import sys

sys.path.append("..")

import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision.transforms as transforms
from luna import Luna
from skimage.metrics import structural_similarity
from torch.utils.data import Subset
from tqdm import tqdm
from util import *

from hyperdiffusion import FrozenUnet, HyperDiffusion


def get_results(y, model, M, N):
    """Sample from ``model`` for one measurement under ``M`` random weight vectors.

    For each of the ``M`` iterations a fresh weight vector is drawn with
    ``torch.randn(model.in_dim)`` and passed to ``model.p_sample_loop`` together
    with the conditioning batch; the mean and variance of the returned batch
    over dim 0 are recorded alongside the batch itself.

    Relies on the module-level globals ``R`` (image side length) and ``device``.

    Args:
        y: Measurement tensor holding ``R * R`` values; it is reshaped to
            ``(1, 1, R, R)``, repeated ``N`` times along dim 0 and moved to
            ``device`` to form the conditioning batch.
        model: Object exposing ``in_dim`` and
            ``p_sample_loop(shape, cond, in_vec=...)``.
        M: Number of weight vectors to draw.
        N: Number of copies of the measurement in the conditioning batch.

    Returns:
        Tuple ``(samples, means, variances)``:
            samples: every batch returned by ``p_sample_loop``, stacked along
                dim 0 with ``torch.vstack`` and then squeezed.
            means: the ``M`` per-weight batch means, stacked and squeezed.
            variances: the ``M`` per-weight batch variances, stacked and squeezed.
    """
    cond = y.reshape(1, 1, R, R).repeat(N, 1, 1, 1).to(device)

    means = []
    variances = []
    samples = []
    # Evaluation only: disable autograd so no graph is recorded or kept alive
    # by the batches that are retained in ``samples`` below.
    with torch.no_grad():
        for _ in tqdm(range(M)):
            # Generate weights
            weight = torch.randn(model.in_dim, device=device)
            sample = model.p_sample_loop(cond.shape, cond, in_vec=weight)
            means.append(sample.mean(dim=0))
            variances.append(sample.var(dim=0))
            samples.append(sample)

    # NOTE: squeeze() without arguments removes every size-1 dimension, so with
    # M == 1 the leading (per-weight) dimension of means/variances is dropped too.
    means = torch.stack(means).squeeze()
    variances = torch.stack(variances).squeeze()
    samples = torch.vstack(samples).squeeze()

    return samples, means, variances


def id_test(idx):
    """Run the "id" test on one test-set item and save its heatmaps.

    Draws samples for ``test_set[idx]`` through ``get_results`` and writes seven
    PNG heatmaps into ``WRITE_DIR``: the mean prediction, the squared error of
    the mean / median / mode prediction against the ground truth, and the three
    maps titled EU, AU and TU (presumably epistemic, aleatoric and total
    uncertainty - confirm against the accompanying write-up).

    Relies on the module-level globals ``test_set``, ``model``, ``M``, ``N``
    and ``WRITE_DIR``.

    Args:
        idx: Index into ``test_set``.
    """

    def as_numpy(tensor):
        # Detach, move to host memory and convert to a NumPy array.
        return tensor.detach().cpu().numpy()

    print("id test")
    x_ref, y_meas = test_set[idx]
    # Ground truth is shifted and scaled with (x + 1) / 2 before comparison.
    gt = ((x_ref + 1) / 2).squeeze().numpy()
    samples, means, variances = get_results(y_meas, model, M, N)

    # Point estimates taken over all samples
    mean = as_numpy(samples.mean(dim=0))
    median = as_numpy(samples.median(dim=0).values)
    mode = as_numpy(samples.mode(dim=0).values)

    # EU: variance of the per-weight means; AU: mean of the per-weight
    # variances; TU: their sum.
    eu = means.var(dim=0).detach().cpu().squeeze()
    au = variances.mean(dim=0).detach().cpu().squeeze()
    tu = au + eu

    # Squared error of each point estimate
    mean_err = (mean - gt) ** 2
    median_err = (median - gt) ** 2
    mode_err = (mode - gt) ** 2

    # Plot results
    sns.set(font_scale=1.5, font="serif")

    # (data, title, output file name, extra heatmap keyword arguments)
    panels = [
        (
            mean,
            "Mean Prediction",
            "luna_id_hyperdiffusion_pred.png",
            {"vmin": 0, "vmax": 1},
        ),
        (mean_err, "Mean Error", "luna_id_hyperdiffusion_mean.png", {}),
        (median_err, "Median Error", "luna_id_hyperdiffusion_median.png", {}),
        (mode_err, "Mode Error", "luna_id_hyperdiffusion_mode.png", {}),
        (eu, "$\\widehat{\\text{EU}}$", "luna_id_hyperdiffusion_eu.png", {}),
        (au, "$\\widehat{\\text{AU}}$", "luna_id_hyperdiffusion_au.png", {}),
        (tu, "$\\widehat{\\text{TU}}$", "luna_id_hyperdiffusion_tu.png", {}),
    ]
    for data, title, filename, extra in panels:
        axes = sns.heatmap(data, cmap="coolwarm", **extra)
        axes.set_title(title, pad=15)
        plt.axis("off")
        plt.savefig(os.path.join(WRITE_DIR, filename), transparent=True)
        # Clear the figure so the next heatmap starts from an empty canvas.
        plt.clf()


