import itertools
import os
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Iterator

import higher  # type: ignore
import torch
from base import Algorithm
from data.get import get_dataset
from layers import TemperatureScaler
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader  # type: ignore
from tqdm import tqdm  # type: ignore
from utils import (Stats, get_test_name, params_sans, seed,  # type: ignore
                   set_logger, str2bool)

from maml.model import MAML, MAMLMiniImageNet, MAMLOmniglot

# Short alias for the tensor type, intended for annotations (not referenced in the code visible in this file).
T = torch.Tensor


class MAMLTrainer(Algorithm):
    """Trainer that meta-trains, temperature-tunes and evaluates a MAML model.

    Inner-loop adaptation is performed with ``higher.innerloop_ctx``; the outer
    (meta) update uses Adam over the parameters returned by
    ``params_sans(model, without=TemperatureScaler)``. Checkpoints and CSV logs
    are written below
    ``results/<order>/<dataset>/<model name>/<n>-way/<k>-shot``.
    """

    def __init__(self, args: Namespace, model: MAML, trainset: DataLoader, valset: DataLoader):
        """Move the model to ``args.device`` and set up optimizer, paths and stats.

        Args:
            args: parsed command-line namespace. This constructor reads
                ``device``, ``lr``, ``first_order``, ``dataset``, ``n_way``,
                ``k_shot``, ``run``, ``ood_test`` and ``corrupt_test``; the
                other methods read further fields.
            model: the MAML network to train or evaluate.
            trainset: loader iterated by ``train()`` and ``tune()``.
            valset: loader iterated by ``test()``.
        """
        super().__init__()

        self.args = args
        self.model = model.to(self.args.device)
        self.optimizer = torch.optim.Adam(params_sans(self.model, without=TemperatureScaler), lr=args.lr)
        self.trainset = trainset
        self.valset = valset
        self.best_acc = 0.
        self.iter = 0
        self.finished = False

        order = "first-order" if self.args.first_order else "second-order"
        self.results_path = os.path.join("results", order, f"{args.dataset}", f"{self.model.name}", f"{args.n_way}-way", f"{args.k_shot}-shot")
        self.models_path = os.path.join(self.results_path, "models")
        for d in [self.results_path, self.models_path]:
            os.makedirs(d, exist_ok=True)

        self.tr_stats = Stats(["accuracy", "nll", "ece"])
        logs = [("id_ood_entropy", os.path.join(self.results_path, f"id-ood-entropy-{get_test_name(self.args)}-run-{self.args.run}.txt"))] if args.ood_test else None
        stats = ["accuracy", "nll", "ece", "softmax_entropy"] + (["aupr", "auroc"] if args.ood_test else [])
        self.te_stats = Stats(stats, logs)

        # `corrupt_stats` only exists when corrupt_test is set; one Stats object per level index 0..5
        if self.args.corrupt_test:
            self.corrupt_stats = [Stats(
                ["nll", "accuracy", "ece", "aupr", "auroc", "softmax_entropy"],
                [("id_ood_entropy", os.path.join(self.results_path, f"id-ood-entropy-{get_test_name(self.args)}-run-{self.args.run}-level-{i}.txt"))]
            ) for i in range(6)]

    def fit(self) -> None:
        """Run meta-training until ``args.metatrain_iters`` iterations are done.

        Restores any existing checkpoint for this run first and returns early
        if it is already marked finished. After every ``train()`` call the
        model is validated and logged; a checkpoint is written either every
        time (``save_best_val`` false) or only when validation accuracy
        improves. At the end the last saved checkpoint is reloaded and
        re-saved with ``finished=True``.
        """
        self.load_model(self.models_path)
        if self.finished:
            self.log("called fit() on a model which has finished training")
            return

        while self.iter < self.args.metatrain_iters:
            self.train()
            self.test()
            stats = self.log_train_stats(self.results_path)
            acc = stats["accuracy"]
            # read the flag from self.args: a bare `args` only resolves when this file runs as a script
            if not self.args.save_best_val:
                self.save_model(self.models_path)
                continue

            if acc > self.best_acc:
                self.log(f"({self.args.run}) new best accuracy: {acc:.4f}")
                self.best_acc = acc
                self.save_model(self.models_path)

        self.log("finished training, marking last saved model as finished")
        self.load_model(self.models_path)
        self.save_model(self.models_path, finished=True)

    def train(self) -> None:
        """Run meta-updates until the next validation point or the end of training.

        Each loader item is a meta-batch ``(xtrain, ytrain, xtest, ytest)``
        indexed by task along dim 0. Every task is adapted for
        ``args.inner_train_steps`` SGD steps on its support set, the query
        loss gradient is accumulated over tasks, and one Adam step is taken
        per meta-batch. Returns once ``self.iter`` reaches
        ``args.metatrain_iters`` or a multiple of ``args.val_interval``.
        """
        device = self.args.device
        self.model.train()
        for xtrain, ytrain, xtest, ytest in tqdm(self.trainset, ncols=75, total=self.args.val_interval, leave=False):
            xtrain, ytrain, xtest, ytest = xtrain.to(device), ytrain.to(device), xtest.to(device), ytest.to(device)
            grads = [torch.zeros_like(p) for p in self.model.parameters()]
            # the inner optimizer is built without the BatchNorm2d parameters (see params_sans)
            opt = torch.optim.SGD(params_sans(self.model, without=nn.BatchNorm2d), lr=self.args.inner_lr)
            for i in range(xtrain.size(0)):
                with higher.innerloop_ctx(self.model, opt, copy_initial_weights=False, track_higher_grads=not self.args.first_order) as (fmodel, diffopt):
                    for _ in range(self.args.inner_train_steps):
                        logit = fmodel(xtrain[i])
                        loss = F.cross_entropy(logit, ytrain[i])
                        diffopt.step(loss)

                    logit = fmodel(xtest[i])
                    loss = F.cross_entropy(logit, ytest[i]) / self.args.batch_size
                    # first-order: gradient w.r.t. the current (adapted) parameters;
                    # second-order: gradient w.r.t. the time=0 parameters, through the inner steps
                    newgrads = torch.autograd.grad(loss, fmodel.parameters()) if self.args.first_order \
                        else torch.autograd.grad(loss, fmodel.parameters(time=0))
                    grads = [g1 + g2 for (g1, g2) in zip(grads, newgrads)]

                    with torch.no_grad():
                        self.tr_stats.update_acc((logit.argmax(dim=-1) == ytest[i]).sum().item(), ytest[i].size(0))
                        self.tr_stats.update_nll(logit, ytest[i])
                        self.tr_stats.update_ece(logit, ytest[i])

            self.optimizer.zero_grad()
            # NOTE(review): the loss above is already divided by batch_size, so this divides a second time;
            # left unchanged because removing it would alter the effective gradient scale/clipping - confirm intent.
            for p, g in zip(self.model.parameters(), grads):
                if p.grad is None:
                    p.grad = g / self.args.batch_size
                    continue

                p.grad.add_(g / self.args.batch_size)
            nn.utils.clip_grad_value_(self.model.parameters(), 10)
            self.optimizer.step()

            self.iter += 1
            if self.iter == self.args.metatrain_iters or self.iter % self.args.val_interval == 0:
                return

    def test(self, n: int = -1) -> None:
        """Adapt on each task of ``self.valset`` and accumulate evaluation stats.

        The model stays in train mode. Each task is adapted for
        ``args.inner_val_steps`` steps, then query logits are passed through
        ``fmodel.tmp_layer`` before metrics are computed. Depending on
        ``args.ood_test`` / ``args.corrupt_test`` extra OOD-detection metrics
        are recorded.

        Args:
            n: number of batches to evaluate; ``-1`` (default) evaluates the
                whole loader.
        """
        total = len(self.valset) if n == -1 else n
        device = self.args.device
        self.model.train()
        for it, (xtrain, ytrain, xtest, ytest) in enumerate(tqdm(self.valset, ncols=75, total=total, leave=False)):
            xtrain, ytrain, xtest, ytest = xtrain.to(device), ytrain.to(device), xtest.to(device), ytest.to(device)
            opt = torch.optim.SGD(params_sans(self.model, without=nn.BatchNorm2d), lr=self.args.inner_lr)
            for i in range(xtrain.size(0)):
                with higher.innerloop_ctx(self.model, opt, track_higher_grads=False) as (fmodel, fopt):
                    for _ in range(self.args.inner_val_steps):
                        logit = fmodel(xtrain[i])
                        loss = F.cross_entropy(logit, ytrain[i])
                        fopt.step(loss)

                    with torch.no_grad():
                        logit = fmodel(xtest[i])
                        logit = fmodel.tmp_layer(logit)

                        qx, qy, preds = xtest[i], ytest[i], logit.softmax(dim=-1)
                        energy = -torch.logsumexp(-logit, dim=-1)

                        if self.args.ood_test:
                            # a dedicated name is used here so the `n` argument is not overwritten
                            half = qy.size(0) // 2

                            # unpacking into two names requires the split to yield exactly two chunks
                            _, ood_qy = torch.split(qy, half)
                            _, ood_preds = torch.split(preds, half)

                            # NOTE(review): metrics are computed on the second half, which the comment below
                            # describes as the random-class half - confirm this is the intended half.
                            self.te_stats.update_acc((ood_preds.argmax(dim=-1) == ood_qy).sum().item(), ood_qy.size(0))
                            self.te_stats.update_nll(ood_preds, ood_qy, softmaxxed=True)
                            self.te_stats.update_ece(ood_preds, ood_qy, softmaxxed=True)
                            self.te_stats.update_softmax_entropy(ood_preds, ood_qy.size(0), softmaxxed=True)

                            # in ood, the first half of the query set is always the in distribution samples. The second half are the random classes.
                            y = torch.cat((torch.zeros(half, device=ood_qy.device), torch.ones(half, device=ood_qy.device)))
                            self.te_stats.update_aupr_auroc(y, energy)
                            self.te_stats.log_id_ood_entropy(y, preds, softmaxxed=True)
                        else:
                            self.te_stats.update_nll(logit, ytest[i])
                            self.te_stats.update_acc((logit.argmax(dim=-1) == ytest[i]).sum().item(), ytest[i].size(0))
                            self.te_stats.update_ece(logit, ytest[i])
                            self.te_stats.update_softmax_entropy(logit, ytest[i].size(0))

                        if self.args.corrupt_test:
                            n_way = self.args.n_way
                            # use self.args: a bare `args` only resolves when this file runs as a script
                            shots = self.args.val_query_shots
                            n_corruptions = qx.size(0) // (shots * n_way)

                            # the query axis is viewed as (n_way, n_corruptions, shots) and split per corruption;
                            # presumably the dataset orders queries that way - TODO confirm against the loader
                            level_preds = [v.reshape(-1, n_way) for v in torch.split(preds.view(n_way, n_corruptions, shots, n_way), 1, dim=1)]
                            id_preds = level_preds[0]

                            level_energy = [v.reshape(-1) for v in torch.split(energy.view(n_way, n_corruptions, shots), 1, dim=1)]
                            id_energy = level_energy[0]

                            level_qy = [v.reshape(-1) for v in torch.split(qy.view(n_way, n_corruptions, shots), 1, dim=1)]

                            for (stat, split_preds, split_energy, y) in zip(self.corrupt_stats, level_preds, level_energy, level_qy):
                                stat.update_acc((split_preds.argmax(dim=-1) == y).sum().item(), y.size(0))
                                stat.update_nll(split_preds, y, softmaxxed=True)
                                stat.update_ece(split_preds, y, softmaxxed=True)
                                stat.update_softmax_entropy(split_preds, y.size(0), softmaxxed=True)

                                # index-0 split is labelled 0 and the current split is labelled 1 for AUPR/AUROC
                                id_ood = torch.cat((torch.zeros(y.size(0), device=y.device), torch.ones(y.size(0), device=y.device)))
                                total_energy = torch.cat((id_energy, split_energy))
                                stat.update_aupr_auroc(id_ood, total_energy)
                                stat.log_id_ood_entropy(id_ood, torch.cat((id_preds, split_preds)), softmaxxed=True)

            # `it` is 0-based, so stop once it + 1 batches have been processed
            if n != -1 and it + 1 >= total:
                break

    def load_model(self, path: str) -> None:
        """Restore trainer state from ``<path>/<run>.pt`` if that file exists.

        Restores ``iter``, ``finished``, ``best_acc``, the model and optimizer
        state dicts, and the ``tuned`` flag (``False`` when the checkpoint has
        no such key). If the state dict has no ``tmp_layer`` entries, a fresh
        ``TemperatureScaler`` is attached after loading. Does nothing when the
        file is absent.

        Args:
            path: directory containing the checkpoint.
        """
        sd_path = os.path.join(path, f"{self.args.run}.pt")
        if os.path.exists(sd_path):
            # map onto the configured device so checkpoints written on another GPU index still load
            saved = torch.load(sd_path, map_location=self.args.device)
            self.iter = saved["iter"]
            self.finished = saved["finished"]
            self.best_acc = saved["best_acc"]

            found = any("tmp_layer" in k for k in saved["state_dict"])

            self.model.load_state_dict(saved["state_dict"])
            if not found:
                self.model.tmp_layer = TemperatureScaler().to(self.args.device)

            self.model.tuned = saved.get("tuned", False)
            self.optimizer.load_state_dict(saved["optimizer"])
            print(f"loaded saved model: {self.iter=} {self.finished=} {self.model.tuned=}")

    def save_model(self, path: str, finished: bool = False) -> None:
        """Write the current trainer state to ``<path>/<run>.pt``.

        Args:
            path: directory to write the checkpoint into.
            finished: value stored under the ``finished`` key.
        """
        sd_path = os.path.join(path, f"{self.args.run}.pt")
        save = dict(
            iter=self.iter,
            finished=finished,
            state_dict=self.model.state_dict(),
            optimizer=self.optimizer.state_dict(),
            tuned=self.model.tuned,
            best_acc=self.best_acc
        )
        torch.save(save, sd_path)

    def log_train_stats(self, path: str) -> Dict[str, float]:
        """Flush train and validation stats to CSV and log a one-line summary.

        Args:
            path: directory receiving ``train-run-<run>.csv`` and
                ``val-run-<run>.csv``.

        Returns:
            Mapping from stat name to value as returned by
            ``self.te_stats.log_stats`` (the validation stats).
        """
        self.tr_stats.log_stats(os.path.join(path, f"train-run-{self.args.run}.csv"))
        names, values = self.te_stats.log_stats(os.path.join(path, f"val-run-{self.args.run}.csv"))

        msg = f"({self.args.run}) iter: {self.iter}/{self.args.metatrain_iters} "
        for n, v in zip(names, values):
            msg += f"{n}: {v:.4f} "

        self.log(msg)
        return {n: v for (n, v) in zip(names, values)}

    def tune(self) -> None:
        """Collect adapted query logits and hand them to the temperature layer.

        Iterates ``self.trainset`` (stopping after the batch with index 50),
        adapts on each task for ``args.inner_val_steps`` steps, gathers the
        query logits and labels, then calls ``self.model.tmp_layer.tune`` and
        sets ``self.model.tuned = True``.
        """
        # get the logits list from the validation (training set in this object)
        # NOTE(review): the progress-bar total comes from valset although trainset is iterated - confirm
        total = len(self.valset)
        device = self.args.device
        logit_lst, label_lst = [], []
        for it, (xtrain, ytrain, xtest, ytest) in enumerate(tqdm(self.trainset, ncols=75, total=total, leave=False)):
            xtrain, ytrain, xtest, ytest = xtrain.to(device), ytrain.to(device), xtest.to(device), ytest.to(device)
            opt = torch.optim.SGD(params_sans(self.model, without=nn.BatchNorm2d), lr=self.args.inner_lr)
            for i in range(xtrain.size(0)):
                self.model.train()
                with higher.innerloop_ctx(self.model, opt, track_higher_grads=False) as (fmodel, fopt):
                    for _ in range(self.args.inner_val_steps):
                        logit = fmodel(xtrain[i])
                        loss = F.cross_entropy(logit, ytrain[i])
                        fopt.step(loss)

                    # NOTE(review): eval() is applied to self.model while the forward pass below uses fmodel;
                    # verify whether the mode change is meant to reach fmodel.
                    self.model.eval()
                    with torch.no_grad():
                        logit = fmodel(xtest[i])

                        logit_lst.append(logit)
                        label_lst.append(ytest[i])

            if it == 50:
                break

        # call the temperature scaling layer
        logits, labels = torch.cat(logit_lst), torch.cat(label_lst)
        self.model.tmp_layer.tune(logits, labels)
        self.model.tuned = True

    def log_test_stats(self, path: str, test_name: str = "test") -> Dict[str, float]:
        """Flush test stats (and per-level corruption stats) to CSV and log them.

        Args:
            path: directory receiving the CSV files.
            test_name: file-name prefix for the CSV files.

        Returns:
            Mapping from stat name to value from the most recent
            ``log_stats`` call: the test stats normally, or the last
            corruption level's stats when ``args.corrupt_test`` is set.
        """
        names, values = self.te_stats.log_stats(os.path.join(path, f"{test_name}-run-{self.args.run}.csv"))
        self.log(" ".join([f"{n} {v:.4f}" for (n, v) in zip(names, values)]))

        if self.args.corrupt_test:
            for i, stat in enumerate(self.corrupt_stats):
                names, values = stat.log_stats(os.path.join(path, f"{test_name}-run-{self.args.run}-level-{i}.csv"))
                self.log(f"level {i} " + " ".join([f"{n} {v:.4f}" for (n, v) in zip(names, values)]))

        return {n: v for (n, v) in zip(names, values)}

    def log(self, msg: str) -> None:
        """Emit ``msg`` at INFO level through ``args.logger``."""
        self.args.logger.info(msg)


