# System/Library imports
from typing import *
import time

# Common data science imports
from omegaconf import OmegaConf
import torch
from tqdm import tqdm

# wandb is an optional dependency: it is only needed when experiment logging
# is enabled. Catch only ImportError (a bare `except` would also swallow
# KeyboardInterrupt/SystemExit) and bind the name so callers can test for it.
try:
    import wandb
except ImportError:
    wandb = None

# GPytorch
import gpytorch
from gpytorch.constraints import GreaterThan

# Our imports
from gp.ski.model import SKIGPModel
from gp.util import flatten_dict, flatten_dataset, split_dataset, filter_param


# =============================================================================
# Train / Eval
# =============================================================================

def train_gp(config, train_dataset, test_dataset):
    """Train a SKI GP by minimizing the negative exact marginal log likelihood.

    The model is optimized with Adam on the full training set for
    ``config.training.epochs`` steps. After every step the model is evaluated
    on the test set via ``eval_gp`` and, when ``config.wandb.watch`` is truthy,
    the metrics are logged to Weights & Biases.

    Args:
        config: Configuration object (converted with ``OmegaConf.to_container``
            for logging). Fields read here: ``dataset.name``;
            ``model.{num_inducing, dtype, device, noise, noise_constraint,
            learn_noise, grid_size}``; ``training.{seed, epochs,
            learning_rate}``; ``wandb.{watch, project, entity, group}``.
        train_dataset: Dataset passed to ``flatten_dataset`` to obtain the
            training inputs and targets.
        test_dataset: Dataset passed to ``flatten_dataset`` to obtain the
            test inputs and targets.

    Returns:
        Tuple ``(model, model.likelihood)`` after training.

    Raises:
        ImportError: If ``config.wandb.watch`` is set but ``wandb`` could not
            be imported.

    Side effects:
        Calls ``torch.set_default_dtype`` with the configured dtype, which
        changes the process-wide default, and prints progress to stdout.
    """
    # Unpack dataset
    dataset_name = config.dataset.name

    # Unpack model configuration
    model_cfg = config.model
    num_inducing = model_cfg.num_inducing  # only used to label the wandb run
    dtype = getattr(torch, model_cfg.dtype)
    device = model_cfg.device
    noise = model_cfg.noise
    noise_constraint = model_cfg.noise_constraint
    learn_noise = model_cfg.learn_noise
    grid_size = model_cfg.grid_size
    # NOTE: ARD is force-disabled here; config.model.use_ard is ignored.
    use_ard = False

    # Unpack training configuration
    # NOTE: `seed` is only used to label the wandb run; no RNG is seeded here.
    train_cfg = config.training
    seed = train_cfg.seed
    epochs = train_cfg.epochs
    lr = train_cfg.learning_rate

    # Set wandb
    use_wandb = config.wandb.watch
    if use_wandb:
        # wandb is imported optionally at module level; fail with a clear
        # message instead of a NameError/AttributeError on first use.
        if globals().get("wandb") is None:
            raise ImportError(
                "config.wandb.watch is enabled but the 'wandb' package could not be imported"
            )

        # Create wandb config with training/model config
        config_dict = flatten_dict(OmegaConf.to_container(config, resolve=True))

        # Create name
        rname = f"ski_{dataset_name}_{num_inducing}_{noise}_{seed}"

        # Initialize wandb
        wandb.init(
            project=config.wandb.project,
            entity=config.wandb.entity,
            group=config.wandb.group,
            name=rname,
            config=config_dict,
        )

    print("Setting dtype to ...", dtype)
    torch.set_default_dtype(dtype)

    # Dataset preparation
    train_x, train_y = flatten_dataset(train_dataset)
    train_x = train_x.to(dtype=dtype, device=device).contiguous()
    train_y = train_y.to(dtype=dtype, device=device).contiguous()
    print("HERE", train_x.shape, train_y.shape)

    test_x, test_y = flatten_dataset(test_dataset)
    test_x = test_x.to(dtype=dtype, device=device).contiguous()
    test_y = test_y.to(dtype=dtype, device=device).contiguous()

    # Model
    likelihood = gpytorch.likelihoods.GaussianLikelihood(
        noise_constraint=GreaterThan(noise_constraint)
    ).to(device=device)
    likelihood.noise = torch.tensor([noise]).to(device=device)
    model = SKIGPModel(
        train_x, train_y, likelihood, grid_size=grid_size, use_ard=use_ard
    ).to(device=device)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    # Training parameters
    model.train()
    likelihood.train()

    # Set optimizer
    if learn_noise:
        params = model.parameters()
    else:
        # Hand the optimizer the output of filter_param for the raw noise
        # parameter name (presumably everything except that parameter, so the
        # noise stays fixed -- confirm against gp.util.filter_param).
        params = filter_param(model.named_parameters(), "likelihood.noise_covar.raw_noise")
    optimizer = torch.optim.Adam([{'params': params}], lr=lr)
    # Constant multiplier of 1.0: the learning rate is not actually changed.
    lr_sched = lambda epoch: 1.0
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_sched)

    # Only synchronize when training on a CUDA device; an unconditional
    # torch.cuda.synchronize() fails on CPU-only runs.
    sync_device = torch.device(device)
    on_cuda = sync_device.type == "cuda"

    # Training loop
    pbar = tqdm(range(epochs), desc="Optimizing MLL")
    for epoch in pbar:
        t1 = time.perf_counter()

        # Full-batch step on the negative marginal log likelihood
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)

        loss.backward()
        # step optimizers and learning rate schedulers
        optimizer.step()
        scheduler.step()
        if on_cuda:
            # Wait for queued kernels so the epoch timing is meaningful.
            torch.cuda.synchronize(sync_device)
        t2 = time.perf_counter()

        # Log
        loss_value = loss.item()
        pbar.set_description(f"Epoch {epoch+1}/{epochs}")
        pbar.set_postfix(MLL=f"{-loss_value}")

        # Evaluate (eval_gp switches the model to eval mode, so switch back)
        test_rmse, test_nll = eval_gp(model, test_x, test_y, device=device)
        model.train()

        if use_wandb:
            results = {
                # Log a plain float rather than the live loss tensor.
                "loss": loss_value,
                "test_nll": test_nll,
                "test_rmse": test_rmse,
                "epoch_time": t2 - t1,
                "noise": model.get_noise(),
                "lengthscale": wandb.Histogram(model.get_lengthscale().detach().cpu().numpy()),
                "outputscale": model.get_outputscale(),
            }

            wandb.log(results)

    return model, model.likelihood


