import os
from argparse import ArgumentParser, Namespace
from functools import partial
from typing import Any, Dict

import numpy as np
import pandas as pd  # type: ignore
import seaborn as sns  # type: ignore
import torch
from base import Algorithm
from data.get import get_dataset
from matplotlib import pyplot as plt  # type: ignore
from mpl_toolkits.axes_grid1 import ImageGrid  # type: ignore
from torch.nn import functional as F
from torch.utils.data import DataLoader  # type: ignore
from tqdm import tqdm  # type: ignore
from utils import (Stats, get_test_name, seed, set_logger,  # type: ignore
                   softmax_log_softmax_of_sample, str2bool, to_device)

from mahalanobis.models import (proto_ddu_cnn, proto_mahalanobis_cnn,
                                proto_sngp_cnn, protonet_cnn)

# shorthand alias for tensor annotations
T = torch.Tensor

# the trainer accepts any model object (built by the factories selected in the script section below),
# so the model type is left as Any
Model = Any


class MahalanobisTrainer(Algorithm):
    def __init__(self, args: Namespace, model: Model, trainset: DataLoader, valset: DataLoader):
        """Set up the optimizer, LR schedule, output directories and metric trackers.

        Args:
            args: parsed command line options. Fields read directly here: ``device``, ``lr``,
                ``lr_steps``, ``lr_gamma``, ``forward_type``, ``model``, ``pma_type``, ``t``,
                ``comment``, ``dataset``, ``n_way``, ``k_shot``, ``run``, ``ood_test`` and
                ``corrupt_test``.
            model: the network to train; it is moved to ``args.device`` and must expose ``name``.
            trainset: loader stored as ``self.trainset``.
            valset: loader stored as ``self.valset``.
        """
        super().__init__()

        self.args = args
        self.model = model.to(self.args.device)
        # NOTE: SGD with momentum 0.9 was previously tried in place of Adam
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=args.lr_steps,
            gamma=args.lr_gamma,
        )
        self.trainset = trainset
        self.valset = valset
        self.best_acc = 0.
        self.iter = 0
        self.finished = False

        # proto-mahalanobis runs are additionally keyed by the PMA type and the temperature
        variant_dir = args.forward_type
        if args.model == "proto-mahalanobis":
            variant_dir = os.path.join(args.pma_type, f"t-{args.t}", variant_dir)

        self.results_path = os.path.join(
            "results", args.comment, args.dataset, self.model.name, variant_dir,
            f"{self.args.n_way}-way", f"{self.args.k_shot}-shot",
        )
        self.models_path = os.path.join(self.results_path, "models")
        os.makedirs(self.results_path, exist_ok=True)
        os.makedirs(self.models_path, exist_ok=True)

        self.tr_stats = Stats(["accuracy", "nll", "ece"])

        # the OOD test additionally gets an entropy log file and the aupr/auroc metrics
        test_logs = None
        if args.ood_test:
            entropy_log = os.path.join(self.results_path, f"id-ood-entropy-{get_test_name(self.args)}-run-{self.args.run}.txt")
            test_logs = [("id_ood_entropy", entropy_log)]

        test_metrics = ["accuracy", "nll", "ece", "softmax_entropy"]
        if args.ood_test:
            test_metrics = test_metrics + ["aupr", "auroc"]
        self.te_stats = Stats(test_metrics, test_logs)

        if self.args.corrupt_test:
            # one tracker per corruption level (six in total), each with its own entropy log file
            per_level = []
            for level in range(6):
                level_log = os.path.join(self.results_path, f"id-ood-entropy-{get_test_name(self.args)}-run-{self.args.run}-level-{level}.txt")
                per_level.append(Stats(
                    ["nll", "accuracy", "ece", "aupr", "auroc", "softmax_entropy"],
                    [("id_ood_entropy", level_log)],
                ))
            self.corrupt_stats = per_level

    def fit(self) -> None:
        """Run the full meta-training loop, resuming from a checkpoint when one exists.

        Alternates ``train()`` and ``test()`` until ``self.iter`` reaches
        ``args.metatrain_iters``. After every round the checkpoint is written either
        unconditionally, or, when ``args.save_best_val`` is set, only if the validation
        accuracy improved on ``self.best_acc``. When the loop ends, the last saved checkpoint
        is reloaded and re-saved with ``finished=True``.

        Returns immediately (after logging) if the loaded checkpoint is already finished.
        """
        self.load_model(self.models_path)
        if self.finished:
            self.log("called fit() on a model which has finished training")
            return

        while self.iter < self.args.metatrain_iters:
            self.train()
            self.test()

            stats = self.log_train_stats(self.results_path)
            acc = stats["accuracy"]
            # read the option from self.args: the module-level `args` only exists when this file
            # is executed as a script, so relying on it breaks any other use of the class
            if not self.args.save_best_val:
                self.save_model(self.models_path)
                continue

            if acc > self.best_acc:
                self.log(f"({self.args.run}) new best accuracy: {acc:.4f}")
                self.best_acc = acc
                self.save_model(self.models_path)

        self.log("finished training, marking last saved model as finished")
        # reload so that the checkpoint flagged as finished is the last one written
        # (the best one under save_best_val), not the in-memory weights
        self.load_model(self.models_path)
        self.save_model(self.models_path, finished=True)

    def train(self) -> None:
        """Train on episodes from ``self.trainset`` until the next validation point.

        Every task in every batch is one optimizer step (and one scheduler step). The method
        returns as soon as ``self.iter`` becomes a multiple of ``args.val_interval``, or when
        the loader is exhausted.

        With ``args.ood_training`` set, ``n_way`` uniform noise samples in ``[-3, 3)`` are
        appended to the query set; they contribute to the loss but are dropped before the
        training statistics are updated.

        Raises:
            ValueError: if ``args.forward_type`` is not "sigmoid", "exp" or "softmax".
        """
        self.model.train()
        for (x_spt, y_spt, x_qry, y_qry) in tqdm(self.trainset, ncols=75, leave=False, total=self.args.val_interval / self.args.batch_size):
            for (xs, ys, xq, yq) in zip(x_spt, y_spt, x_qry, y_qry):
                # the toy datasets put the n_way k_shot into the support tensor so we need to get that value if it exists
                # (read before the .to() call below, which may return a new tensor object)
                n_way = xs.n_way if hasattr(xs, "n_way") else self.args.n_way
                k_shot = xs.k_shot if hasattr(xs, "k_shot") else self.args.k_shot
                xs, ys, xq, yq = xs.to(self.args.device), ys.to(self.args.device), xq.to(self.args.device), yq.to(self.args.device)
                # s: number of real query samples, n: number of appended noise samples
                s, n = xq.size(0), n_way

                if self.args.ood_training:
                    xq = torch.cat((xq, torch.rand(n, *xq.size()[1:], device=xq.device) * 6 - 3))

                logit = self.model(xs, ys, xq, n_way=n_way, k_shot=k_shot)

                if self.args.forward_type in ["sigmoid", "exp"]:
                    # pin the one-hot width to n_way; without num_classes the width is inferred
                    # from the largest label present and can come out narrower than the targets
                    # built below
                    yq_hot = F.one_hot(yq, num_classes=int(n_way)).float()
                    if self.args.ood_training:
                        # noise samples belong to no class: all-zero targets
                        yq_hot = torch.cat((yq_hot, torch.zeros(n, n_way, device=yq.device)))

                    loss = F.binary_cross_entropy(logit, yq_hot)
                elif self.args.forward_type == "softmax":
                    # the classification loss only sees the real query samples
                    logit = logit[:s]

                    loss = F.nll_loss(logit, yq)

                    if self.args.ood_training:
                        # binary target: 1 for real query samples, 0 for the appended noise
                        labels = torch.cat((torch.ones(s, dtype=torch.float32, device=xs.device), torch.zeros(n, dtype=torch.float32, device=xs.device)))
                        loss += F.binary_cross_entropy(torch.sigmoid(self.model.energy[:, 0]), labels)
                else:
                    raise ValueError("inference type is not implemented")

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

                # remove the noise samples from here because we don't need to track stats on those
                # this is true no matter what, so there is no need to condition on anything
                logit = logit[:s]
                with torch.no_grad():
                    self.tr_stats.update_acc((logit.argmax(dim=-1) == yq).sum().item(), yq.size(0))
                    self.tr_stats.update_nll(logit, yq)
                    self.tr_stats.update_ece(logit, yq)
                    # NOTE(review): tr_stats is created with only accuracy/nll/ece and the matching
                    # call in test() is commented out -- confirm this update is still wanted
                    self.tr_stats.update_aupr_auroc(yq, logit)

                self.iter += 1
                if self.iter % self.args.val_interval == 0:
                    return

    def test(self) -> None:
        """Evaluate the model on ``self.valset`` and accumulate metrics.

        Runs ``model.inference`` on every task under ``torch.no_grad()`` and updates
        ``self.te_stats``. With ``args.ood_test`` the query set is split into two halves and
        AUPR/AUROC are computed from the energy scores. With ``args.corrupt_test`` the query
        set is regrouped per corruption level and each level updates its own tracker in
        ``self.corrupt_stats``, using the first level as the in-distribution reference.
        """
        self.model.eval()
        with torch.no_grad():
            for (x_spt, y_spt, x_qry, y_qry) in tqdm(self.valset, ncols=75, leave=False):
                for (sx, sy, qx, qy) in zip(x_spt, y_spt, x_qry, y_qry):
                    # the toy datasets put the n_way k_shot into the support tensor so we need to get that value if it exists
                    n_way = sx.n_way if hasattr(sx, "n_way") else self.args.n_way
                    k_shot = sx.k_shot if hasattr(sx, "k_shot") else self.args.k_shot

                    sx, sy, qx, qy = to_device(sx, sy, qx, qy, device=self.args.device)
                    preds, _, energy = self.model.inference(sx, sy, qx, n_way=n_way, k_shot=k_shot, inference_style=self.args.inference_style)

                    if self.args.ood_test:
                        n = qy.size(0) // 2
                        _, ood_qy = torch.split(qy, n)
                        _, ood_preds = torch.split(preds, n)

                        # NOTE(review): these classification metrics use the SECOND half of the
                        # query set, which the comment below describes as the random classes --
                        # confirm that this is the intended half
                        self.te_stats.update_acc((ood_preds.argmax(dim=-1) == ood_qy).sum().item(), ood_qy.size(0))
                        self.te_stats.update_nll(ood_preds, ood_qy, softmaxxed=True)
                        self.te_stats.update_ece(ood_preds, ood_qy, softmaxxed=True)
                        self.te_stats.update_softmax_entropy(ood_preds, ood_qy.size(0), softmaxxed=True)

                        # in ood, the first half of the query set is always the in distribution samples. The second half are the random classes.
                        y = torch.cat((torch.zeros(n, device=ood_qy.device), torch.ones(n, device=ood_qy.device)))
                        self.te_stats.update_aupr_auroc(y, energy)
                        self.te_stats.log_id_ood_entropy(y, preds, softmaxxed=True)
                    else:
                        self.te_stats.update_acc((preds.argmax(dim=-1) == qy).sum().item(), qy.size(0))
                        self.te_stats.update_nll(preds, qy, softmaxxed=True)
                        self.te_stats.update_ece(preds, qy, softmaxxed=True)
                        self.te_stats.update_softmax_entropy(preds, qy.size(0), softmaxxed=True)

                    if self.args.corrupt_test:
                        # use self.args here: the module-level `args` only exists when this file
                        # is run as a script
                        n_corruptions = qx.size(0) // (self.args.val_query_shots * n_way)

                        # regroup (class, corruption level, shot) so that each list entry holds
                        # all samples of one corruption level; entry 0 is the reference level
                        preds = preds.view(n_way, n_corruptions, self.args.val_query_shots, n_way)
                        preds = torch.split(preds, 1, dim=1)
                        preds = [v.reshape(-1, n_way) for v in preds]
                        id_preds = preds[0]

                        energy = energy.view(n_way, n_corruptions, self.args.val_query_shots)
                        energy = torch.split(energy, 1, dim=1)
                        energy = [v.reshape(-1) for v in energy]
                        id_energy = energy[0]

                        qy = qy.view(n_way, n_corruptions, self.args.val_query_shots)
                        qy = torch.split(qy, 1, dim=1)
                        qy = [v.reshape(-1) for v in qy]

                        for (stat, split_preds, split_energy, y) in zip(self.corrupt_stats, preds, energy, qy):
                            stat.update_acc((split_preds.argmax(dim=-1) == y).sum().item(), y.size(0))
                            stat.update_nll(split_preds, y, softmaxxed=True)
                            stat.update_ece(split_preds, y, softmaxxed=True)
                            stat.update_softmax_entropy(split_preds, y.size(0), softmaxxed=True)

                            # detection labels: 0 for the reference level, 1 for the current level
                            y = torch.cat((torch.zeros(y.size(0), device=y.device), torch.ones(y.size(0), device=y.device)))
                            total_energy = torch.cat((id_energy, split_energy))
                            stat.update_aupr_auroc(y, total_energy)
                            stat.log_id_ood_entropy(y, torch.cat((id_preds, split_preds)), softmaxxed=True)

    def tune(self) -> None:
        """Calibrate the trained model and flag it as tuned.

        Models whose name does not contain "mahalanobis" delegate to their ``tmp_layer``.
        Otherwise logits and distances are collected from ``self.trainset`` at temperature 1,
        the NLL of the plain softmax prediction is taken as the reference, and the temperature
        is raised in steps of 1 (starting at 1.0) until the NLL of the sampled prediction is
        within 0.01 of that reference.
        """
        if "mahalanobis" not in self.args.model:
            self.model.tmp_layer.reset_parameters()
            self.model.tmp_layer.tune_few_shot(self.model, self.args.device, self.trainset, self.args.n_way, self.args.k_shot)
            self.model.tuned = True
            return

        self.model.set_temperature(1.0)
        self.model.eval()

        logit_parts, label_parts, dist_parts = [], [], []
        with torch.no_grad():
            for batch_idx, (x_spt, y_spt, x_qry, y_qry) in enumerate(tqdm(self.trainset, ncols=75, leave=False, total=50)):
                for (sx, sy, qx, qy) in zip(x_spt, y_spt, x_qry, y_qry):
                    sx, sy, qx, qy = to_device(sx, sy, qx, qy, device=self.args.device)
                    task_logits, task_dists = self.model.get_logits(sx, sy, qx, n_way=self.args.n_way, k_shot=self.args.k_shot)
                    logit_parts.append(task_logits)
                    dist_parts.append(task_dists)
                    label_parts.append(qy)

                # the check runs after the batch was consumed, so batches 0..50 are all used
                if batch_idx == 50:
                    break

        logits = torch.stack(logit_parts)  # type: ignore
        labels = torch.stack(label_parts)  # type: ignore
        dists = torch.stack(dist_parts)  # type: ignore
        # flatten the task dimension so every row is a single query sample
        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)
        dists = dists.view(-1, dists.size(-1))

        # reference metrics from the plain softmax over the collected logits
        reference = Stats(["accuracy", "nll", "ece"])
        softmax_probs = logits.softmax(dim=-1)
        reference.update_acc((logits.argmax(dim=-1) == labels).sum().item(), labels.size(0))
        reference.update_nll(softmax_probs, labels, softmaxxed=True)
        reference.update_ece(softmax_probs, labels, softmaxxed=True)
        dist = reference.get_stats()

        def sampled_stats() -> Dict[str, Any]:
            """Metrics of the sample-based prediction at the model's current temperature."""
            tracker = Stats(["accuracy", "nll", "ece"])
            with torch.no_grad():
                energy = self.model.sigma(dists) / self.model.t

                # draw the samples in chunks of 1000 rows
                chunks = []
                for lo in range(0, logits.size(0), 1000):
                    hi = lo + 1000
                    draws = torch.distributions.Normal(logits[lo:hi], energy[lo:hi]).sample((self.model.samples,))
                    probs, _ = softmax_log_softmax_of_sample(draws)
                    chunks.append(probs)

                sampled = torch.cat(chunks)
                tracker.update_acc((sampled.argmax(dim=-1) == labels).sum().item(), labels.size(0))
                tracker.update_nll(sampled, labels, softmaxxed=True)
                tracker.update_ece(sampled, labels, softmaxxed=True)
                return tracker.get_stats()

        # the names `dist`, `temp` and `t` appear verbatim in the printed message below
        temp = 0.0
        calibrated = False
        while not calibrated:
            temp += 1
            self.model.set_temperature(temp)
            self.model.eval()
            t = sampled_stats()
            print(f"{dist['nll']=} <==({temp=})==> {t['nll']=}")
            calibrated = t["nll"] <= (dist["nll"] + 0.01)

        self.model.set_temperature(temp)
        self.model.tuned = True

    def load_model(self, path: str) -> None:
        """Restore the trainer from ``<path>/<run>.pt`` when that file exists.

        Loads the iteration counter, best accuracy, model weights, the ``tuned`` flag
        (``False`` when the checkpoint has no such entry), optimizer and scheduler state and
        the ``finished`` flag. A missing file is only logged; the trainer is left untouched.

        Args:
            path: directory that holds the per-run checkpoint files.
        """
        checkpoint_file = os.path.join(path, f"{self.args.run}.pt")
        if os.path.exists(checkpoint_file):
            checkpoint = torch.load(checkpoint_file, map_location="cpu")

            self.iter = checkpoint["iter"]
            self.best_acc = checkpoint["best_acc"]

            self.model.load_state_dict(checkpoint["state_dict"])
            self.model.tuned = checkpoint.get("tuned", False)
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.scheduler.load_state_dict(checkpoint["scheduler"])
            self.finished = checkpoint["finished"]
            self.log(f"loaded saved model: {self.iter=} {self.finished=} {self.model.tuned=}")
        else:
            self.log(f"no model found at: {checkpoint_file}")

    def save_model(self, path: str, finished: bool = False) -> None:
        """Write the trainer state to ``<path>/<run>.pt``.

        Args:
            path: directory that holds the per-run checkpoint files.
            finished: value stored under the ``finished`` key of the checkpoint.
        """
        checkpoint = {
            "iter": self.iter,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "tuned": self.model.tuned,
            "finished": finished,
            "best_acc": self.best_acc,
        }
        torch.save(checkpoint, os.path.join(path, f"{self.args.run}.pt"))

    def log_train_stats(self, path: str) -> Dict[str, float]:
        """Log the training and validation trackers and report the validation metrics.

        Both trackers are handed a per-run CSV path under ``path`` via ``log_stats``. The
        names and values returned for the validation tracker are logged on one line together
        with the iteration counter.

        Args:
            path: directory for the ``train-run-<run>.csv`` and ``val-run-<run>.csv`` files.

        Returns:
            Mapping from metric name to value for the validation tracker.
        """
        run = self.args.run
        self.tr_stats.log_stats(os.path.join(path, f"train-run-{run}.csv"))
        names, values = self.te_stats.log_stats(os.path.join(path, f"val-run-{run}.csv"))

        progress = f"({self.args.run}) iter: {self.iter}/{self.args.metatrain_iters} "
        metrics = "".join(f"{name}: {value:.4f} " for name, value in zip(names, values))
        self.log(progress + metrics)

        return dict(zip(names, values))

    def log_test_stats(self, path: str, test_name: str = "test") -> Dict[str, float]:
        """Log the test tracker (and the per-corruption-level trackers) and return the test metrics.

        Args:
            path: directory for the ``<test_name>-run-<run>.csv`` files.
            test_name: prefix used in the CSV file names.

        Returns:
            Mapping from metric name to value for ``self.te_stats``. The per-level corruption
            metrics are logged but never part of the return value.
        """
        names, values = self.te_stats.log_stats(os.path.join(path, f"{test_name}-run-{self.args.run}.csv"))
        test_pairs = list(zip(names, values))
        self.log(" ".join([f"{n} {v:.4f}" for (n, v) in test_pairs]))

        if self.args.corrupt_test:
            for i, stat in enumerate(self.corrupt_stats):
                # separate names here: reusing `names`/`values` used to make the method return
                # the metrics of the last corruption level instead of the test metrics
                level_names, level_values = stat.log_stats(os.path.join(path, f"{test_name}-run-{self.args.run}-level-{i}.csv"))
                self.log(f"level {i} " + " ".join([f"{n} {v:.4f}" for (n, v) in zip(level_names, level_values)]))

        return dict(test_pairs)

    def log(self, msg: str) -> None:
        """Emit ``msg`` at INFO level through the logger stored on ``self.args``."""
        logger = self.args.logger
        logger.info(msg)

    def set_datasets(self, trainset: DataLoader, valset: DataLoader) -> None:
        """Replace the loaders stored as ``self.trainset`` and ``self.valset``."""
        self.trainset, self.valset = trainset, valset

    def cov_experiment(self) -> None:
        """Plot class covariance/precision matrices and precision spectra for a few tasks.

        Takes the first batch of ``self.valset`` and, for the first three tasks in it, asks the
        model for its per-class covariance matrices (skipped for models whose name contains
        "protonet"). For every class the covariance and its inverse are saved as images, and
        the singular values of all precision matrices of a task are saved as one histogram.
        Files go to ``<results_path>/cov-prec-hist-final`` as both PDF and PNG.
        """
        # raw string: "\l" is not a valid escape sequence, so the non-raw spelling triggers a
        # warning at compile time; the runtime value is unchanged
        eig_label = r"$\lambda$"

        with torch.no_grad():
            # only the first validation batch is used
            for (x_spt, y_spt, x_qry, y_qry) in tqdm(self.valset, ncols=75, leave=False):
                break

            sns.set(rc={'figure.figsize': (8, 7), "figure.facecolor": 'white'})
            sns.set_style("whitegrid")

            for i, (sx, sy, qx, qy) in enumerate(zip(x_spt, y_spt, x_qry, y_qry)):
                sx, sy, qx, qy = to_device(sx, sy, qx, qy, device=self.args.device)
                if "protonet" not in self.args.model:
                    covs = self.model.get_cov(sx, sy, n_way=self.args.n_way, k_shot=self.args.k_shot)
                    covs = covs.detach().cpu()
                    covpath = os.path.join(self.results_path, "cov-prec-hist-final")
                    print(f"{covpath=}")
                    os.makedirs(covpath, exist_ok=True)

                    fig = plt.figure(figsize=(7, 6))
                    ax = ImageGrid(
                        fig, 111, nrows_ncols=(1, 1), axes_pad=0.05, share_all=True,
                        cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.05
                    )[0]

                    eigs = {v: [] for v in [eig_label, "class"]}
                    for j in range(covs.size(0)):
                        cov, precision = covs[j], torch.inverse(covs[j])
                        for matrix, name in zip([cov, precision], ["covariance", "precision"]):
                            im = ax.imshow(matrix, cmap="viridis")
                            fig.tight_layout()
                            ax.axis("off")
                            cbar = ax.cax.colorbar(im)
                            cbar.ax.axis("off")
                            fig.savefig(os.path.join(covpath, f"{i}-{type(self.trainset.dataset).__name__}-{self.model.name}-{name}-class-{j}.pdf"))  # type: ignore
                            fig.savefig(os.path.join(covpath, f"{i}-{type(self.trainset.dataset).__name__}-{self.model.name}-{name}-class-{j}.png"))  # type: ignore
                            plt.close()

                        # collect the spectrum of the precision matrix, tagged with its class index
                        lambdas = torch.linalg.svdvals(precision).numpy()
                        for v in lambdas:
                            eigs[eig_label].append(v)
                            eigs["class"].append(j)

                    df = pd.DataFrame(eigs)
                    # the histogram is drawn on whatever pyplot currently considers the active
                    # axes/figure, not on the ImageGrid axes used above
                    ax, fig = plt.gca(), plt.gcf()
                    sns.histplot(df, x=eig_label, hue="class", element="step", bins=50, kde=True, ax=ax, palette="viridis")  # type: ignore

                    fontsize = 26
                    plt.setp(ax.get_legend().get_texts(), fontsize=fontsize)
                    plt.setp(ax.get_legend().get_title(), fontsize=fontsize)

                    ax.set_xlabel(eig_label, fontsize=fontsize)
                    ax.set_ylabel("Count", fontsize=fontsize)
                    ax.xaxis.set_tick_params(labelsize=fontsize)
                    ax.yaxis.set_tick_params(labelsize=fontsize)

                    fig.tight_layout()
                    fig.savefig(os.path.join(covpath, f"{i}-{type(self.trainset.dataset).__name__}-{self.model.name}-eigs.pdf"))
                    fig.savefig(os.path.join(covpath, f"{i}-{type(self.trainset.dataset).__name__}-{self.model.name}-eigs.png"))
                    plt.close()

                # stop after the third task (indices 0, 1 and 2)
                if i == 2:
                    return


