import copy
import os
from argparse import ArgumentParser, Namespace
from typing import Dict

import torch
from base import Algorithm
from data.get import get_dataset
from layers import TemperatureScaler
from maml.model import MAML, MAMLMiniImageNet, MAMLOmniglot
from torch.nn import functional as F
from torch.utils.data import DataLoader  # type: ignore
from tqdm import tqdm  # type: ignore
from utils import Stats, get_test_name, params_sans, seed, set_logger, str2bool

T = torch.Tensor


class ReptileTrainer(Algorithm):
    """Reptile meta-learner for a MAML-architecture model.

    Covers meta-training (``fit``/``train``), support-set adaptation plus query-set evaluation (``test``),
    temperature-scaling calibration (``tune``), checkpointing and metric logging. Results and checkpoints are
    written under ``results/<dataset>/<model.name>/<norm_type>/<n_way>-way/<k_shot>-shot``.
    """

    def __init__(self, args: Namespace, model: MAML, trainset: DataLoader, valset: DataLoader):
        """
        Args:
            args: parsed command-line options; must additionally carry ``device`` and ``logger``.
            model: the network to meta-train; it is moved to ``args.device``.
            trainset: loader consumed by ``train`` and ``tune``.
            valset: loader consumed by ``test``.
        """
        super().__init__()

        self.args = args
        self.model = model.to(self.args.device)

        # reptile uses the same Adam statistics throughout training, so it is only initialized once here. It needs to be reset for the eval loops.
        self.inner_opt = torch.optim.Adam(params_sans(self.model, without=TemperatureScaler), lr=self.args.inner_lr, betas=(0, 0.999))
        # the meta update writes p.grad by hand in train(); lr=1.0 so the step size is fully determined there
        self.optimizer = torch.optim.SGD(params_sans(self.model, without=TemperatureScaler), lr=1.0)
        self.trainset = trainset
        self.valset = valset
        self.iter = 0
        self.finished = False
        self.best_acc = 0.0

        self.results_path = os.path.join(
            "results", f"{args.dataset}", f"{self.model.name}", f"{args.norm_type}", f"{args.n_way}-way", f"{args.k_shot}-shot"
        )
        self.models_path = os.path.join(self.results_path, "models")
        for d in [self.results_path, self.models_path]:
            os.makedirs(d, exist_ok=True)

        self.tr_stats = Stats(["accuracy", "nll", "ece"])
        logs = [("id_ood_entropy", os.path.join(self.results_path, f"id-ood-entropy-{get_test_name(self.args)}-run-{self.args.run}.txt"))] if args.ood_test else None
        stats = ["accuracy", "nll", "ece", "softmax_entropy"] + (["aupr", "auroc"] if args.ood_test else [])
        self.te_stats = Stats(stats, logs)

        if self.args.corrupt_test:
            # one Stats object per corruption level (index 0 is used as the in-distribution reference in test())
            self.corrupt_stats = [Stats(
                ["nll", "accuracy", "ece", "aupr", "auroc", "softmax_entropy"],
                [("id_ood_entropy", os.path.join(self.results_path, f"id-ood-entropy-{get_test_name(self.args)}-run-{self.args.run}-level-{i}.txt"))]
            ) for i in range(6)]

    def fit(self) -> None:
        """Meta-train until ``metatrain_iters`` is reached, validating after every ``train()`` round.

        Resumes from an existing checkpoint if present. With ``save_best_val`` the checkpoint is only rewritten
        when validation accuracy improves; otherwise it is rewritten after every validation round. At the end
        the last saved checkpoint is reloaded and re-saved with ``finished=True``.
        """
        self.load_model(self.models_path)
        if self.finished:
            self.log("called fit() on a model which has finished training")
            return

        while self.iter < self.args.metatrain_iters:
            self.train()
            self.test()

            stats = self.log_train_stats(self.results_path)
            acc = stats["accuracy"]
            # read the flag from self.args: the module-level `args` only exists when this file runs as a script
            if not self.args.save_best_val:
                self.save_model(self.models_path)
                continue

            if acc > self.best_acc:
                self.log(f"({self.args.run}) new best accuracy: {acc:.4f}")
                self.best_acc = acc
                self.save_model(self.models_path)

        self.log("finished training, marking last saved model as finished")
        self.load_model(self.models_path)
        self.save_model(self.models_path, finished=True)

    def _inner_steps(self, x: torch.Tensor, y: torch.Tensor, steps: int, batch_size: int) -> None:
        """Take ``steps`` inner-loop optimizer steps, each on a random subset of at most ``batch_size`` rows of (x, y)."""
        for _ in range(steps):
            batch = torch.randperm(y.size(0))[:batch_size]
            loss = F.cross_entropy(self.model(x[batch]), y[batch])
            self.inner_opt.zero_grad()
            loss.backward()
            self.inner_opt.step()

    def _adapt_from(self, state: dict, opt_state: dict, xs: torch.Tensor, ys: torch.Tensor) -> None:
        """Reset the model and inner optimizer to the given snapshots, then fine-tune on the support set (xs, ys).

        Uses the evaluation settings ``inner_val_steps`` / ``inner_eval_batch_size``.
        """
        self.model.load_state_dict(state)  # type: ignore
        self.inner_opt.load_state_dict(opt_state)
        self._inner_steps(xs, ys, self.args.inner_val_steps, self.args.inner_eval_batch_size)

    def train(self) -> None:
        """Run Reptile meta-updates until the next validation point or the end of training.

        For every meta-batch, each task is adapted from the same starting weights. The summed parameter change
        (adapted - initial) is then applied as a single SGD step. The step size is annealed linearly from
        1 to 0 over ``metatrain_iters``.
        """
        self.model.train()
        # reptile has no support and query set so we can just use the query set which should have the proper training shots
        for _, _, xtrain, ytrain in tqdm(self.trainset, total=self.args.val_interval, ncols=75, leave=False):
            state = copy.deepcopy(self.model.state_dict())
            # running sum over the meta-batch of (task-adapted params - initial params); plain tensors, no autograd
            deltas = [torch.zeros_like(p) for p in self.model.parameters()]
            for x, y in zip(xtrain, ytrain):
                x, y = x.to(self.args.device), y.to(self.args.device)
                # each inner step only sees a random subset of at most inner_train_batch_size examples
                self._inner_steps(x, y, self.args.inner_train_steps, self.args.inner_train_batch_size)

                # these are just for stats tracking, reptile doesn't have a separate tailed meta update. This has to be done
                # before we update the meta gradients and reload the model state.
                with torch.no_grad():
                    perm = torch.randperm(y.size(0))[:self.args.inner_train_batch_size]
                    x, y = x[perm], y[perm]
                    logit = self.model(x)
                    self.tr_stats.update_acc((logit.argmax(dim=-1) == y).sum().item(), y.size(0))
                    self.tr_stats.update_nll(logit, y)
                    self.tr_stats.update_ece(logit, y)

                # detach so the accumulated deltas (and the .grad tensors built from them) carry no autograd graph
                adapted = [p.detach().clone() for p in self.model.parameters()]
                self.model.load_state_dict(state)  # type: ignore
                with torch.no_grad():
                    for delta, phi_tilde, phi in zip(deltas, adapted, self.model.parameters()):
                        delta.add_(phi_tilde - phi)

            lr = 1.0 - (self.iter / self.args.metatrain_iters)
            # NOTE(review): normalizes by args.batch_size, i.e. assumes every meta-batch holds exactly batch_size tasks
            denom = self.args.inner_train_steps * self.args.batch_size
            self.optimizer.zero_grad()
            for p, delta in zip(self.model.parameters(), deltas):
                # SGD subtracts p.grad, so the negated delta moves the weights toward the adapted weights
                p.grad = lr * (-delta / denom)
            self.optimizer.step()
            self.iter += 1
            if self.iter == self.args.metatrain_iters or self.iter % self.args.val_interval == 0:
                return

    def test(self) -> None:
        """Evaluate on ``self.valset``.

        Each task adapts a fresh copy of the current weights to its support set. Query-set metrics accumulate
        into ``te_stats``, and also into ``corrupt_stats`` when ``corrupt_test`` is set. The model and inner
        optimizer are restored to their pre-call state before returning.
        """
        # NOTE(review): train mode during evaluation, presumably so normalization layers use batch statistics
        # (the "transductive" norm_type) — confirm
        self.model.train()
        state, opt_state = copy.deepcopy(self.model.state_dict()), copy.deepcopy(self.inner_opt.state_dict())
        try:
            # during testing, we need to maintain the difference between the support and query sets, so use the regular sets like MAML
            for xtrain, ytrain, xtest, ytest in tqdm(self.valset, ncols=75, leave=False):
                for xs, ys, xq, yq in zip(xtrain, ytrain, xtest, ytest):
                    xs, ys, xq, yq = xs.to(self.args.device), ys.to(self.args.device), xq.to(self.args.device), yq.to(self.args.device)
                    self._adapt_from(state, opt_state, xs, ys)

                    with torch.no_grad():
                        logit = self.model.tmp_layer(self.model(xq))
                        self._update_test_stats(xq, yq, logit)
        finally:
            # return model/opt to the original state before exiting
            self.model.load_state_dict(state)  # type: ignore
            self.inner_opt.load_state_dict(opt_state)

    def _update_test_stats(self, xq: torch.Tensor, yq: torch.Tensor, logit: torch.Tensor) -> None:
        """Accumulate the metrics for one task's query set from its temperature-scaled logits."""
        preds = logit.softmax(dim=-1)
        # NOTE(review): the usual energy score is -logsumexp(logit); -logsumexp(-logit) is a (negated) soft minimum
        # of the logits. Kept as-is — confirm this is the intended OOD score.
        energy = -torch.logsumexp(-logit, dim=-1)

        if self.args.ood_test:
            self._update_ood_stats(yq, preds, energy)
        else:
            self.te_stats.update_acc((logit.argmax(dim=-1) == yq).sum().item(), yq.size(0))
            self.te_stats.update_nll(logit, yq)
            self.te_stats.update_ece(logit, yq)
            self.te_stats.update_softmax_entropy(logit, yq.size(0))

        if self.args.corrupt_test:
            self._update_corrupt_stats(xq, yq, preds, energy)

    def _update_ood_stats(self, qy: torch.Tensor, preds: torch.Tensor, energy: torch.Tensor) -> None:
        """OOD test: the first half of the query set is in-distribution, the second half are random classes."""
        n = qy.size(0) // 2
        # unpacking into two names also enforces an even split of the query set
        _, ood_qy = torch.split(qy, n)
        _, ood_preds = torch.split(preds, n)

        # NOTE(review): the classification metrics are computed on the second (random-class) half, not the
        # in-distribution half. Kept as-is — confirm this is intended.
        self.te_stats.update_acc((ood_preds.argmax(dim=-1) == ood_qy).sum().item(), ood_qy.size(0))
        self.te_stats.update_nll(ood_preds, ood_qy, softmaxxed=True)
        self.te_stats.update_ece(ood_preds, ood_qy, softmaxxed=True)
        self.te_stats.update_softmax_entropy(ood_preds, ood_qy.size(0), softmaxxed=True)

        # binary OOD labels: 0 for the in-distribution first half, 1 for the random-class second half
        y = torch.cat((torch.zeros(n, device=ood_qy.device), torch.ones(n, device=ood_qy.device)))
        self.te_stats.update_aupr_auroc(y, energy)
        self.te_stats.log_id_ood_entropy(y, preds, softmaxxed=True)

    def _update_corrupt_stats(self, qx: torch.Tensor, qy: torch.Tensor, preds: torch.Tensor, energy: torch.Tensor) -> None:
        """Corruption test: score each corruption level separately against level 0 as the in-distribution reference.

        Assumes the query set is laid out as (n_way, n_corruptions, val_query_shots) along its first axis.
        """
        n_way, shots = self.args.n_way, self.args.val_query_shots
        n_corruptions = qx.size(0) // (shots * n_way)

        level_preds = [v.reshape(-1, n_way) for v in torch.split(preds.view(n_way, n_corruptions, shots, n_way), 1, dim=1)]
        level_energy = [v.reshape(-1) for v in torch.split(energy.view(n_way, n_corruptions, shots), 1, dim=1)]
        level_qy = [v.reshape(-1) for v in torch.split(qy.view(n_way, n_corruptions, shots), 1, dim=1)]
        id_preds, id_energy = level_preds[0], level_energy[0]

        for stat, split_preds, split_energy, y in zip(self.corrupt_stats, level_preds, level_energy, level_qy):
            stat.update_acc((split_preds.argmax(dim=-1) == y).sum().item(), y.size(0))
            stat.update_nll(split_preds, y, softmaxxed=True)
            stat.update_ece(split_preds, y, softmaxxed=True)
            stat.update_softmax_entropy(split_preds, y.size(0), softmaxxed=True)

            # level-0 samples are labelled 0, this level's samples 1
            ood_y = torch.cat((torch.zeros(y.size(0), device=y.device), torch.ones(y.size(0), device=y.device)))
            stat.update_aupr_auroc(ood_y, torch.cat((id_energy, split_energy)))
            stat.log_id_ood_entropy(ood_y, torch.cat((id_preds, split_preds)), softmaxxed=True)

    def load_model(self, path: str) -> None:
        """Restore the training state from ``<path>/<run>.pt`` if that file exists; otherwise do nothing.

        Checkpoints without ``tmp_layer`` entries are loaded non-strictly (any other key mismatch still raises),
        and the model gets a fresh ``TemperatureScaler``.
        """
        sd_path = os.path.join(path, f"{self.args.run}.pt")
        if not os.path.exists(sd_path):
            return

        # map_location lets a checkpoint written from another device load onto this run's device
        saved = torch.load(sd_path, map_location=self.args.device)
        self.iter = saved["iter"]
        self.finished = saved["finished"]
        self.best_acc = saved["best_acc"]

        has_tmp_layer = any("tmp_layer" in k for k in saved["state_dict"])
        # a strict load would reject a checkpoint that lacks the tmp_layer entries the model registers,
        # so relax it only in that case and verify that nothing else is mismatched
        result = self.model.load_state_dict(saved["state_dict"], strict=has_tmp_layer)
        if not has_tmp_layer:
            mismatched = [k for k in result.missing_keys if "tmp_layer" not in k] + list(result.unexpected_keys)
            if mismatched:
                raise RuntimeError(f"checkpoint {sd_path} does not match the model: {mismatched}")

        self.model.tuned = saved.get("tuned", False)
        if not has_tmp_layer:
            self.model.tmp_layer = TemperatureScaler().to(self.args.device)

        self.inner_opt.load_state_dict(saved["inner_opt"])
        print(f"loaded saved model: {self.iter=} {self.model.tuned=}")

    def save_model(self, path: str, finished: bool = False) -> None:
        """Write the iteration count, finished flag, model and inner-optimizer state, tuned flag and best
        accuracy to ``<path>/<run>.pt``."""
        sd_path = os.path.join(path, f"{self.args.run}.pt")
        save = dict(
            iter=self.iter,
            finished=finished,
            state_dict=self.model.state_dict(),
            inner_opt=self.inner_opt.state_dict(),
            tuned=self.model.tuned,
            best_acc=self.best_acc
        )
        torch.save(save, sd_path)

    def log_train_stats(self, path: str) -> Dict[str, float]:
        """Write train/validation stats to CSV files in ``path``, log a summary line and return the validation stats."""
        self.tr_stats.log_stats(os.path.join(path, f"train-run-{self.args.run}.csv"))
        names, values = self.te_stats.log_stats(os.path.join(path, f"val-run-{self.args.run}.csv"))

        msg = f"({self.args.run}) iter: {self.iter}/{self.args.metatrain_iters} "
        msg += "".join(f"{n}: {v:.4f} " for n, v in zip(names, values))

        self.log(msg)
        return dict(zip(names, values))

    def tune(self) -> None:
        """Fit the temperature-scaling layer on query logits after adaptation, then mark the model as tuned.

        Adapts to the tasks of up to the first 51 batches of ``self.trainset``. Adaptation uses the same eval
        settings as ``test``. Model and inner-optimizer state are restored before tuning the temperature.
        """
        # get the logits list from the validation (training set in this object)
        state, opt_state = copy.deepcopy(self.model.state_dict()), copy.deepcopy(self.inner_opt.state_dict())
        logit_lst, label_lst = [], []
        try:
            for i, (xtrain, ytrain, xtest, ytest) in enumerate(tqdm(self.trainset, ncols=75, leave=False)):
                for xs, ys, xq, yq in zip(xtrain, ytrain, xtest, ytest):
                    self.model.train()
                    xs, ys, xq, yq = xs.to(self.args.device), ys.to(self.args.device), xq.to(self.args.device), yq.to(self.args.device)
                    self._adapt_from(state, opt_state, xs, ys)

                    # NOTE(review): logits are collected in eval mode here, while test() evaluates in train mode;
                    # confirm the calibration matches the test-time setting
                    self.model.eval()
                    with torch.no_grad():
                        logit_lst.append(self.model(xq))
                        label_lst.append(yq)

                # the break happens after batch index 50 has been processed, i.e. 51 batches in total
                if i == 50:
                    break
        finally:
            # return model/opt to the original state before exiting
            self.model.load_state_dict(state)  # type: ignore
            self.inner_opt.load_state_dict(opt_state)

        # call the temperature scaling layer
        logits, labels = torch.cat(logit_lst).detach(), torch.cat(label_lst).detach()
        self.model.tmp_layer.tune(logits, labels)
        self.model.tuned = True

    def log_test_stats(self, path: str, test_name: str = "test") -> Dict[str, float]:
        """Write the test stats (and per-level corruption stats if enabled) to CSV, log them, and return the
        main test stats."""
        names, values = self.te_stats.log_stats(os.path.join(path, f"{test_name}-run-{self.args.run}.csv"))
        self.log(" ".join([f"{n} {v:.4f}" for (n, v) in zip(names, values)]))

        if self.args.corrupt_test:
            for i, stat in enumerate(self.corrupt_stats):
                # separate names so the per-level results do not overwrite the main stats returned below
                level_names, level_values = stat.log_stats(os.path.join(path, f"{test_name}-run-{self.args.run}-level-{i}.csv"))
                self.log(f"level {i} " + " ".join([f"{n} {v:.4f}" for (n, v) in zip(level_names, level_values)]))

        return dict(zip(names, values))

    def log(self, msg: str) -> None:
        """Log ``msg`` at INFO level through ``args.logger``."""
        self.args.logger.info(msg)


