import sys

sys.path.append("..")

import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision.transforms as transforms
from era5 import Era5
from skimage.metrics import structural_similarity
from torch.utils.data import Subset
from tqdm import tqdm
from util import *

from denoising_diffusion_pytorch.denoising_diffusion_pytorch import (
    DropoutDiffusion,
    DropoutUnet,
)


def get_results(y, model, M, N):
    """Draw ``M`` seeded batches of ``N`` conditional samples and summarise them.

    Reads the module-level globals ``R`` (used to reshape ``y``) and
    ``device`` (where the conditioning tensor and seeds are placed).

    Args:
        y: Conditioning measurement. It is reshaped to ``(1, 1, R, R)``, so it
            must hold exactly ``R * R`` elements.
        model: Object exposing ``p_sample_loop(shape, cond, seed=...)``.
        M: Number of random seeds; ``p_sample_loop`` is called once per seed.
        N: Batch size of each call (the conditioning is repeated ``N`` times).

    Returns:
        A tuple ``(samples, means, variances)``. ``samples`` is all batches
        concatenated along dim 0; ``means`` and ``variances`` are the
        per-seed mean / variance over the batch dim, stacked along a new
        leading dim of length ``M``. Assuming ``p_sample_loop`` returns a
        tensor shaped like ``cond`` (``(N, 1, R, R)``) -- TODO confirm -- the
        shapes are ``(M * N, R, R)``, ``(M, R, R)`` and ``(M, R, R)``.
    """
    cond = y.reshape(1, 1, R, R).repeat(N, 1, 1, 1).to(device)

    means = []
    variances = []
    samples = []
    seeds = torch.randint(0, 2**16, (M,), device=device).tolist()
    for seed in tqdm(seeds, total=len(seeds)):
        sample = model.p_sample_loop(cond.shape, cond, seed=seed)
        samples.append(sample)
        means.append(sample.mean(dim=0))
        variances.append(sample.var(dim=0))

    # Squeeze only dim 1 (the singleton channel axis). A bare ``squeeze()``
    # would also drop the leading seed/sample axis whenever it has length 1
    # (e.g. M == 1), which makes the callers' ``dim=0`` reductions run over an
    # image axis instead. ``squeeze(1)`` is a no-op if dim 1 is not size 1.
    means = torch.stack(means).squeeze(1)
    variances = torch.stack(variances).squeeze(1)
    samples = torch.vstack(samples).squeeze(1)

    return samples, means, variances


def id_test(idx):
    """Sample the model on one test-set item and save seven heatmaps.

    Uses the module-level globals ``test_set``, ``model``, ``M``, ``N`` and
    ``WRITE_DIR``. The reference image ``x`` is mapped through ``(x + 1) / 2``
    before being compared against the mean, median and mode of the samples.
    The maps titled EU / AU / TU (presumably epistemic / aleatoric / total
    uncertainty) are, respectively, the variance of the per-seed means, the
    mean of the per-seed variances, and their sum.

    Args:
        idx: Index into ``test_set`` of the ``(x, y)`` pair to evaluate.
    """
    print("id test")
    x, y = test_set[idx]
    gt = ((x + 1) / 2).squeeze().numpy()
    samples, means, variances = get_results(y, model, M, N)

    def _as_numpy(tensor):
        # Move a tensor off the graph/device and expose it as a NumPy array.
        return tensor.detach().cpu().numpy()

    def _save_heatmap(values, title, filename, **heatmap_kwargs):
        # Render one heatmap, write it under WRITE_DIR, then reset the figure.
        axes = sns.heatmap(values, cmap="coolwarm", **heatmap_kwargs)
        axes.set_title(title, pad=15)
        plt.axis("off")
        plt.savefig(os.path.join(WRITE_DIR, filename), transparent=True)
        plt.clf()

    # Point estimates over all pooled samples.
    mean = _as_numpy(samples.mean(dim=0))
    median = _as_numpy(samples.median(dim=0).values)
    mode = _as_numpy(samples.mode(dim=0).values)

    # Uncertainty maps derived from the per-seed statistics.
    eu = means.var(dim=0).detach().cpu().squeeze()
    au = variances.mean(dim=0).detach().cpu().squeeze()
    tu = au + eu

    # Squared error of each point estimate against the rescaled reference.
    mean_err = (mean - gt) ** 2
    median_err = (median - gt) ** 2
    mode_err = (mode - gt) ** 2

    sns.set(font_scale=1.5, font="serif")

    # (values, title, output file, extra heatmap kwargs) in output order.
    panels = [
        (mean, "Mean Prediction", "era5_id_dropout_pred.png", {"vmin": 0, "vmax": 1}),
        (mean_err, "Mean Error", "era5_id_dropout_mean.png", {}),
        (median_err, "Median Error", "era5_id_dropout_median.png", {}),
        (mode_err, "Mode Error", "era5_id_dropout_mode.png", {}),
        (eu, "$\\widehat{\\text{EU}}$", "era5_id_dropout_eu.png", {}),
        (au, "$\\widehat{\\text{AU}}$", "era5_id_dropout_au.png", {}),
        (tu, "$\\widehat{\\text{TU}}$", "era5_id_dropout_tu.png", {}),
    ]
    for values, title, filename, extra in panels:
        _save_heatmap(values, title, filename, **extra)


