import os

import numpy as np  # type: ignore
import torch
from torch.distributions import Normal

from models.deep_ensemble import Model
from models.swag import SWAG
from utils import Args, ModelSaveDict, Stats, get_mixture_mu_var


def schedule(epoch: int, args: Args) -> float:
    """Return the learning rate to use for ``epoch``.

    The rate is piecewise in ``t = epoch / args.swag_start``:

    * ``t <= 0.5``: constant ``args.swag_lr_init``
    * ``0.5 < t <= 0.9``: linear decay from ``args.swag_lr_init`` towards
      ``args.swag_lr``
    * ``t > 0.9``: constant ``args.swag_lr``

    Args:
        epoch: zero-based index of the current epoch.
        args: configuration providing ``swag_start``, ``swag_lr`` and
            ``swag_lr_init``.

    Returns:
        The learning rate for this epoch.
    """
    lr_ratio = args.swag_lr / args.swag_lr_init

    # A non-positive start epoch leaves no warm-up phase to anneal over (and
    # would divide by zero below), so use the final rate straight away.
    if args.swag_start <= 0:
        return args.swag_lr_init * lr_ratio

    t = epoch / args.swag_start
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        # linear interpolation between 1.0 (at t=0.5) and lr_ratio (at t=0.9)
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
        factor = lr_ratio
    return args.swag_lr_init * factor


def adjust_learning_rate(optimizer: torch.optim.SGD, lr: float) -> float:
    """Set ``lr`` on every parameter group of ``optimizer``.

    Args:
        optimizer: optimizer whose ``param_groups`` entries are updated
            in place.
        lr: learning rate to write into each group.

    Returns:
        The ``lr`` value that was applied, unchanged.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
    return lr


def train(args: Args, run: int) -> ModelSaveDict:
    """Train a network with SGD and accumulate its weights into a SWAG model.

    Side effect: ``args.swag_start`` is overwritten with ``args.epochs // 2``.

    Args:
        args: experiment configuration (data loaders, dimensions, device and
            SWAG/optimizer hyper-parameters).
        run: index of the current run, used to build the save path.

    Returns:
        A single-entry mapping from ``<model dir>/model.pt`` to the state
        dict of the SWAG model after it has been moved to the CPU.

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is ``None``.
        OverflowError: if the training loss becomes inf or NaN.
    """
    data_missing = args.train_set is None or args.test_set is None
    if data_missing:
        raise ValueError("train or test cannot be none")

    layer_dims = [args.x_dim, args.h_dim, args.y_dim]
    net = Model(*layer_dims).to(args.device)

    swag_kwargs = dict(
        no_cov_mat=False, max_num_models=args.model_n, model_args=layer_dims
    )
    swag_model = SWAG(Model, args.device, **swag_kwargs).to(args.device)

    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=args.swag_lr_init,
        momentum=args.swag_momentum,
        weight_decay=args.weight_decay,
    )

    # weight collection begins after the first half of training
    args.swag_start = args.epochs // 2
    final_epoch = args.epochs - 1

    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, schedule(epoch, args))

        for x, y, idx in args.train_set:
            x = x.to(args.device)
            y = y.to(args.device)
            idx = idx.to(args.device)  # moved to the device but not used below

            mu, logvar = net(x)
            logvar = logvar.clamp(args.logvar_min, args.logvar_max)

            # Gaussian negative log-likelihood with a predicted variance,
            # up to a constant factor and offset
            inv_var = torch.exp(-logvar)
            loss = (inv_var * (y - mu) ** 2 + logvar).mean()

            if not torch.isfinite(loss).all():
                raise OverflowError(
                    f"got nans in training:\nmu:{mu}\nlogvar:{logvar}\nargs: {args}"
                )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch <= args.swag_start:
            continue

        # maybe this doesn't need to be done on every epoch (it didn't in the original)
        swag_model.collect_model(net)
        # every fifth epoch and on the last one; NOTE(review): presumably
        # sample(0.0) resets the SWAG weights to their mean -- confirm in models.swag
        if epoch % 5 == 0 or epoch == final_epoch:
            swag_model.sample(0.0)

    save_path = os.path.join(args.get_model_dir(run), "model.pt")
    return {save_path: swag_model.cpu().state_dict()}


def eval(args: Args, run: int, models: ModelSaveDict) -> Stats:
    """Evaluate the saved SWAG model of ``run`` on ``args.test_set``.

    For every test batch, ``args.samples`` Monte Carlo predictions are drawn
    from the SWAG model and combined into a single mean/variance with
    ``get_mixture_mu_var``. The log-likelihood is computed on the normalized
    targets; predictions and targets are then denormalized with the dataset's
    y moments before being recorded.

    Args:
        args: experiment configuration (data loaders, dimensions, device,
            number of samples, logvar clamp range).
        run: index of the run whose weights are looked up in ``models``.
        models: mapping from save path to state dict, as returned by
            ``train``.

    Returns:
        The populated ``Stats`` object for the test set.

    Raises:
        ValueError: if ``args.train_set`` or ``args.test_set`` is ``None``.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test cannot be none")

    model_args = [args.x_dim, args.h_dim, args.y_dim]
    swag_model = SWAG(
        Model,
        args.device,
        no_cov_mat=False,
        max_num_models=args.model_n,
        model_args=model_args,
    ).to(args.device)

    swag_model.load_state_dict(
        models[os.path.join(args.get_model_dir(run), "model.pt")]
    )
    # Switch the network that actually produces the predictions to eval mode.
    # Previously a separate, never-used Model instance was switched instead.
    swag_model.eval()

    t = Stats(args.test_set.dataset, args)
    y_mu, y_sigma = args.test_set.dataset.get_y_moments()  # type: ignore
    with torch.no_grad():
        for x, y, _ in args.test_set:
            x, y = x.to(args.device), y.to(args.device)

            mus, logvars = swag_model.mc(x, args.samples)
            logvars = torch.clamp(logvars, args.logvar_min, args.logvar_max)

            mu, var = get_mixture_mu_var(mus, logvars)
            sigma = torch.sqrt(var)

            # calculate log likelihood before denormalizing, can cause nan otherwise
            log_lik = Normal(mu, sigma).log_prob(y).sum().item()

            # denormalize the data
            mu = mu * y_sigma + y_mu
            sigma *= y_sigma
            y = y * y_sigma + y_mu

            t.set_history(mu, sigma, y)
            t.set_regression_stats(mu, sigma, y, log_lik)

        t.normalize_stats(len(args.test_set.dataset))
        # the log-likelihood was computed in the normalized target space;
        # subtracting log(y_sigma) brings it back to the original scale
        t.ll -= np.log(y_sigma)

    return t
