import os

import torch

from diffusion_arithmetics.ldms.models import get_inv_scheduler, get_ldm_celeba, get_scheduler
from diffusion_arithmetics.ldms.sample_ldm import generate_latents, generate_samples, get_noises
from diffusion_arithmetics.normality_utils import kl_div, plot_dists, plt_qq, stat_test_normality

# Number of initial noises requested from get_noises.
N_SAMPLES = 1024
# Batch size passed to every generate_samples / generate_latents call.
BATCH_SIZE = 64
# Seed passed to get_noises (presumably so the initial noises are reproducible).
SEED = 420
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Per-T output directory, formatted with T. The trailing slash matters: file
# names are appended to this path by plain string concatenation.
SAVE_DIR_PATH = "experiments/ldm_outs/T_{}/"

# Values of T swept by experiment(); each is passed to get_scheduler and
# get_inv_scheduler (presumably the number of diffusion timesteps — confirm).
TS = [10, 100, 1000]


def save_results(noises, samples, latents, samples2, T):
    """Persist the four experiment outputs for a single value of ``T``.

    The directory ``SAVE_DIR_PATH.format(T)`` is created if it does not exist.
    ``noises``, ``samples``, ``latents`` and ``samples2`` are then written with
    ``torch.save``, in that order, to ``noise.pt``, ``samples.pt``,
    ``latents.pt`` and ``samples2.pt`` inside it.
    """
    out_dir = SAVE_DIR_PATH.format(T)
    os.makedirs(out_dir, exist_ok=True)

    # Insertion order is preserved, so files are written in the same order as before.
    artifacts = {
        "noise.pt": noises,
        "samples.pt": samples,
        "latents.pt": latents,
        "samples2.pt": samples2,
    }
    for filename, payload in artifacts.items():
        # SAVE_DIR_PATH ends with "/", so concatenation yields a valid path.
        torch.save(payload, out_dir + filename)


def experiment():
    """Run the sample -> latent -> resample round trip for every ``T`` in ``TS``.

    One set of noises (``get_noises`` with ``N_SAMPLES`` and ``SEED``) and one
    UNet (from ``get_ldm_celeba``) are shared by all values of ``T``. For each
    ``T`` this:

    1. generates ``samples`` from the noises with the forward scheduler,
    2. maps those samples to ``latents`` via ``generate_latents`` and the
       inverse scheduler,
    3. generates ``samples2`` from the recovered latents,
    4. saves all four tensors via ``save_results``,
    5. prints a normality test of the latents and the KL divergence between
       noises and latents, and writes a distribution plot and a QQ plot into
       ``SAVE_DIR_PATH.format(T)``.

    NOTE(review): no ``torch.no_grad()`` here; confirm the helpers disable
    autograd internally.
    """
    noises = get_noises(n_samples=N_SAMPLES, seed=SEED)
    # Only the first return value of get_ldm_celeba is used in this script.
    ldm_unet, _ = get_ldm_celeba(device=DEVICE)
    # Keyword arguments common to every generate_samples / generate_latents call.
    shared = dict(diffusion_unet=ldm_unet, batch_size=BATCH_SIZE, device=DEVICE)

    for T in TS:
        print(f"Starting for {T=}")
        scheduler = get_scheduler(T=T)
        inv_scheduler = get_inv_scheduler(T=T)

        samples = generate_samples(noise=noises, diffusion_scheduler=scheduler, **shared)
        latents = generate_latents(samples=samples, diffusion_scheduler=inv_scheduler, **shared)
        samples2 = generate_samples(noise=latents, diffusion_scheduler=scheduler, **shared)

        # save_results also creates this directory, which the plots below rely on.
        save_results(noises, samples, latents, samples2, T)
        out_dir = SAVE_DIR_PATH.format(T)

        # Self-documenting f-strings echo their expression text, so the names
        # ``T``, ``noises`` and ``latents`` are part of the printed output.
        print(f"{stat_test_normality(latents)=}")
        print(f"{kl_div(noises, latents)=}")
        plot_dists(
            noises,
            latents,
            "Noise",
            f"Latent T={T}",
            xlim=(-4, 4),
            ylim=(0, 0.5),
            savepath=out_dir + "plot_dists.pdf",
        )
        plt_qq(noises, latents, "CelebA", T, "LDM", path=out_dir + "plot_qq.pdf")


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    experiment()
