import argparse
import os
import random

import numpy as np
import torch

from ddim_models import Unet, GaussianDiffusion, Trainer, normalize_to_01

# -----------------------------
# Utils
# -----------------------------
def set_random_seed(seed=42):
    """Seed every random number generator this script touches.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG, and
    PyTorch's CPU generator plus its CUDA generators (current device and all
    devices). Then switches cuDNN to deterministic kernels and disables its
    autotuner, so repeated runs pick the same algorithms.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def parse_args(argv=None):
    """Parse the command-line options for the CLG-DDIM sampling script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses the real command line (``sys.argv[1:]``), exactly
            as before. Passing an explicit list makes the parser usable from
            tests and notebooks.

    Returns:
        argparse.Namespace with two attributes:
            generation_index (int): run index appended to the output filename;
                defaults to 0.
            guidance_scale (float): guidance strength handed to the diffusion
                model; defaults to 0.0.
    """
    parser = argparse.ArgumentParser(description="CLG-DDIM Sampling Script")
    parser.add_argument(
        "--generation_index",
        type=int,
        default=0,
        help="Run index appended to the saved sample filename.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=0.0,
        help="Guidance strength passed to the diffusion model.",
    )
    return parser.parse_args(argv)

# -----------------------------
# Parse arguments
# -----------------------------
args = parse_args()
config = vars(args)
generation_index = config["generation_index"]
guidance_scale = config["guidance_scale"]

print("Sampling Configuration:")
for k, v in config.items():
    print(f"{k}: {v}")

# -----------------------------
# Load & preprocess data
# -----------------------------
stimulus_latent_data = np.load('datasets/cropped_resized_images_relative_128_cleaned.npy')
# data_min / data_max returned here are reused at the end of the script to
# undo the scaling on the generated samples.
stimulus_latent_data_normalized, data_min, data_max = normalize_to_01(torch.Tensor(stimulus_latent_data))
# Flatten to (N, 128, 128), then add a channel axis -> (N, 1, 128, 128).
# NOTE: np.expand_dims converts the torch tensor into a NumPy array here.
stimulus_latent_data_normalized = stimulus_latent_data_normalized.reshape(-1, 128, 128)
stimulus_latent_data_normalized = np.expand_dims(stimulus_latent_data_normalized, axis=1)

# -----------------------------
# Target group settings
# -----------------------------
sign = 1  # multiplied into guidance_scale before it reaches GaussianDiffusion
group_size = 6
group_index = 2
# Contiguous index range [group_index * group_size, (group_index + 1) * group_size),
# passed to GaussianDiffusion as target_neuron_indices.
activated_group = np.arange(group_index * group_size, (group_index + 1) * group_size)
print("Activated group:", activated_group)

# -----------------------------
# Initialize diffusion model
# -----------------------------
timesteps = 150

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    flash_attn=False
)

diffusion = GaussianDiffusion(
    model,
    height=128,
    width=128,
    timesteps=timesteps,
    guidance_scale=sign * guidance_scale,
    target_neuron_indices=activated_group,
    positive_probing=False
)

# The Trainer is only used for trainer.load() below (no training happens in
# this script); presumably it restores checkpoint "20" into `diffusion`.
trainer = Trainer(
    diffusion,
    stimulus_latent_data_normalized,
    train_batch_size=64,
    train_lr=2e-4,
    train_num_steps=20000,
    gradient_accumulate_every=2,
    ema_decay=0.98,
    amp=True,
    calculate_fid=False
)

trainer.load("20")

# -----------------------------
# Sampling with CLG-DDIM
# -----------------------------
sample_index = 31
# Start the edit 5 steps below the last timestep index (timesteps - 1).
edit_t = timesteps - 1 - 5

x_start = torch.Tensor(stimulus_latent_data_normalized[sample_index]).unsqueeze(0)  # (1, 1, 128, 128)
x_start = diffusion.normalize(x_start)

# Seed immediately before the stochastic calls so the noised start point and
# the sampling run are reproducible.
set_random_seed(2025)
xt = diffusion.q_sample_ddim(x_start=x_start, t=edit_t)
# Replicate the single noised input 9 times along the batch axis; the sampler
# runs on all 9 copies.
xt = xt.detach().clone().repeat(9, 1, 1, 1).to(diffusion.device)

sampled_image = diffusion.ddim_xt_p_sample(xt, xt.shape, edit_t=edit_t)
sampled_image = sampled_image.reshape(-1, 128, 128)
x_start = x_start.reshape(-1, 128, 128)

# -----------------------------
# Denormalize and save results
# -----------------------------
# Invert the min-max scaling applied by normalize_to_01.
# NOTE(review): assumes ddim_xt_p_sample returns values in [0, 1] (i.e. it
# undoes diffusion.normalize itself) — confirm.
sampled_images_denorm = sampled_image * (data_max - data_min) + data_min
# NOTE(review): xt was derived from diffusion.normalize(x_start), so this
# inverse is only exact if normalize is the identity; it only feeds the
# commented-out save below.
xt_denorm = xt * (data_max - data_min) + data_min

# torch.save does not create missing directories; make sure the output folder
# exists so a long sampling run does not fail at the very last step.
os.makedirs('samples', exist_ok=True)

# torch.save(xt_denorm, f'samples/xt_denorm_{guidance_scale}_id_{sample_index}_timestep_{edit_t}.pt')
# Save a detached CPU copy: the samples may live on diffusion.device, and a
# CUDA-resident tensor cannot be loaded on a CPU-only machine without
# map_location.
torch.save(
    sampled_images_denorm.detach().cpu(),
    f'samples/ddim_{guidance_scale}_gid_{group_index}_id_{sample_index}_timestep_{sign}_{edit_t}_{generation_index}.pt',
)
