# System/Library imports
import time
from typing import *

# Common data science imports
import hydra
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
import numpy as np
from sklearn.cluster import KMeans
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, Dataset

# For logging
try:
    import wandb
except:
    pass

# Gpytorch imports
import gpytorch
from linear_operator.settings import max_cholesky_size

# Our imports
from gp.softki.model import SoftGP
from gp.util import dynamic_instantiation, flatten_dict, flatten_dataset, split_dataset, filter_param, heatmap


# =============================================================================
# Train and Evaluate
# =============================================================================

def _synchronize_if_cuda(device) -> None:
    """Wait for queued CUDA work on ``device``; do nothing for non-CUDA devices.

    Used so per-batch wall-clock timing includes the GPU work. Calling
    ``torch.cuda.synchronize()`` unconditionally raises on CPU-only setups.
    """
    if torch.device(device).type == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize(device)


def _init_interp_points(induce_init, train_features, num_inducing, dim, dtype, device) -> torch.Tensor:
    """Return ``num_inducing`` initial interpolation points as a ``dtype`` tensor on ``device``.

    ``induce_init == "kmeans"`` uses the k-means cluster centers of ``train_features``.
    Any other value draws points uniformly from [0, 1) in ``dim`` dimensions.
    """
    if induce_init == "kmeans":
        print("Using kmeans ...")
        kmeans = KMeans(n_clusters=num_inducing)
        kmeans.fit(train_features)
        points = torch.tensor(kmeans.cluster_centers_)
    else:
        print("Using random ...")
        points = torch.rand(num_inducing, dim)
    # Cast both branches so the points always match the model's dtype.
    return points.to(dtype=dtype, device=device)


