import os
from typing import Tuple

import numpy as np  # type: ignore
import torch
from torch.distributions import Normal
from tqdm import tqdm  # type: ignore

from models.mc_dropout import MCDropoutHetero
from utils import Args, ModelSaveDict, Stats, dropout_eval, get_mixture_mu_var


def train(args: Args, run: int) -> ModelSaveDict:
    """Train an ``MCDropoutHetero`` regressor for a single run.

    The network predicts a mean and a log-variance for every input and is fit
    by minimising ``exp(-logvar) * (y - mu)**2 + logvar``. That is twice the
    Gaussian negative log-likelihood minus the constant ``log(2*pi)``.

    Args:
        args: Experiment configuration. Reads ``x_dim``, ``h_dim``, ``ps``,
            ``y_dim``, ``device``, ``lr``, ``weight_decay``, ``epochs``,
            ``train_set``, ``test_set``, ``logvar_min``, ``logvar_max`` and
            ``get_model_dir``.
        run: Run index; used only to build the checkpoint key.

    Returns:
        A dict with a single entry mapping ``<model dir>/model.pt`` to the
        trained model's ``state_dict``. The model is moved to the CPU first.

    Raises:
        ValueError: If ``args.train_set`` or ``args.test_set`` is None.
    """
    # test_set is not used here; validating it anyway presumably makes a
    # misconfigured run fail before any training time is spent.
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    model = MCDropoutHetero(args.x_dim, args.h_dim, args.ps, args.y_dim).to(args.device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # position=1 leaves row 0 free, presumably for an outer progress bar.
    for _ in tqdm(range(args.epochs), position=1, leave=False):
        # Batches are (x, y, idx) triples; the sample indices are not needed
        # for training, so they are not copied to the device.
        for x, y, _idx in args.train_set:
            x, y = x.to(args.device), y.to(args.device)

            mu, logvar = model(x)
            # NOTE(review): assumes outputs are (batch, 1) and y is (batch,) —
            # TODO confirm. squeeze() also drops the batch dim for batch size 1.
            mu = mu.squeeze()
            logvar = logvar.squeeze()

            # Bound the predicted log-variance so exp(-logvar) stays in a
            # finite range and the variance cannot run off to extremes.
            logvar = torch.clamp(logvar, args.logvar_min, args.logvar_max)
            # Heteroscedastic Gaussian NLL (x2, constant dropped), batch mean.
            loss = (torch.exp(-logvar) * (y - mu) ** 2 + logvar).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    models: ModelSaveDict = {
        os.path.join(f"{args.get_model_dir(run)}", "model.pt"): model.cpu().state_dict()
    }
    return models


def eval(args: Args, run: int, models: ModelSaveDict) -> Stats:
    """Evaluate a trained ``MCDropoutHetero`` model on ``args.test_set``.

    For each test batch:

    - draw ``args.samples`` Monte-Carlo dropout predictions;
    - combine them into a single mean/variance with ``get_mixture_mu_var``;
    - score the targets' log-likelihood in the normalized target space;
    - denormalize predictions and targets with the dataset's target moments
      and record them in a ``Stats`` object.

    Note: this function shadows the builtin ``eval`` within this module.

    Args:
        args: Experiment configuration. Reads ``x_dim``, ``h_dim``, ``ps``,
            ``y_dim``, ``device``, ``samples``, ``test_set`` and
            ``get_model_dir``.
        run: Run index; selects the checkpoint key.
        models: Mapping of checkpoint path to ``state_dict``, as returned by
            ``train``.

    Returns:
        The populated ``Stats`` for the test set.

    Raises:
        ValueError: If ``args.test_set`` is None.
        KeyError: If ``models`` has no entry for this run's ``model.pt`` path.
    """
    if args.test_set is None:
        raise ValueError("test set cannot be None")

    model = MCDropoutHetero(args.x_dim, args.h_dim, args.ps, args.y_dim).to(args.device)
    model.load_state_dict(
        models[os.path.join(f"{args.get_model_dir(run)}", "model.pt")]
    )
    # Moments used to map normalized targets/predictions back to data scale.
    y_mu, y_sigma = args.test_set.dataset.get_y_moments()  # type: ignore

    # presumably eval mode with dropout kept active for MC sampling — confirm
    # against utils.dropout_eval.
    dropout_eval(model)
    t = Stats(args.test_set.dataset, args)
    with torch.no_grad():
        for x, y, _ in args.test_set:
            # The model lives on args.device, so the inputs must too. The
            # predictions are then moved next to y, so the likelihood and the
            # statistics are computed on y's device.
            mus, logvars = model.mc(x.to(args.device), args.samples)

            # assumes mc returns (samples, batch, 1) tensors — TODO confirm.
            mus = mus.squeeze(2).to(y.device)
            logvars = logvars.squeeze(2).to(y.device)

            # Collapse the MC samples (dim 0) into one predictive mean/variance.
            mu, var = get_mixture_mu_var(mus, logvars, dim=0)
            sigma = torch.sqrt(var)

            # Compute the log likelihood before denormalizing; doing it
            # afterwards can produce NaN.
            log_lik = Normal(mu, sigma).log_prob(y).sum().item()

            # Denormalize predictions and targets back to the data scale.
            mu = mu * y_sigma + y_mu
            sigma = sigma * y_sigma
            y = y * y_sigma + y_mu

            t.set_history(mu, sigma, y)
            t.set_regression_stats(mu, sigma, y, log_lik)

        t.normalize_stats(len(args.test_set.dataset))
        # Change of variables y = y_sigma * z + y_mu: log p(y) = log p(z) - log(y_sigma).
        # NOTE(review): subtracting it once assumes ll is a per-sample average
        # of a scalar target — TODO confirm for y_dim > 1.
        t.ll -= np.log(y_sigma)

    return t
