import os
from argparse import ArgumentParser, Namespace
from typing import Dict

import torch
from base import Algorithm
from data.get import get_dataset
from torch.nn import functional as F
from torch.utils.data import DataLoader  # type: ignore
from tqdm import tqdm  # type: ignore
from utils import (Stats, get_test_name, seed, set_logger,  # type: ignore
                   str2bool, to_device)

from protonet.model import Protonet, ProtonetCNN4

T = torch.Tensor


class ProtonetTrainer(Algorithm):
    """Train and evaluate a prototypical network on episodic few-shot tasks.

    The trainer owns the model, an Adam optimizer and a ``MultiStepLR``
    schedule that is stepped once per task. Checkpoints are written to
    ``results/<dataset>/<model name>/<n>-way/<k>-shot/models/<run>.pt`` and
    per-run statistics are logged as CSV files in the parent results directory.
    """

    def __init__(self, args: Namespace, model: Protonet, trainset: DataLoader, valset: DataLoader):
        """Set up the model, optimizer, scheduler, output directories and stats.

        Args:
            args: parsed options; this class reads ``device``, ``lr``,
                ``lr_steps``, ``lr_gamma``, ``dataset``, ``n_way``, ``k_shot``,
                ``run``, ``metatrain_iters``, ``val_interval``, ``batch_size``,
                ``save_best_val`` and ``logger`` from it.
            model: the network to train; moved to ``args.device``.
            trainset: loader iterated by :meth:`train`.
            valset: loader iterated by :meth:`test` (validation or test data,
                depending on what the caller passes).
        """
        super().__init__()

        self.args = args
        self.model = model.to(self.args.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
        self.trainset = trainset
        self.valset = valset
        self.best_acc = 0.
        self.iter = 0
        self.finished = False

        self.results_path = os.path.join("results", f"{args.dataset}", f"{self.model.name}", f"{self.args.n_way}-way", f"{self.args.k_shot}-shot")
        self.models_path = os.path.join(self.results_path, "models")
        for d in [self.results_path, self.models_path]:
            os.makedirs(d, exist_ok=True)

        self.tr_stats = Stats(["accuracy", "loss", "nll", "ece", "aupr", "auroc"])
        self.te_stats = Stats(["accuracy", "loss", "nll", "ece", "aupr", "auroc"])

    @staticmethod
    def _episode_loss(logits: torch.Tensor, qy: torch.Tensor) -> torch.Tensor:
        """Return the mean negative log-likelihood of the query labels ``qy``."""
        log_probs = F.log_softmax(logits, dim=-1)
        # build the row index on the labels' device so the lookup never mixes devices
        rows = torch.arange(qy.size(0), device=qy.device)
        return -log_probs[rows, qy].mean()

    @staticmethod
    def _record_stats(stats: "Stats", logits: torch.Tensor, loss: torch.Tensor, qy: torch.Tensor) -> None:
        """Feed one task's results into ``stats``; callers must disable autograd."""
        n_query = qy.size(0)
        # the loss is a per-task mean, so it is re-weighted by the number of queries
        stats.update_loss(loss * n_query, n_query)
        stats.update_acc((logits.argmax(dim=-1) == qy).sum().item(), n_query)
        stats.update_nll(logits, qy)
        stats.update_ece(logits, qy)
        stats.update_aupr_auroc(qy, logits)

    def fit(self) -> None:
        """Run the train/validate loop until ``args.metatrain_iters`` is reached.

        An existing checkpoint for this run is restored first, so an interrupted
        run resumes from its saved iteration. When ``args.save_best_val`` is
        false a checkpoint is written after every validation pass; otherwise
        only when the validation accuracy beats the best seen so far. Finally
        the last saved checkpoint is reloaded and re-saved with
        ``finished=True``.
        """
        self.load_model(self.models_path)
        if self.finished:
            self.log("called fit() on a model which has finished training")
            return

        while self.iter < self.args.metatrain_iters:
            self.train()
            self.test()

            stats = self.log_train_stats(self.results_path)
            acc = stats["accuracy"]
            # read the option from self.args: the class must not depend on a
            # module-level ``args`` that only exists when run as a script
            if not self.args.save_best_val:
                self.save_model(self.models_path)
                continue

            if acc > self.best_acc:
                self.log(f"({self.args.run}) new best accuracy: {acc:.4f}")
                self.best_acc = acc
                self.save_model(self.models_path)

        self.log("finished training, marking last saved model as finished")
        self.load_model(self.models_path)
        self.save_model(self.models_path, finished=True)

    def train(self) -> None:
        """Train task by task until ``self.iter`` hits a multiple of ``args.val_interval``.

        Every task performs one optimizer step and one scheduler step. The
        method also returns when the training loader is exhausted.
        """
        self.model.train()
        batches_per_interval = self.args.val_interval // self.args.batch_size
        for (x_spt, y_spt, x_qry, y_qry) in tqdm(self.trainset, ncols=75, leave=False, total=batches_per_interval):
            # each loader batch holds several tasks; they are processed one at a time
            for (sx, sy, qx, qy) in zip(x_spt, y_spt, x_qry, y_qry):
                sx, sy, qx, qy = to_device(sx, sy, qx, qy, device=self.args.device)
                logits = self.model(sx, sy, qx)
                loss = self._episode_loss(logits, qy)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                with torch.no_grad():
                    self._record_stats(self.tr_stats, logits, loss, qy)

                self.iter += 1
                if self.iter % self.args.val_interval == 0:
                    return

    def test(self) -> None:
        """Evaluate the model on every task in ``self.valset``, accumulating ``self.te_stats``."""
        self.model.eval()
        with torch.no_grad():
            for (x_spt, y_spt, x_qry, y_qry) in tqdm(self.valset, ncols=75, leave=False):
                for (sx, sy, qx, qy) in zip(x_spt, y_spt, x_qry, y_qry):
                    sx, sy, qx, qy = to_device(sx, sy, qx, qy, device=self.args.device)
                    logits = self.model(sx, sy, qx)
                    loss = self._episode_loss(logits, qy)
                    self._record_stats(self.te_stats, logits, loss, qy)

    def load_model(self, path: str) -> None:
        """Restore trainer state from ``<path>/<run>.pt`` if that file exists.

        Restores the iteration counter, best accuracy, model, optimizer and
        scheduler state, and the ``finished`` flag. Does nothing when there is
        no checkpoint.
        """
        sd_path = os.path.join(path, f"{self.args.run}.pt")
        if os.path.exists(sd_path):
            # remap storages so a checkpoint written on another device still loads
            saved = torch.load(sd_path, map_location=self.args.device)
            self.iter = saved["iter"]
            self.best_acc = saved["best_acc"]

            self.model.load_state_dict(saved["state_dict"])
            self.optimizer.load_state_dict(saved["optimizer"])
            self.scheduler.load_state_dict(saved["scheduler"])
            self.finished = saved["finished"]
            print(f"loaded saved model: {self.iter=} {self.finished=}")

    def save_model(self, path: str, finished: bool = False) -> None:
        """Write the current trainer state to ``<path>/<run>.pt``.

        Args:
            path: directory that receives the checkpoint.
            finished: stored in the checkpoint; :meth:`fit` refuses to continue
                training from a checkpoint where this is true.
        """
        sd_path = os.path.join(path, f"{self.args.run}.pt")
        save = dict(
            iter=self.iter,
            state_dict=self.model.state_dict(),
            optimizer=self.optimizer.state_dict(),
            scheduler=self.scheduler.state_dict(),
            finished=finished,
            best_acc=self.best_acc
        )
        torch.save(save, sd_path)

    def log_train_stats(self, path: str) -> Dict[str, float]:
        """Log the train and validation stats to CSV and return the validation values.

        Args:
            path: directory for ``train-run-<run>.csv`` and ``val-run-<run>.csv``.

        Returns:
            Mapping from stat name to value, as reported by ``self.te_stats``.
        """
        self.tr_stats.log_stats(os.path.join(path, f"train-run-{self.args.run}.csv"))
        names, values = self.te_stats.log_stats(os.path.join(path, f"val-run-{self.args.run}.csv"))
        pairs = list(zip(names, values))

        msg = f"({self.args.run}) iter: {self.iter}/{self.args.metatrain_iters} "
        msg += "".join(f"{n}: {v:.4f} " for (n, v) in pairs)

        self.log(msg)
        return dict(pairs)

    def log_test_stats(self, path: str, test_name: str = "test") -> Dict[str, float]:
        """Log ``self.te_stats`` to ``<path>/<test_name>-run-<run>.csv`` and return the values.

        Args:
            path: directory for the CSV file.
            test_name: prefix of the CSV file name.

        Returns:
            Mapping from stat name to value.
        """
        names, values = self.te_stats.log_stats(os.path.join(path, f"{test_name}-run-{self.args.run}.csv"))
        pairs = list(zip(names, values))
        self.log(" ".join([f"{n} {v:.4f}" for (n, v) in pairs]))
        return dict(pairs)

    def log(self, msg: str) -> None:
        """Emit ``msg`` at INFO level through ``args.logger``."""
        self.args.logger.info(msg)


if __name__ == "__main__":
    # Command-line entry point: parse the options, build the data loaders and
    # the model, then either train or evaluate a finished checkpoint.
    parser = ArgumentParser("argument parser for Protonet")

    parser.add_argument("--dataset", type=str, default="miniimagenet", choices=["miniimagenet", "omniglot"], help="the dataset to use")
    parser.add_argument("--data-root", type=str, help="the dataset root directory")
    parser.add_argument("--ood-test", type=str2bool, default=False, help="if True, use random classes for the query set")
    parser.add_argument("--corrupt-test", type=str2bool, default=False, help="if True, use corrupted query set")
    parser.add_argument("--num-workers", type=int, default=4, help="the number of workers for the dataloader")
    parser.add_argument("--metatrain-iters", type=int, default=60000, help="the number of metatrain iters to run")
    parser.add_argument("--lr-steps", type=int, nargs="+", default=[30000])
    parser.add_argument("--lr", type=float, default=1e-3, help="meta optimization learn rate")
    parser.add_argument("--lr-gamma", type=float, default=0.5)
    parser.add_argument("--run", type=int, default=0, help="independent run number")
    parser.add_argument("--gpu", type=int, default=0, help="the gpu index")
    # required: without it the dataset lookup below would fail with KeyError(None)
    parser.add_argument("--mode", type=str, choices=["train", "test"], required=True, help="whether to train a model or test a finished one")
    parser.add_argument("--model", type=str, choices=["CNN4"], default="CNN4", help="the underlying model for protonet")
    parser.add_argument("--save-best-val", type=str2bool, default=False, help="whether or not to apply early stopping with validation")
    parser.add_argument("--n-way", type=int, default=5, help="the number of classes in each task")
    parser.add_argument("--k-shot", type=int, default=5, help="the number of examples of each class")
    parser.add_argument("--train-query-shots", type=int, default=15, help="the number of testshots to use")
    parser.add_argument("--val-query-shots", type=int, default=15, help="the number of testshots to use")
    parser.add_argument("--val-interval", type=int, default=500, help="validation interval during training")
    parser.add_argument("--batch-size", type=int, default=16, help="metatrain batch size")

    args = parser.parse_args()
    args.logger = set_logger("INFO")
    if torch.cuda.is_available():
        args.device = torch.device(f"cuda:{args.gpu}")
    else:
        # keep the script usable on machines without a GPU instead of crashing later
        args.logger.info("CUDA is not available, falling back to CPU")
        args.device = torch.device("cpu")
    args.get_val = True

    # seed before doing anything else with the dataloaders
    seed(args.run)

    trainset, valset, testset = get_dataset(args)

    print(args.dataset)
    model_deref = {"CNN4": ProtonetCNN4}
    # number of input channels passed to the model for each dataset
    ch = {"omniglot": 1, "miniimagenet": 3}
    # the second loader is the one ProtonetTrainer.test() iterates:
    # validation data while training, test data while testing
    ds_deref = {"train": [trainset, valset], "test": [trainset, testset]}

    train, val = ds_deref[args.mode]
    trainer = ProtonetTrainer(args, model_deref[args.model](ch[args.dataset], 64, 64, args.n_way), trainset=train, valset=val)

    if args.mode == "train":
        trainer.fit()
        trainer.log("finished training")
    elif args.mode == "test":
        trainer.load_model(trainer.models_path)
        if not trainer.finished:
            raise ValueError("running tests on an unfinished model")

        trainer.test()
        trainer.log_test_stats(trainer.results_path, test_name=get_test_name(args))
        trainer.log("finished testing")
    else:
        raise NotImplementedError()
