import os

import numpy as np  # type: ignore
import torch
from torch.distributions import Normal, kl_divergence

from models.r1bnn import Model
from models.set_encoder import SetEncoder
from utils import Args, ModelSaveDict, Stats


def train(args: Args, run: int) -> ModelSaveDict:
    """Train the regression model together with its input generator.

    Every iteration runs two optimization phases:

    1. The model is fitted with a Gaussian negative log-likelihood on the
       natural batch plus ``args.kl_beta * model.kl()``. Once
       ``epoch > args.epochs // 2`` the batch is extended with generated
       inputs, and the predictive distribution on those inputs is pulled
       towards unit variance with a weight that grows with the squared
       distance to the nearest natural input.
    2. The generator is updated to minimise the model's predictive entropy on
       its samples, maximise its own sampling entropy, and pay a penalty for
       samples whose mean squared distance to their subset exceeds
       ``exp_norm``.

    Args:
        args: experiment configuration. Reads ``train_set``, ``test_set``,
            ``x_dim``, ``h_dim``, ``y_dim``, ``model_n``, ``device``, ``lr``,
            ``weight_decay``, ``beta_lr``, ``norm_c``, ``epochs``,
            ``logvar_min``, ``logvar_max``, ``kl_beta``, ``ls`` and
            ``get_model_dir``.
        run: run index, used only to build the checkpoint paths.

    Returns:
        Mapping from checkpoint path (``model.pt`` / ``generator.pt`` under
        ``args.get_model_dir(run)``) to the corresponding CPU state dict.

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is ``None``.
        OverflowError: if the generator loss becomes infinite or NaN.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    model = Model(args.x_dim, args.h_dim, args.y_dim, args.model_n).to(args.device)

    # Only the first parameter group is regularised with weight decay.
    decayed_modules = (
        model.l1_shared,
        model.l2_shared,
        model.mu_shared,
        model.logvar_shared,
    )
    undecayed_modules = (
        model.l1_s,
        model.l1_r,
        model.l2_s,
        model.l2_r,
        model.mu_s,
        model.mu_r,
        model.logvar_s,
        model.logvar_r,
    )
    optimizer = torch.optim.Adam(
        [
            {
                "params": [p for m in decayed_modules for p in m.parameters()],
                "lr": args.lr,
                "weight_decay": args.weight_decay,
            },
            {
                "params": [p for m in undecayed_modules for p in m.parameters()],
                "lr": args.lr,
            },
        ]
    )

    gen = SetEncoder(args.x_dim, args.h_dim, args.x_dim).to(args.device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=args.beta_lr)

    # Norm of an all-ones vector of size x_dim, scaled by norm_c; used as the
    # threshold above which generated points are penalised for their distance.
    exp_norm = torch.ones(args.x_dim).norm().float() * args.norm_c

    for epoch in range(args.epochs):
        # Generated inputs only take part in the model update in later epochs.
        use_generated = epoch > args.epochs // 2

        # Walk 2 * model_n iterators over the train set in lockstep: the first
        # model_n batches feed the model, the remaining ones feed the generator.
        for data in zip(*[args.train_set for _ in range(args.model_n * 2)]):
            x = torch.stack([xb.to(args.device) for (xb, _, _) in data[: args.model_n]])
            y = torch.stack([yb.to(args.device) for (_, yb, _) in data[: args.model_n]])

            x_phi = torch.stack(
                [xb.to(args.device) for (xb, _, _) in data[args.model_n :]]
            )

            n, b, ft = x.size()
            if use_generated:
                # Generator gradients from this phase would be wiped by
                # gen_opt.zero_grad() before they are ever applied, so the
                # sample is drawn without autograd tracking.
                with torch.no_grad():
                    gen_mu, gen_logvar, _ = gen(x.view(-1, ft))
                    x_gen = Normal(gen_mu, torch.exp(gen_logvar / 2)).rsample()

                x = torch.cat((x, x_gen.view(n, b, ft)), dim=1)

            # forward on both natural and (if present) generated inputs
            mu, logvar = model(x)
            logvar = torch.clamp(logvar, args.logvar_min, args.logvar_max)

            mu_nat, logvar_nat = mu[:, :b], logvar[:, :b]

            loss = (
                torch.exp(-logvar_nat) * (y - mu_nat) ** 2 + logvar_nat
            ).mean() + args.kl_beta * model.kl()

            if use_generated:
                mu_gen, logvar_gen = mu[:, b:], logvar[:, b:]

                # Squared distance from every generated point to its nearest
                # natural point in the same slice: (n, b, b) -> min -> (n, b).
                d, _ = (
                    ((x[:, b:].unsqueeze(2) - x[:, :b].unsqueeze(1)) ** 2)
                    .sum(dim=3)
                    .min(dim=2)
                )

                # Weight in [0, 1): ~0 next to natural data, -> 1 far away.
                w = 1 - torch.exp(-d / (2 * args.ls ** 2))
                # KL between the prediction and a unit-variance Gaussian with
                # the same mean, i.e. only the variance is regularised.
                kl = kl_divergence(
                    Normal(mu_gen, torch.exp(logvar_gen / 2)), Normal(mu_gen, 1.0)
                )
                # NOTE(review): w is (n, b); this product relies on kl
                # broadcasting against that shape -- confirm the model output
                # carries no trailing target axis.
                # NOTE(review): the divisor 5 is hard-coded; presumably it is
                # meant to be args.model_n -- confirm before changing.
                loss = loss + (w * kl).sum() / 5

            # optimize phase one: model parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # optimize phase two: generator parameters
            gen_mu, gen_logvar, subsets = gen(x_phi.view(-1, ft))
            x_gen = Normal(gen_mu, torch.exp(gen_logvar / 2)).rsample()

            mu, logvar = model(x_gen)
            logvar = torch.clamp(logvar, args.logvar_min, args.logvar_max)

            gen_entropy = Normal(gen_mu, torch.exp(gen_logvar / 2)).entropy().mean()
            entropy = Normal(mu, torch.exp(logvar / 2)).entropy().mean()

            # Mean squared distance from each generated point to the natural
            # points selected by the `subsets` index returned by the generator.
            nat_dist = (
                ((x_gen.unsqueeze(1) - x_phi.view(-1, ft)[subsets]) ** 2)
                .sum(dim=2)
                .mean(dim=1)
            )

            # Only points further away than the threshold are penalised.
            nat_dist = nat_dist[nat_dist > exp_norm]
            if nat_dist.size(0) == 0:
                # Nothing exceeded the threshold; keep the zero on the same
                # device as the rest of the loss terms.
                nat_dist = torch.tensor(0.0, device=x_gen.device)

            # low predictive entropy, high sampling entropy, bounded distance
            loss = entropy - gen_entropy + nat_dist.mean()

            if torch.any(torch.isinf(loss)) or torch.any(torch.isnan(loss)):
                raise OverflowError(
                    f"x gen: {x_gen} entropy: {entropy} "
                    f"gen_entropy: {gen_entropy} nat_dist: {nat_dist} "
                    f"loss: {loss}"
                )

            gen_opt.zero_grad()
            loss.backward()
            gen_opt.step()

    models: ModelSaveDict = {
        os.path.join(args.get_model_dir(run), "model.pt"): model.cpu().state_dict(),
        os.path.join(args.get_model_dir(run), "generator.pt"): gen.cpu().state_dict(),
    }

    return models


def eval(args: Args, run: int, models: ModelSaveDict) -> Stats:
    """Evaluate the trained model of one run on the test set.

    The model weights are looked up in ``models`` under the ``model.pt`` path
    of ``args.get_model_dir(run)``. For every test batch the output of
    ``model.mc`` is collapsed twice along its leading axis by matching the
    mean and variance of an equally weighted Gaussian mixture, the log
    likelihood is computed on the normalized targets, and the de-normalized
    predictions are recorded in a ``Stats`` object.

    Args:
        args: experiment configuration (datasets, model sizes, device, number
            of samples and log-variance clamp bounds are read from it).
        run: run index, used to locate the checkpoint key.
        models: mapping from checkpoint path to state dict, as returned by
            ``train``.

    Returns:
        The populated ``Stats`` instance.

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is ``None``.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    net = Model(args.x_dim, args.h_dim, args.y_dim, args.model_n).to(args.device)
    weights_key = os.path.join(args.get_model_dir(run), "model.pt")
    net.load_state_dict(models[weights_key])
    net.eval()

    stats = Stats(args.test_set.dataset, args)
    y_mu, y_sigma = args.test_set.dataset.get_y_moments()  # type: ignore

    with torch.no_grad():
        for inputs, targets, _ in args.test_set:
            inputs, targets = inputs.to(args.device), targets.to(args.device)

            mus, logvars = net.mc(inputs, args.samples)
            logvars = torch.clamp(logvars, args.logvar_min, args.logvar_max)

            # First reduction over the leading axis: mixture mean, and
            # mixture variance as E[var + mu^2] - (E[mu])^2.
            sample_mu = mus.mean(dim=0)
            second_moment = ((torch.exp(logvars)) + (mus ** 2)).mean(dim=0)
            sample_var = second_moment - (sample_mu ** 2)

            # Second reduction with the same moment-matching formula
            # (presumably over the ensemble members -- confirm against
            # Model.mc).
            mu = sample_mu.mean(dim=0)
            var = (sample_var + sample_mu ** 2).mean(dim=0) - (mu ** 2)
            sigma = torch.sqrt(var)

            # The log likelihood is taken on normalized values; doing it after
            # de-normalizing can produce NaN.
            log_lik = Normal(mu, sigma).log_prob(targets).sum().item()

            # Map predictions and targets back to the original target scale.
            mu = mu * y_sigma + y_mu
            sigma = sigma * y_sigma
            targets = targets * y_sigma + y_mu

            stats.set_history(mu, sigma, targets)
            stats.set_regression_stats(mu, sigma, targets, log_lik)

        stats.normalize_stats(len(args.test_set.dataset))
        # Shift the stored log likelihood by log(y_sigma); presumably this
        # accounts for the rescaling of the targets -- confirm with Stats.
        stats.ll -= np.log(y_sigma)

    return stats
