import math
import os

import numpy as np  # type: ignore
import torch
from matplotlib import cm  # type: ignore
from matplotlib import pyplot as plt  # type: ignore
from sklearn.manifold import TSNE  # type: ignore
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

from models.mc_dropout import MCDropoutHetero as Model
from models.set_encoder import SetEncoder
from utils import (Args, ModelSaveDict, Stats, dropout_eval,
                   get_mixture_mu_var, plot_generated)


def rampup(epoch: int, epochs: int) -> float:
    """Return a Gaussian-shaped ramp-up coefficient for ``epoch``.

    The value is ``exp(-5 * (1 - epoch / epochs) ** 2)``: it is ``exp(-5)``
    at ``epoch == 0`` and reaches ``1.0`` when ``epoch == epochs``.

    Args:
        epoch: Current epoch index.
        epochs: Epoch count used to normalise ``epoch``.

    Returns:
        The ramp-up coefficient as a float.

    Raises:
        ZeroDivisionError: If ``epochs`` is zero.
    """
    progress = float(epoch) / epochs  # type: ignore
    remaining = 1.0 - progress
    # Evaluated left to right, i.e. (-5.0 * remaining) * remaining.
    exponent = -5.0 * remaining * remaining
    return math.exp(exponent)


def train(args: Args, run: int) -> ModelSaveDict:
    """Alternately train the regression model and the sample generator.

    Every batch goes through two optimisation phases:

    1. The model is fitted with a heteroscedastic Gaussian loss
       ``exp(-logvar) * (y - mu)^2 + logvar``. Once ``epoch`` exceeds half of
       ``args.epochs``, samples drawn from the generator are appended to the
       batch and a distance-weighted KL term between the model's prediction
       on those samples and a unit-variance Normal with the same mean is
       added to the loss.
    2. The generator is updated to minimise the model's predictive entropy on
       freshly generated samples plus a distance penalty, minus the entropy
       of the generator's own distribution.

    Args:
        args: Experiment configuration (data loaders, dimensions, learning
            rates, device, ...).
        run: Run index, used only to build the keys of the returned dict.

    Returns:
        A dict mapping ``<model_dir>/model.pt`` and ``<model_dir>/encoder.pt``
        to the state dicts of the model and the generator. Both modules are
        moved to the CPU before their state dicts are taken.

    Raises:
        ValueError: If ``args.train_set`` or ``args.test_set`` is ``None``.
        OverflowError: If the generator loss becomes ``inf`` or ``nan``.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test cannot be none")

    model_args = (args.x_dim, args.h_dim, args.ps, args.y_dim)
    model = Model(*model_args).to(args.device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    gen = SetEncoder(args.x_dim, args.h_dim, args.x_dim).to(args.device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=args.beta_lr)

    # threshold for the distance penalty in phase two: the norm of an
    # all-ones vector of size x_dim, scaled by norm_c
    exp_norm = torch.ones(args.x_dim).norm().float() * args.norm_c

    # with torch.autograd.detect_anomaly():
    for epoch in range(args.epochs):
        # generated samples only take part in phase one in the later epochs
        use_generated = epoch > args.epochs // 2

        # the loader is iterated twice in lockstep so that the generator
        # batch (x_phi) is drawn independently of the model batch (x, y)
        for (x, y, _), (x_phi, _, _) in zip(args.train_set, args.train_set):
            if x.size(0) < 6:
                continue

            # model and gen live on args.device, so the batches have to be
            # there as well (no-op when they already are)
            x = x.to(args.device)
            y = y.to(args.device)
            x_phi = x_phi.to(args.device)

            b, _ = x.size()
            if use_generated:
                gen_mu, gen_logvar, _ = gen(x)
                x_gen = Normal(gen_mu, torch.exp(gen_logvar / 2)).rsample()
                x = torch.cat((x, x_gen), dim=0)

            # PHASE ONE: forward on real (and generated) inputs, minimize the
            # Gaussian loss plus the weighted KL term on the generated inputs
            mu, logvar = model(x)
            mu, logvar = mu.squeeze(1), logvar.squeeze(1)
            logvar = torch.clamp(logvar, args.logvar_min, args.logvar_max)

            # the first b rows are real data, anything after is generated
            mu_nat, logvar_nat = mu[:b], logvar[:b]
            mu_gen, logvar_gen = mu[b:], logvar[b:]

            loss = (torch.exp(-logvar_nat) * (y - mu_nat) ** 2 + logvar_nat).mean()

            if use_generated:
                # squared distance from every generated point to its nearest
                # real point in the batch
                d, _ = (
                    ((x[b:].unsqueeze(1).detach() - x[:b].unsqueeze(0)) ** 2)
                    .sum(dim=2)
                    .min(dim=1)
                )

                # weight grows from 0 towards 1 with the distance to the data
                w = 1 - torch.exp(-d / (2 * args.ls ** 2))
                # KL between the predicted Normal and a unit-variance Normal
                # that shares its mean
                kl = kl_divergence(
                    Normal(mu_gen, torch.exp(logvar_gen / 2)), Normal(mu_gen, 1.0),
                )

                loss += (w * kl).sum()

            # optimize phase one; only the model's optimizer steps here, any
            # gradient that reached gen is cleared by gen_opt.zero_grad() below
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # PHASE TWO: generate, forward, minimize generator objective
            # NOTE(review): `subsets` is used to index x_phi below; presumably
            # the rows each generated sample was encoded from - confirm
            gen_mu, gen_logvar, subsets = gen(x_phi)
            x_gen = Normal(gen_mu, torch.exp(gen_logvar / 2)).rsample()

            # forward pass with the new samples
            mu, logvar = model(x_gen)
            mu, logvar = mu.squeeze(1), logvar.squeeze(1)
            logvar = torch.clamp(logvar, args.logvar_min, args.logvar_max)

            gen_entropy = Normal(gen_mu, torch.exp(gen_logvar / 2)).entropy().mean()
            entropy = Normal(mu, torch.exp(logvar / 2)).entropy().mean()

            # the loss here should be towards low entropy, away from the data
            nat_dist = (
                ((x_gen.unsqueeze(1) - x_phi[subsets]) ** 2).sum(dim=2).mean(dim=1)
            )

            # only distances above the threshold are penalised; when none
            # exceed it the penalty is a constant zero
            nat_dist = nat_dist[nat_dist > (exp_norm * 1)]
            if nat_dist.size(0) == 0:
                nat_dist = torch.tensor(0.0, device=args.device)

            loss = entropy + nat_dist.mean() - gen_entropy
            if torch.any(torch.isinf(loss)) or torch.any(torch.isnan(loss)):
                raise OverflowError(
                    f"x gen: {x_gen} nat_dist:{nat_dist} nat_dist:{nat_dist.mean()}"
                )

            # only the generator's optimizer steps here; gradients left on the
            # model are cleared by optimizer.zero_grad() in the next iteration
            gen_opt.zero_grad()
            loss.backward()
            gen_opt.step()

    # plot_generated(model, gen, args)

    model_dir = args.get_model_dir(run)
    models: ModelSaveDict = {
        os.path.join(model_dir, "model.pt"): model.cpu().state_dict(),
        os.path.join(model_dir, "encoder.pt"): gen.cpu().state_dict(),
    }

    return models


def eval(args: Args, run: int, models: ModelSaveDict) -> Stats:
    """Evaluate a trained model on the test set and collect statistics.

    The model weights are looked up in ``models`` under
    ``<model_dir>/model.pt``. For every test batch, ``args.samples`` Monte
    Carlo predictions are combined into a single mean and variance, the
    log-likelihood is computed in normalised target space, and predictions
    and targets are then mapped back to the original target scale before
    being recorded.

    Args:
        args: Experiment configuration (data loaders, dimensions, device, ...).
        run: Run index used to build the key of the model's state dict.
        models: Mapping from save path to state dict.

    Returns:
        The populated ``Stats`` object.

    Raises:
        ValueError: If ``args.train_set`` or ``args.test_set`` is ``None``.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test cannot be none")

    model = Model(args.x_dim, args.h_dim, args.ps, args.y_dim).to(args.device)
    weights_key = os.path.join(args.get_model_dir(run), "model.pt")
    model.load_state_dict(models[weights_key])
    # NOTE(review): presumably puts the model in eval mode while keeping
    # dropout active for MC sampling - confirm against utils.dropout_eval
    dropout_eval(model)

    stats = Stats(args.test_set.dataset, args)
    y_mu, y_sigma = args.test_set.dataset.get_y_moments()  # type: ignore
    with torch.no_grad():
        for x, y, _ in args.test_set:
            x = x.to(args.device)
            y = y.to(args.device)

            mus, logvars = model.mc(x, args.samples)
            mus = mus.squeeze(2)
            logvars = torch.clamp(
                logvars.squeeze(2), args.logvar_min, args.logvar_max
            )

            # collapse the MC samples into one predictive mean and variance
            mu, var = get_mixture_mu_var(mus, logvars)
            sigma = torch.sqrt(var)

            # the log likelihood is taken before denormalizing; doing it
            # afterwards can produce nan
            log_lik = Normal(mu, sigma).log_prob(y).sum().item()

            # map predictions and targets back to the original target scale
            mu = mu * y_sigma + y_mu
            sigma *= y_sigma
            y = y * y_sigma + y_mu

            stats.set_history(mu, sigma, y)
            stats.set_regression_stats(mu, sigma, y, log_lik)

        stats.normalize_stats(len(args.test_set.dataset))
        # correct the log likelihood for the rescaling of the targets
        stats.ll -= np.log(y_sigma)

    return stats
