import os

import gpytorch  # type: ignore
import numpy as np  # type: ignore
import torch
from gpytorch.likelihoods import GaussianLikelihood  # type: ignore
from gpytorch.mlls import ExactMarginalLogLikelihood  # type: ignore
from torch.distributions import Normal
from tqdm import tqdm  # type: ignore

from models.gp import ExactGP
from utils import Args, ModelSaveDict, Stats

# NOTE(review): gpytorch settings such as ``fast_pred_var`` act as context
# managers -- their state is applied on ``__enter__`` and restored on
# ``__exit__``. Constructing one at module level without ``with`` therefore
# appears not to change any global setting, i.e. this line is likely a no-op.
# It also seems at odds with ``eval`` in this file, which explicitly enables
# ``fast_pred_var`` around prediction. Confirm the intended behavior.
gpytorch.settings.fast_pred_var(state=False)


def train(
    args: Args,
    run: int,
    *,
    n_iters: int = 50,
    lr: float = 1e-1,
) -> ModelSaveDict:
    """Fit an exact GP and its Gaussian likelihood on the full training set.

    Hyperparameters are optimized with Adam on the negative exact marginal
    log likelihood. The whole training set is used as a single batch on
    every step.

    Args:
        args: Experiment configuration. ``args.train_set`` and
            ``args.test_set`` must both be set.
        run: Run index used to build the model save directory.
        n_iters: Number of optimizer steps. Defaults to 50.
        lr: Adam learning rate. Defaults to 0.1.

    Returns:
        A dict mapping ``<model_dir>/model.pt`` to the trained model's
        ``state_dict``. The model is moved to CPU before the state dict
        is taken.

    Raises:
        ValueError: If either the train or the test set is missing.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")

    likelihood = GaussianLikelihood().double()
    model = ExactGP(args.train_set.dataset.x, args.train_set.dataset.y, likelihood).to(args.device).double()  # type: ignore

    mll = ExactMarginalLogLikelihood(likelihood, model)
    optimizer = torch.optim.Adam([{"params": model.parameters()}], lr=lr)

    # Only the target scale is needed here. It shifts the progress readout of
    # the MLL from the standardized target space back to the original scale.
    # NOTE(review): these are the *test* set's moments; presumably both splits
    # share one normalization -- confirm.
    _, y_sigma = args.test_set.dataset.get_y_moments()  # type: ignore

    # The full training set is used on every step, so move/cast it once
    # instead of once per iteration.
    x = args.train_set.dataset.x.to(args.device).double()  # type: ignore
    y = args.train_set.dataset.y.to(args.device).double()  # type: ignore

    # Nothing inside the loop switches modes, so enter training mode once.
    model.train()
    likelihood.train()

    loss_log = tqdm(total=0, bar_format="train: {desc}", position=2, leave=False)
    try:
        for i in range(n_iters):
            output = model(x)
            loss = -mll(output, y)

            desc = (
                f"iter: {i} mll: {-loss.item() - np.log(y_sigma)} "
                f"ls: {model.covar_module.base_kernel.lengthscale.item()} "
                f"noise: {model.likelihood.noise.item()}"
            )
            loss_log.set_description(desc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    finally:
        # Close the nested status bar so its terminal line is released,
        # even if an optimization step raises.
        loss_log.close()

    # NOTE(review): ``eval`` relies on this state dict also carrying the
    # likelihood's noise (it builds a fresh likelihood and only loads the
    # model's state). That presumably holds because ``model.likelihood`` is a
    # registered submodule -- confirm in models.gp.ExactGP.
    models: ModelSaveDict = {
        os.path.join(args.get_model_dir(run), "model.pt"): model.cpu().state_dict()
    }

    return models


def eval(args: Args, run: int, models: ModelSaveDict) -> Stats:
    """Evaluate the GP saved for ``run`` on the test set.

    Rebuilds the exact GP on the training data, since exact GP prediction
    conditions on it. It then loads the saved weights and predicts the test
    targets through the likelihood, so the predictive distribution includes
    observation noise. Predictive variances are computed under gpytorch's
    ``fast_pred_var`` setting.

    The log likelihood is a sum over test points in the standardized target
    space. After ``normalize_stats`` it is shifted by ``-log(y_sigma)`` to
    express it in the original target scale. NOTE(review): this presumes
    ``normalize_stats`` divides by the number of test points -- confirm in
    ``utils.Stats``.

    NOTE(review): the function name shadows the builtin ``eval``; it is kept
    for interface compatibility.

    Args:
        args: Experiment configuration. ``args.train_set`` and
            ``args.test_set`` must both be set.
        run: Run index used to locate the saved model entry.
        models: Mapping produced by ``train``, keyed by
            ``<model_dir>/model.pt``.

    Returns:
        The populated ``Stats`` object for the test set.

    Raises:
        ValueError: If either the train or the test set is missing.
    """
    if args.train_set is None or args.test_set is None:
        raise ValueError("train or test set cannot be None")
    train_data = args.train_set.dataset
    test_data = args.test_set.dataset

    stats = Stats(test_data, args)

    lik = GaussianLikelihood().double()
    gp = ExactGP(train_data.x, train_data.y, lik).to(args.device).double()  # type: ignore
    ckpt_key = os.path.join(args.get_model_dir(run), "model.pt")
    gp.load_state_dict(models[ckpt_key])
    y_mu, y_sigma = test_data.get_y_moments()  # type: ignore

    gp.eval()
    lik.eval()
    with torch.no_grad():
        x_test = test_data.x.to(args.device).double()  # type: ignore
        y_test = test_data.y.to(args.device).double()  # type: ignore

        with gpytorch.settings.fast_pred_var():
            predictive = lik(gp(x_test))
            mean = predictive.mean
            std = predictive.variance.sqrt()
            # Summed log density of the standardized targets.
            total_ll = Normal(mean, std).log_prob(y_test).sum().item()

        # Undo the target standardization so the stats are reported in the
        # original target units.
        mean = mean * y_sigma + y_mu
        std = std * y_sigma
        y_test = y_test * y_sigma + y_mu

        stats.set_regression_stats(mean, std, y_test, total_ll)
        stats.set_history(mean, std, y_test)

    stats.normalize_stats(len(test_data))
    # Jacobian of y = y_std * y_sigma + y_mu: each point's log density shifts
    # by -log(y_sigma).
    stats.ll -= np.log(y_sigma)

    return stats
