# System/Library imports
from typing import *
import time

# Common data science imports
from omegaconf import OmegaConf
import numpy as np
from sklearn.cluster import KMeans
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

try:
    import wandb
except:
    pass

# GPytorch
import gpytorch
from gpytorch.constraints import GreaterThan
from gpytorch.kernels import RBFKernel, MaternKernel

# Our imports
from gp.simplex_gp.model import SimplexGPModel
from gp.util import dynamic_instantiation, flatten_dict, flatten_dataset, split_dataset, filter_param, heatmap


# =============================================================================
# Configuration
# =============================================================================

# Default experiment configuration, assembled section by section.

# Model section. train_gp forwards the solver-related entries to
# gpytorch.settings: cg.iter -> max_cg_iterations, cg.tol -> cg_tolerance
# (training), cg.eval_tol -> eval_cg_tolerance (evaluation),
# lanc_iter -> max_root_decomposition_size, pre_size -> max_preconditioner_size.
_MODEL_DEFAULTS = {
    'name': 'simplex-ski',
    'kernel': {
        '_target_': 'RBFLattice',
        'ard_num_dims': None,
    },
    'cg': {
        'iter': 500,
        'eval_tol': 1e-2,
        'tol': 1,
    },
    'lanc_iter': 100,
    'pre_size': 100,
    'use_ard': True,
    'use_scale': True,
    'noise': 0.5,
    'noise_constraint': 1e-1,
    'learn_noise': True,
    # Name of a torch dtype attribute; train_gp resolves it via getattr(torch, ...).
    'dtype': 'float32',
    'device': 'cpu',
}

# Dataset section: dataset name plus split fractions.
_DATASET_DEFAULTS = {
    'name': 'elevators',
    'train_frac': 0.9,
    'val_frac': 0.0,
    'num_workers': 0,
}

# Optimization section: train_gp reads the learning rate and epoch count for
# its Adam loop and puts the seed into the wandb run name.
_TRAINING_DEFAULTS = {
    'seed': 42,
    'learning_rate': 0.1,
    'epochs': 50,
}

# Weights & Biases section: `watch` switches logging on or off in train_gp.
_WANDB_DEFAULTS = {
    'watch': True,
    'group': 'test',
    'entity': 'bogp',
    'project': 'softki3',
}

CONFIG = OmegaConf.create({
    'model': _MODEL_DEFAULTS,
    'dataset': _DATASET_DEFAULTS,
    'training': _TRAINING_DEFAULTS,
    'wandb': _WANDB_DEFAULTS,
})


# =============================================================================
# Train / Eval
# =============================================================================

def train_gp(config, train_dataset, test_dataset):
    """Fit a ``SimplexGPModel`` by maximizing the exact marginal log likelihood.

    The whole training set is used as a single batch. Each epoch performs one
    Adam step and then evaluates on the test set with ``eval_gp``. When
    ``config.wandb.watch`` is truthy, per-epoch metrics are sent to
    Weights & Biases.

    Args:
        config: Attribute-style configuration (e.g. the module-level ``CONFIG``)
            providing ``dataset.name``, the ``model`` fields ``use_ard``,
            ``use_scale``, ``dtype``, ``device``, ``noise``, ``learn_noise``,
            ``cg.iter``, ``cg.tol``, ``cg.eval_tol``, ``lanc_iter`` and
            ``pre_size``, the ``training`` fields ``seed``, ``epochs`` and
            ``learning_rate``, and the ``wandb`` fields ``watch``, ``project``,
            ``entity`` and ``group``.
        train_dataset: Dataset accepted by ``flatten_dataset``; used for fitting.
        test_dataset: Dataset accepted by ``flatten_dataset``; evaluated after
            every epoch.

    Returns:
        The tuple ``(model, model.likelihood)`` after the last epoch.

    Raises:
        ImportError: If ``config.wandb.watch`` is truthy but the optional
            ``wandb`` package could not be imported by this module.

    Side effects:
        Calls ``torch.set_default_dtype`` with the configured dtype, prints
        progress information, and (when watching) starts a wandb run.
    """
    dataset_name = config.dataset.name

    # Model configuration
    model_cfg = config.model
    use_ard = model_cfg.use_ard
    use_scale = model_cfg.use_scale
    dtype = getattr(torch, model_cfg.dtype)
    device = model_cfg.device
    noise = model_cfg.noise
    learn_noise = model_cfg.learn_noise

    # Linear-solver configuration
    cg_iter = model_cfg.cg.iter
    cg_tol = model_cfg.cg.tol
    cg_eval_tol = model_cfg.cg.eval_tol
    lanc_iter = model_cfg.lanc_iter
    pre_size = model_cfg.pre_size

    # Training configuration
    # NOTE(review): `seed` is only used in the wandb run name below; no RNG is
    # seeded in this function.
    seed = config.training.seed
    epochs = config.training.epochs
    lr = config.training.learning_rate

    # wandb is imported optionally at module level; fail with a clear message
    # instead of a NameError when logging is requested but it is unavailable.
    watch = config.wandb.watch
    if watch and "wandb" not in globals():
        raise ImportError(
            "config.wandb.watch is enabled but the optional 'wandb' package "
            "could not be imported; install wandb or set wandb.watch to False."
        )

    if watch:
        # Log the full (flattened) configuration alongside the run.
        config_dict = flatten_dict(OmegaConf.to_container(config, resolve=True))
        rname = f"simplex-ski_{dataset_name}_{noise}_{seed}"
        wandb.init(
            project=config.wandb.project,
            entity=config.wandb.entity,
            group=config.wandb.group,
            name=rname,
            config=config_dict
        )

    print("Setting dtype to ...", dtype)
    torch.set_default_dtype(dtype)

    # Dataset preparation: move both splits to the configured dtype/device.
    train_x, train_y = flatten_dataset(train_dataset)
    train_x = train_x.to(dtype=dtype, device=device).contiguous()
    train_y = train_y.to(dtype=dtype, device=device).contiguous()
    print("HERE", train_x.shape, train_y.shape)

    test_x, test_y = flatten_dataset(test_dataset)
    test_x = test_x.to(dtype=dtype, device=device).contiguous()
    test_y = test_y.to(dtype=dtype, device=device).contiguous()

    # Model
    # NOTE(review): the configured `noise` / `noise_constraint` are not passed
    # to the model here; the likelihood is whatever SimplexGPModel builds.
    model = SimplexGPModel(train_x, train_y, use_ard=use_ard, use_scale=use_scale).to(device=device)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

    model.train()
    model.likelihood.train()

    # Optimizer: when the noise is not learned, hand the optimizer whatever
    # filter_param returns for the raw-noise parameter name (presumably all
    # other parameters -- verify against gp.util.filter_param).
    if learn_noise:
        params = model.parameters()
    else:
        params = filter_param(model.named_parameters(), "likelihood.noise_covar.raw_noise")
    optimizer = torch.optim.Adam([{'params': params}], lr=lr)
    # Constant multiplier: the scheduler currently leaves the learning rate unchanged.
    lr_sched = lambda epoch: 1.0
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_sched)

    # Only synchronize with the GPU when actually training on one; calling
    # torch.cuda.synchronize() without CUDA raises.
    torch_device = torch.device(device)
    sync_cuda = torch_device.type == "cuda" and torch.cuda.is_available()

    # Training loop
    pbar = tqdm(range(epochs), desc="Optimizing MLL")
    for epoch in pbar:
        t1 = time.perf_counter()

        optimizer.zero_grad()
        with gpytorch.settings.cg_tolerance(cg_tol), \
            gpytorch.settings.max_cg_iterations(cg_iter), \
            gpytorch.settings.max_preconditioner_size(pre_size), \
            gpytorch.settings.max_root_decomposition_size(lanc_iter):
            output = model(train_x)
            loss = -mll(output, train_y)

            loss.backward()
            # Step optimizer and learning-rate scheduler
            optimizer.step()
            scheduler.step()
        if sync_cuda:
            # Wait for queued GPU work so the epoch time below is meaningful.
            torch.cuda.synchronize(torch_device)
        t2 = time.perf_counter()

        # Progress bar
        loss_value = loss.item()
        pbar.set_description(f"Epoch {epoch+1}/{epochs}")
        pbar.set_postfix(MLL=f"{-loss_value}")

        # Evaluate on the test split, then switch back to training mode
        # because eval_gp puts the model into eval mode.
        test_rmse, test_nll = eval_gp(model, test_x, test_y, device=device, lanc_iter=lanc_iter, cg_iter=cg_iter, cg_tol=cg_eval_tol, pre_size=pre_size)
        model.train()

        if watch:
            results = {
                # Log a plain float rather than a graph-attached tensor.
                "loss": loss_value,
                "test_nll": test_nll,
                "test_rmse": test_rmse,
                "epoch_time": t2 - t1,
                "noise": model.get_noise(),
                "lengthscale": wandb.Histogram(model.get_lengthscale().detach().cpu().numpy()),
                "outputscale": model.get_outputscale(),
            }

            wandb.log(results)

    return model, model.likelihood


def eval_gp(model, x, y, device="cuda:0", dtype=torch.float32, lanc_iter=100, pre_size=100, cg_iter=1000, cg_tol=1e-2):
    """Evaluate ``model`` on inputs ``x`` against targets ``y``.

    The model is switched to eval mode (and left there) and called on ``x``
    under ``torch.no_grad()`` with GPyTorch's fast predictive variances.

    Args:
        model: Model exposing ``eval()``, ``likelihood.noise``, ``get_noise()``,
            ``get_lengthscale()`` and ``get_outputscale()``; calling it on ``x``
            must return an object with ``mean`` and ``variance``.
        x: Input tensor passed to the model.
        y: Target tensor compared against the predictive mean.
        device: Unused; kept so existing callers that pass it keep working.
        dtype: Unused; kept so existing callers that pass it keep working.
        lanc_iter: Value for ``gpytorch.settings.max_root_decomposition_size``.
        pre_size: Value for ``gpytorch.settings.max_preconditioner_size``.
        cg_iter: Value for ``gpytorch.settings.max_cg_iterations``.
        cg_tol: Value for ``gpytorch.settings.eval_cg_tolerance``.

    Returns:
        ``(rmse, nll)`` where ``rmse`` is the square root of the squared error
        averaged over dimension 0, and ``nll`` is the mean negative
        log-probability of ``y`` under a Normal with the predictive mean and a
        variance equal to the predictive variance plus ``model.likelihood.noise``.
    """
    model.eval()

    with gpytorch.settings.eval_cg_tolerance(cg_tol), \
       gpytorch.settings.max_cg_iterations(cg_iter), \
       gpytorch.settings.max_preconditioner_size(pre_size), \
       gpytorch.settings.max_root_decomposition_size(lanc_iter), \
       gpytorch.settings.fast_pred_var(), torch.no_grad():

        # The model is called directly (not through the likelihood), so the
        # observation noise is added to the variance by hand below.
        pred_y = model(x)
        pred_mean = pred_y.mean

        rmse = (pred_mean - y).pow(2).mean(0).sqrt()
        nll = - torch.distributions.Normal(
            pred_mean,
            pred_y.variance.add(model.likelihood.noise).sqrt()
        ).log_prob(y).mean()

    print("RMSE", rmse, rmse.dtype, "NLL", nll, "NOISE", model.get_noise().item(), "LENGTHSCALE", model.get_lengthscale(), "OUTPUTSCALE", model.get_outputscale())
    return rmse, nll


# =============================================================================
# Quick Test
# =============================================================================

# if __name__ == "__main__":
#     from data.get_uci import ElevatorsDataset

#     # Get dataset
#     dataset = ElevatorsDataset("../../data/uci_datasets/uci_datasets/elevators/data.csv")
#     train_dataset, val_dataset, test_dataset = split_dataset(
#         dataset,
#         train_frac=CONFIG.dataset.train_frac,
#         val_frac=CONFIG.dataset.val_frac    
#     )

#     # Test
#     model, likelihood = train_gp(CONFIG, train_dataset, test_dataset)
    
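#     # Evaluate (eval_gp takes input/target tensors, not a likelihood and a dataset)
#     dtype, device = getattr(torch, CONFIG.model.dtype), CONFIG.model.device
#     test_x, test_y = flatten_dataset(test_dataset)
#     test_x = test_x.to(dtype=dtype, device=device).contiguous()
#     test_y = test_y.to(dtype=dtype, device=device).contiguous()
#     eval_gp(model, test_x, test_y, device=device)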