def train_gp(config: DictConfig, train_dataset: Dataset, test_dataset: Dataset) -> SoftGP:
    """Fit a SoftGP on ``train_dataset``, evaluating on ``test_dataset`` after every epoch.

    Each epoch does the following:
      * runs Adam over mini-batches of the negative marginal log-likelihood
        (``-model.mll``);
      * re-solves the model weights on the full training set (``model.fit``);
      * evaluates with ``eval_gp``.
    When ``config.wandb.watch`` is true, per-epoch metrics and K_zz diagnostics
    are logged to wandb. On the last epoch the inducing points and K_zz are also
    saved as artifacts.

    Args:
        config: Config with ``dataset``, ``model``, ``training`` and ``wandb``
            sections. If ``config.model.use_ard`` is true,
            ``config.model.kernel.ard_num_dims`` is set in place to
            ``train_dataset.dim``.
        train_dataset: Training data. It must expose ``.dim`` and yield
            ``(x, y)`` pairs.
        test_dataset: Data passed to ``eval_gp`` after each epoch.

    Returns:
        The trained SoftGP model.
    """
    dataset_name = config.dataset.name
    model_cfg = config.model

    # Set the ARD dimension before instantiating, so the kernel is built only once.
    if model_cfg.use_ard:
        model_cfg.kernel.ard_num_dims = train_dataset.dim
    kernel = dynamic_instantiation(model_cfg.kernel)

    # Read every model option up front, so config errors surface before k-means runs.
    use_scale = model_cfg.use_scale
    num_inducing = model_cfg.num_inducing
    induce_init = model_cfg.induce_init
    dtype = getattr(torch, model_cfg.dtype)  # e.g. "float64" -> torch.float64
    device = model_cfg.device
    noise = model_cfg.noise
    learn_noise = model_cfg.learn_noise
    solver = model_cfg.solver
    cg_tolerance = model_cfg.cg_tolerance
    mll_approx = model_cfg.mll_approx
    fit_chunk_size = model_cfg.fit_chunk_size
    use_qr = model_cfg.use_qr
    use_T = model_cfg.use_T
    T = model_cfg.T
    learn_T = model_cfg.learn_T
    min_T = model_cfg.min_T
    use_threshold = model_cfg.use_threshold
    threshold = model_cfg.threshold
    learn_threshold = model_cfg.learn_threshold
    topk = model_cfg.topk

    # NOTE(review): seed only feeds the wandb run name here; confirm the caller
    # seeds torch / numpy / KMeans if reproducibility is expected.
    seed = config.training.seed
    batch_size = config.training.batch_size
    epochs = config.training.epochs
    lr = config.training.learning_rate

    if config.wandb.watch:
        config_dict = flatten_dict(OmegaConf.to_container(config, resolve=True))
        rname = f"softki_{config.wandb.group}_{dataset_name}_{num_inducing}_{noise}_{seed}"
        wandb.init(
            project=config.wandb.project,
            entity=config.wandb.entity,
            group=config.wandb.group,
            name=rname,
            config=config_dict,
        )

    # The full training set is used for k-means init and for the per-epoch weight solve.
    train_features, train_labels = flatten_dataset(train_dataset)
    interp_points = _init_interp_points(
        induce_init, train_features, num_inducing, train_dataset.dim, dtype, device
    )

    model = SoftGP(
        kernel,
        interp_points,
        dtype=dtype,
        device=device,
        noise=noise,
        learn_noise=learn_noise,
        use_T=use_T,
        T=T,
        learn_T=learn_T,
        min_T=min_T,
        use_threshold=use_threshold,
        threshold=threshold,
        learn_threshold=learn_threshold,
        use_scale=use_scale,
        solver=solver,
        cg_tolerance=cg_tolerance,
        mll_approx=mll_approx,
        fit_chunk_size=fit_chunk_size,
        use_qr=use_qr,
        topk=topk,
        use_dot=model_cfg.use_dot,
    )

    # With a fixed noise, leave the raw noise parameter out of the optimizer.
    if learn_noise:
        params = model.parameters()
    else:
        params = filter_param(model.named_parameters(), "likelihood.noise_covar.raw_noise")
    optimizer = torch.optim.Adam(params, lr=lr)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=config.dataset.num_workers)
    pbar = tqdm(range(epochs), desc="Optimizing MLL")
    for epoch in pbar:
        # One epoch of hyperparameter fitting (including the interpolation points).
        elapsed_time = 0
        neg_mlls = []
        for x_batch, y_batch in train_loader:
            t1 = time.perf_counter()
            x_batch = x_batch.clone().detach().to(dtype=dtype, device=device)
            y_batch = y_batch.clone().detach().to(dtype=dtype, device=device)

            optimizer.zero_grad()
            with gpytorch.settings.max_root_decomposition_size(100), max_cholesky_size(int(1.e7)):
                neg_mll = -model.mll(x_batch, y_batch)
            loss_value = neg_mll.item()
            neg_mlls.append(loss_value)
            neg_mll.backward()
            optimizer.step()
            # Sync so the timing includes asynchronous GPU work; skipped off-GPU.
            _synchronize_if_cuda(device)
            elapsed_time += time.perf_counter() - t1

            pbar.set_description(f"Epoch {epoch+1}/{epochs}")
            pbar.set_postfix(MLL=f"{loss_value}")

        # Solve for the weights given the current (fixed) interpolation points.
        t2 = time.perf_counter()
        model.fit(train_features, train_labels)
        t3 = time.perf_counter()

        eval_results = eval_gp(model, test_dataset, device=device, num_workers=config.dataset.num_workers)

        if config.wandb.watch:
            K_zz = model._mk_cov(model.interp_points).detach().cpu().numpy()
            custom_bins = [0, 1e-20, 1e-10, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 0.5, 20]
            bin_counts, bin_edges = np.histogram(K_zz.flatten(), bins=custom_bins)
            log_data = {
                "loss": torch.tensor(neg_mlls).mean(),
                "test_rmse": eval_results["rmse"],
                "test_nll": eval_results["nll"],
                "epoch_time": elapsed_time,
                "fit_time": t3 - t2,
                "noise": model.noise.cpu(),
                "lengthscale": wandb.Histogram(model.get_lengthscale().detach().cpu()),
                "outputscale": model.get_outputscale(),
                "threshold": model.threshold.cpu().item(),
                "T": wandb.Histogram(model.T.detach().cpu()),
                "K_zz_norm_2": np.linalg.norm(K_zz, ord='fro'),
                "K_zz_norm_1": np.linalg.norm(K_zz, ord=1),
                "K_zz_norm_inf": np.linalg.norm(K_zz, ord=np.inf),
                "avg_eles_0.9": eval_results["avg_eles_0.9"],
                "std_eles_0.9": eval_results["std_eles_0.9"],
            }
            # zip stops at len(bin_counts), so each count is keyed by its left bin edge.
            for cnt, edge in zip(bin_counts, bin_edges):
                log_data[f"K_zz_bin_{edge}"] = cnt

            if epoch % 10 == 0 or epoch == epochs - 1:
                img = heatmap(K_zz)
                log_data.update({
                    "inducing_points": wandb.Histogram(model.interp_points.detach().cpu().numpy()),
                    "K_zz": wandb.Image(img),
                })

            if epoch == epochs - 1:
                # Save the final inducing points and K_zz as wandb artifacts.
                artifact = wandb.Artifact(f"inducing_points_{rname}_{epoch}", type="parameters")
                np.save("array.npy", model.interp_points.detach().cpu().numpy())
                artifact.add_file("array.npy")
                wandb.log_artifact(artifact)

                artifact = wandb.Artifact(f"K_zz_{rname}_{epoch}", type="parameters")
                np.save("K_zz.npy", K_zz)
                artifact.add_file("K_zz.npy")
                wandb.log_artifact(artifact)

            wandb.log(log_data)

    return model