def eval_gp(model, x, y, device="cuda:0", dtype=torch.float32):
    """Evaluate a trained GP on ``(x, y)`` and return ``(rmse, nll)``.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` afterwards.

    The prediction ``model(x)`` is handled in two ways:

    * If it is a tensor, it is treated as point predictions and the NLL is
      reported as the placeholder value ``0``.
    * Otherwise it is treated as a predictive distribution exposing ``.mean``
      and ``.variance``; the RMSE uses the mean, and the NLL is the mean
      negative log-density of ``y`` under independent Normals whose variance
      is the predictive variance plus ``model.likelihood.noise``.

    Args:
        model: Model providing ``eval()``, ``__call__``, ``likelihood``,
            ``get_noise()``, ``get_lengthscale()`` and ``get_outputscale()``.
        x: Test inputs, passed to the model unchanged.
        y: Test targets, compared against the predictions.
        device: Unused; kept for backward compatibility with existing callers.
        dtype: Unused; kept for backward compatibility with existing callers.

    Returns:
        Tuple ``(rmse, nll)``. ``rmse`` is a tensor (squared error averaged
        over dim 0, then square-rooted); ``nll`` is ``0`` or a scalar tensor as
        described above.
    """
    model.eval()

    with gpytorch.settings.eval_cg_tolerance(1e-2), \
            gpytorch.settings.max_cg_iterations(1000), \
            gpytorch.settings.max_preconditioner_size(100), \
            gpytorch.settings.max_root_decomposition_size(100), \
            gpytorch.settings.fast_pred_var(), torch.no_grad():

        pred_y = model(x)

        if torch.is_tensor(pred_y):
            # Point predictions: no predictive variance available for an NLL.
            pred_mean = pred_y
            nll = 0
        else:
            # Distribution output: subtracting a tensor from the distribution
            # itself is not supported, so work with its mean and variance.
            # The variance is read inside the settings context above.
            pred_mean = pred_y.mean
            pred_std = pred_y.variance.add(model.likelihood.noise).sqrt()
            nll = -torch.distributions.Normal(pred_mean, pred_std).log_prob(y).mean()

        rmse = (pred_mean - y).pow(2).mean(0).sqrt()

    print("RMSE", rmse, rmse.dtype, "NLL", nll, "NOISE", model.get_noise().item(), "LENGTHSCALE", model.get_lengthscale(), "OUTPUTSCALE", model.get_outputscale())
    return rmse, nll
