import matplotlib.pyplot as plt
import statsmodels.api as sm
import torch
import torch.distributions as dist
from scipy.stats import shapiro
from tqdm import tqdm


def normal_dist_test(dist_tens: torch.Tensor) -> float:
    """Run a Shapiro-Wilk normality test over every element of a tensor.

    The tensor is flattened, so all of its elements are treated as one
    univariate sample regardless of the original shape.

    Args:
        dist_tens: Tensor holding the sample to test.

    Returns:
        The p-value reported by ``scipy.stats.shapiro``.
    """
    # detach()/cpu() make the NumPy conversion work for tensors that track
    # gradients or live on an accelerator; plain ``.numpy()`` raises for both.
    flattened_data = dist_tens.detach().cpu().flatten().numpy()

    _, shapiro_p_value = shapiro(flattened_data)

    return float(shapiro_p_value)


def _gaussian_params(sample: torch.Tensor):
    """Return the (mean, std) of ``sample`` as 1-element CPU tensors."""
    sample = sample.detach()
    if not sample.is_floating_point():
        # mean()/std() are not defined for integer tensors.
        sample = sample.float()
    # reshape(1) keeps the same one-element shape the distributions were
    # built with before; cpu() keeps both operands on one device.
    mu = sample.mean().reshape(1).cpu()
    sigma = torch.std(sample).reshape(1).cpu()
    return mu, sigma


def kl_div(dist1: torch.Tensor, dist2: torch.Tensor) -> float:
    """Return KL(P || Q) between Gaussians fitted to two samples.

    Each input is summarised by its mean and standard deviation (as computed
    by ``torch.std`` with its default settings) over all elements; the
    divergence is then evaluated in closed form between the two resulting
    ``Normal`` distributions. This is therefore a moment-matched Gaussian
    approximation, not a non-parametric estimate from the raw samples.

    Args:
        dist1: Sample defining P.
        dist2: Sample defining Q.

    Returns:
        The divergence as a Python float.

    Note:
        Each sample needs enough spread for its standard deviation to be a
        positive, finite number; otherwise building the ``Normal`` fails.
    """
    dist1_mu, dist1_sigma = _gaussian_params(dist1)
    dist2_mu, dist2_sigma = _gaussian_params(dist2)

    P = dist.Normal(dist1_mu, dist1_sigma)
    Q = dist.Normal(dist2_mu, dist2_sigma)

    return dist.kl_divergence(P, Q).item()


def plot_dists(array1, array2, name_arr1, name_arr2, xlim=None, ylim=None, savepath=None):
    """Plot overlaid, density-normalised histograms of two arrays.

    Both inputs are flattened and drawn as semi-transparent 150-bin
    histograms on a single axes, with ticks and tick labels hidden.

    Args:
        array1: First data array; must provide ``.flatten()``.
        array2: Second data array; must provide ``.flatten()``.
        name_arr1: Legend label for ``array1``.
        name_arr2: Legend label for ``array2``.
        xlim: Optional x-axis limits passed to ``plt.xlim``.
        ylim: Optional y-axis limits passed to ``plt.ylim``.
        savepath: If given, the figure is written there and then closed;
            otherwise it is displayed with ``plt.show()``.

    Note:
        This updates the global ``plt.rcParams`` and applies the
        ``seaborn-v0_8-paper`` style, so plots created afterwards in the
        same process are affected too.
    """
    data1 = array1.flatten()
    data2 = array2.flatten()

    # Keep a handle so the figure can be released once it has been saved.
    fig = plt.figure(figsize=(10, 6))
    plt.rcParams.update(
        {
            "pgf.texsystem": "pdflatex",
            "font.family": "serif",
            "font.size": 15,
            "axes.labelsize": 15,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
            "legend.fontsize": 12,
            "lines.linewidth": 2,
            "text.usetex": False,
            "pgf.rcfonts": False,
        }
    )
    # NOTE(review): no axes exist yet at this point, so this call likely has
    # no lasting effect; the tight_layout() further down is the one applied
    # to the finished plot. Candidate for removal - confirm.
    plt.tight_layout(rect=[0, 0.03, 1, 0.85])
    # Applied after the rcParams update above, so any keys the style sheet
    # defines take precedence over the values set there.
    plt.style.use("seaborn-v0_8-paper")
    plt.subplot(1, 1, 1)
    plt.hist(data1, bins=150, alpha=0.5, label=name_arr1, density=True)
    plt.hist(data2, bins=150, alpha=0.5, label=name_arr2, density=True)
    # The explicit fontsize here overrides the "legend.fontsize" rcParam.
    plt.legend(loc="upper right", fontsize=24)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)

    plt.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

    plt.tight_layout()
    if savepath is None:
        plt.show()
    else:
        plt.savefig(savepath)
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls that save to disk accumulate open figures.
        plt.close(fig)


def stat_test_normality(latents, limit=5_000, p_val=0.05):
    """Test each latent in a batch for normality and tally the outcomes.

    Every entry along the first dimension of ``latents`` is flattened and
    passed to ``normal_dist_test``. Latents with more than ``limit`` values
    are randomly subsampled down to ``limit`` values first.

    Args:
        latents: Tensor whose first dimension indexes the individual latents.
        limit: Maximum number of values from one latent given to the test.
        p_val: Significance level; a latent counts as accepted when its
            p-value is strictly greater than this.

    Returns:
        Dict with the counts ``"n_rej"`` / ``"n_acc"`` and the corresponding
        fractions ``"%_rej"`` / ``"%_acc"`` (in the range 0-1). An empty
        batch yields zeros for all four entries.
    """
    n_tests = latents.shape[0]
    if n_tests == 0:
        # Nothing to test; avoid dividing by zero below.
        return {"n_rej": 0, "%_rej": 0.0, "n_acc": 0, "%_acc": 0.0}

    n_acc = 0
    for latent in tqdm(latents, total=n_tests):
        values = latent.detach().flatten().cpu()
        n_values = values.shape[0]
        if n_values > limit:
            # Subsample WITHOUT replacement: drawing indices with replacement
            # would feed duplicated (tied) values to the normality test.
            keep = torch.randperm(n_values)[:limit]
            values = values[keep]
        if normal_dist_test(values) > p_val:
            n_acc += 1

    # Everything not accepted (including a NaN p-value) counts as rejected.
    n_rej = n_tests - n_acc
    return {"n_rej": n_rej, "%_rej": n_rej / n_tests, "n_acc": n_acc, "%_acc": n_acc / n_tests}


