import argparse
import logging
import os
import random
from argparse import Namespace
from typing import Any, Union

import numpy as np  # type: ignore
import torch
from data.get import get_dataset
from matplotlib import cm
from matplotlib import pyplot as plt  # type: ignore
from mpl_toolkits.axes_grid1 import ImageGrid  # type: ignore
from sklearn.manifold import TSNE
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils import get_color, to_cpu

from sngp.linear import sngp_linear
from sngp.trainer import SNGPTrainer

# Short alias for tensor annotations used throughout this file.
T = torch.Tensor


# Loose aliases kept for annotations: anything array-like, and a scalar that may be a float or a tensor.
Arr = Any
FT = Union[float, T]
# Colormap used to colour entropy values in the plots.
CM = cm.cool
# Maps a model variant key to the label drawn on the layer-wise latent plot. The lookup is done with
# `args.variant`, whose runnable value is "sngp" (the only key of `model_deref` in the script entry point),
# so that key must be present; "sngp_linear" is kept for backward compatibility.
name_deref = {"sngp": "SNGP", "sngp_linear": "SNGP"}


class SNGPToy(SNGPTrainer):
    """SNGP trainer that produces uncertainty plots for 2-D toy datasets.

    `experiment` is the entry point: it loads (or trains and saves) the model, recomputes the model's
    precision/covariance over the training loader and then renders the entropy map via `plot_metatest`.
    Every figure is written to `self.results_path` as both a PDF and a PNG.
    """

    def __init__(self, args: Namespace, model: nn.Module, trainset: DataLoader, valset: DataLoader):
        """Forward all arguments to `SNGPTrainer` and start with logit tracking disabled."""
        super().__init__(args, model, trainset, valset)
        self.track_logits: bool = False

    def _save_figure(self, fig: Any, suffix: str) -> None:
        """Save `fig` as `<dataset name>-<suffix>.pdf` and `.png` under `self.results_path`, then close it."""
        stem = os.path.join(self.results_path, f"{self.trainset.dataset.name}-{suffix}")  # type: ignore
        for ext in ("pdf", "png"):
            fig.savefig(f"{stem}.{ext}")
        # close explicitly: experiments run back to back would otherwise accumulate open figures
        plt.close(fig)

    @staticmethod
    def _entropy(probs: T) -> T:
        """Return `-sum(p * log(p))` over the last dimension, clamping `p` from below to avoid `log(0)`."""
        clamped = torch.clamp(probs, 1e-45)
        return -(clamped * torch.log(clamped)).sum(dim=-1)

    def get_final_logits(self, x: T, y: T, ood: T) -> Any:
        """Run `self.model.mc` on in-distribution and OOD inputs and record statistics for both.

        In-distribution results are accumulated in `self.tr_stats`. OOD results are accumulated in
        `self.te_stats` against placeholder labels drawn with `torch.randint`, since OOD points carry
        no ground-truth class.

        Args:
            x: in-distribution inputs, already on the model's device.
            y: labels for `x`.
            ood: out-of-distribution inputs, already on the model's device.

        Returns:
            A tuple `(id_entropy, ood_entropy, None)` holding the per-sample entropy of the first output
            of `model.mc` for `x` and for `ood`. The third element is always `None`.
        """
        # use the trainer's own model rather than a module-level global, so the class also works when imported
        with torch.no_grad():
            yhat, id_logits = self.model.mc(x)  # type: ignore
            yhat, id_logits = to_cpu(yhat, id_logits)

        id_entropy = self._entropy(yhat)
        id_logits, y = to_cpu(id_logits, y)
        n_id = y.size(0)
        self.tr_stats.update_loss(F.cross_entropy(id_logits, y) * n_id, n_id)
        self.tr_stats.update_acc((id_logits.argmax(dim=-1) == y).sum().item(), n_id)
        self.tr_stats.update_nll(id_logits, y)
        self.tr_stats.update_ece(id_logits, y, softmaxxed=True)
        self.tr_stats.update_aupr_auroc(y, id_logits)

        # the OOD pass is inference only as well, so it does not need an autograd graph either
        with torch.no_grad():
            ood_yhat, ood_logits = self.model.mc(ood)  # type: ignore
            ood_logits, ood_yhat = to_cpu(ood_logits, ood_yhat)

        ood_entropy = self._entropy(ood_yhat)
        ood_y = torch.randint(0, self.args.classes, (ood_logits.size(0),))
        n_ood = ood_y.size(0)
        self.te_stats.update_loss(F.cross_entropy(ood_logits, ood_y) * n_ood, n_ood)
        self.te_stats.update_acc((ood_logits.argmax(dim=-1) == ood_y).sum().item(), n_ood)
        self.te_stats.update_nll(ood_logits, ood_y)
        self.te_stats.update_ece(ood_logits, ood_y, softmaxxed=True)
        self.te_stats.update_aupr_auroc(ood_y, ood_logits)

        return id_entropy, ood_entropy, None

    def plot_metatest(self, x: T, y: T) -> None:
        """Draw the OOD entropy map with the labelled points of one batch on top and save it as `*-toy`.

        Args:
            x: batch of 2-D inputs; used directly for plotting, so presumably a CPU tensor — TODO confirm.
            y: class labels for `x`, used to pick the marker colours.
        """
        self.model.eval()
        ood = self.valset.dataset.sample_uniform()  # type: ignore
        # follow args.device (as `experiment` does) instead of hard-coding CUDA, so CPU-only runs work
        device = self.args.device
        id_entropy, ood_entropy, _ = self.get_final_logits(x.to(device), y.to(device), ood.to(device))

        # plain floats for the image extent
        xmin, xmax = ood[:, 0].min().item(), ood[:, 0].max().item()
        ymin, ymax = ood[:, 1].min().item(), ood[:, 1].max().item()
        fig = plt.figure(figsize=(7, 6))
        axes = ImageGrid(fig, 111, nrows_ncols=(1, 1), axes_pad=0.05, share_all=True, cbar_location="right", cbar_mode="single", cbar_size="5%", cbar_pad=0.05)
        ax = axes[0]

        # every point is drawn twice: an outlined disc with a star marker on top
        colors = [get_color(v.item()) for v in y]
        ax.scatter(x[:, 0], x[:, 1], c=colors, s=30, edgecolors=(0, 0, 0, 0.5), linewidths=2.0)
        ax.scatter(x[:, 0], x[:, 1], c=colors, marker='*', s=20)

        text_x, text_y = 0.0, 2.5
        ood_stats = self.te_stats.get_stats()
        id_stats = self.tr_stats.get_stats()

        text = f"Accuracy: {id_stats['accuracy']:.2f}\nID Entropy: {id_entropy.mean():.2f}\nOOD Entropy: {ood_entropy.mean():.2f}\nECE: {ood_stats['ece']:.2f}"
        ax.text(text_x, text_y, horizontalalignment="center", color=[0, 0, 0, 0.7], fontweight="bold", fontsize=14, verticalalignment="top", s=text)
        ax.axis("off")

        # keep RGB only; the reshape assumes sample_uniform() returns a flattened 100x100 grid — TODO confirm
        ent = CM(ood_entropy)[:, :3]
        ent = ent.reshape(100, 100, 3)
        im = ax.imshow(np.transpose(ent, (1, 0, 2)), origin="lower", cmap=CM, extent=(xmin, xmax, ymin, ymax))

        cbar = axes[-1].cax.colorbar(im)
        cbar.ax.axis("off")

        fig.tight_layout()
        self._save_figure(fig, "toy")

    def plot_layerwise_latent(self, x: T, y: T) -> None:
        """Plot a 2-D t-SNE embedding of each saved feature layer and save it as `*-layerwise-latent`.

        In-distribution points are drawn as stars and up to 512 randomly chosen OOD points are coloured
        by their min-max normalised entropy.

        Args:
            x: batch of in-distribution inputs.
            y: labels for `x`; only forwarded to `get_final_logits`.
        """
        self.model.eval()
        self.model.save_features = True  # type: ignore

        ood = self.valset.dataset.sample_uniform()  # type: ignore
        ood = ood[torch.randperm(ood.size(0))[:512]]

        fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(7 * 4, 6))
        all_x = torch.cat((x, ood))

        # the model will call the base layers which will save the features at each level. once we have the yhats for every
        # instance in all_x we can calculate the entropy of each one to make a color for it
        device = self.args.device
        _, all_entropy, _ = self.get_final_logits(x.to(device), y.to(device), all_x.to(device))

        # all_x is [x, ood], so the first n_id rows are in-distribution and the rest are OOD
        n_id = x.size(0)
        ood_ent = all_entropy[n_id:]
        # min-max normalise for the colormap; the clamp avoids 0/0 when all entropies are equal
        span = torch.clamp(ood_ent.max() - ood_ent.min(), min=1e-12)
        ood_ent = (ood_ent - ood_ent.min()) / span

        # fall back to the raw variant string when there is no display name for it
        variant = self.args.variant
        display_name = name_deref.get(variant, variant)

        # NOTE(review): assumes self.model.features holds one tensor per layer whose rows line up with all_x
        # (i.e. features of the last forward pass) — confirm against the model implementation
        for j, (ft, ax) in enumerate(zip(self.model.features, axes)):  # type: ignore
            ft2d = TSNE(n_components=2).fit_transform(ft.detach().cpu().numpy())
            id2d, ood2d = ft2d[:n_id], ft2d[n_id:]

            text_x, text_y = 0.0, 5.5
            text = f"{display_name} layer: {j}"
            ax.text(text_x, text_y, horizontalalignment="center", color=[0, 0, 0, 0.7], fontweight="bold", fontsize=14, verticalalignment="top", s=text)

            # colours are passed as explicit RGBA values, so no `cmap` argument is needed here
            ax.scatter(ood2d[:, 0], ood2d[:, 1], label="OOD", c=CM(ood_ent))
            fig.colorbar(cm.ScalarMappable(cmap=CM), ax=ax)

            ax.scatter(id2d[:, 0], id2d[:, 1], color=get_color(4), marker='*', s=100, label="query set")
            ax.legend()
            leg = ax.get_legend()
            # `legendHandles` was renamed to `legend_handles` in Matplotlib 3.7; support both spellings
            handles = getattr(leg, "legend_handles", None)
            if handles is None:
                handles = leg.legendHandles
            # the OOD entry is first in the legend; give it one neutral colour instead of a per-point colour
            handles[0].set_color('black')

        fig.tight_layout()
        self._save_figure(fig, "layerwise-latent")

    def plot_latent2d(self, x: T, y: T) -> None:
        """Scatter the `down_project` output for `x` and for OOD samples; save as `*-latent-projection-toy`.

        Args:
            x: batch of in-distribution inputs.
            y: unused; kept so all plot methods share one signature.
        """
        self.model.eval()
        ood = self.valset.dataset.sample_uniform()  # type: ignore
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 6))

        device = self.args.device
        x2d = self.model.down_project(x.to(device)).cpu().detach()  # type: ignore
        ood2d = self.model.down_project(ood.to(device)).cpu().detach()  # type: ignore

        ax.scatter(ood2d[:, 0], ood2d[:, 1], color=get_color(6), marker="o", label="OOD")
        ax.scatter(x2d[:, 0], x2d[:, 1], color=get_color(3), marker='*', s=100, edgecolors=(0, 0, 0, 0.5), label="query set")
        ax.legend()

        fig.tight_layout()
        self._save_figure(fig, "latent-projection-toy")

    def set_track_logits(self, val: bool) -> None:
        """Set the `track_logits` flag."""
        self.track_logits = val

    def experiment(self, fname_prepend: str = "") -> None:
        """Train if needed, refresh the model's covariance and render the toy plot.

        Args:
            fname_prepend: currently unused; kept for interface compatibility.

        Raises:
            RuntimeError: if the validation loader yields no batch to plot.
        """
        self.load_model(self.models_path)
        if not self.finished:
            # NOTE(review): the previous comment here said logit tracking is always turned ON for
            # experiments, but the code passes False. Behaviour is kept; confirm which is intended.
            self.set_track_logits(False)
            for _ in range(self.args.epochs):
                self.train()

            self.save_model(self.models_path, finished=True)

        # one pass over the training data with update_prec=True, bracketed by init_sigma_lambda / compute_cov
        self.model.init_sigma_lambda()  # type: ignore
        for (x, y) in self.trainset:
            x, y = x.to(self.args.device), y.to(self.args.device)
            self.model(x, update_prec=True, y=y)
        self.model.compute_cov()  # type: ignore

        # plot the first validation batch; fail loudly instead of silently reusing the last training batch
        first_batch = next(iter(self.valset), None)
        if first_batch is None:
            raise RuntimeError("validation loader produced no batches; nothing to plot")
        x, y = first_batch

        with torch.no_grad():
            # plot_latent2d(x, y) and plot_layerwise_latent(x, y) can be called here as well
            self.plot_metatest(x, y)