def ood_test(idx):
    """Sample the model on an out-of-distribution measurement and save plots.

    The measurement comes from ``get_ood_measurement`` (defined outside this
    block), called with the global ``R``, the given ``idx`` and
    ``noise_std=0.01``. Four heatmaps are written under the global
    ``WRITE_DIR``: the mean prediction plus the maps titled EU, AU and TU
    (variance of per-seed means, mean of per-seed variances, and their sum).

    Args:
        idx: Forwarded to ``get_ood_measurement`` as ``idx``.
    """
    print("ood test")

    y = get_ood_measurement(R, idx=idx, noise_std=0.01)
    samples, means, variances = get_results(y, model, M, N)

    def _save_heatmap(values, title, filename, **heatmap_kwargs):
        # Render one heatmap, write it under WRITE_DIR, then reset the figure.
        axes = sns.heatmap(values, cmap="coolwarm", **heatmap_kwargs)
        axes.set_title(title, pad=15)
        plt.axis("off")
        plt.savefig(os.path.join(WRITE_DIR, filename), transparent=True)
        plt.clf()

    # Pooled mean prediction and the uncertainty maps from per-seed statistics.
    mean = samples.mean(dim=0).detach().cpu().numpy()
    eu = means.var(dim=0).detach().cpu().squeeze()
    au = variances.mean(dim=0).detach().cpu().squeeze()
    tu = au + eu

    sns.set(font_scale=1.5, font="serif")

    # (values, title, output file, extra heatmap kwargs) in output order.
    panels = [
        (mean, "Mean Prediction", "era5_ood_dropout_pred.png", {"vmin": 0, "vmax": 1}),
        (eu, "$\\widehat{\\text{EU}}$", "era5_ood_dropout_eu.png", {}),
        (au, "$\\widehat{\\text{AU}}$", "era5_ood_dropout_au.png", {}),
        (tu, "$\\widehat{\\text{TU}}$", "era5_ood_dropout_tu.png", {}),
    ]
    for values, title, filename, extra in panels:
        _save_heatmap(values, title, filename, **extra)


