import os
from typing import Tuple

import numpy as np  # type: ignore
import torch
from torch.distributions import Normal
from tqdm import tqdm  # type: ignore

from models.deep_ensemble import Model
from utils import Args, ModelSaveDict, Stats


def make_adv(
    x: torch.Tensor,
    y: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    ft_ranges: torch.Tensor,
    eps: float = 0.01,
) -> torch.Tensor:
    """Return a sign-gradient perturbation of ``x`` for the Gaussian NLL loss.

    The loss ``mean((y - mu)^2 / exp(logvar) + logvar)`` is differentiated with
    respect to ``x`` and every input feature is moved by
    ``eps * ft_ranges`` in the direction of the gradient's sign.

    Args:
        x: Inputs the predictions were computed from. Must require grad and be
            part of the autograd graph that produced ``logvar``.
        y: Regression targets, broadcastable against ``mu``.
        mu: Predicted means. Detached inside the loss, so they act as constants.
        logvar: Predicted log-variances, computed from ``x``.
        ft_ranges: Per-feature scale of the perturbation; must broadcast
            against ``x``.
        eps: Fraction of ``ft_ranges`` used as the step size. Defaults to the
            previously hard-coded ``0.01``.

    Returns:
        ``x + eps * ft_ranges * sign(d loss / d x)``, same shape as ``x``.
    """
    var = torch.exp(logvar)
    # NOTE(review): mu is detached, so the input gradient flows only through
    # logvar. Confirm this is intended rather than using the full NLL gradient.
    loss = ((y - mu.detach()) ** 2 / var + logvar).mean()
    # retain_graph=True leaves the graph behind logvar intact so the caller can
    # still backpropagate through it after this call.
    grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
    return x + (eps * ft_ranges) * torch.sign(grad)