def plt_qq(noise_tens, lat_tens, ds_name, t, diff_type, path=None):
    """Draw a two-sample Q-Q plot of noise values against latent values.

    Both tensors are flattened and handed to statsmodels'
    ``qqplot_2samples`` with ``line="s"``. The x-axis is labelled "Noise"
    and the y-axis "Latent"; ticks and tick labels are hidden.

    Args:
        noise_tens: Tensor with the reference noise sample.
        lat_tens: Tensor with the latent sample.
        ds_name: Dataset name shown in the title.
        t: Value shown as ``T`` in the title and the y-axis label.
        diff_type: Model-family label shown in the title.
        path: If given, the figure is written there and then closed;
            otherwise it is displayed with ``plt.show()``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    # detach() keeps the NumPy conversion working for tensors that track
    # gradients.
    noise_values = noise_tens.detach().cpu().flatten().numpy()
    latent_values = lat_tens.detach().cpu().flatten().numpy()
    sm.qqplot_2samples(noise_values, latent_values, ax=ax, line="s")

    # Set after the statsmodels call so these labels are the ones displayed.
    ax.set_title(f"{diff_type} | {ds_name} | T={t}", fontsize=25)
    ax.set_xlabel("Noise", fontsize=20)
    ax.set_ylabel(f"Latent ($T={t}$)", fontsize=20)
    ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

    if path is None:
        plt.show()
    else:
        fig.savefig(path)
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls that save to disk accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    # Each section loads one reference noise tensor and compares it against
    # the latents produced with different numbers of steps T.

    # cifar: Q-Q plots only, displayed interactively.
    cifar_noise = torch.load("experiments/cifar_outs/T_10/noise.pt", weights_only=False)
    cifar_runs = (
        (10, "experiments/cifar_outs/T_10/latent.pt"),
        (100, "experiments/cifar_outs/T_100/latent.pt"),
        (1000, "experiments/cifar_outs/T_1000/latent.pt"),
        (4000, "experiments/cifar_outs/T_4000/latent.pt"),
    )
    for n_steps, latent_file in cifar_runs:
        cifar_latent = torch.load(latent_file, weights_only=False)
        plt_qq(cifar_noise, cifar_latent, "Cifar-10", n_steps, "DDPM")

    # imagenet: Q-Q plot (displayed) plus a saved histogram comparison.
    imgnet_noise = torch.load("experiments/imagenet_outs/T_10/noise.pt", weights_only=False)
    imgnet_runs = (
        (
            10,
            "experiments/imagenet_outs/T_10/latent.pt",
            "experiments/T_outs/imagenet_outs/T_10/dists_plot.png",
        ),
        (
            100,
            "experiments/imagenet_outs/T_100/latent.pt",
            "experiments/T_outs/imagenet_outs/T_100/dists_plot.png",
        ),
        (
            1000,
            "experiments/imagenet_outs/T_1000/latent.pt",
            "experiments/T_outs/imagenet_outs/T_1000/dists_plot.png",
        ),
    )
    for n_steps, latent_file, hist_file in imgnet_runs:
        imgnet_latent = torch.load(latent_file, weights_only=False)
        plt_qq(imgnet_noise, imgnet_latent, "ImageNet", n_steps, "DDPM")
        plot_dists(
            imgnet_noise.cpu(),
            imgnet_latent.cpu(),
            "Noise",
            "Latent",
            xlim=(-4, 4),
            ylim=(0, 1.6),
            savepath=hist_file,
        )

    # The T=4000 ImageNet run is handled separately: its Q-Q plot is saved
    # rather than shown.
    # NOTE(review): this latent is read from "experiments/T_outs/..." and is
    # named "latents.pt", unlike every other run - confirm the path is intended.
    imgnet_latent_4000 = torch.load("experiments/T_outs/imagenet_outs/T_4000/latents.pt", weights_only=False)
    plt_qq(
        imgnet_noise,
        imgnet_latent_4000,
        "ImageNet",
        4000,
        "DDPM",
        path="experiments/T_outs/imagenet_outs/T_4000/qq_plot.png",
    )
    plot_dists(
        imgnet_noise.cpu(),
        imgnet_latent_4000.cpu(),
        "Noise",
        "Latent",
        xlim=(-4, 4),
        ylim=(0, 1.6),
        savepath="experiments/T_outs/imagenet_outs/T_4000/dists_plot.png",
    )

    # celeba: Q-Q plots only, displayed interactively.
    ldm_noise = torch.load("experiments/ldm_outs/T_10/noise.pt", weights_only=False)
    ldm_runs = (
        (10, "experiments/ldm_outs/T_10/latent.pt"),
        (100, "experiments/ldm_outs/T_100/latent.pt"),
        (1000, "experiments/ldm_outs/T_1000/latent.pt"),
    )
    for n_steps, latent_file in ldm_runs:
        ldm_latent = torch.load(latent_file, weights_only=False)
        plt_qq(ldm_noise, ldm_latent, "CelebA", n_steps, "LDM")