def eval_gp(model: SoftGP, test_dataset: Dataset, device="cuda:0", num_workers=0) -> Dict[str, Any]:
    """Evaluate ``model`` on ``test_dataset`` and print a one-line summary.

    Computes three things:
      * the test RMSE of ``model.pred``;
      * the mean per-point negative log-likelihood under independent Normals
        with mean ``model.pred(x)`` and std ``sqrt(diag(model.pred_cov(x)))``;
      * the mean and std of how many interpolation weights
        (``model._interp(x)``), taken largest first, are needed to reach a
        cumulative mass of 0.9.

    Args:
        model: Trained SoftGP.
        test_dataset: Dataset yielding ``(x, y)`` pairs.
        device: Device that batches are moved to.
        num_workers: DataLoader worker count.

    Returns:
        Dict with "rmse" (float), "nll" (0-d tensor), "avg_eles_0.9" and
        "std_eles_0.9" (0-d tensors). If no row's weights reach 0.9, the
        count for that row is the full number of weights.
    """
    batch_size = 256 if len(test_dataset) > 50000 else 1024
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    sq_errors = []
    nlls = []
    num_eles = []
    # Evaluation only: don't build autograd graphs for every test batch.
    with torch.no_grad():
        for x_batch, y_batch in tqdm(test_loader):
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            pred_mean = model.pred(x_batch)

            # Count how many of the largest interpolation weights reach 0.9 mass.
            W_star_z = model._interp(x_batch)
            sorted_vals, _ = torch.sort(W_star_z, dim=1, descending=True)
            reached = torch.cumsum(sorted_vals, dim=1) >= 0.9
            # Index of the first True per row (argmax of the one-hot "cumsum == 1" marker).
            first_idx = reached.float().cumsum(dim=1).eq(1).float().argmax(dim=1)
            # Rows that never reach 0.9 would otherwise report index 0; use the full row width.
            first_idx = torch.where(
                reached.any(dim=1), first_idx, torch.full_like(first_idx, reached.shape[1] - 1)
            )
            num_eles.append(first_idx + 1)

            sq_errors.append((pred_mean - y_batch).detach().cpu() ** 2)
            covar = model.pred_cov(x_batch)
            nll = -torch.distributions.Normal(pred_mean, torch.sqrt(covar.diag())).log_prob(y_batch).detach().cpu()
            nlls.append(nll)

    rmse = torch.sqrt(torch.sum(torch.cat(sq_errors)) / len(test_dataset)).item()
    neg_mll = torch.cat(nlls).mean()
    counts = torch.cat(num_eles).float()
    avg_eles = counts.mean()
    std_eles = counts.std()

    print("RMSE:", rmse, "NEG_MLL", neg_mll.item(), "avg", avg_eles, "std", std_eles, "NOISE", model.noise.cpu().item(), "LENGTHSCALE", model.get_lengthscale(), "OUTPUTSCALE", model.get_outputscale(), "THRESHOLD", model.threshold.cpu().item(), "T", model.T.cpu())

    return {
        "rmse": rmse,
        "nll": neg_mll,
        "avg_eles_0.9": avg_eles,
        "std_eles_0.9": std_eles,
    }


# =============================================================================
# Quick test
# =============================================================================

@hydra.main(version_base=None, config_path="./", config_name="config")
def main(config):
    """Smoke test: train and then evaluate a SoftGP on the UCI "pol" dataset.

    The loaded config's ``dataset`` and ``wandb`` sections are replaced. The run
    reads the pol CSV, uses zero DataLoader workers and disables wandb logging.
    """
    # Struct mode off so new keys can be attached to the loaded config.
    OmegaConf.set_struct(config, False)
    from data.get_uci import PoleteleDataset

    uci_root = "../../data/uci_datasets/uci_datasets"
    full_dataset = PoleteleDataset(f"{uci_root}/pol/data.csv")

    config.dataset = {"name": "pol", "num_workers": 0}
    config.wandb = {"watch": False}

    # 90% train; with val_frac=0.0 the validation split is discarded.
    train_split, _, test_split = split_dataset(full_dataset, train_frac=0.9, val_frac=0.0)

    trained = train_gp(config, train_split, test_split)
    eval_gp(trained, test_split, device=config.model.device)


if __name__ == "__main__":
    main()
    