# System/Library imports
from typing import *
import time

# Common data science imports
from omegaconf import OmegaConf
import torch
from tqdm import tqdm

try:
    import wandb
except:
    pass

# GPytorch
import gpytorch
from gpytorch.constraints import GreaterThan

# Our imports
from gp.skip.model import SKIP
from gp.util import flatten_dict, flatten_dataset, split_dataset, filter_param


# =============================================================================
# Train / Eval
# =============================================================================

def train_gp(config, train_dataset, test_dataset):
    """Train a SKIP GP by minimising the negative exact marginal log-likelihood.

    The model is optimised with Adam for ``config.training.epochs`` steps, each
    step using the whole flattened training set. After every step the model is
    evaluated on the test set via ``eval_gp`` and, when ``config.wandb.watch``
    is truthy, the metrics are logged to Weights & Biases.

    Args:
        config: OmegaConf config. Reads ``dataset.name``; ``model.use_ard``,
            ``model.dtype`` (name of a ``torch`` attribute, resolved with
            ``getattr``), ``model.device``, ``model.noise``,
            ``model.noise_constraint``, ``model.learn_noise``,
            ``model.grid_size``, ``model.cg.tol``, ``model.cg.eval_tol``,
            ``model.lanc_iter``, ``model.pre_size``; ``training.seed``,
            ``training.epochs``, ``training.learning_rate``; and
            ``wandb.watch`` / ``wandb.project`` / ``wandb.entity`` /
            ``wandb.group``.
        train_dataset: Dataset handed to ``flatten_dataset`` to obtain the
            training inputs and targets.
        test_dataset: Dataset handed to ``flatten_dataset`` to obtain the
            inputs and targets evaluated after every epoch.

    Returns:
        The tuple ``(model, model.likelihood)``.

    Raises:
        ImportError: If ``config.wandb.watch`` is truthy but ``wandb`` cannot
            be imported.

    Side effects:
        Changes the process-wide default dtype via ``torch.set_default_dtype``,
        prints the chosen dtype, and (through ``eval_gp``) prints metrics every
        epoch.
    """
    dataset_name = config.dataset.name

    # Unpack model configuration.
    # NOTE(review): the original also read ``config.model.use_scale`` and
    # ``config.model.cg.iter`` but never used them; they are not consumed here.
    # If a CG iteration cap was intended, it has to be wired in explicitly.
    model_cfg = config.model
    use_ard = model_cfg.use_ard
    dtype = getattr(torch, model_cfg.dtype)
    device = model_cfg.device
    noise = model_cfg.noise
    noise_constraint = model_cfg.noise_constraint
    learn_noise = model_cfg.learn_noise
    grid_size = model_cfg.grid_size
    cg_tol = model_cfg.cg.tol
    cg_eval_tol = model_cfg.cg.eval_tol
    lanc_iter = model_cfg.lanc_iter
    pre_size = model_cfg.pre_size

    # Unpack training configuration
    seed = config.training.seed
    epochs = config.training.epochs
    lr = config.training.learning_rate

    # Set wandb
    use_wandb = bool(config.wandb.watch)
    if use_wandb:
        # Imported here so that a missing/broken wandb installation fails
        # immediately with the real import error, rather than with a NameError
        # in the middle of training (the module-level import is best-effort).
        import wandb

        wandb.init(
            project=config.wandb.project,
            entity=config.wandb.entity,
            group=config.wandb.group,
            name=f"skip_{dataset_name}_{noise}_{seed}",
            # Flattened copy of the full (resolved) configuration.
            config=flatten_dict(OmegaConf.to_container(config, resolve=True)),
        )

    print("Setting dtype to ...", dtype)
    torch.set_default_dtype(dtype)

    # Dataset preparation
    train_x, train_y = flatten_dataset(train_dataset)
    train_x = train_x.to(dtype=dtype, device=device).contiguous()
    train_y = train_y.to(dtype=dtype, device=device).contiguous()

    test_x, test_y = flatten_dataset(test_dataset)
    test_x = test_x.to(dtype=dtype, device=device).contiguous()
    test_y = test_y.to(dtype=dtype, device=device).contiguous()

    # Model
    likelihood = gpytorch.likelihoods.GaussianLikelihood(
        noise_constraint=GreaterThan(noise_constraint)
    ).to(device=device)
    likelihood.noise = torch.tensor([noise]).to(device=device)
    model = SKIP(train_x, train_y, likelihood, grid_size=grid_size, use_ard=use_ard).to(device=device)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    # Training mode
    model.train()
    likelihood.train()

    # Set optimizer
    if learn_noise:
        params = model.parameters()
    else:
        # Keep the noise fixed at its initial value; presumably filter_param
        # drops the parameter with this name — confirm against gp.util.
        params = filter_param(model.named_parameters(), "likelihood.noise_covar.raw_noise")
    optimizer = torch.optim.Adam([{'params': params}], lr=lr)
    # Constant multiplier: the scheduler currently leaves the learning rate unchanged.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0)

    # Only CUDA work needs an explicit synchronisation for wall-clock timing;
    # calling torch.cuda.synchronize() on a CPU run would raise.
    sync_device = torch.device(device) if device is not None else None
    needs_cuda_sync = sync_device is not None and sync_device.type == "cuda"

    # Training loop
    pbar = tqdm(range(epochs), desc="Optimizing MLL")
    for epoch in pbar:
        t1 = time.perf_counter()

        optimizer.zero_grad()
        with gpytorch.settings.cg_tolerance(cg_tol), \
            gpytorch.settings.use_toeplitz(False), \
            gpytorch.settings.max_preconditioner_size(pre_size), \
            gpytorch.settings.max_root_decomposition_size(lanc_iter):
            output = model(train_x)
            loss = -mll(output, train_y)

            loss.backward()
        # step optimizers and learning rate schedulers
        optimizer.step()
        scheduler.step()
        if needs_cuda_sync:
            torch.cuda.synchronize(sync_device)
        t2 = time.perf_counter()

        # Pull the scalar out once; reused for the progress bar and for wandb.
        loss_value = loss.item()

        # Log
        pbar.set_description(f"Epoch {epoch+1}/{epochs}")
        pbar.set_postfix(MLL=f"{-loss_value}")

        # Evaluate (eval_gp switches the model to eval mode, so switch back).
        test_rmse, test_nll = eval_gp(
            model,
            likelihood,
            test_x,
            test_y,
            device=device,
            lanc_iter=lanc_iter,
            pre_size=pre_size,
            cg_tol=cg_eval_tol,
        )
        model.train()

        if use_wandb:
            wandb.log({
                "loss": loss_value,
                "test_nll": test_nll,
                "test_rmse": test_rmse,
                "epoch_time": t2 - t1,
                "noise": model.get_noise(),
                "lengthscale": wandb.Histogram(model.get_lengthscale().detach().cpu().numpy()),
                "outputscale": model.get_outputscale(),
            })

    return model, model.likelihood


