import argparse
import os

import numpy as np
import torch

from ddim_models import Unet, GaussianDiffusion, Trainer, normalize_to_01, sdedit_p_sample

def parse_args(argv=None):
    """Parse command-line options for the SDEdit sampling script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as the zero-argument call
            did before; passing a list lets the parser be driven
            programmatically (e.g. from tests or another script).

    Returns:
        argparse.Namespace with ``sample_index`` (int, default 8): the index of
        the stimulus in the loaded dataset that will be edited.
    """
    parser = argparse.ArgumentParser(description="SDEdit sampling script")
    parser.add_argument(
        "--sample_index",
        type=int,
        default=8,
        help="index of the stimulus image to edit (default: %(default)s)",
    )
    return parser.parse_args(argv)

args = parse_args()
sample_index = args.sample_index

# Echo every parsed option, one "name: value" line each, so the settings of
# this run appear in its output.
print("Sampling Configuration:")
for option_name, option_value in vars(args).items():
    print(f"{option_name}: {option_value}")

# --------------------------
# Load and preprocess data
# --------------------------
# Raw stimulus array; reshaped below into 128x128 images.
stimulus_latent_data = np.load('datasets/cropped_resized_images_relative_128_cleaned.npy')
# normalize_to_01 also returns the min/max it used; they are kept so the
# samples can be mapped back to the original value range before saving.
stimulus_latent_data_normalized, data_min, data_max = normalize_to_01(torch.Tensor(stimulus_latent_data))
stimulus_latent_data_normalized = stimulus_latent_data_normalized.reshape(-1, 128, 128)
# Add a singleton channel axis -> (N, 1, 128, 128). If normalize_to_01 returns
# a (CPU) torch tensor, np.expand_dims also converts it to a NumPy array; that
# array is what the Trainer and the sample selection below consume.
stimulus_latent_data_normalized = np.expand_dims(stimulus_latent_data_normalized, axis=1)

# Fail fast on a bad --sample_index, before the model is built and the
# checkpoint is loaded. Negative values keep NumPy's count-from-the-end meaning.
num_stimuli = len(stimulus_latent_data_normalized)
if not -num_stimuli <= sample_index < num_stimuli:
    raise IndexError(
        f"--sample_index {sample_index} is out of range for {num_stimuli} stimuli"
    )

# --------------------------
# Set up model and diffusion
# --------------------------
timesteps = 150

# U-Net network that the diffusion process below wraps.
unet_options = dict(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    flash_attn=False,
)
model = Unet(**unet_options)

# Diffusion process over 128x128 inputs.
diffusion_options = dict(
    height=128,
    width=128,
    timesteps=timesteps,
    guidance_scale=250.0,
    positive_probing=False,
)
diffusion = GaussianDiffusion(model, **diffusion_options)

# This script only calls trainer.load(); the training hyper-parameters are
# presumably supplied because the Trainer constructor expects them — confirm
# against ddim_models before trimming.
trainer_options = dict(
    train_batch_size=64,
    train_lr=2e-4,
    train_num_steps=20000,
    gradient_accumulate_every=2,
    ema_decay=0.98,
    amp=True,
    calculate_fid=False,
)
trainer = Trainer(diffusion, stimulus_latent_data_normalized, **trainer_options)

# Identifier of the saved checkpoint to restore (passed verbatim to load()).
checkpoint_milestone = "20"
trainer.load(checkpoint_milestone)

# --------------------------
# Prepare input for editing
# --------------------------
# Diffusion step handed to sdedit_p_sample (the diffusion was built with
# `timesteps` steps).
edit_t = 140
# How many edited copies of the chosen stimulus to produce in one batch.
num_samples = 128
# Select one stimulus of shape (1, 128, 128), add a batch axis and tile it
# -> (num_samples, 1, 128, 128). Presumably each copy receives different noise
# inside sdedit_p_sample — TODO confirm.
x_start = torch.Tensor(stimulus_latent_data_normalized[sample_index]).unsqueeze(0).repeat(num_samples, 1, 1, 1)
x_start = x_start.to(diffusion.device)

# --------------------------
# Apply SDEdit sampling
# --------------------------
# Intentionally not wrapped in torch.no_grad(): the diffusion is configured
# with a guidance_scale, which may rely on gradients — confirm before adding.
sampled_image = sdedit_p_sample(diffusion, x_start, edit_t=edit_t)
sampled_image = sampled_image.reshape(-1, 128, 128)

# --------------------------
# Denormalize and save
# --------------------------
# Drop any autograd history and move to CPU before saving, so the .pt file can
# be loaded on machines without the sampling device (no map_location needed).
# This also puts the samples on the same device as data_min/data_max, which
# come from normalizing a CPU tensor.
sampled_image = sampled_image.detach().cpu()
# Undo the [0, 1] normalization: x * (max - min) + min.
sampled_images_denorm = sampled_image * (data_max - data_min) + data_min
# torch.save does not create missing parent directories.
os.makedirs('samples', exist_ok=True)
torch.save(sampled_images_denorm, f'samples/face_sampled_image_{edit_t}_idx_{sample_index}.pt')