# Script entry point: parse CLI options, build the datasets and model, then train or test.
if __name__ == "__main__":
    parser = ArgumentParser("argument parser for MAML")

    parser.add_argument("--dataset", type=str, default="omniglot", choices=["omniglot", "miniimagenet"], help="the dataset to use")
    parser.add_argument("--data-root", type=str, help="the dataset root directory")
    parser.add_argument("--ood-test", type=str2bool, default=False, help="if True, use random classes for the query set")
    parser.add_argument("--corrupt-test", type=str2bool, default=False, help="if True, use corrupted query set")
    parser.add_argument("--num-workers", type=int, default=4, help="the number of workers for the dataloader")
    parser.add_argument("--n-way", type=int, default=5, help="the number of classes in each task")
    parser.add_argument("--k-shot", type=int, default=5, help="the number of examples of each class")
    parser.add_argument("--train-query-shots", type=int, default=15, help="the number of testshots to use")
    parser.add_argument("--val-query-shots", type=int, default=15, help="the number of testshots to use")
    parser.add_argument("--val-interval", type=int, default=500, help="validation interval during training")
    parser.add_argument("--first-order", type=str2bool, default=False, help="whether or not to use first order gradients")
    parser.add_argument("--metatrain-iters", type=int, default=60000, help="the number of metatrain iters to run")
    parser.add_argument("--run", type=int, default=0, help="independent run number")
    parser.add_argument("--gpu", type=int, default=0, help="the gpu index")
    parser.add_argument("--batch-size", type=int, default=32, help="the size of the metabatch which runs before each update")
    parser.add_argument("--inner-lr", type=float, default=0.1, help="the inner learn rate for adaptation")
    parser.add_argument("--inner-train-steps", type=int, default=1, help="the number of inner gradient steps for metatraining")
    parser.add_argument("--inner-val-steps", type=int, default=1, help="the number of inner gradient steps for metatesting")
    parser.add_argument("--lr", type=float, default=1e-4, help="meta optimization learn rate")
    # required: without it args.mode is None and the dataset lookup below fails with an opaque KeyError
    parser.add_argument("--mode", type=str, choices=["train", "test"], required=True, help="whether to meta-train or to run the test procedure")
    parser.add_argument("--save-best-val", type=str2bool, default=False, help="whether or not to apply early stopping with validation")

    args = parser.parse_args()
    args.logger = set_logger("INFO")
    args.device = torch.device(f"cuda:{args.gpu}")

    # seed before doing anything else with the dataloaders
    seed(args.run)
    if args.ood_test and args.corrupt_test:
        raise ValueError("only one of ood test and corrupt test can be set at a time")

    trainset, valset, testset = get_dataset(args)

    model_deref = {"omniglot": MAMLOmniglot, "miniimagenet": MAMLMiniImageNet}
    # in test mode the validation split takes the trainer's "trainset" slot (consumed by tune())
    # and the test split takes the "valset" slot (consumed by test())
    ds_deref = {"train": [trainset, valset], "test": [valset, testset]}

    train, val = ds_deref[args.mode]
    trainer = MAMLTrainer(args, model_deref[args.dataset](args), trainset=train, valset=val)

    if args.mode == "train":
        trainer.fit()
        trainer.log("finished training")
    elif args.mode == "test":
        trainer.load_model(trainer.models_path)
        if not trainer.finished:
            raise ValueError("running tests on an unfinished model")

        if not trainer.model.tuned:
            trainer.tune()
            trainer.save_model(trainer.models_path, finished=True)  # will be saved with tuned flag set to true, temp buffer will be saved

        trainer.test()
        trainer.log_test_stats(trainer.results_path, test_name=get_test_name(args))
        trainer.log("finished testing")
    else:
        # defensive only: argparse `choices` already restricts mode to train/test
        raise NotImplementedError()
