# System imports
import time
from typing import *

# Gpytorch / Torch
import gpytorch
import gpytorch.constraints
from gpytorch.constraints import GreaterThan
import torch
from torch.utils.data import DataLoader

# Other
from omegaconf import OmegaConf
from tqdm import tqdm 

try:
    import wandb 
except:
    pass

# Our imports
from gp.exact.mll import CGDMLL
from gp.exact.model import ExactGPModel, KeOpsExactGPModel
from gp.util import dynamic_instantiation, flatten_dict, flatten_dataset


# =============================================================================
# Train / Eval
# =============================================================================

def train_gp(config, train_dataset, test_dataset):
    """Fit an exact GP by maximising the marginal log-likelihood with Adam.

    The test metrics are recomputed after every optimiser step and, when
    ``config.wandb.watch`` is set, logged to wandb together with the loss,
    the step time and the current hyperparameters.

    Args:
        config: Experiment configuration. Reads ``dataset.name``,
            ``model.*`` (kernel, use_ard, use_scale, dtype, device, noise,
            noise_constraint, learn_noise, max_cg_iters, cg_tolerance and the
            optional use_keops / keops_kernel_type), ``training.*`` (seed,
            epochs, learning_rate) and ``wandb.*``. When ``model.use_ard`` is
            set, ``config.model.kernel.ard_num_dims`` is overwritten with
            ``train_dataset.dim``.
        train_dataset: Training data; passed to ``flatten_dataset`` and must
            expose ``dim`` when ARD is enabled.
        test_dataset: Held-out data used for the per-epoch metrics.

    Returns:
        The ``(model, likelihood)`` pair after the last epoch.

    Side effects:
        Calls ``torch.set_default_dtype`` with the configured dtype, and
        initialises a wandb run when ``config.wandb.watch`` is true.
    """
    # Unpack dataset
    dataset_name = config.dataset.name

    # Unpack model configuration
    model_cfg = config.model
    use_ard = model_cfg.use_ard
    use_scale = model_cfg.use_scale
    dtype = getattr(torch, model_cfg.dtype)
    device = model_cfg.device
    noise = model_cfg.noise
    noise_constraint = model_cfg.noise_constraint
    learn_noise = model_cfg.learn_noise
    max_cg_iters = model_cfg.max_cg_iters
    cg_tolerance = model_cfg.cg_tolerance
    use_keops = model_cfg.get('use_keops', False)

    # ARD needs the input dimensionality, so it is written into the kernel
    # config before the kernel is instantiated.
    if use_ard:
        config.model.kernel.ard_num_dims = train_dataset.dim
    kernel = dynamic_instantiation(config.model.kernel)

    # Unpack training configuration
    seed = config.training.seed
    epochs = config.training.epochs
    lr = config.training.learning_rate

    # Set wandb
    if config.wandb.watch:
        # Create wandb config with training/model config
        config_dict = flatten_dict(OmegaConf.to_container(config, resolve=True))

        # Create name
        rname = f"exact_{dataset_name}_{noise}_{seed}"

        # Initialize wandb
        wandb.init(
            project=config.wandb.project,
            entity=config.wandb.entity,
            group=config.wandb.group,
            name=rname,
            config=config_dict
        )

    print("Setting dtype to ...", dtype)
    torch.set_default_dtype(dtype)

    # Dataset preparation
    train_x, train_y = flatten_dataset(train_dataset)
    train_x = train_x.to(dtype=dtype, device=device)
    train_y = train_y.to(dtype=dtype, device=device)

    # Model
    likelihood = gpytorch.likelihoods.GaussianLikelihood(
        noise_constraint=GreaterThan(noise_constraint)
    ).to(device=device)
    likelihood.noise = torch.tensor([noise]).to(device=device)

    # Only the KeOps path keeps a separate evaluation model; on the other
    # path these stay None and evaluation goes through eval_gp.
    test_x = test_y = None
    test_model = test_likelihood = None

    # Choose model based on use_keops flag
    if use_keops:
        test_x, test_y = flatten_dataset(test_dataset)
        test_x = test_x.to(dtype=dtype, device=device)
        test_y = test_y.to(dtype=dtype, device=device)

        # Extract kernel type from config
        kernel_type = config.model.get('keops_kernel_type', 'rbf')
        nu = config.model.kernel.get('nu', 1.5)
        ard_num_dims = train_dataset.dim if use_ard else None
        model = KeOpsExactGPModel(
            train_x, train_y, likelihood,
            kernel_type=kernel_type, nu=nu, use_scale=use_scale, ard_num_dims=ard_num_dims,
        ).to(device=device)
        test_likelihood = gpytorch.likelihoods.GaussianLikelihood(
            noise_constraint=GreaterThan(noise_constraint)
        ).to(device=device)
        # NOTE(review): the evaluation model is built with the *test* data as
        # its conditioning set and is later queried at those same inputs, so
        # the logged "test" metrics are not held-out predictions of the
        # training GP. Confirm this is intended.
        test_model = KeOpsExactGPModel(
            test_x, test_y, test_likelihood,
            kernel_type=kernel_type, nu=nu, use_scale=use_scale, ard_num_dims=ard_num_dims,
        ).to(device=device)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    else:
        model = ExactGPModel(train_x, train_y, likelihood, kernel=kernel, use_scale=use_scale).to(device=device)
        mll = CGDMLL(likelihood, model, max_cg_iters=max_cg_iters, cg_tolerance=cg_tolerance)

    # Training parameters
    model.train()
    likelihood.train()
    if test_model is not None:
        test_likelihood.eval()
        test_model.eval()

    # Set optimizer. With learn_noise disabled, the likelihood's parameters
    # are left out of the optimiser so the noise stays at its configured
    # value (they are matched by identity among model.parameters()).
    if learn_noise:
        params = list(model.parameters())
    else:
        frozen_ids = {id(p) for p in likelihood.parameters()}
        params = [p for p in model.parameters() if id(p) not in frozen_ids]
    optimizer = torch.optim.Adam([
        {'params': params},
    ], lr=lr)

    # Training loop
    pbar = tqdm(range(epochs), desc="Optimizing MLL")
    for epoch in pbar:
        t1 = time.perf_counter()

        # Full-batch step: the KeOps path hands the latent output to
        # ExactMarginalLogLikelihood, the other path passes it through the
        # likelihood before calling CGDMLL.
        optimizer.zero_grad()
        if use_keops:
            output = model(train_x)
        else:
            output = likelihood(model(train_x))
        loss = -mll(output, train_y)
        loss.backward()

        # step optimizers
        optimizer.step()
        t2 = time.perf_counter()

        # Log
        pbar.set_description(f"Epoch {epoch+1}/{epochs}")
        pbar.set_postfix(MLL=f"{-loss.item()}")

        # Evaluate
        if use_keops:
            test_model.load_state_dict(model.state_dict())
            # Assumes KeOpsExactGPModel derives from gpytorch's ExactGP, whose
            # train() drops the cached prediction strategy: toggling the mode
            # keeps predictions from reusing caches built with the previous
            # epoch's hyperparameters.
            test_model.train()
            test_model.eval()
            with torch.no_grad():
                output2 = test_model(test_x)
                means = output2.mean.detach()
                se = torch.sum((means - test_y)**2)
                test_rmse = torch.sqrt(se / len(test_dataset))
                variances = output2.covariance_matrix.diagonal().detach()
                stds = variances.sqrt()
            try:
                test_nll = -torch.distributions.Normal(means, stds).log_prob(test_y).mean()
            except Exception as e:
                # Best effort: invalid standard deviations (e.g. NaN) must not
                # abort training, so report and fall back to 0.
                print(e)
                test_nll = 0
        else:
            test_rmse, test_nll = eval_gp(model, likelihood, test_dataset, device=device)
            # eval_gp switches both modules to eval mode; restore training mode.
            model.train()
            likelihood.train()

        if config.wandb.watch:
            results = {
                # Log plain floats rather than graph-attached tensors.
                "loss": loss.item(),
                "test_nll": float(test_nll),
                "test_rmse": float(test_rmse),
                "epoch_time": t2 - t1,
                "noise": model.get_noise(),
                "lengthscale": model.get_lengthscale(),
                "outputscale": model.get_outputscale(),
            }
            wandb.log(results)

    return model, likelihood


