import os

import numpy as np  # type: ignore
import torch
from torch.distributions import Normal

from models.r1bnn import Model
from utils import Args, ModelSaveDict, Stats


def train(args: Args, run: int) -> ModelSaveDict:
    """Fit a freshly constructed ``Model`` and return its weights.

    Args:
        args: Run configuration. This function reads ``train_set``,
            ``test_set``, the model dimensions (``x_dim``, ``h_dim``,
            ``y_dim``, ``model_n``), ``device``, ``lr``, ``weight_decay``,
            ``epochs``, ``logvar_min``/``logvar_max`` and ``kl_beta``.
        run: Run index, used only to build the path under which the
            weights are keyed.

    Returns:
        A one-entry mapping from ``<model_dir>/model.pt`` to the CPU state
        dict of the trained model. Nothing is written to disk here.

    Raises:
        ValueError: If either the train or the test set is ``None``.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    device = args.device
    net = Model(args.x_dim, args.h_dim, args.y_dim, args.model_n).to(device)

    # Two optimizer groups: only the shared layers receive weight decay.
    shared_layers = (
        net.l1_shared,
        net.l2_shared,
        net.mu_shared,
        net.logvar_shared,
    )
    factor_layers = (
        net.l1_s,
        net.l1_r,
        net.l2_s,
        net.l2_r,
        net.mu_s,
        net.mu_r,
        net.logvar_s,
        net.logvar_r,
    )
    shared_group = {
        "params": [p for layer in shared_layers for p in layer.parameters()],
        "lr": args.lr,
        "weight_decay": args.weight_decay,
    }
    factor_group = {
        "params": [p for layer in factor_layers for p in layer.parameters()],
        "lr": args.lr,
    }
    opt = torch.optim.Adam([shared_group, factor_group])

    for _epoch in range(args.epochs):
        # Iterate ``model_n`` copies of the training set in lockstep, so every
        # step sees ``model_n`` batches stacked along a new leading dimension
        # (presumably one batch per ensemble member -- confirm against Model).
        lockstep = zip(*([args.train_set] * args.model_n))
        for member_batches in lockstep:
            inputs = torch.stack([xb.to(device) for (xb, _, _) in member_batches])
            targets = torch.stack([yb.to(device) for (_, yb, _) in member_batches])

            pred_mu, pred_logvar = net(inputs)
            pred_logvar = torch.clamp(pred_logvar, args.logvar_min, args.logvar_max)

            # Heteroscedastic Gaussian negative log-likelihood, up to a
            # constant scale and offset.
            sq_err = (targets - pred_mu) ** 2
            nll = (torch.exp(-pred_logvar) * sq_err + pred_logvar).mean()
            # Regularizer supplied by the model, weighted by kl_beta.
            penalty = args.kl_beta * net.kl()
            objective = nll + penalty

            opt.zero_grad()
            objective.backward()
            opt.step()

    save_path = os.path.join(args.get_model_dir(run), "model.pt")
    trained: ModelSaveDict = {save_path: net.cpu().state_dict()}
    return trained


def eval(args: Args, run: int, models: ModelSaveDict) -> Stats:
    """Evaluate a trained model on ``args.test_set`` and collect statistics.

    The weights are looked up in ``models`` under ``<model_dir>/model.pt``.
    For every test batch the predictions returned by ``model.mc`` are
    collapsed over their first two dimensions into a single Gaussian by
    moment matching, and the log likelihood is computed in normalized target
    space before predictions and targets are denormalized for the remaining
    statistics.

    Args:
        args: Run configuration. This function reads ``train_set``,
            ``test_set``, the model dimensions, ``device``, ``samples`` and
            ``logvar_min``/``logvar_max``.
        run: Run index, used to build the key of the weights in ``models``.
        models: Mapping from save path to state dict, as returned by
            ``train``.

    Returns:
        The populated ``Stats`` object.

    Raises:
        ValueError: If either the train or the test set is ``None``.
        KeyError: If ``models`` has no entry for this run's model path.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    model = Model(args.x_dim, args.h_dim, args.y_dim, args.model_n).to(args.device)
    model.load_state_dict(models[os.path.join(args.get_model_dir(run), "model.pt")])

    model.eval()
    t = Stats(args.test_set.dataset, args)
    y_mu, y_sigma = args.test_set.dataset.get_y_moments()  # type: ignore
    with torch.no_grad():
        for x, y, _ in args.test_set:
            x, y = x.to(args.device), y.to(args.device)

            # NOTE(review): assumes dim 0 indexes MC samples and dim 1 indexes
            # ensemble members -- confirm against Model.mc.
            mus, logvars = model.mc(x, args.samples)
            logvars = torch.clamp(logvars, args.logvar_min, args.logvar_max)
            component_vars = torch.exp(logvars)

            # Gaussian mixture: moment-match over the first dimension ...
            sample_mu = mus.mean(dim=0)
            sample_var = (component_vars + mus ** 2).mean(dim=0) - (
                sample_mu ** 2
            )

            # ... and then over the second one.
            mu = sample_mu.mean(dim=0)
            var = (sample_var + sample_mu ** 2).mean(dim=0) - (mu ** 2)

            # E[x^2] - E[x]^2 suffers from cancellation and can round to zero
            # or below, which would make sqrt() return NaN/0 and break Normal.
            # The mixture variance equals the mean component variance plus the
            # (non-negative) spread of the component means, so the mean
            # component variance is an exact lower bound to clamp against.
            var = torch.max(var, component_vars.mean(dim=0).mean(dim=0))
            sigma = torch.sqrt(var)

            # calculate log likelihood before denormalizing, can cause nan otherwise
            log_lik = Normal(mu, sigma).log_prob(y).sum().item()

            # denormalize the data
            mu = mu * y_sigma + y_mu
            sigma = sigma * y_sigma
            y = y * y_sigma + y_mu

            t.set_history(mu, sigma, y)
            t.set_regression_stats(mu, sigma, y, log_lik)

        t.normalize_stats(len(args.test_set.dataset))
        # The log likelihood above was computed on normalized targets;
        # presumably this is the change-of-variables correction back to the
        # original scale (log-density shifts by -log(y_sigma)).
        t.ll -= np.log(y_sigma)

    return t