def train(args: Args, run: int) -> ModelSaveDict:
    """Train an ensemble of ``args.model_n`` models with adversarial training.

    Every step minimises the Gaussian negative log-likelihood on the clean
    batch plus the same loss on sign-gradient perturbed inputs produced by
    ``make_adv``. All ensemble members share a single Adam optimizer.

    Args:
        args: Run configuration; this function reads ``train_set``,
            ``test_set``, ``x_dim``, ``h_dim``, ``y_dim``, ``model_n``,
            ``device``, ``lr``, ``weight_decay``, ``epochs``, ``logvar_min``
            and ``logvar_max``.
        run: Run index, forwarded to ``args.get_model_dir`` to build the keys
            of the returned dictionary.

    Returns:
        Mapping from ``<model_dir>/model-<i>.pt`` to the state dict of ensemble
        member ``i`` after ``m.cpu()`` has been called on it.

    Raises:
        ValueError: If ``args.train_set`` or ``args.test_set`` is ``None``.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    models = [
        Model(args.x_dim, args.h_dim, args.y_dim).to(args.device)
        for _ in range(args.model_n)
    ]
    # A single optimizer updates the parameters of every ensemble member.
    params = [p for m in models for p in m.parameters()]

    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    ft_ranges = args.train_set.dataset.get_feature_ranges().to(args.device)  # type: ignore

    for _ in tqdm(range(args.epochs)):
        for m in models:
            m.train()

        # Zipping model_n independent iterators over the loader gives each
        # member its own batch per step (distinct batches if the loader
        # shuffles -- presumably it does; verify against the caller).
        for data in zip(*[args.train_set for _ in range(args.model_n)]):
            xs = torch.stack([x.to(args.device) for (x, _, _) in data])
            ys = torch.stack([y.to(args.device) for (_, y, _) in data])
            # make_adv differentiates the loss with respect to the inputs.
            xs.requires_grad_(True)

            yhats = [m(x) for (m, x) in zip(models, xs)]
            mus = torch.stack([v[0] for v in yhats])
            logvars = torch.stack(
                [torch.clamp(v[1], args.logvar_min, args.logvar_max) for v in yhats]
            )

            loss = (torch.exp(-logvars) * (ys - mus) ** 2 + logvars).mean()

            x_adv = make_adv(xs, ys, mus, logvars, ft_ranges)

            yhats_adv = [m(x) for (m, x) in zip(models, x_adv)]
            mus_adv = torch.stack([v[0] for v in yhats_adv])
            # Clamp exactly like the clean pass so exp(-logvar) cannot overflow
            # and both loss terms are computed under the same bounds.
            logvars_adv = torch.stack(
                [
                    torch.clamp(v[1], args.logvar_min, args.logvar_max)
                    for v in yhats_adv
                ]
            )

            loss = loss + (
                torch.exp(-logvars_adv) * (ys - mus_adv) ** 2 + logvars_adv
            ).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    model_dict: ModelSaveDict = {}
    for i, m in enumerate(models):
        path = os.path.join(args.get_model_dir(run), f"model-{i}.pt")
        model_dict[path] = m.cpu().state_dict()

    return model_dict


def eval(args: Args, run: int, model_dict: ModelSaveDict) -> Stats:
    """Evaluate the trained ensemble on ``args.test_set``.

    Each member is rebuilt from ``model_dict``, and the members' Gaussian
    predictions are combined into a single Gaussian by matching the mean and
    variance of the equally weighted mixture.

    Args:
        args: Run configuration; this function reads ``train_set``,
            ``test_set``, ``x_dim``, ``h_dim``, ``y_dim``, ``model_n``,
            ``device``, ``logvar_min`` and ``logvar_max``.
        run: Run index, forwarded to ``args.get_model_dir`` to rebuild the
            keys used in ``model_dict``.
        model_dict: Mapping from ``<model_dir>/model-<i>.pt`` to the state dict
            of ensemble member ``i``.

    Returns:
        A ``Stats`` object populated with the regression statistics and
        prediction history of the test set.

    Raises:
        ValueError: If ``args.train_set`` or ``args.test_set`` is ``None``.
        KeyError: If ``model_dict`` lacks the entry for one of the members.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    models = []
    for i in range(args.model_n):
        m = Model(args.x_dim, args.h_dim, args.y_dim)
        path = os.path.join(args.get_model_dir(run), f"model-{i}.pt")
        m.load_state_dict(model_dict[path])
        m.eval()
        models.append(m.to(args.device))

    t = Stats(args.test_set.dataset, args)
    y_mu, y_sigma = args.test_set.dataset.get_y_moments()  # type: ignore
    with torch.no_grad():
        for x, y, _ in args.test_set:
            x, y = x.to(args.device), y.to(args.device)

            yhats = [m(x) for m in models]
            mus = torch.stack([v[0] for v in yhats])
            logvars = torch.stack([v[1] for v in yhats])
            logvars = torch.clamp(logvars, args.logvar_min, args.logvar_max)

            # Moment-matched Gaussian of the equally weighted mixture:
            # variance = mean member variance + spread of the member means.
            # The spread is computed from deviations rather than as
            # E[mu^2] - E[mu]^2, which can cancel to a non-positive value in
            # floating point and turn sigma into NaN.
            mu = mus.mean(dim=0)
            var = torch.exp(logvars).mean(dim=0) + ((mus - mu) ** 2).mean(dim=0)
            sigma = torch.sqrt(var)

            # Log likelihood is taken before rescaling (in the standardised
            # target space); doing it afterwards can cause NaN.
            ll = Normal(mu, sigma).log_prob(y).sum().item()

            # Scale predictions and targets back using the dataset's y moments.
            mu = mu * y_sigma + y_mu
            sigma = sigma * y_sigma
            y = y * y_sigma + y_mu

            t.set_regression_stats(mu, sigma, y, ll)
            t.set_history(mu, sigma, y)

        t.normalize_stats(len(args.test_set.dataset))
        # Presumably the change-of-variables term that expresses the averaged
        # log likelihood in the original target units -- confirm against Stats.
        t.ll -= np.log(y_sigma)

    return t