def ood_test(idx):
    """Run the "ood" test on one measurement and save its heatmaps.

    Obtains a measurement from ``get_ood_measurement(R, idx=idx)``, draws
    samples for it through ``get_results`` and writes four PNG heatmaps into
    ``WRITE_DIR``: the mean prediction and the three maps titled EU, AU and TU
    (presumably epistemic, aleatoric and total uncertainty - confirm against
    the accompanying write-up).

    Relies on the module-level globals ``R``, ``model``, ``M``, ``N`` and
    ``WRITE_DIR``.

    Args:
        idx: Forwarded to ``get_ood_measurement`` as its ``idx`` keyword.
    """
    print("ood test")

    measurement = get_ood_measurement(R, idx=idx)
    samples, means, variances = get_results(measurement, model, M, N)

    # Mean over all samples, as a NumPy array
    mean = samples.mean(dim=0).detach().cpu().numpy()

    # EU: variance of the per-weight means; AU: mean of the per-weight
    # variances; TU: their sum.
    eu = means.var(dim=0).detach().cpu().squeeze()
    au = variances.mean(dim=0).detach().cpu().squeeze()
    tu = au + eu

    # Plot results
    sns.set(font_scale=1.5, font="serif")

    # (data, title, output file name, extra heatmap keyword arguments)
    panels = [
        (
            mean,
            "Mean Prediction",
            "luna_ood_hyperdiffusion_pred.png",
            {"vmin": 0, "vmax": 1},
        ),
        (eu, "$\\widehat{\\text{EU}}$", "luna_ood_hyperdiffusion_eu.png", {}),
        (au, "$\\widehat{\\text{AU}}$", "luna_ood_hyperdiffusion_au.png", {}),
        (tu, "$\\widehat{\\text{TU}}$", "luna_ood_hyperdiffusion_tu.png", {}),
    ]
    for data, title, filename, extra in panels:
        axes = sns.heatmap(data, cmap="coolwarm", **extra)
        axes.set_title(title, pad=15)
        plt.axis("off")
        plt.savefig(os.path.join(WRITE_DIR, filename), transparent=True)
        # Clear the figure so the next heatmap starts from an empty canvas.
        plt.clf()