def eval_gp(model, likelihood, x, y, device="cuda:0", dtype=torch.float32, lanc_iter=100, pre_size=100, cg_tol=1e-2):
    """Compute RMSE and mean Gaussian negative log-likelihood of ``model`` on ``(x, y)``.

    The model is switched to eval mode (and left in it), queried once on ``x``
    without gradient tracking, and the resulting metrics are printed together
    with the model's noise, lengthscale and outputscale.

    Args:
        model: Model to evaluate. Must be callable on ``x`` and expose
            ``likelihood.noise``, ``get_noise()``, ``get_lengthscale()`` and
            ``get_outputscale()``.
        likelihood: Unused; the noise is read from ``model.likelihood``.
        x: Inputs passed to ``model``.
        y: Targets compared against the predictive mean.
        device: Unused; kept for interface compatibility.
        dtype: Unused; kept for interface compatibility.
        lanc_iter: Value for ``gpytorch.settings.max_root_decomposition_size``.
        pre_size: Value for ``gpytorch.settings.max_preconditioner_size``.
        cg_tol: Value for ``gpytorch.settings.eval_cg_tolerance``.

    Returns:
        ``(rmse, nll)`` as tensors: the root of the squared error averaged over
        dimension 0, and the mean negative log-density of ``y`` under a Normal
        whose variance is the predictive variance plus the likelihood noise.
    """
    model.eval()

    with gpytorch.settings.eval_cg_tolerance(cg_tol), gpytorch.settings.use_toeplitz(False):
        with gpytorch.settings.max_preconditioner_size(pre_size):
            with gpytorch.settings.max_root_decomposition_size(lanc_iter):
                with gpytorch.settings.fast_pred_var(), torch.no_grad():
                    # Latent posterior; observation noise is added by hand below.
                    posterior = model(x)

                    residual = posterior.mean - y
                    rmse = residual.pow(2).mean(0).sqrt()

                    predictive_std = posterior.variance.add(model.likelihood.noise).sqrt()
                    predictive = torch.distributions.Normal(posterior.mean, predictive_std)
                    nll = - predictive.log_prob(y).mean()

    print(
        "RMSE", rmse, rmse.dtype,
        "NLL", nll,
        "NOISE", model.get_noise().item(),
        "LENGTHSCALE", model.get_lengthscale(),
        "OUTPUTSCALE", model.get_outputscale(),
    )
    return rmse, nll
