import matplotlib.pyplot as plt
import umap

from diffusion_arithmetics.ddim_utils import generate_latents, generate_noises, generate_samples
from diffusion_arithmetics.models import get_openai_cifar

# Experiment: generate DDIM samples from random noise, compute latents for each
# step count in T_RANGE, and project the latents together with the sample and
# its source noise into 2-D with UMAP. The scatter plot is saved to SAVE_PATH.

# Step count passed to get_openai_cifar when building the pipeline.
DIFFUSION_STEPS = 1000
# Values assigned to `diffusion.num_timesteps` before each generate_latents call;
# one latent is collected per value, in this order.
T_RANGE = [25, 50, 100, 500, 1000]
NUMBER_OF_SAMPLES = 1
BATCH_SIZE = 1
SAVE_PATH = "umap_experiment.png"


def _flatten_to_numpy(tensor):
    """Return *tensor* as a flat 1-D NumPy array.

    Detaches from autograd and moves to CPU first (both are no-ops when
    already satisfied). `.numpy()` raises on CUDA or grad-tracking tensors.
    Uses `reshape` rather than `view` so non-contiguous tensors also work.
    """
    return tensor.detach().cpu().reshape(-1).numpy()


model, diffusion, args = get_openai_cifar(steps=DIFFUSION_STEPS)
noises = generate_noises(NUMBER_OF_SAMPLES, args)
samples = generate_samples(
    random_noises=noises,
    number_of_samples=NUMBER_OF_SAMPLES,
    batch_size=BATCH_SIZE,
    diffusion_pipeline=diffusion,
    ddim_model=model,
    diffusion_args=args,
)

# The step count is conveyed to generate_latents by mutating the pipeline object
# (presumably read inside generate_latents -- confirm in ddim_utils). Restore the
# original value afterwards so the pipeline is left as it was built, even on error.
original_num_timesteps = diffusion.num_timesteps
lats = []
try:
    for n_steps in T_RANGE:
        diffusion.num_timesteps = n_steps
        latents = generate_latents(
            ddim_generations=samples, batch_size=BATCH_SIZE, diffusion_pipeline=diffusion, ddim_model=model
        )
        lats.append(latents)
finally:
    diffusion.num_timesteps = original_num_timesteps

# Drop the leading dimension of each result. This assumes a batch of size 1,
# which matches NUMBER_OF_SAMPLES == 1 (TODO confirm generate_latents' output shape).
latents = [lat.squeeze(0) for lat in lats]
sample = samples[0]
noise = noises[0]

# Flatten every tensor to a 1-D CPU array so UMAP sees one row per item.
latents_flat = [_flatten_to_numpy(latent) for latent in latents]
sample_flat = _flatten_to_numpy(sample)
noise_flat = _flatten_to_numpy(noise)

# Row order: one row per T in T_RANGE, then the sample, then the noise.
data = latents_flat + [sample_flat] + [noise_flat]
n_latents = len(latents_flat)
sample_idx = n_latents
noise_idx = n_latents + 1

# UMAP's default n_neighbors (15) exceeds this tiny dataset, and UMAP would
# truncate it to n_samples - 1 with a warning. Set it explicitly instead.
reducer = umap.UMAP(n_neighbors=min(15, len(data) - 1))
embedding = reducer.fit_transform(data)

# Plotting
plt.figure(figsize=(10, 8))
plt.scatter(embedding[:n_latents, 0], embedding[:n_latents, 1], label="Latents", alpha=0.6)
# Label each latent point with the step count that produced it.
for n_steps, (x, y) in zip(T_RANGE, embedding[:n_latents]):
    plt.annotate(f"T={n_steps}", (x, y))
plt.scatter(embedding[sample_idx, 0], embedding[sample_idx, 1], label="Sample", color="red", marker="x")
plt.scatter(embedding[noise_idx, 0], embedding[noise_idx, 1], label="Noise", color="green")
plt.legend()
plt.title("UMAP Projection of Latents and Sample")
plt.xlabel("UMAP Dimension 1")
plt.ylabel("UMAP Dimension 2")
# pyplot has no `save`; `savefig` writes the current figure to disk.
plt.savefig(SAVE_PATH)
plt.close()
