import argparse
import logging
import os
import random
from argparse import Namespace
from functools import partial
from typing import Any, Union

import numpy as np  # type: ignore
import torch
from data.get import get_dataset
from matplotlib import cm  # type: ignore
from matplotlib import pyplot as plt  # type: ignore
from matplotlib.lines import Line2D  # type: ignore
from mpl_toolkits.axes_grid1 import ImageGrid  # type: ignore
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import Stats, get_color, set_logger, str2bool, to_cpu, to_device

from mahalanobis.models import (proto_ddu_linear, proto_mahalanobis_linear,
                                proto_sngp_linear, protonet_linear)
from mahalanobis.trainer import MahalanobisTrainer

# Short aliases used in the annotations of this module.
T = torch.Tensor
Arr = Any
FT = Union[float, T]

# Colormap applied to the heatmap values before they are drawn.
CM = cm.viridis


# Human-readable labels keyed by model identifier.
name_deref = {
    "protonet": "ProtoNet",
    "proto-mahalanobis": "GTNet (Ours)",
    "proto-sngp": "ProtoSNGP",
    "proto-ddu": "ProtoDDU",
}


class ProtoToy(MahalanobisTrainer):
    """Trainer for the few-shot toy experiments: OOD statistics and diagnostic plots.

    Extends ``MahalanobisTrainer`` with entropy/energy heatmaps, per-class
    covariance images and an OOD evaluation loop. OOD inputs are obtained from
    ``valset.dataset.sample_uniform()``.

    All methods read the model and the configuration through ``self.model`` and
    ``self.args`` so the class does not depend on module-level script globals.
    """

    # Side length of the square grid the heatmap values are reshaped to.
    # NOTE(review): assumes ``sample_uniform()`` returns exactly 100 * 100 points
    # laid out as a grid -- confirm against the dataset implementation.
    _GRID_SIDE = 100

    # A batch of N rows is split into pieces of ``N // _N_CHUNKS`` rows for inference.
    _N_CHUNKS = 10

    def __init__(self, args: Namespace, model: nn.Module, trainset: DataLoader, valset: DataLoader):
        """Forward all arguments to the parent trainer; logit tracking starts disabled."""
        super().__init__(args, model, trainset, valset)
        self.track_logits: bool = False

    # ------------------------------------------------------------------ helpers

    def _chunked_inference(self, sx: T, sy: T, x: T, n_way: int, k_shot: int, inference_style: str) -> Any:
        """Run ``self.model.inference`` over ``x`` piece by piece and concatenate the outputs.

        Args:
            sx, sy: support inputs and labels handed to every inference call.
            x: inputs to predict on; split along dim 0.
            n_way, k_shot: task shape forwarded to the model.
            inference_style: forwarded to the model unchanged.

        Returns:
            ``(preds, log_preds, energy)``, each the dim-0 concatenation of the
            per-piece model outputs.
        """
        # ``max(1, ...)`` keeps the split size valid when ``x`` has fewer rows than _N_CHUNKS.
        piece = max(1, x.size(0) // self._N_CHUNKS)
        outputs = [
            self.model.inference(sx, sy, part, n_way=n_way, k_shot=k_shot, inference_style=inference_style)
            for part in torch.split(x, piece)
        ]
        preds, log_preds, energy = (torch.cat(parts) for parts in zip(*outputs))
        return preds, log_preds, energy

    @staticmethod
    def _new_figure() -> Any:
        """Create a 7x6 figure holding a single ``ImageGrid`` axes with a colorbar slot."""
        fig = plt.figure(figsize=(7, 6))
        ax = ImageGrid(fig, 111, nrows_ncols=(1, 1), axes_pad=0.05, share_all=True, cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.05)[0]
        return fig, ax

    @staticmethod
    def _extent(points: T) -> Any:
        """Return ``(xmin, xmax, ymin, ymax)`` over the first two columns of ``points``."""
        xs, ys = points[:, 0], points[:, 1]
        return xs.min().item(), xs.max().item(), ys.min().item(), ys.max().item()

    @staticmethod
    def _add_caption(ax: Any, text: str) -> None:
        """Write ``text`` at the fixed caption position and hide the axes decorations."""
        ax.text(0.0, 2.8, horizontalalignment="center", color=[1, 1, 1, 0.7], fontweight="bold", fontsize=14, verticalalignment="top", s=text)
        ax.axis("off")

    @staticmethod
    def _scatter_task(ax: Any, sx: T, sy: T, qx: T, qy: T) -> None:
        """Scatter the support points (outlined circles) and the query points (stars)."""
        ax.scatter(sx[:, 0], sx[:, 1], c=[get_color(v.item()) for v in sy], s=30, edgecolors=(0, 0, 0, 0.5), linewidths=2.0)
        ax.scatter(qx[:, 0], qx[:, 1], c=[get_color(v.item()) for v in qy], marker='*', s=20)

    def _draw_heatmap(self, ax: Any, values: Any, extent: Any) -> Any:
        """Map ``values`` through ``CM`` and draw them as a square RGB image; return the image."""
        side = self._GRID_SIDE
        rgb = CM(values)[:, :3].reshape(side, side, 3)
        return ax.imshow(np.transpose(rgb, (1, 0, 2)), origin="lower", cmap=CM, extent=extent)

    @staticmethod
    def _add_legend_and_colorbar(ax: Any, im: Any) -> None:
        """Add the train/test marker legend and a colorbar with its axis decorations hidden."""
        handles = [
            Line2D([0], [0], marker='o', color='w', label='train', markerfacecolor='black', markersize=10),
            Line2D([0], [0], marker='*', color='w', label='test', markerfacecolor='black', markersize=10),
        ]
        ax.legend(handles=handles)
        cbar = ax.cax.colorbar(im)
        cbar.ax.axis("off")

    @staticmethod
    def _save_and_close(fig: Any, directory: str, stem: str) -> None:
        """Save ``fig`` as ``<stem>.pdf`` and ``<stem>.png`` in ``directory``, then close it."""
        for ext in ("pdf", "png"):
            fig.savefig(os.path.join(directory, f"{stem}.{ext}"))
        # Close this specific figure so no figure is left registered with pyplot.
        plt.close(fig)

    # --------------------------------------------------------------- public API

    def get_final_logits(self, sx: T, sy: T, qx: T, qy: T, ood: T, inference_style: str) -> Any:
        """Predict on the query and OOD inputs of one task and rebuild the task statistics.

        ``self.id_stats`` and ``self.ood_stats`` are replaced by fresh ``Stats``
        objects and updated from the predictions. OOD labels are random integers
        in ``[0, n_way)``.

        Args:
            sx, sy: support inputs and labels.
            qx, qy: query inputs and labels.
            ood: OOD inputs.
            inference_style: forwarded to ``self.model.inference``.

        Returns:
            ``(ood_entropy, ood_energy)``: the entropy of the OOD predictions and
            the OOD energies, both on the CPU.
        """
        self.id_stats = Stats(["accuracy", "nll", "ece", "softmax_entropy"])
        self.ood_stats = Stats(["nll", "ece", "softmax_entropy"])
        n_way, k_shot = self.args.n_way, self.args.k_shot

        with torch.no_grad():
            id_preds, _, _ = self.model.inference(sx, sy, qx, n_way=n_way, k_shot=k_shot, inference_style=inference_style)
            ood_preds, _, ood_energy = self._chunked_inference(sx, sy, ood, n_way, k_shot, inference_style)

        id_preds, ood_preds, ood_energy = id_preds.detach(), ood_preds.detach(), ood_energy.detach()

        ood_y = torch.randint(0, n_way, (ood_preds.size(0),))
        ood_preds, id_preds, qy, ood_y = to_cpu(ood_preds, id_preds, qy, ood_y)

        # clamp away exact zeros so that the logarithms below stay finite
        ood_preds = torch.clamp(ood_preds, 1e-45)
        id_preds = torch.clamp(id_preds, 1e-45)

        softmaxxed = True  # all models should be softmaxxing the preds before returning to maintain consistency between the sampling methods
        self.id_stats.update_acc((id_preds.argmax(dim=-1) == qy).sum().item(), qy.size(0))
        self.id_stats.update_nll(id_preds, qy, softmaxxed=softmaxxed)
        self.id_stats.update_ece(id_preds, qy, softmaxxed=softmaxxed)
        self.id_stats.update_softmax_entropy(id_preds, qy.size(0), softmaxxed=softmaxxed)

        self.ood_stats.update_nll(ood_preds, ood_y, softmaxxed=softmaxxed)
        self.ood_stats.update_ece(ood_preds, ood_y, softmaxxed=softmaxxed)
        self.ood_stats.update_softmax_entropy(ood_preds, ood_y.size(0), softmaxxed=softmaxxed)

        ood_entropy = -(ood_preds * torch.log(ood_preds)).sum(dim=-1)

        return ood_entropy, ood_energy.cpu()

    def plot_metatest(self, x_spt: T, y_spt: T, x_qry: T, y_qry: T) -> None:
        """Save entropy, energy and covariance plots for every task of a meta-batch.

        Per task this writes, as both ``.pdf`` and ``.png``:
          * one entropy heatmap per inference style ("distance", "softmax-sample"),
          * one energy heatmap, built from the energies of the last inference
            style evaluated,
          * one covariance image per class under ``<results_path>/covariances``
            unless ``args.model`` contains "protonet".

        ``args.n_way`` / ``args.k_shot`` are overwritten from the support tensor
        when it carries ``n_way`` / ``k_shot`` attributes.
        """
        self.model.eval()
        device = self.args.device
        for i, (sx, sy, qx, qy) in enumerate(zip(x_spt, y_spt, x_qry, y_qry)):
            self.args.n_way = sx.n_way if hasattr(sx, "n_way") else self.args.n_way
            self.args.k_shot = sx.k_shot if hasattr(sx, "k_shot") else self.args.k_shot
            stem = f"{i}-{self.trainset.dataset.name}-{self.model.name}"  # type: ignore

            for inference_style in ["distance", "softmax-sample"]:
                ood = self.valset.dataset.sample_uniform()  # type: ignore
                ood_entropy, ood_energy = self.get_final_logits(
                    sx.to(device), sy.to(device), qx.to(device), qy.to(device), ood.to(device), inference_style=inference_style
                )
                extent = self._extent(ood)

                id_stats = self.id_stats.get_stats()
                ood_stats = self.ood_stats.get_stats()
                text = (
                    f"Accuracy: {id_stats['accuracy']:.2f}\nNLL {id_stats['nll']:.2f}"
                    f"\nENT ID/OOD: {id_stats['softmax_entropy']:.2f} / {ood_stats['softmax_entropy']:.2f}"
                    f"\nECE ID/OOD {id_stats['ece']:.2f} / {ood_stats['ece']:.2f}"
                )

                fig, logit_ax = self._new_figure()
                self._add_caption(logit_ax, text)
                self._scatter_task(logit_ax, sx, sy, qx, qy)
                # entropies are divided by their maximum before the colormap is applied
                im = self._draw_heatmap(logit_ax, ood_entropy / ood_entropy.max().item(), extent)
                self._add_legend_and_colorbar(logit_ax, im)
                fig.tight_layout()
                self._save_and_close(fig, self.results_path, f"{stem}-{inference_style}-entropy-toy")

            # The energy plot is drawn once per task, from ``ood_energy`` / ``extent`` of the
            # last inference style above. Its figure is created here (not inside the loop) so
            # that no unused figure is left open.
            fig, energy_ax = self._new_figure()
            self._scatter_task(energy_ax, sx, sy, qx, qy)
            self._add_caption(energy_ax, "Energy")
            # NOTE(review): an earlier comment here said the sign of the energy has to be
            # flipped, but the values are only divided by their maximum -- confirm intent.
            im = self._draw_heatmap(energy_ax, ood_energy / ood_energy.max().item(), extent)
            self._add_legend_and_colorbar(energy_ax, im)
            fig.tight_layout()
            self._save_and_close(fig, self.results_path, f"{stem}-energy-toy")

            self.tr_stats.zero()
            self.te_stats.zero()

            # plot the covariance plots
            if "protonet" not in self.args.model:
                covs = self.model.get_cov(sx.to(device), sy.to(device), n_way=self.args.n_way, k_shot=self.args.k_shot)
                covs = covs.detach().cpu()
                covpath = os.path.join(self.results_path, "covariances")
                os.makedirs(covpath, exist_ok=True)

                for j in range(covs.size(0)):
                    # one fresh figure per class; the colorbar is tied to this class's image
                    fig, ax = self._new_figure()
                    cov_im = ax.imshow(covs[j], cmap="viridis")
                    fig.tight_layout()
                    ax.axis("off")
                    cbar = ax.cax.colorbar(cov_im)
                    cbar.ax.axis("off")
                    self._save_and_close(fig, covpath, f"{stem}-class-{j}")

    def set_track_logits(self, val: bool) -> None:
        """Set the ``track_logits`` flag."""
        self.track_logits = val

    def test_ood_toy(self) -> None:
        """Evaluate on OOD inputs for every validation task and fill ``self.ood_te_stats``.

        For each task the query inputs are concatenated with the output of
        ``sample_uniform()``; predictions are computed for all of them. Accuracy,
        NLL, ECE and softmax entropy are updated from the OOD predictions against
        random labels in ``[0, n_way)``. AUPR/AUROC are updated from the energies
        of all inputs with target 0 for query inputs and 1 for OOD inputs.
        """
        self.model.eval()
        stats = ["accuracy", "nll", "ece", "softmax_entropy", "aupr", "auroc"]
        logs = [("id_ood_entropy", os.path.join(self.results_path, f"id-ood-entropy-run-{self.args.run}.txt"))]
        self.ood_te_stats = Stats(stats, logs)

        with torch.no_grad():
            for (x_spt, y_spt, x_qry, y_qry) in tqdm(self.valset, ncols=75, leave=False):
                for (sx, sy, qx, _) in zip(x_spt, y_spt, x_qry, y_qry):
                    # the toy datasets put the n_way k_shot into the support tensor so we need to get that value if it exists
                    n_way = sx.n_way if hasattr(sx, "n_way") else self.args.n_way
                    k_shot = sx.k_shot if hasattr(sx, "k_shot") else self.args.k_shot

                    qx_ood = self.valset.dataset.sample_uniform()  # type: ignore
                    # labels are drawn from the task's own number of classes
                    qy_ood = torch.randint(0, n_way, (qx_ood.size(0),))

                    n_id = qx.size(0)
                    qx = torch.cat((qx, qx_ood))

                    sx, sy, qx, qy_ood = to_device(sx, sy, qx, qy_ood, device=self.args.device)

                    preds, _, energy = self._chunked_inference(sx, sy, qx, n_way, k_shot, self.args.inference_style)

                    preds_ood = preds[n_id:]
                    energy_id, energy_ood = energy[:n_id], energy[n_id:]
                    is_ood = torch.cat((torch.zeros(energy_id.size(0)), torch.ones(energy_ood.size(0))))

                    self.ood_te_stats.update_acc((preds_ood.argmax(dim=-1) == qy_ood).sum().item(), qy_ood.size(0))
                    self.ood_te_stats.update_nll(preds_ood, qy_ood, softmaxxed=True)
                    self.ood_te_stats.update_ece(preds_ood, qy_ood, softmaxxed=True)
                    self.ood_te_stats.update_aupr_auroc(is_ood, energy)
                    self.ood_te_stats.update_softmax_entropy(preds_ood, qy_ood.size(0), softmaxxed=True)

    def experiment(self, fname_prepend: str = "") -> None:
        """Load (or train and tune) the model, then plot and evaluate it.

        Steps: load the checkpoint; train for ``args.epochs`` epochs if it is not
        marked finished; tune if the model is not marked tuned; reseed torch,
        ``random`` and NumPy with ``args.seed``; on run 0 plot the first validation
        batch; finally run the standard test and the OOD test for both inference
        styles and log their results.

        Args:
            fname_prepend: accepted for interface compatibility; not used here.
        """
        self.load_model(self.models_path)
        if not self.finished:
            # NOTE(review): the previous comment said logit tracking is always turned on
            # for experiments, yet the flag is set to False -- confirm which is intended.
            self.set_track_logits(False)
            for _ in range(self.args.epochs):
                self.train()

            self.save_model(self.models_path, finished=True)

        if not self.model.tuned:
            self.tune()
            self.save_model(self.models_path, finished=True)  # will be saved with tuned flag set to true, temp buffer will be saved

        # reseed everything because the Gaussians dataset doesn't give the same shuffling for some reason.
        seed = self.args.seed
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

        if self.args.run == 0:
            for (x, y, xt, yt) in self.valset:
                self.plot_metatest(x, y, xt, yt)
                break

        for inference_style in ["distance", "softmax-sample"]:
            self.args.inference_style = inference_style

            self.test()
            self.log_test_stats(self.results_path, test_name=f"standard-{inference_style}")

            self.test_ood_toy()
            _, _ = self.ood_te_stats.log_stats(os.path.join(self.results_path, f"ood-{inference_style}-run-{self.args.run}.csv"))


if __name__ == "__main__":
    # Script entry point: parse the CLI options, then train/evaluate the selected
    # model on each of the three toy datasets in turn.
    parser = argparse.ArgumentParser("parser for MAML based models")

    parser.add_argument("--model", type=str, default="proto-mahalanobis", help="the variant of the model to run")
    parser.add_argument("--run", type=int, default=0, help="the run number")
    parser.add_argument("--metatrain-iters", type=int, default=10000, help="the number of metatrain iters to run")
    parser.add_argument("--lr-steps", type=int, nargs='+', default=[4500], help="steps to do learning rate decay (divided by the standard batch size)")
    parser.add_argument("--comment", type=str, default="", help="comment for experiment")
    parser.add_argument("--pma-type", type=str, default="no-residual", choices=["multiplicative", "no-residual", "additive"], help="the type of residual connection in the PMA layer")
    parser.add_argument("--lr-gamma", type=float, default=0.5, help="learning rate decay gamma")
    parser.add_argument("--lr", type=float, default=1e-3, help="the learning rate of the optimizer")
    parser.add_argument("--p", type=float, default=0.1, help="the dropout rate")
    parser.add_argument("--t", type=float, default=1.0, help="the temperature parameter for logit normal softmax sampling")
    parser.add_argument("--weight-decay", type=float, default=0.0, help="the weight decay for the optimizer")
    parser.add_argument("--root", type=str, default=os.path.join("/", "home", "datasets"))
    parser.add_argument("--epochs", type=int, default=50, help="the total number of trainign epochs to run")
    parser.add_argument("--ctype", type=str, default="scalar", help="the scaling constant type for spectral normalization")
    parser.add_argument("--batch-size", type=int, default=32, help="the batch size")
    parser.add_argument("--sigmoid-bias", type=str2bool, default=False, help="whether to put a bias on the sigmoid output for gradient stability")
    parser.add_argument("--ood-test", type=str2bool, default=False, help="OOD test needs to be in the args but it is not used in the toy examples")
    parser.add_argument("--corrupt-test", type=str2bool, default=False, help="corrupt test needs to be in the args but it is not used in the toy examples")
    parser.add_argument("--ood-training", type=str2bool, default=False, help="whether or not to incldue OOD training.")
    parser.add_argument("--forward-type", type=str, default="sigmoid", choices=["sigmoid", "exp", "softmax"], help="the forward pass type used during the training process")
    parser.add_argument("--val-interval", type=int, default=2000, help="validation interval duritn training")
    parser.add_argument("--momentum", type=float, default=0.99, help="SGD momentum for the inner loop SGD")
    parser.add_argument("--n-way", type=int, default=5, help="number of classes for training")
    # the help text used to be a copy of the --n-way one
    parser.add_argument("--k-shot", type=int, default=5, help="number of examples (shots) per class for training")
    parser.add_argument("--num-workers", type=int, default=8, help="the number of workers for the dataloader")
    parser.add_argument("--experiment-name", type=str, help="the name of an experiment to run (eval on trained model)")
    parser.add_argument("--ood-test-class", action="store_true", help="whether or not to use OOD classes as test set for omniglot")
    parser.add_argument("--encoder-type", type=str, default="diag", help="the encoder type for the proto mahalanobis model")
    parser.add_argument("--rank", type=int, default=1, help="the rank for the proto mahalanobis model (low rank version only)")

    args = parser.parse_args()
    model: nn.Module
    args.logger = set_logger("INFO")
    args.inference_style = "softmax-sample"

    # the default option is in terms of metatrain iterations, but we need to move the optimization into the batch level for learning
    # of the covariance matrices, therefore the scheduler needs to take a step once a batch
    args.lr_steps = [v // args.batch_size for v in args.lr_steps]

    logging.basicConfig(format="%(asctime)s %(levelname)-8s %(message)s", level="INFO", datefmt="%Y-%m-%d %H:%M:%S")

    args.device = device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args.log = logging.getLogger()

    # (dataset name, n_way, k_shot, test shots) for every toy experiment
    toy_tasks = [
        ("few-shot-toy-moons", 2, 5, 15),
        ("few-shot-toy-circles", 2, 5, 15),
        ("few-shot-toy-gaussian", 10, 5, 15),
    ]
    for ds, way, shot, test_shots in toy_tasks:
        # every dataset starts from the same seed, derived from the run number
        torch.manual_seed(args.run)
        random.seed(args.run)
        np.random.seed(args.run)

        args.n_way = way
        args.k_shot = shot
        args.test_shots = test_shots
        args.dataset = ds
        args.seed = args.run

        # The constructors are rebuilt per dataset because they bind the current ``args.n_way``.
        h_dim = 64
        model_deref = {
            "protonet": partial(protonet_linear, 6, 2, h_dim, args.n_way, args.p, ctype="none", spectral=False, forward_type=args.forward_type),
            "protonet-sn": partial(protonet_linear, 6, 2, h_dim, args.n_way, args.p, ctype="none", spectral=True, forward_type=args.forward_type),
            "proto-sngp": partial(proto_sngp_linear, 6, 2, h_dim, args.n_way, args.p, ctype="none", spectral=True, forward_type=args.forward_type, gp_h_dim=64, gp_in_dim=64),
            "proto-ddu": partial(proto_ddu_linear, 6, 2, h_dim, args.n_way, args.p, ctype="none", spectral=True, forward_type=args.forward_type, cov_dim=64),
            "proto-mahalanobis": partial(
                proto_mahalanobis_linear, 6, 2, h_dim, args.n_way, args.p,
                ctype="none",
                spectral=True,
                encoder=args.encoder_type,
                rank=args.rank,
                beta=args.sigmoid_bias,
                forward_type=args.forward_type,
                pma_type=args.pma_type,
                t=args.t,
            ),
        }

        # Reject an unknown --model with a usage error before any dataset is built,
        # instead of failing later with a bare KeyError.
        if args.model not in model_deref:
            parser.error(f"unknown --model '{args.model}', expected one of: {', '.join(model_deref)}")

        trainset, _, testset = get_dataset(args)
        trainerclass = ProtoToy

        # the model is instantiated after the datasets, as before
        model = model_deref[args.model]()
        trainer = trainerclass(args, model, trainset, testset)
        print(f"\n\nstarting toy experiment: {args.dataset} on model: {model.name}\n\n")
        trainer.experiment()