def stats_test(agg="mean"):
    """Average reconstruction metrics over ``test_set`` and write them to a CSV.

    For every ``(x, y)`` item of ``test_set`` samples are drawn through
    ``get_results``, reduced over dim 0 to a single prediction with the chosen
    aggregation, and compared against the ground truth ``(x + 1) / 2``. The
    per-item values of PSNR, SSIM, mean absolute error ("l1") and the
    ``"crps"``, ``"fcrps"`` and ``"acrps"`` entries returned by ``get_crps``
    are summed and divided by ``len(test_set)``.

    The result is written as a header row plus one value row to
    ``luna_hyperdiffusion_stats_{agg}.csv`` inside ``WRITE_DIR``.

    Relies on the module-level globals ``test_set``, ``model``, ``M``, ``N``
    and ``WRITE_DIR``.

    Args:
        agg: How samples are reduced to one prediction: ``"mean"``,
            ``"median"`` or ``"mode"``.

    Raises:
        ValueError: If ``agg`` is not one of the supported names. This is
            checked before any sampling is done.
    """
    print("stats")

    reducers = {
        "mean": lambda s: s.mean(dim=0),
        "median": lambda s: s.median(dim=0).values,
        "mode": lambda s: s.mode(dim=0).values,
    }
    # Fail fast: reject a bad aggregation name before the first (expensive)
    # sampling pass instead of after it.
    if agg not in reducers:
        raise ValueError(f"Invalid aggregation function: {agg}")
    reduce_samples = reducers[agg]

    fields = ["psnr", "ssim", "crps", "fcrps", "acrps", "l1"]
    totals = dict.fromkeys(fields, 0)
    for x, y in tqdm(test_set, total=len(test_set)):
        gt = ((x + 1) / 2).squeeze().numpy()
        samples, _, _ = get_results(y, model, M, N)
        pred = reduce_samples(samples).detach().cpu().numpy()
        totals["psnr"] += PSNR(gt, pred)
        totals["l1"] += np.abs(gt - pred).mean()
        totals["ssim"] += structural_similarity(gt, pred, data_range=1)
        crps_parts = get_crps(gt, samples)
        for key in ("crps", "fcrps", "acrps"):
            totals[key] += crps_parts[key]

    n_items = len(test_set)
    data = {field: totals[field] / n_items for field in fields}
    with open(
        os.path.join(WRITE_DIR, f"luna_hyperdiffusion_stats_{agg}.csv"), "w"
    ) as f:
        f.write(",".join(fields) + "\n")
        f.write(",".join(str(data[field]) for field in fields) + "\n")


if __name__ == "__main__":
    # Script configuration. These names are read as module-level globals by
    # get_results, id_test, ood_test and stats_test.
    WRITE_DIR = "/XXXX-2/XXXX-1/scratch/XXXX-3-hyperdiffusion/figs"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    R = 128  # image side length used for resizing and reshaping
    T = 100  # passed to HyperDiffusion as ``timesteps``
    n_params = 2440241  # passed to HyperDiffusion as ``n_params``
    M = 10  # number of weight vectors drawn per measurement
    N = 100  # conditioning batch size per weight vector
    D = 100  # number of trailing dataset items used as the test set

    # Create the output directory up front so a missing folder is not
    # discovered only after the sampling has already been run.
    os.makedirs(WRITE_DIR, exist_ok=True)

    # Load model weights
    backbone = FrozenUnet(
        dim=16, dim_mults=(1, 2, 4, 8), channels=1, self_condition=True, device=device
    )
    model = HyperDiffusion(backbone, image_size=R, timesteps=T, n_params=n_params).to(
        device
    )
    checkpoint_path = (
        "/XXXX-2/XXXX-1/scratch/uq_diffusion/luna/notebook/tmp/hyperddpm_0.pt"
    )
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on a GPU still loads when running on the CPU fallback.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    # Test dataset
    tfm = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((R, R), antialias=True),
        ]
    )
    dset = Luna("/XXXX-2/XXXX-1/scratch/uq_diffusion/luna/data", tfm)
    # The last D items of the dataset form the test set.
    test_set = Subset(dset, range(len(dset) - D, len(dset)))

    id_test(-1)
    ood_test(-1)
    # stats_test(agg="mean")
    # stats_test(agg="median")
    # stats_test(agg="mode")
