import copy
import os

import numpy as np  # type: ignore
import torch
from matplotlib import pyplot as plt  # type: ignore
from sklearn.manifold import TSNE  # type: ignore

from models.mc_dropout import MCDropoutHeteroConv
from models.set_encoder import LatentPerturber
from utils import (Args, ClassificationStats, ClassificationStatsTemp,
                   ModelSaveDict, dropout_eval, get_conv_args, softmax_entropy)


# Validate (and plot latents) every this many epochs, plus on the last epoch.
_VAL_EVERY = 10
# Number of test batches consumed per validation pass.
_VAL_BATCHES = 100
# Cap on rows collected into each latent plotting buffer.
_PLOT_MAX_POINTS = 200
# Base of the target temperature ``_TEMP_BASE ** w`` in the classifier's phi loss.
_TEMP_BASE = 5


def _validate(
    args: Args, model: MCDropoutHeteroConv, perturber: LatentPerturber
) -> "tuple[ClassificationStatsTemp, torch.Tensor]":
    """Score ``model`` on the first ``_VAL_BATCHES`` test batches.

    Puts the model into dropout-eval mode, accumulates MC-dropout
    predictions, and collects test-set latents from ``model.phi`` (until at
    least ``_PLOT_MAX_POINTS`` rows are gathered) for the t-SNE plot.

    Returns:
        The computed statistics and the collected test latents (on CPU).
    """
    assert args.test_set is not None  # checked by train(); narrows the Optional

    dropout_eval(model)
    # NOTE(review): `length` looks sized for _VAL_BATCHES batches; the
    # positional 10 and the trailing factor 10 are unexplained here — confirm.
    stats = ClassificationStatsTemp(
        args.test_set.dataset, args, 10, length=_VAL_BATCHES * args.batch_size * 10,
    )
    plot_test = torch.Tensor()
    with torch.no_grad():
        for i, (x, y) in enumerate(args.test_set):
            x, y = x.to(args.device), y.to(args.device)

            # Only pay for the extra phi forward pass while its output is
            # still being kept for plotting.
            if plot_test.size(0) < _PLOT_MAX_POINTS:
                _, _, _, _, plt_test, _ = model.phi(x, perturber)
                plot_test = torch.cat((plot_test, plt_test.cpu().detach()), dim=0)

            mus = model.mc(x, args.samples).mean(dim=0)
            stats.set_multiclass(y, mus)

            if i == _VAL_BATCHES - 1:
                break

    stats.calc_stats()
    return stats, plot_test


def _save_latent_tsne(
    args: Args,
    epoch: int,
    plot_x: torch.Tensor,
    plot_x_prime: torch.Tensor,
    plot_test: torch.Tensor,
) -> None:
    """Embed collected latents with t-SNE and save a scatter plot as a PDF.

    Three groups are drawn: natural training latents (``x``), perturbed
    latents (``x'``) and the first 100 test latents (``x test``). The file is
    written to ``charts/mc-pad-latent/<dataset>/lambda-<kl_lambda>/ls-<ls>/
    latent-tsne-epoch-<epoch>.pdf``; missing directories are created.
    """
    print(f"starting TSNE for epoch: {epoch}")
    data_2d = TSNE().fit_transform(
        torch.cat((plot_x, plot_x_prime, plot_test[:100]), dim=0)
    )

    path = os.path.join(
        "charts",
        "mc-pad-latent",
        f"{args.train_set.dataset.name}",  # type: ignore
        f"lambda-{args.kl_lambda}",
        f"ls-{args.ls}",
    )
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(path, exist_ok=True)

    # Rows of data_2d follow the concatenation order above: [x | x' | x test].
    x_end = plot_x.size(0)
    x_prime_end = x_end + plot_x_prime.size(0)

    fig, ax = plt.subplots()
    try:
        ax.scatter(data_2d[:x_end, 0], data_2d[:x_end, 1], label="x")
        ax.scatter(
            data_2d[x_end:x_prime_end, 0], data_2d[x_end:x_prime_end, 1], label="x'"
        )
        ax.scatter(
            data_2d[x_prime_end:, 0], data_2d[x_prime_end:, 1], label="x test"
        )
        ax.legend()
        ax.set_title(f"{args.train_set.dataset.name} epoch: {epoch}")  # type: ignore
        fig.savefig(os.path.join(path, f"latent-tsne-epoch-{epoch}.pdf"))
    finally:
        # pyplot keeps every figure open until it is closed explicitly;
        # without this, one figure leaks per validation epoch.
        plt.close(fig)


