import os
import numpy as np
import torch
from tqdm import tqdm
from diffusers import UNet2DModel, DDPMScheduler

import sys
# NOTE: add the project root to sys.path here (before the imports below) so that the project-local modules can be imported.
import model_sample_sensitivity.sample_sensitivity as sensitivity
import utils.loaders as loaders

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.multiprocessing as mp

# User-specified paths (all empty by default; fill them in before running).
# Run directory of the "old" model; worker() loads it with the default epoch.
EPS_MODEL_OLD_PATH = ""
# Run directory of the "new" model; worker() loads it with epoch=199.
EPS_MODEL_NEW_PATH = ""
# File read with np.load in main(); its entries index into the CelebA train split.
OLD_MAN_INDICES_PATH = ""
# Destination handed to np.save for the per-sample correlation array.
CORR_SAVE_PATH = ""
BATCH_SIZE = 128  # total number of noise samples
# Number of visible CUDA devices; main() uses it to split the noise batch
# across worker processes.
NUM_GPUS = torch.cuda.device_count()

# Model loading helper
def load_ddpm_sensitivity(run_dir, device, epoch=None):
    """Build a ``DDPMSampleSensitivity`` wrapper for the model stored in ``run_dir``.

    ``loaders.load_models`` is called with ``run_dir`` for both of its
    positional path arguments and returns two pairs; only the first
    (model, scheduler) pair is used here.

    Args:
        run_dir: Run directory forwarded to ``loaders.load_models``.
        device: Device the loaded model is moved to.
        epoch: Forwarded unchanged to ``loaders.load_models``.

    Returns:
        A ``sensitivity.DDPMSampleSensitivity`` built with the fixed
        hyperparameters below.
    """
    loaded = loaders.load_models(run_dir, run_dir, dataset="celeba", epoch=epoch)
    # The second pair is unpacked only to keep the expected return shape explicit.
    (eps_model, scheduler), (_unused_a, _unused_b) = loaded

    # Fixed settings used for every model loaded by this script.
    sensitivity_kwargs = {
        "num_hutchinson_samples": 1,
        "ode_step_size": 1e-3,
        "min_clamp": 1e-1,
        "max_clamp": 1e1,
    }
    return sensitivity.DDPMSampleSensitivity(
        eps_model.to(device),
        scheduler,
        **sensitivity_kwargs,
    )

def worker(rank, z1_chunk, old_man_imgs, result_queue):
    """Compute per-sample correlations for one chunk of noise and enqueue them.

    Loads the old and new models, precomputes sample paths from ``z1_chunk``
    under both, obtains the old model's sensitivity with respect to
    ``old_man_imgs``, and correlates, per sample, the difference between the
    two models' last path entries with the last sensitivity entry.

    Args:
        rank: Worker index; also used as the CUDA device ordinal when CUDA is
            available.
        z1_chunk: Tensor of initial noise whose first dimension indexes samples.
        old_man_imgs: Tensor of images forwarded to
            ``sensitivity_given_sample_path``.
        result_queue: Queue that receives one 1-D NumPy array holding a
            Pearson correlation per sample in ``z1_chunk``.
    """
    # Only touch the CUDA API when CUDA exists; otherwise the CPU fallback
    # would crash in set_device before doing any work.
    if torch.cuda.is_available():
        device = f"cuda:{rank}"
        torch.cuda.set_device(rank)
    else:
        device = "cpu"

    print(f"[GPU {rank}] Loading models...")
    ddpm_old = load_ddpm_sensitivity(EPS_MODEL_OLD_PATH, device)
    ddpm_new = load_ddpm_sensitivity(EPS_MODEL_NEW_PATH, device, epoch=199)
    z1_chunk = z1_chunk.to(device)
    old_man_imgs = old_man_imgs.to(device)

    print(f"[GPU {rank}] Computing sample paths...")
    zt_old, logpt_old = ddpm_old.precompute_sample_path(z1_chunk)
    zt_new = ddpm_new.precompute_sample_path_no_logpt(z1_chunk)

    print(f"[GPU {rank}] Computing sensitivities and correlations...")
    et = ddpm_old.sensitivity_given_sample_path(zt_old, logpt_old, old_man_imgs)

    # Index -1 along dim 1 selects the last stored entry of each path
    # (presumably the final sampling step -- confirm against the sensitivity module).
    samples_old = zt_old[:, -1, :, :, :]
    samples_new = zt_new[:, -1, :, :, :]
    perturbation = et[:, -1, :, :, :]
    diff = samples_new - samples_old

    # detach(): .numpy() refuses tensors that still require grad.
    # reshape(): unlike view(), it also works if the slice is not viewable.
    diff_flat = diff.detach().reshape(diff.shape[0], -1).cpu().numpy()
    perturbation_flat = (
        perturbation.detach().reshape(perturbation.shape[0], -1).cpu().numpy()
    )

    correlations = np.array([
        np.corrcoef(diff_row, pert_row)[0, 1]
        for diff_row, pert_row in zip(diff_flat, perturbation_flat)
    ])
    result_queue.put(correlations)

