import os

import torch

from diffusion_arithmetics.ddim_utils import generate_latents, generate_noises, generate_samples
from diffusion_arithmetics.metrics import calc_angles, calc_distances
from diffusion_arithmetics.models import get_kd_cifar100, get_kd_small_cifar10, get_openai_cifar, get_openai_imagenet

# Output directory for every artifact (.pth tensors) this experiment writes.
EXPERIMENT_DIR = "experiments/angles"
# Target device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Forwarded as `steps=` to every model factory below
# (presumably the number of diffusion timesteps — confirm in diffusion_arithmetics.models).
STEPS = 4000
# Number of noises / samples generated per model.
NUMBER_OF_SAMPLES = 2 * 1024
# Batch size used for sample generation and latent computation.
BATCH_SIZE = 128

# model name -> (model, diffusion, args), as unpacked by the experiment loop.
# NOTE: this dict literal builds every model eagerly, so all checkpoints are
# loaded up front and held for the lifetime of the script.
# The openai models use the detected `device` (rather than a hard-coded "cuda")
# so the script does not crash on CPU-only hosts; the kd_* models are loaded
# on CPU here and moved onto the accelerator by the loop before use.
MODELS = {
    "openai_cifar10": get_openai_cifar(steps=STEPS, device=device),
    "openai_imagenet": get_openai_imagenet(steps=STEPS, device=device),
    "kd_cifar10_limit32": get_kd_small_cifar10(
        steps=STEPS,
        model_path="res/DDGM_collapse/CIFAR_10_limit_32/model050000_0.pt",
        device="cpu",
    ),
    "kd_cifar10_limit2048": get_kd_small_cifar10(
        steps=STEPS,
        model_path="res/DDGM_collapse/CIFAR_10_limit_2048/model050000_0.pt",
        device="cpu",
    ),
    "kd_cifar10_limit4096": get_kd_small_cifar10(
        steps=STEPS,
        model_path="res/DDGM_collapse/CIFAR_10_limit_4096/model050000_0.pt",
        device="cpu",
    ),
    "kd_cifar100": get_kd_cifar100(steps=STEPS, device="cpu"),
}


# Make sure the output directory exists before any torch.save writes into it.
os.makedirs(EXPERIMENT_DIR, exist_ok=True)

# For each model: noise -> samples -> latents -> samples again, saving every
# intermediate tensor and printing angle/distance statistics between them.
for model_name, (model, diffusion, args) in MODELS.items():
    print(f"** RUNNING MODEL: {model_name} **")

    # Artifacts are written as <EXPERIMENT_DIR>/<model_name>_<artifact>.pth
    path_prefix = os.path.join(EXPERIMENT_DIR, model_name + "_")

    # Use the detected device instead of a hard-coded "cuda" string.
    model = model.to(device)

    # Step 1: draw the initial noises and generate samples from them.
    noises = generate_noises(NUMBER_OF_SAMPLES, args)
    torch.save(noises.cpu(), path_prefix + "noise.pth")
    samples = generate_samples(
        random_noises=noises,
        number_of_samples=NUMBER_OF_SAMPLES,
        batch_size=BATCH_SIZE,
        diffusion_pipeline=diffusion,
        ddim_model=model,
        diffusion_args=args,
    )
    torch.save(samples.cpu(), path_prefix + "samples.pth")

    # Step 2: map the generated samples back to latents
    # (presumably DDIM inversion — see generate_latents).
    latents = generate_latents(
        ddim_generations=samples, batch_size=BATCH_SIZE, diffusion_pipeline=diffusion, ddim_model=model
    )
    torch.save(latents.cpu(), path_prefix + "latents.pth")

    # Step 3: regenerate samples starting from those latents.
    samples2 = generate_samples(
        random_noises=latents.to(device),
        number_of_samples=NUMBER_OF_SAMPLES,
        batch_size=BATCH_SIZE,
        diffusion_pipeline=diffusion,
        ddim_model=model,
        diffusion_args=args,
    )
    # Persist the regenerated samples (this previously re-saved `noises` by mistake).
    torch.save(samples2.cpu(), path_prefix + "samples_from_latents.pth")

    # Each metric call receives its own .cpu() view/copy, exactly as before,
    # so an in-place op inside one metric cannot leak into the other.
    angles_stats = calc_angles(noises=noises.cpu(), samples=samples.cpu(), latents=latents.cpu())
    dist_stats = calc_distances(
        noise=noises.cpu(), sample_from_noise=samples.cpu(), latent=latents.cpu(), sample_from_latent=samples2.cpu()
    )

    print("Statistics:")
    print({**angles_stats, **dist_stats})

    # MODELS keeps a reference to this model, and nn.Module.to moves parameters
    # in place (assuming `model` is an nn.Module), so move it back to CPU so that
    # successive models do not accumulate on the GPU.
    model.to("cpu")
    if device == "cuda":
        torch.cuda.empty_cache()