if __name__ == "__main__":
    # command-line entry point: build the trainer, then either meta-train it or tune + test a finished model
    parser = ArgumentParser("argument parser for Reptile")

    parser.add_argument("--dataset", type=str, default="omniglot", choices=["omniglot", "miniimagenet"], help="the dataset to use")
    parser.add_argument("--data-root", type=str, help="the dataset root directory")
    parser.add_argument("--get-val", action="store_true", help="whether or not to get a validation split (only for datasets where no explicit validation set is given)")
    parser.add_argument("--ood-test", type=str2bool, default=False, help="if True, use random classes for the query set")
    parser.add_argument("--corrupt-test", type=str2bool, default=False, help="if True, use corrupted query set")
    parser.add_argument("--num-workers", type=int, default=4, help="the number of workers for the dataloader")
    parser.add_argument("--val-interval", type=int, default=500, help="validation interval during training")
    parser.add_argument("--n-way", type=int, default=5, help="the number of classes in each task")
    parser.add_argument("--k-shot", type=int, default=5, help="the number of exampels of each class")
    parser.add_argument("--filters", type=int, choices=[32, 64], help="the number of filters to use in the CNN layers")
    parser.add_argument("--train-query-shots", type=int, default=10, help="the number of training shots to provide for sub sampling")
    parser.add_argument("--val-query-shots", type=int, default=15, help="the number of testshots to use")
    parser.add_argument("--metatrain-iters", type=int, default=100000, help="the number of metatrain iters to run")
    # NOTE(review): --seed is not read anywhere in this file; seeding below uses --run
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument("--run", type=int, default=0, help="independent run number")
    parser.add_argument("--batch-size", type=int, default=32, help="the size of the metabatch which runs before each update")
    parser.add_argument("--inner-train-batch-size", type=int, default=10, help="the inner batch size for subsampling the support set during training")
    parser.add_argument("--inner-eval-batch-size", type=int, default=10, help="the inner batch size for subsampling the support set during eval")
    parser.add_argument("--inner-lr", type=float, default=0.1, help="the inner learn rate for adaptation")
    parser.add_argument("--inner-train-steps", type=int, default=1, help="the number of inner gradient steps for metatraining")
    parser.add_argument("--inner-val-steps", type=int, default=1, help="the number of inner gradient steps for metatesting")
    parser.add_argument("--gpu", type=int, default=0, help="the gpu index to run on")
    # required: without a mode the ds_deref lookup below would fail with a bare KeyError
    parser.add_argument("--mode", type=str, choices=["train", "test"], required=True)
    parser.add_argument("--norm-type", type=str, choices=["transductive", "reptile-norm"])
    parser.add_argument("--save-best-val", type=str2bool, default=False, help="whether or not to apply early stopping with validation")

    args = parser.parse_args()
    args.logger = set_logger("INFO")
    args.device = torch.device(f"cuda:{args.gpu}")

    # seed before doing anything else with the dataloaders
    seed(args.run)

    if args.ood_test and args.corrupt_test:
        raise ValueError("only one of ood test and corrupt test can be set at a time")

    if args.norm_type == "reptile-norm":
        raise NotImplementedError("reptile norm has not been implemented yet")

    trainset, valset, testset = get_dataset(args)

    # reptile uses the same model structure as MAML
    model_deref = {"omniglot": MAMLOmniglot, "miniimagenet": MAMLMiniImageNet}
    # NOTE(review): in test mode the trainer's "train" loader (the one used by tune()) is the meta-training set,
    # although tune()'s comment speaks of validation data — confirm which split temperature scaling should use
    ds_deref = {"train": [trainset, valset], "test": [trainset, testset]}

    train, val = ds_deref[args.mode]
    trainer = ReptileTrainer(args, model_deref[args.dataset](args, filters=args.filters), trainset=train, valset=val)

    if args.mode == "train":
        trainer.fit()
        trainer.log("finished training")
    elif args.mode == "test":
        trainer.load_model(trainer.models_path)
        if not trainer.finished:
            raise ValueError("running tests on an unfinished model")

        if not trainer.model.tuned:
            trainer.tune()
            trainer.save_model(trainer.models_path, finished=True)  # will be saved with tuned flag set to true, temp buffer will be saved

        trainer.test()
        trainer.log_test_stats(trainer.results_path, test_name=get_test_name(args))
        trainer.log("finished testing")
    else:
        raise NotImplementedError()