def _collect_result(rank, proc, result_queue, poll_seconds=5.0):
    """Return the item worker ``rank`` put on ``result_queue``.

    Polls instead of blocking indefinitely so that a worker which died before
    producing a result raises ``RuntimeError`` rather than hanging the parent.
    """
    import queue

    while True:
        try:
            return result_queue.get(timeout=poll_seconds)
        except queue.Empty:
            if proc.is_alive():
                continue
            # The worker may have enqueued its result just before exiting;
            # give the pipe one last chance to deliver it.
            try:
                return result_queue.get(timeout=poll_seconds)
            except queue.Empty:
                raise RuntimeError(
                    f"Worker {rank} exited with code {proc.exitcode} "
                    "without returning a result"
                ) from None


def main():
    """Run the correlation experiment across all workers and save the result.

    Loads the CelebA training images selected by ``OLD_MAN_INDICES_PATH``,
    draws ``BATCH_SIZE`` Gaussian noise samples, splits them across worker
    processes, gathers each worker's correlations in chunk order, saves the
    concatenated array to ``CORR_SAVE_PATH`` and prints the median.

    Raises:
        RuntimeError: If a worker process exits without returning a result.
    """
    # Load old man images
    transform = transforms.Compose([
        transforms.CenterCrop(140),
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    celeba_dataset = datasets.CelebA(
        root="/nobackup/users/scarv/data",
        split="train",
        download=False,
        transform=transform
    )
    train_indices = np.load(OLD_MAN_INDICES_PATH)
    # De-duplicate so that each image appears only once in the stack.
    train_indices = list(set(train_indices.tolist()))
    celeba_subset = torch.utils.data.Subset(celeba_dataset, train_indices)
    old_man_imgs = torch.stack(
        [celeba_subset[i][0] for i in range(len(celeba_subset))], dim=0
    )

    # Draw BATCH_SIZE Gaussian noise samples
    z1 = torch.randn((BATCH_SIZE, 3, 64, 64))
    # Split z1 across workers. torch.chunk rejects 0 chunks, so fall back to a
    # single worker when no GPU is visible; it may also return fewer chunks
    # than requested, so the loop below iterates over the chunks themselves.
    num_workers = max(NUM_GPUS, 1)
    z1_chunks = torch.chunk(z1, num_workers)

    mp.set_start_method('spawn', force=True)
    procs = []
    queues = []
    all_corrs = []
    try:
        for rank, z1_chunk in enumerate(z1_chunks):
            # One queue per worker keeps results attributable to their chunk,
            # so the saved array follows the order of z1 rather than the
            # order in which workers happen to finish.
            worker_queue = mp.Queue()
            proc = mp.Process(
                target=worker,
                args=(rank, z1_chunk, old_man_imgs, worker_queue),
            )
            proc.start()
            procs.append(proc)
            queues.append(worker_queue)

        # Drain every queue before joining: a child cannot exit cleanly while
        # its queued data is still unread.
        for rank, (proc, worker_queue) in enumerate(zip(procs, queues)):
            all_corrs.append(_collect_result(rank, proc, worker_queue))
    except BaseException:
        # Do not leave sibling workers running once the run has failed.
        for proc in procs:
            if proc.is_alive():
                proc.terminate()
        raise
    finally:
        for proc in procs:
            proc.join()

    correlations = np.concatenate(all_corrs)
    np.save(CORR_SAVE_PATH, correlations)
    print(f"Saved correlations to {CORR_SAVE_PATH}. Median correlation: {np.median(correlations):.4f}")

if __name__ == "__main__":
    main()
