import os
from typing import Tuple

import numpy as np  # type: ignore
import torch
from torch.distributions import Normal, kl_divergence
from tqdm import tqdm  # type: ignore

from models.deep_ensemble import Model
from models.set_encoder import SetEncoder
from utils import Args, ModelSaveDict, Stats


def _ensemble_forward(
    models: list, inputs: torch.Tensor, logvar_min: float, logvar_max: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run ensemble member ``k`` on ``inputs[k]`` and stack its two outputs.

    Each model is called on its own slice of ``inputs`` (``zip`` over the
    leading dimension). The first output of each model is used as a mean and
    the second as a log-variance, which is clamped to
    ``[logvar_min, logvar_max]``.

    Returns:
        ``(mus, logvars)``, each stacked along a new leading ensemble dimension.
    """
    outputs = [m(x) for m, x in zip(models, inputs)]
    mus = torch.stack([out[0] for out in outputs])
    logvars = torch.stack(
        [torch.clamp(out[1], logvar_min, logvar_max) for out in outputs]
    )
    return mus, logvars


def train(args: Args, run: int) -> ModelSaveDict:
    """Train a deep ensemble jointly with an adversarial set-encoder generator.

    Every step zips ``2 * args.model_n`` passes over ``args.train_set``. The
    first ``model_n`` batches are the ensemble's training data (one batch per
    member). The remaining batches feed the generator update.

    Phase one (ensemble update): each member minimises
    ``exp(-logvar) * (y - mu) ** 2 + logvar`` on its own batch. Once
    ``epoch > args.epochs // 2``, the batch is augmented with generator
    samples. A KL term towards ``N(mu, 1)`` on those samples is then added,
    weighted by ``1 - exp(-d / (2 * ls ** 2))``, where ``d`` is the squared
    distance from a sample to its nearest natural point in the same batch.

    Phase two (generator update): only ``gen_opt`` steps. The objective is:

    * the mean ensemble predictive entropy on the generated samples,
    * minus the generator's own entropy,
    * plus the mean of those sample-to-data distances that exceed ``exp_norm``.

    Any ensemble gradients produced in phase two are cleared by
    ``optimizer.zero_grad()`` before the next ensemble step.

    Args:
        args: Run configuration; ``train_set`` and ``test_set`` must be set.
        run: Run index, used to build the checkpoint directory.

    Returns:
        Mapping of checkpoint path -> CPU ``state_dict`` for every ensemble
        member (``model-{i}.pt``) and for the generator (``generator.pt``).

    Raises:
        ValueError: If ``args.train_set`` or ``args.test_set`` is ``None``.
        OverflowError: If the generator loss becomes ``inf`` or ``nan``.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    models = [
        Model(args.x_dim, args.h_dim, args.y_dim).to(args.device)
        for _ in range(args.model_n)
    ]
    params = [p for m in models for p in m.parameters()]
    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)

    gen = SetEncoder(args.x_dim, args.h_dim, args.x_dim).to(args.device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=args.beta_lr)

    # sqrt(x_dim) * norm_c: distance threshold for the generator's data penalty.
    exp_norm = torch.ones(args.x_dim).norm().float() * args.norm_c

    for epoch in tqdm(range(args.epochs)):
        for m in models:
            m.train()
        # Generator augmentation of the ensemble batch is only used late.
        use_gen = epoch > args.epochs // 2

        for data in zip(*[args.train_set for _ in range(args.model_n * 2)]):
            xs = torch.stack([x.to(args.device) for (x, _, _) in data[: args.model_n]])
            ys = torch.stack([y.to(args.device) for (_, y, _) in data[: args.model_n]])
            x_phis = torch.stack(
                [x.to(args.device) for (x, _, _) in data[args.model_n :]]
            )

            # PHASE ONE: fit the ensemble, optionally on generator-augmented data
            n, b, ft = xs.size()
            if use_gen:
                # The generator is not stepped in this phase (its grads are
                # zeroed before phase two), so don't backprop into it.
                with torch.no_grad():
                    gen_mu, gen_logvar, _ = gen(xs.view(-1, ft))
                    x_gen = Normal(gen_mu, torch.exp(gen_logvar / 2)).rsample()
                xs = torch.cat((xs, x_gen.view(n, b, ft)), dim=1)

            mus, logvars = _ensemble_forward(
                models, xs, args.logvar_min, args.logvar_max
            )
            mu_nat, logvar_nat = mus[:, :b], logvars[:, :b]
            mu_gen, logvar_gen = mus[:, b:], logvars[:, b:]

            loss = (torch.exp(-logvar_nat) * (ys - mu_nat) ** 2 + logvar_nat).mean()

            if use_gen:
                # Squared distance from each generated point to the nearest
                # natural point in the same member's batch.
                d, _ = (
                    ((xs[:, b:].unsqueeze(2) - xs[:, :b].unsqueeze(1)) ** 2)
                    .sum(dim=3)
                    .min(dim=2)
                )
                w = 1 - torch.exp(-d / (2 * args.ls ** 2))

                # Pull the predictive variance on generated points towards 1,
                # more strongly the farther they are from natural data.
                kl = kl_divergence(
                    Normal(mu_gen, torch.exp(logvar_gen / 2)), Normal(mu_gen, 1.0)
                )
                # NOTE(review): fixed 1/5 weight on the KL term; consider an arg.
                loss += (w * kl).sum() / 5

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # PHASE TWO: generate, forward, minimize the generator objective
            x_phi_flat = x_phis.view(-1, x_phis.size(2))
            gen_mu, gen_logvar, subsets = gen(x_phi_flat)
            gen_dist = Normal(gen_mu, torch.exp(gen_logvar / 2))
            x_gen = gen_dist.rsample()

            # Each member scores the samples generated from its own batch.
            mus, logvars = _ensemble_forward(
                models, x_gen.view(x_phis.size()), args.logvar_min, args.logvar_max
            )

            gen_entropy = gen_dist.entropy().mean()
            entropy = Normal(mus, torch.exp(logvars / 2)).entropy().mean()

            # Mean squared distance of each sample to its `subsets` natural
            # points; only samples beyond `exp_norm` are penalised.
            nat_dist = (
                ((x_gen.unsqueeze(1) - x_phi_flat[subsets]) ** 2)
                .sum(dim=2)
                .mean(dim=1)
            )
            nat_dist = nat_dist[nat_dist > exp_norm]
            if nat_dist.numel() == 0:
                # mean() of an empty tensor is nan; use a zero on the same device.
                nat_dist = nat_dist.new_zeros(())

            loss = entropy - gen_entropy + nat_dist.mean()

            if not torch.isfinite(loss).all():
                raise OverflowError(
                    f"x gen: {x_gen} entropy: {entropy} nat_dist:{nat_dist} "
                    f"gen entropy: {gen_entropy} nat_dist mean:{nat_dist.mean()}"
                )

            gen_opt.zero_grad()
            loss.backward()
            gen_opt.step()

    model_dict: ModelSaveDict = {}
    for i, m in enumerate(models):
        path = os.path.join(args.get_model_dir(run), f"model-{i}.pt")
        model_dict[path] = m.cpu().state_dict()
    model_dict[
        os.path.join(args.get_model_dir(run), "generator.pt")
    ] = gen.cpu().state_dict()

    return model_dict