def train(args: Args, run: int) -> ModelSaveDict:
    """Jointly train an MC-dropout classifier and a latent perturber.

    Each training step performs two updates:

    1. Classifier (``optimizer``): NLL on the class logits, plus a squared
       error pushing ``exp(temp_phi)`` toward ``_TEMP_BASE ** w``. Here
       ``temp_phi`` is the last output column of
       ``model.phi(x, perturber, theta=True)``, and
       ``w = 1 - exp(-d / (2 * ls**2))`` increases with the (detached)
       distance ``d``.
    2. Perturber (``phi_opt``): minimise ``entropy + mean(dist) - h``,
       computed from ``model.phi(x_phi, perturber, theta=False)`` on a second
       batch.

    Every ``_VAL_EVERY`` epochs, and on the final epoch, the model is
    validated on up to ``_VAL_BATCHES`` test batches. A t-SNE plot of natural,
    perturbed and test latents is saved at the same time.

    Args:
        args: experiment configuration; ``train_set`` and ``test_set`` must be set.
        run: run index, used to build the checkpoint directory.

    Returns:
        Mapping from checkpoint path to CPU ``state_dict`` for the classifier
        (``model.pt``) and the perturber (``perturber.pt``).

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is ``None``.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    conv_args = get_conv_args(args)
    perturber = LatentPerturber(conv_args[-1][1]).to(args.device)
    phi_opt = torch.optim.Adam(perturber.parameters(), lr=args.beta_lr)

    model = MCDropoutHeteroConv(
        (args.x_dim, args.x_dim), args.h_dim, args.ps, args.y_dim, conv_args
    ).to(args.device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    criterion = torch.nn.NLLLoss()

    for epoch in range(args.epochs):
        is_val_epoch = epoch % _VAL_EVERY == 0 or epoch == args.epochs - 1
        # Latent buffers for the t-SNE plot. They are filled only on
        # validation epochs, using the same predicate that triggers plotting,
        # so a final epoch that is not a multiple of _VAL_EVERY still plots
        # x / x'.
        plot_x, plot_x_prime = torch.Tensor(), torch.Tensor()

        model.train()
        # The loader is iterated twice in lockstep: (x, y) drives the
        # classifier update and x_phi drives the perturber update. This
        # presumably relies on the loader shuffling so the batches differ — confirm.
        for (x, y), (x_phi, _) in zip(args.train_set, args.train_set):
            x = x.to(args.device, non_blocking=True)
            y = y.to(args.device, non_blocking=True)
            x_phi = x_phi.to(args.device, non_blocking=True)

            # ---- classifier update ----
            # The last output column is the temperature output; the remaining
            # columns are the class logits.
            logits = model(x)[:, :-1]
            loss = criterion(torch.log_softmax(logits, dim=1), y)

            out_phi, _, _, dist, _, _ = model.phi(x, perturber, theta=True)
            temp_phi = out_phi[:, -1]

            d, _ = dist.min(dim=0)
            # Detached, so w is a fixed target weight. w is ~0 for small d and
            # approaches 1 as d grows (assuming d >= 0).
            w = 1 - torch.exp(-d.detach() / (2 * args.ls ** 2))  # type: ignore
            loss_phi = ((_TEMP_BASE ** w - torch.exp(temp_phi)) ** 2).mean()
            loss = loss + loss_phi

            # This backward also writes grads into the perturber's parameters;
            # phi_opt.zero_grad() below clears them before the perturber step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # ---- perturber update ----
            yhat_phi, h, subsets_idx, dist, plt_x, plt_x_prime = model.phi(
                x_phi, perturber, theta=False
            )

            if is_val_epoch and plot_x.size(0) < _PLOT_MAX_POINTS:
                plot_x = torch.cat((plot_x, plt_x.cpu().detach()), dim=0)
                plot_x_prime = torch.cat(
                    (plot_x_prime, plt_x_prime.cpu().detach()), dim=0
                )

            # 1. entropy --> push the perturber toward where the model predicts low entropy
            # 2. dist --> keep perturbed points close to natural points
            # 3. -h --> increase the entropy of the distribution sampling
            #    perturbed points (diversity)
            entropy = softmax_entropy(yhat_phi).mean()
            dist = dist[subsets_idx[0], subsets_idx[1]]
            loss = entropy + dist.mean() - h

            # Any grads this backward puts on model parameters are cleared by
            # optimizer.zero_grad() before the next classifier step.
            phi_opt.zero_grad()
            loss.backward()
            phi_opt.step()

        if is_val_epoch:
            stats, plot_test = _validate(args, model, perturber)
            print(
                f"epoch: {epoch} val acc: {stats.accuracy:.4f} cal: {stats.cal_error:.4f}, nll: {stats.nll:.4f}"
            )
            _save_latent_tsne(args, epoch, plot_x, plot_x_prime, plot_test)

    # Deep-copy before .cpu() so the live training modules stay on their device.
    model_dir = f"{args.get_model_dir(run, classification=True)}"
    models: ModelSaveDict = {
        os.path.join(model_dir, "model.pt"): copy.deepcopy(model).cpu().state_dict(),
        os.path.join(model_dir, "perturber.pt"): copy.deepcopy(perturber).cpu().state_dict(),
    }
    return models


def eval(args: Args, run: int, models: ModelSaveDict) -> ClassificationStats:
    """Evaluate a saved classifier checkpoint on the whole test set.

    Rebuilds ``MCDropoutHeteroConv`` from ``args`` and loads the ``model.pt``
    state dict for ``run`` from ``models`` (the mapping returned by
    ``train``). It then accumulates MC-dropout predictions for every test
    batch.

    Args:
        args: experiment configuration; ``test_set`` must be set.
        run: run index used to locate the checkpoint key.
        models: mapping from checkpoint path to ``state_dict``.

    Returns:
        Classification statistics computed over the test set.

    Raises:
        ValueError: if ``args.test_set`` is ``None``.
        KeyError: if ``models`` has no ``model.pt`` entry for ``run``.
    """
    # Only the test set is needed here, so only it is checked and named.
    if args.test_set is None:
        raise ValueError("test set cannot be None")

    model = MCDropoutHeteroConv(
        (args.x_dim, args.x_dim), args.h_dim, args.ps, args.y_dim, get_conv_args(args)
    ).to(args.device)
    model_path = os.path.join(
        f"{args.get_model_dir(run, classification=True)}", "model.pt"
    )
    model.load_state_dict(models[model_path])

    # run the testing loop
    dropout_eval(model)
    # NOTE(review): third argument is hard-coded to 10; confirm whether it
    # should follow args.y_dim.
    stats = ClassificationStatsTemp(args.test_set.dataset, args, 10)
    with torch.no_grad():
        for x, y in args.test_set:
            x, y = x.to(args.device), y.to(args.device)
            # Averaged over dim 0 of the MC output (presumably the sample axis).
            mus = model.mc(x, args.samples).mean(dim=0)
            stats.set_multiclass(y, mus)

    stats.calc_stats()
    return stats