if __name__ == "__main__":
    # Command line entry point: parse the options, build the datasets and the selected model,
    # then either train, run the covariance experiment, or tune and evaluate a finished model.
    # NOTE: `args` is deliberately kept as a module-level name.
    parser = ArgumentParser("argument parser for MAML")

    parser.add_argument("--dataset", type=str, default="omniglot", choices=["omniglot", "miniimagenet"], help="the dataset to use")
    parser.add_argument("--data-root", type=str, help="the dataset root directory")
    parser.add_argument("--ood-test", type=str2bool, default=False, help="if True, use random classes for the query set")
    parser.add_argument("--corrupt-test", type=str2bool, default=False, help="if True, use corrupted query set")
    parser.add_argument("--experiment-name", type=str, help="the name of an experiment to run (eval on trained model)")
    parser.add_argument("--num-workers", type=int, default=4, help="the number of workers for the dataloader")
    parser.add_argument("--pma-type", type=str, default="no-residual", choices=["multiplicative", "no-residual", "additive"], help="the type of residual connection in the PMA layer")
    parser.add_argument("--comment", type=str, default="", help="comment for experiment")
    parser.add_argument("--metatrain-iters", type=int, default=60000, help="the number of metatrain iters to run")
    parser.add_argument("--val-interval", type=int, default=500, help="the number of metatrain iters between validation runs")
    parser.add_argument("--batch-size", type=int, default=16, help="batch size for training")
    parser.add_argument("--lr-steps", type=int, nargs="+", default=[45000])
    parser.add_argument("--weight-decay", type=float, default=0.0, help="the weight decay for the optimizer")
    parser.add_argument("--sigmoid-bias", type=str2bool, default=False, help="whether to put a bias on the sigmoid output for gradient stability")
    parser.add_argument("--ood-training", type=str2bool, default=False, help="whether or not to include OOD training.")
    parser.add_argument("--cov-experiment", type=str2bool, default=False, help="run the covariance experiment")
    parser.add_argument("--inference-style", type=str, default="distance", choices=["distance", "softmax-sample"], help="the inference style of the model")
    parser.add_argument("--forward-type", type=str, default="softmax", choices=["sigmoid", "exp", "softmax"], help="the forward pass type used during the training process")
    parser.add_argument("--lr-gamma", type=float, default=0.5)
    parser.add_argument("--p", type=float, default=0.1, help="the dropout rate, (filterwise for convolutions)")
    parser.add_argument("--t", type=float, default=0.01, help="the temperature parameter for logit normal softmax sampling")
    parser.add_argument("--n-way", type=int, default=5, help="the number of classes in each task")
    parser.add_argument("--k-shot", type=int, default=5, help="the number of examples of each class")
    parser.add_argument("--ctype", type=str, default="scalar", help="the scaling constant type for spectral normalization")
    parser.add_argument("--lr", type=float, default=1e-3, help="meta optimization learn rate")
    parser.add_argument("--run", type=int, default=0, help="independent run number")
    parser.add_argument("--gpu", type=int, default=0, help="the gpu index")
    # both options index lookup tables below; without a value the script used to die with
    # "KeyError: None", so let argparse report the missing option instead
    parser.add_argument("--mode", type=str, required=True, choices=["train", "test"])
    parser.add_argument("--model", type=str, required=True, choices=["proto-mahalanobis", "proto-ddu", "protonet", "protonet-sn", "proto-sngp"])
    parser.add_argument("--encoder-type", type=str, default="diag", help="the encoder type for the proto mahalanobis model")
    parser.add_argument("--rank", type=int, default=1, help="the rank for the proto mahalanobis model (low rank version only)")
    parser.add_argument("--save-best-val", type=str2bool, default=False, help="whether or not to apply early stopping with validation")
    parser.add_argument("--train-query-shots", type=int, default=15, help="the number of testshots to use")
    parser.add_argument("--val-query-shots", type=int, default=15, help="the number of testshots to use")

    args = parser.parse_args()
    args.logger = set_logger("INFO")
    args.device = torch.device(f"cuda:{args.gpu}")
    args.get_val = True
    if args.ood_test and args.corrupt_test:
        raise ValueError("only one of ood test and corrupt test can be set at a time")

    print(f"{args.n_way=} {args.k_shot=}")

    # seed before doing anything else with the dataloaders
    seed(args.run)

    trainset, valset, testset = get_dataset(args)

    # the spectral normalization for the convolutional layers needs to have the dimension at each layer given
    dim_deref = {
        "omniglot": ((1, 1, 28, 28), (1, 64, 14, 14), (1, 64, 7, 7), (1, 64, 4, 4)),
        "miniimagenet": ((1, 3, 84, 84), (1, 64, 42, 42), (1, 64, 21, 21), (1, 64, 10, 10)),
    }

    print(args.dataset)
    # per-dataset input channels, convolution filters and covariance dimension
    ch = {"omniglot": 1, "miniimagenet": 3}
    filters = {"omniglot": 64, "miniimagenet": 64}
    cov_dim = {"omniglot": 64, "miniimagenet": 1600}

    # lazily-built model factories so that only the selected model is instantiated
    model_deref = {
        "proto-ddu": partial(proto_ddu_cnn, dim_deref[args.dataset], ch[args.dataset], filters[args.dataset], args.n_way, args.p, ctype="none", spectral=True, forward_type=args.forward_type, cov_dim=cov_dim[args.dataset]),
        "protonet": partial(protonet_cnn, dim_deref[args.dataset], ch[args.dataset], filters[args.dataset], args.n_way, args.p, ctype="none", spectral=False, forward_type=args.forward_type),
        "protonet-sn": partial(protonet_cnn, dim_deref[args.dataset], ch[args.dataset], filters[args.dataset], args.n_way, args.p, ctype="none", spectral=True, forward_type=args.forward_type),
        "proto-sngp": partial(proto_sngp_cnn, dim_deref[args.dataset], ch[args.dataset], filters[args.dataset], args.n_way, args.p, ctype="none", spectral=True, forward_type=args.forward_type, gp_in_dim=cov_dim[args.dataset], gp_h_dim=cov_dim[args.dataset]),
        "proto-mahalanobis": partial(
            proto_mahalanobis_cnn, dim_deref[args.dataset], ch[args.dataset], filters[args.dataset], args.n_way, args.p, ctype="none",
            spectral=True,
            encoder=args.encoder_type,
            rank=args.rank,
            beta=args.sigmoid_bias,
            forward_type=args.forward_type,
            pma_type=args.pma_type,
            t=args.t
        ),
    }
    # in test mode the validation split takes the "train" role (it is what tune() iterates)
    # and the test split is the one evaluated
    ds_deref = {"train": [trainset, valset], "test": [valset, testset]}

    train, val = ds_deref[args.mode]
    trainer = MahalanobisTrainer(args, model_deref[args.model](), trainset=train, valset=val)

    if args.mode == "train":
        trainer.fit()
        trainer.log("finished training")
    elif args.cov_experiment:
        trainer.load_model(trainer.models_path)
        if not trainer.finished:
            raise ValueError("running tests on an unfinished model")

        trainer.cov_experiment()

    elif args.mode == "test":
        trainer.load_model(trainer.models_path)
        if not trainer.finished:
            raise ValueError("running tests on an unfinished model")

        if not trainer.model.tuned:
            trainer.tune()
            trainer.save_model(trainer.models_path, finished=True)  # will be saved with tuned flag set to true, temp buffer will be saved

        # the same trainer is evaluated once per inference style
        for inference_style in ["distance", "softmax-sample"]:
            args.inference_style = inference_style

            trainer.test()
            trainer.log_test_stats(trainer.results_path, test_name=get_test_name(args) + f"-{inference_style}")
            trainer.log(f"finished testing {inference_style=}")
    else:
        raise NotImplementedError()