def eval_gp(model, likelihood, test_dataset, device="cuda:0"):
    """Compute test RMSE and mean negative log-likelihood of a GP.

    The test set is processed in batches of 1024. Each batch is scored with
    independent Normal distributions built from the predictive mean and the
    per-point predictive standard deviation of ``likelihood(model(x))``.

    Args:
        model: GP model; switched to eval mode and left in that mode.
        likelihood: Likelihood applied to the model output; switched to eval
            mode and left in that mode.
        test_dataset: Dataset yielding ``(x, y)`` pairs.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(rmse, nll)`` as 0-dim CPU tensors, where ``nll`` is the negative
        log-likelihood averaged over all test points. Both are NaN for an
        empty dataset.
    """
    # Set into eval mode
    model.eval()
    likelihood.eval()

    test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

    # Check if using KeOps model
    is_keops = isinstance(model, KeOpsExactGPModel)

    # Running totals over all test points; dividing the totals by the dataset
    # size at the end keeps the averages exact when the last batch is smaller.
    squared_error_total = torch.zeros(())
    nll_total = torch.zeros(())

    with torch.no_grad():
        for test_x, test_y in tqdm(test_loader):
            test_x = test_x.to(device=device)
            test_y = test_y.cpu()  # predictions are compared on the CPU
            output = likelihood(model(test_x))
            means = output.mean.detach().cpu()

            if is_keops:
                # For KeOps, read the variances off the covariance matrix
                # diagonal instead of using output.variance.
                variances = output.covariance_matrix.diagonal().detach().cpu()
                stds = variances.sqrt()
            else:
                stds = output.variance.sqrt().detach().cpu()

            log_probs = torch.distributions.Normal(means, stds).log_prob(test_y)
            nll_total = nll_total - log_probs.sum()
            squared_error_total = squared_error_total + torch.sum((means - test_y)**2)

    num_points = len(test_dataset)
    rmse = torch.sqrt(squared_error_total / num_points)
    nll = nll_total / num_points

    print("RMSE", rmse, rmse.dtype, "NLL", nll, "NOISE", model.get_noise().item(), "LENGTHSCALE", model.get_lengthscale(), "OUTPUTSCALE", model.get_outputscale())
    return rmse, nll