def stats_test(agg="mean"):
    """Average reconstruction metrics over ``test_set`` and write them to CSV.

    For every ``(x, y)`` pair in the global ``test_set`` the model is sampled
    via ``get_results`` and the samples are collapsed along dim 0 with the
    chosen aggregation. PSNR, SSIM, mean absolute error and the ``crps`` /
    ``fcrps`` / ``acrps`` entries returned by ``get_crps`` are accumulated and
    divided by ``len(test_set)``. The result is written as a header row plus
    one data row to ``era5_dropout_stats_{agg}.csv`` under ``WRITE_DIR``.

    Args:
        agg: How to collapse the samples into one prediction: ``"mean"``,
            ``"median"`` or ``"mode"``.

    Raises:
        ValueError: If ``agg`` is not one of the supported names. This is
            checked before any sampling is done.
    """
    print("stats")

    reducers = {
        "mean": lambda s: s.mean(dim=0),
        "median": lambda s: s.median(dim=0).values,
        "mode": lambda s: s.mode(dim=0).values,
    }
    # Fail fast: previously a bad ``agg`` was only noticed after the first
    # (expensive) sampling pass had already completed.
    if agg not in reducers:
        raise ValueError(f"Invalid aggregation function: {agg}")
    reduce_samples = reducers[agg]

    fields = ["psnr", "ssim", "crps", "fcrps", "acrps", "l1"]
    totals = dict.fromkeys(fields, 0)
    n_items = len(test_set)

    for x, y in tqdm(test_set, total=n_items):
        # Same rescaling of the reference image as used in id_test.
        gt = ((x + 1) / 2).squeeze().numpy()
        samples, _, _ = get_results(y, model, M, N)
        pred = reduce_samples(samples).detach().cpu().numpy()

        totals["psnr"] += PSNR(gt, pred)
        totals["l1"] += np.abs(gt - pred).mean()
        totals["ssim"] += structural_similarity(gt, pred, data_range=1)
        scores = get_crps(gt, samples)
        for key in ("crps", "fcrps", "acrps"):
            totals[key] += scores[key]

    data = {field: totals[field] / n_items for field in fields}
    with open(os.path.join(WRITE_DIR, f"era5_dropout_stats_{agg}.csv"), "w") as f:
        f.write(",".join(fields) + "\n")
        f.write(",".join(str(data[field]) for field in fields) + "\n")


if __name__ == "__main__":
    # Script configuration. These names are read as module-level globals by
    # get_results / id_test / ood_test / stats_test.
    WRITE_DIR = "/XXXX-2/XXXX-1/scratch/XXXX-3-hyperdiffusion/figs"
    # Make sure the output directory exists before any sampling starts, so a
    # missing folder does not surface only when the first figure is saved.
    os.makedirs(WRITE_DIR, exist_ok=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    R = 128  # image side length: resize target and DropoutDiffusion image_size
    T = 100  # timesteps passed to DropoutDiffusion
    n_params = 2440241  # not referenced in the visible part of this script
    M = 10  # number of seeds per measurement (see get_results)
    N = 100  # samples drawn per seed (see get_results)
    D = 100  # number of trailing dataset items used as the test split

    backbone = DropoutUnet(
        dim=16,
        dim_mults=(1, 2, 4, 8),
        channels=1,
        self_condition=True,
        dropout_rate=0.0,
    )
    model = DropoutDiffusion(backbone, image_size=R, timesteps=T).to(device)
    model.init_rng()
    # map_location lets a checkpoint saved from a CUDA device load when this
    # script has fallen back to CPU (see ``device`` above).
    model.load_state_dict(
        torch.load(
            "/XXXX-2/XXXX-1/scratch/hyperdiffusion/era5/weights/dropout_0.pt",
            map_location=device,
        )
    )
    model.eval()

    # Test dataset: the last D items of the ERA5 dataset.
    tfm = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((R, R), antialias=True),
        ]
    )
    dset = Era5("/XXXX-2/XXXX-1/scratch/uq_diffusion/era5/data", tfm)
    test_set = Subset(dset, range(len(dset) - D, len(dset)))

    # Toggle the experiments to run by (un)commenting the calls below.
    # id_test(0)
    ood_test(-1)
    # stats_test(agg="mean")
    # stats_test(agg="median")
    # stats_test(agg="mode")