if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), registered in this order.
    cli_options = [
        ("--variant", {"type": str, "default": "sngp", "help": "the variant of the model to run"}),
        ("--run", {"type": int, "default": 0, "help": "the run number"}),
        ("--comment", {"type": str, "default": "", "help": "comment for experiment"}),
        ("--lr", {"type": float, "default": 1e-2, "help": "the learning rate of the optimizer"}),
        ("--epochs", {"type": int, "default": 500, "help": "the number of epochs to train for"}),
        ("--lr-steps", {"type": int, "nargs": "+", "default": [250, 350], "help": "learning rate steps for the scheduler"}),
        ("--lr-gamma", {"type": float, "default": 0.2, "help": "scheduler LR multiplicative factor"}),
        ("--p", {"type": float, "default": 0.01, "help": "the dropout rate"}),
        ("--weight-decay", {"type": float, "default": 0.0, "help": "the weight decay for the optimizer"}),
        ("--root", {"type": str, "default": os.path.join("/", "home", "datasets")}),
        ("--ctype", {"type": str, "default": "scalar", "help": "the scaling constant type for spectral normalization"}),
        ("--batch-size", {"type": int, "default": 128, "help": "the batch size"}),
        ("--momentum", {"type": float, "default": 0.9, "help": "SGD momentum for the inner loop SGD"}),
        ("--num-workers", {"type": int, "default": 8, "help": "the number of workers for the dataloader"}),
        ("--experiment-name", {"type": str, "help": "the name of an experiment to run (eval on trained model)"}),
    ]

    parser = argparse.ArgumentParser("parser for SNGP based models")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    args.bins = 15  # add a arg for this later if needed
    model: nn.Module

    logging.basicConfig(format="%(asctime)s %(levelname)-8s %(message)s", level="INFO", datefmt="%Y-%m-%d %H:%M:%S")

    # run on the GPU when one is available, otherwise on the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device
    args.log = logging.getLogger()

    # these do not change between runs
    trainerclass = SNGPToy
    model_deref = {"sngp": sngp_linear}

    # one run per toy dataset: (dataset name, number of classes, value passed to the model as `s`)
    toy_runs = [
        ("toy-moons", 2, 1e-4),
        ("toy-circles", 2, 1e-8),
        ("toy-gaussians", 10, 1e-1),
    ]

    for seed, (ds, classes, s) in enumerate(toy_runs):
        # the position of the run doubles as the seed for every random number generator in use
        for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
            seed_fn(seed)

        args.classes = classes
        args.dataset = ds
        args.seed = seed

        # the middle split returned by get_dataset is not used here
        trainset, _, testset = get_dataset(args)

        model = model_deref[args.variant](12, 2, 128, args.p, args.classes, ctype="none", s=s)
        trainer = trainerclass(args, model, trainset, testset)
        print(f"starting toy experiment: {ds} on model: {model.name}")
        trainer.experiment()
