import os
from argparse import ArgumentParser, Namespace
from functools import partial
from typing import Dict

import torch
from base import Algorithm
from data.get import get_dataset
from torch.nn import functional as F
from torch.utils.data import DataLoader  # type: ignore
from tqdm import tqdm  # type: ignore
from utils import Stats, seed, set_logger, str2bool  # type: ignore

from sngp.model import SNGP, SNGP_WideResNet28_10_cifar

# Shorthand alias for the torch tensor class (kept at module level for importers).
T: type = torch.Tensor


class SNGPTrainer(Algorithm):
    """Train and evaluate an SNGP model, checkpointing progress to disk.

    Optimisation uses SGD with momentum / weight decay and a ``MultiStepLR``
    schedule, all configured from ``args``. Outputs are written under
    ``results/<args.dataset>/<model.name>/``: per-run CSV statistics at the top
    level and the checkpoint ``<args.run>.pt`` in its ``models`` subdirectory.
    """

    # Monte Carlo samples passed to ``model.mc`` per evaluation batch. Sampling is
    # cheap because only the softmax logits are sampled from a Gaussian.
    MC_SAMPLES = 1000

    def __init__(self, args: Namespace, model: SNGP, trainset: DataLoader, valset: DataLoader):
        """Build the optimiser and scheduler and create the output directories.

        Args:
            args: parsed command-line namespace. Besides the optimiser settings
                (``lr``, ``momentum``, ``weight_decay``, ``lr_steps``, ``lr_gamma``)
                the trainer reads ``device``, ``dataset``, ``epochs``, ``run``,
                ``save_best_val`` and ``logger``.
            model: the SNGP model; it is moved to ``args.device``.
            trainset: loader used for the training pass and, in ``test``, for
                accumulating the GP precision matrix.
            valset: loader evaluated by ``test``.
        """
        super().__init__()

        self.args = args
        self.model = model.to(self.args.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
        self.trainset = trainset
        self.valset = valset
        self.best_acc = 0.
        self.epoch = 0  # number of completed epochs; restored by load_model
        self.finished = False

        self.results_path = os.path.join("results", f"{args.dataset}", f"{self.model.name}")
        self.models_path = os.path.join(self.results_path, "models")
        for d in (self.results_path, self.models_path):
            os.makedirs(d, exist_ok=True)

        metric_names = ("accuracy", "loss", "nll", "ece", "aupr", "auroc")
        # separate list objects so the two Stats instances never share state
        self.tr_stats = Stats(list(metric_names))
        self.te_stats = Stats(list(metric_names))

    def fit(self) -> None:
        """Train up to ``args.epochs`` epochs, resuming from a checkpoint if one exists.

        After every epoch the model is evaluated with ``test`` and stats are logged.
        If ``args.save_best_val`` is False the checkpoint is overwritten every epoch;
        otherwise it is written only when validation accuracy beats ``best_acc``.
        When the loop ends, the last saved checkpoint is reloaded and re-saved with
        ``finished=True``. Calling ``fit`` on a finished checkpoint only logs a message.
        """
        self.load_model(self.models_path)
        if self.finished:
            self.log("called fit() on a model which has finished training")
            return

        # NOTE(review): with save_best_val the checkpoint stores the epoch of the best
        # model, so resuming restarts from that epoch rather than the last one run.
        for _ in range(self.epoch, self.args.epochs):
            # re-initialize sigma and lambda here because during fitting we want to make sure they are
            # blank for the validation phase, but they should be saved in the current model for future testing
            self.model.init_sigma_lambda()
            self.train()
            self.test()
            self.scheduler.step()
            self.epoch += 1

            stats = self.log_train_stats(self.results_path)
            acc = stats["accuracy"]
            # use the trainer's own args, not the module-level ``args`` that only
            # exists when this file is executed as a script
            if not self.args.save_best_val:
                self.save_model(self.models_path)
                continue

            if acc > self.best_acc:
                self.log(f"({self.args.run}) new best accuracy: {acc:.4f}")
                self.best_acc = acc
                self.save_model(self.models_path)

        self.log("finished training, marking last saved model as finished")
        self.load_model(self.models_path)
        self.save_model(self.models_path, finished=True)

    def train(self) -> None:
        """Run one optimisation epoch over ``self.trainset``, accumulating ``self.tr_stats``."""
        self.model.train()
        device = self.args.device
        for x, y in tqdm(self.trainset, ncols=75, leave=False):
            x, y = x.to(device), y.to(device)
            self.optimizer.zero_grad()
            logit = self.model(x)
            loss = F.cross_entropy(logit, y)
            loss.backward()
            self.optimizer.step()

            with torch.no_grad():
                n = y.size(0)
                # cross_entropy returns the batch mean; pass the batch sum plus the count
                self.tr_stats.update_loss(loss * n, n)
                self.tr_stats.update_acc((logit.argmax(dim=-1) == y).sum().item(), n)
                self.tr_stats.update_nll(logit, y)
                self.tr_stats.update_ece(logit, y)
                self.tr_stats.update_aupr_auroc(y, logit)

    def test(self) -> None:
        """Refit the GP covariance on ``self.trainset``, then evaluate on ``self.valset``.

        The first pass feeds every training batch through the model with
        ``update_prec=True`` and then calls ``compute_cov()``. The second pass draws
        ``MC_SAMPLES`` Monte Carlo samples per evaluation batch via ``model.mc`` and
        accumulates the results into ``self.te_stats``.
        """
        self.model.eval()
        device = self.args.device

        with torch.no_grad():
            for x, y in tqdm(self.trainset, ncols=75, leave=False):
                x, y = x.to(device), y.to(device)
                self.model(x, update_prec=True, y=y)
            self.model.compute_cov()

            for x, y in tqdm(self.valset, ncols=75, leave=False):
                # labels must live on the same device as the model outputs they are
                # compared against in the loss / accuracy below
                x, y = x.to(device), y.to(device)
                yhat, logit = self.model.mc(x, self.MC_SAMPLES)
                loss = F.cross_entropy(logit, y)

                n = y.size(0)
                self.te_stats.update_loss(loss * n, n)
                self.te_stats.update_acc((logit.argmax(dim=-1) == y).sum().item(), n)
                self.te_stats.update_nll(logit, y)
                self.te_stats.update_ece(yhat, y, softmaxxed=True)
                self.te_stats.update_aupr_auroc(y, logit)

    def _checkpoint_path(self, path: str) -> str:
        """Return the checkpoint file for this run inside directory ``path``."""
        return os.path.join(path, f"{self.args.run}.pt")

    def load_model(self, path: str) -> None:
        """Restore model/optimizer/scheduler state from ``<path>/<run>.pt``.

        Does nothing if the checkpoint file does not exist.
        """
        sd_path = self._checkpoint_path(path)
        if not os.path.exists(sd_path):
            return

        # map tensors onto the configured device so a checkpoint written on one
        # GPU (or machine) can be resumed on another device or on CPU
        saved = torch.load(sd_path, map_location=self.args.device)
        self.epoch = saved["epoch"]
        self.best_acc = saved["best_acc"]

        self.model.load_state_dict(saved["state_dict"])
        self.optimizer.load_state_dict(saved["optimizer"])
        self.scheduler.load_state_dict(saved["scheduler"])
        self.finished = saved["finished"]
        self.log(f"loaded saved model: {self.epoch=} {self.finished=}")

    def save_model(self, path: str, finished: bool = False) -> None:
        """Write the full training state to ``<path>/<run>.pt``, overwriting any previous file.

        Args:
            path: directory to write into.
            finished: stored in the checkpoint (and mirrored on ``self.finished``);
                ``fit`` refuses to resume a checkpoint saved with ``finished=True``.
        """
        self.finished = finished
        save = dict(
            epoch=self.epoch,
            state_dict=self.model.state_dict(),
            optimizer=self.optimizer.state_dict(),
            scheduler=self.scheduler.state_dict(),
            finished=finished,
            best_acc=self.best_acc,
        )
        torch.save(save, self._checkpoint_path(path))

    def log_train_stats(self, path: str) -> Dict[str, float]:
        """Write train/val stats to CSV files in ``path``, log them, and return the val stats."""
        self.tr_stats.log_stats(os.path.join(path, f"train-run-{self.args.run}.csv"))
        names, values = self.te_stats.log_stats(os.path.join(path, f"val-run-{self.args.run}.csv"))

        fields = " ".join(f"{n}: {v:.4f}" for n, v in zip(names, values))
        self.log(f"({self.args.run}) epoch: {self.epoch}/{self.args.epochs} {fields}")
        return dict(zip(names, values))

    def log_test_stats(self, path: str, test_name: str = "test") -> Dict[str, float]:
        """Write the evaluation stats to ``<path>/<test_name>-run-<run>.csv``, log and return them."""
        names, values = self.te_stats.log_stats(os.path.join(path, f"{test_name}-run-{self.args.run}.csv"))
        self.log(" ".join(f"{n} {v:.4f}" for n, v in zip(names, values)))
        return dict(zip(names, values))

    def log(self, msg: str) -> None:
        """Emit ``msg`` at INFO level through ``args.logger``."""
        self.args.logger.info(msg)


if __name__ == "__main__":
    # Command-line entry point: parse arguments, build loaders and model, then either
    # train (--mode train, validating on the held-out split) or evaluate a saved
    # checkpoint on the test split (--mode test).
    parser = ArgumentParser(description="train or evaluate an SNGP model")

    parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "cifar100"], help="the dataset to use")
    parser.add_argument("--data-root", type=str, help="the dataset root directory")
    parser.add_argument("--ood-test", type=str2bool, default=False, help="if True, use random classes for the query set")
    parser.add_argument("--corrupt-test", type=str2bool, default=False, help="if True, use corrupted query set")
    parser.add_argument("--num-workers", type=int, default=4, help="the number of workers for the dataloader")
    parser.add_argument("--epochs", type=int, default=250, help="the number of training epochs to run")
    parser.add_argument("--batch-size", type=int, default=128, help="batch size for training")
    parser.add_argument("--lr-steps", type=int, nargs="+", default=[60, 120, 160], help="epochs at which the learning rate is multiplied by --lr-gamma")
    parser.add_argument("--lr-gamma", type=float, default=0.2, help="learning-rate decay factor applied at each --lr-steps milestone")
    parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
    parser.add_argument("--weight_decay", type=float, default=5e-4, help="SGD weight decay")
    parser.add_argument("--run", type=int, default=0, help="independent run number (also used as the random seed)")
    parser.add_argument("--gpu", type=int, default=0, help="the gpu index")
    parser.add_argument("--lr", type=float, default=1e-1, help="initial SGD learning rate")
    # required: without them the lookups below would fail with an opaque KeyError on None
    parser.add_argument("--mode", type=str, required=True, choices=["train", "test"], help="train a model, or evaluate a saved checkpoint on the test set")
    parser.add_argument("--model", type=str, required=True, choices=["wide-sn-resnet-28-10-cifar"], help="the model architecture")
    parser.add_argument("--filterwise-dropout", type=str2bool, default=True, help="filterwise dropout for WideResNet model")
    parser.add_argument("--p", type=float, default=0.1, help="dropout probability for WideResNet models")
    parser.add_argument("--save-best-val", type=str2bool, default=False, help="if True, only checkpoint when validation accuracy improves (training still runs all epochs)")

    args = parser.parse_args()
    args.logger = set_logger("INFO")
    if torch.cuda.is_available():
        args.device = torch.device(f"cuda:{args.gpu}")
    else:
        args.logger.info("CUDA is not available; falling back to CPU")
        args.device = torch.device("cpu")
    args.get_val = True
    args.val_pct = 0.1

    # seed before doing anything else with the dataloaders
    seed(args.run)

    trainset, valset, testset = get_dataset(args)

    num_classes = {"cifar10": 10, "cifar100": 100}[args.dataset]
    model_deref = {
        "wide-sn-resnet-28-10-cifar": partial(
            SNGP_WideResNet28_10_cifar,
            sngp_kwargs=dict(num_classes=num_classes),
            resnet_kwargs=dict(p=args.p, filterwise_dropout=args.filterwise_dropout, num_classes=num_classes),
        )
    }
    # the training loader is used in both modes: for optimisation and for fitting the GP covariance
    ds_deref = {"train": [trainset, valset], "test": [trainset, testset]}

    train, val = ds_deref[args.mode]
    trainer = SNGPTrainer(args, model_deref[args.model](), trainset=train, valset=val)

    if args.mode == "train":
        trainer.fit()
        trainer.log("finished training")
    elif args.mode == "test":
        ckpt_path = os.path.join(trainer.models_path, f"{args.run}.pt")
        if not os.path.exists(ckpt_path):
            # load_model silently does nothing when the file is missing, which would
            # otherwise report test results for an untrained model
            raise FileNotFoundError(f"no checkpoint to evaluate at {ckpt_path}")
        trainer.load_model(trainer.models_path)
        trainer.test()
        trainer.log_test_stats(trainer.results_path, test_name="standard")
        trainer.log("finished testing")
    else:
        raise NotImplementedError(f"unknown mode: {args.mode}")