def eval(args: Args, run: int, model_dict: ModelSaveDict) -> Stats:
    """Score the ensemble saved for ``run`` on ``args.test_set``.

    Each member is rebuilt from ``model_dict`` and switched to eval mode.
    Member predictions are combined as an equally weighted Gaussian mixture:
    the mean of the means, and the variance from the mixture's first and
    second moments. The log-likelihood is computed in the standardised target
    space. Predictions and targets are then mapped back with the dataset's
    ``get_y_moments()`` before being recorded in the returned ``Stats``.

    Args:
        args: Run configuration; ``train_set`` and ``test_set`` must be set.
        run: Run index, used to locate the ``model-{i}.pt`` entries.
        model_dict: Mapping of checkpoint path -> ``state_dict``.

    Returns:
        The populated ``Stats`` object.

    Raises:
        ValueError: If ``args.train_set`` or ``args.test_set`` is ``None``.
    """
    if any(split is None for split in (args.train_set, args.test_set)):
        raise ValueError("train or test set cannot be None")

    ensemble = []
    for member_idx in range(args.model_n):
        net = Model(args.x_dim, args.h_dim, args.y_dim)
        ckpt_path = os.path.join(args.get_model_dir(run), f"model-{member_idx}.pt")
        net.load_state_dict(model_dict[ckpt_path])
        net.eval()
        ensemble.append(net.to(args.device))

    stats = Stats(args.test_set.dataset, args)
    y_mu, y_sigma = args.test_set.dataset.get_y_moments()  # type: ignore
    with torch.no_grad():
        for batch_x, batch_y, _ in args.test_set:
            batch_x, batch_y = batch_x.to(args.device), batch_y.to(args.device)

            outputs = [net(batch_x) for net in ensemble]
            means = torch.stack([out[0] for out in outputs])
            log_vars = torch.clamp(
                torch.stack([out[1] for out in outputs]),
                args.logvar_min,
                args.logvar_max,
            )

            # Equal-weight Gaussian mixture: E[mu], E[var + mu^2] - E[mu]^2.
            mix_mu = means.mean(dim=0)
            second_moment = (torch.exp(log_vars) + means ** 2).mean(dim=0)
            mix_sigma = torch.sqrt(second_moment - mix_mu ** 2)

            # Likelihood in the standardised space (more numerically stable).
            batch_ll = Normal(mix_mu, mix_sigma).log_prob(batch_y).sum().item()

            # Undo target standardisation before recording.
            out_mu = mix_mu * y_sigma + y_mu
            out_sigma = mix_sigma * y_sigma
            out_y = batch_y * y_sigma + y_mu

            stats.set_regression_stats(out_mu, out_sigma, out_y, batch_ll)
            stats.set_history(out_mu, out_sigma, out_y)

        stats.normalize_stats(len(args.test_set.dataset))
        stats.ll -= np.log(y_sigma)

    return stats
