from pathlib import Path

import numpy as np
import torch

from ddim_models import Unet, GaussianDiffusion, Trainer, normalize_to_01

# Load the stimulus images and rescale them into [0, 1] with normalize_to_01,
# which also returns the min/max it used; those two values are kept so the
# generated samples can be mapped back to the original range at the end.
stimulus_latent_data = np.load('datasets/cropped_resized_images_relative_128_cleaned.npy')
stimulus_latent_data_normalized, data_min, data_max = normalize_to_01(
    torch.Tensor(stimulus_latent_data)
)

# Collapse everything into (N, 1, 128, 128) in one step -- a sample axis plus a
# singleton channel axis -- and hand the trainer a NumPy array. The explicit
# np.asarray makes the tensor -> ndarray conversion visible (np.expand_dims
# used to perform it implicitly).
stimulus_latent_data_normalized = np.asarray(
    stimulus_latent_data_normalized.reshape(-1, 1, 128, 128)
)

# Diffusion configuration: number of diffusion steps handed to GaussianDiffusion.
timesteps = 150

# U-Net backbone used by the diffusion wrapper below.
model = Unet(dim=64, dim_mults=(1, 2, 4, 8), flash_attn=False)

# Diffusion process over 128x128 inputs; the network is trained with the
# 'pred_noise' objective.
diffusion = GaussianDiffusion(
    model,
    timesteps=timesteps,
    height=128,
    width=128,
    objective='pred_noise',
)

# Training loop configuration, grouped as: schedule, batching, optimiser,
# EMA / precision, evaluation.
# NOTE(review): with gradient_accumulate_every=2 the effective batch is
# presumably 2 * 64 samples per optimiser step -- confirm in ddim_models.Trainer.
trainer = Trainer(
    diffusion,
    stimulus_latent_data_normalized,
    train_num_steps=20000,
    train_batch_size=64,
    gradient_accumulate_every=2,
    train_lr=2e-4,
    ema_decay=0.98,
    amp=True,
    calculate_fid=False,
)

trainer.train()

# Sample from the trained model. no_grad keeps autograd from recording the
# whole iterative sampling loop.
# NOTE(review): this samples with the online `diffusion` weights; Trainer is
# configured with ema_decay, so if it keeps an EMA copy of the model, sampling
# from that copy may be preferable -- confirm against ddim_models.Trainer.
with torch.no_grad():
    sampled_image = diffusion.sample(batch_size=256)

# Move to CPU before any further work: the saved .pt file then loads on
# machines without a GPU, and the denormalisation below works whether
# data_min/data_max are Python scalars or CPU tensors.
sampled_image = sampled_image.cpu().reshape(-1, 128, 128)

# Denormalize before saving: map values from [0, 1] back to [data_min, data_max].
sampled_images_denorm = sampled_image * (data_max - data_min) + data_min

# torch.save does not create missing parent directories; create 'samples/' so
# a long training run is not lost to a FileNotFoundError at the very end.
# NOTE(review): the file name says "pred_x0" but training used
# objective='pred_noise' -- confirm the intended name with downstream users
# before renaming it.
output_path = Path('samples') / 'dpm_pred_x0_sampled_images.pt'
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(sampled_images_denorm, output_path)
