import os
import numpy as np
import torch
from tqdm import tqdm
from diffusers import UNet2DModel, DDPMScheduler

import sys
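# NOTE: append the project root to sys.path here (e.g. sys.path.append("/path/to/project"))
# so that the project-local imports below can be resolved.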
import model_sample_sensitivity.sample_sensitivity as sensitivity
import utils.loaders as loaders

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.multiprocessing as mp

# User-specified paths (left empty here; fill them in before running the script)
EPS_MODEL_OLD_PATH = ""  # checkpoint loaded by worker() as the "old" model
EPS_MODEL_NEW_PATH = ""  # checkpoint loaded by worker() as the "new" model
CORR_SAVE_PATH = ""  # destination passed to np.save for the correlation array
BATCH_SIZE = 256  # total number of noise samples
NUM_GPUS = torch.cuda.device_count()  # visible CUDA devices; main() splits the batch across this many workers

# Model loading helper
def load_ddpm_sensitivity(ckpt_path, device):
    """Build a ``DDPMSampleSensitivity`` object from a saved checkpoint.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load`` that stores
            the UNet weights under the ``'state_dict'`` key.
        device: Device the UNet is moved to once its weights are loaded.

    Returns:
        A ``sensitivity.DDPMSampleSensitivity`` wrapping the loaded UNet and a
        1000-step linear-beta ``DDPMScheduler``.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
    """
    prefix = "model."

    # Always deserialize on CPU; the model is moved to `device` afterwards.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt['state_dict']

    # Strip a leading "model." only. A plain str.replace would also rewrite
    # keys that merely contain "model." somewhere in the middle.
    fixed_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    eps_model = UNet2DModel(
        sample_size=28,
        in_channels=1,
        out_channels=1,
        layers_per_block=2,
        block_out_channels=(32, 64, 128),  # 3 blocks
        down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        norm_num_groups=8,
    )
    eps_model.load_state_dict(fixed_state_dict)
    eps_model = eps_model.to(device)

    scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=1e-4,
        beta_end=0.02,
        beta_schedule="linear",
    )

    return sensitivity.DDPMSampleSensitivity(
        eps_model,
        scheduler,
        num_hutchinson_samples=1,
        ode_step_size=1e-3,
        min_clamp=1e-1,
        max_clamp=1e1,
    )

def worker(rank, z1_chunk, subclass_imgs, result_queue):
    """Compute per-sample correlations for one chunk of noise and enqueue them.

    Loads the old and new models, computes their sample paths from
    ``z1_chunk``, computes the old model's sensitivity with respect to
    ``subclass_imgs``, and correlates (new - old) with that sensitivity for
    every sample in the chunk.

    Args:
        rank: Index of this worker; also used as the CUDA device index when
            CUDA is available.
        z1_chunk: Tensor of noise samples, indexed along dim 0.
        subclass_imgs: Tensor handed to ``sensitivity_given_sample_path``.
        result_queue: Queue that receives one 1-D NumPy array holding a
            Pearson correlation per sample of ``z1_chunk``.
    """
    use_cuda = torch.cuda.is_available()
    device = f"cuda:{rank}" if use_cuda else "cpu"
    if use_cuda:
        # Only touch the CUDA runtime when it exists, so the CPU fallback
        # chosen above is actually usable.
        torch.cuda.set_device(rank)

    print(f"[GPU {rank}] Loading models...")
    ddpm_old = load_ddpm_sensitivity(EPS_MODEL_OLD_PATH, device)
    ddpm_new = load_ddpm_sensitivity(EPS_MODEL_NEW_PATH, device)
    z1_chunk = z1_chunk.to(device)
    subclass_imgs = subclass_imgs.to(device)
    print(f"[GPU {rank}] Processing {z1_chunk.shape[0]} samples.")

    print(f"[GPU {rank}] Computing sample paths...")
    zt_old, logpt_old = ddpm_old.precompute_sample_path(z1_chunk)
    zt_new = ddpm_new.precompute_sample_path_no_logpt(z1_chunk)

    print(f"[GPU {rank}] Computing sensitivities and correlations...")
    et = ddpm_old.sensitivity_given_sample_path(zt_old, logpt_old, subclass_imgs)

    # Last entry along dim 1 of each 5-D tensor.
    # NOTE(review): presumably dim 1 is the time axis and -1 is the final sample; confirm.
    samples_old = zt_old[:, -1, :, :, :]
    samples_new = zt_new[:, -1, :, :, :]
    perturbation = et[:, -1, :, :, :]
    diff = samples_new - samples_old

    # detach() guards against tensors still attached to an autograd graph
    # (numpy() refuses those); reshape() does not depend on contiguity.
    diff_flat = diff.detach().reshape(diff.shape[0], -1).cpu().numpy()
    perturbation_flat = (
        perturbation.detach().reshape(perturbation.shape[0], -1).cpu().numpy()
    )

    # One Pearson coefficient per sample: flattened difference vs. sensitivity.
    correlations = np.array([
        np.corrcoef(d, p)[0, 1]
        for d, p in zip(diff_flat, perturbation_flat)
    ])
    result_queue.put(correlations)

def main():
    """Fan the correlation computation out over worker processes and save it.

    Draws ``BATCH_SIZE`` Gaussian noise samples, splits them into chunks,
    starts one ``worker`` process per chunk, gathers each worker's
    correlation array, concatenates them and writes the result to
    ``CORR_SAVE_PATH``.

    Raises:
        RuntimeError: If a worker process exits with a non-zero code before
            all results have been received.
    """
    # Local import: supplies the Empty exception raised by Queue.get(timeout=...).
    import queue

    # Load TMNIST samples
    tmnist_path = ""
    tmnist_ims = torch.load(tmnist_path)

    # Draw BATCH_SIZE Gaussian noise samples
    z1 = torch.randn((BATCH_SIZE, 1, 28, 28))

    # torch.chunk rejects 0 chunks, so use a single worker when no GPU is
    # visible. It may also return fewer chunks than requested, so workers are
    # created per chunk actually produced rather than per NUM_GPUS.
    z1_chunks = torch.chunk(z1, max(NUM_GPUS, 1))

    mp.set_start_method('spawn', force=True)
    result_queue = mp.Queue()
    procs = []
    for rank, z1_chunk in enumerate(z1_chunks):
        proc = mp.Process(target=worker, args=(rank, z1_chunk, tmnist_ims, result_queue))
        proc.start()
        procs.append(proc)

    # Drain the queue before joining. Poll with a timeout so that a crashed
    # worker produces an error instead of blocking this loop forever.
    # Results are collected in arrival order, not rank order.
    all_corrs = []
    while len(all_corrs) < len(procs):
        try:
            all_corrs.append(result_queue.get(timeout=5))
        except queue.Empty:
            failed = [p for p in procs if p.exitcode not in (None, 0)]
            if not failed:
                continue
            for p in procs:
                if p.is_alive():
                    p.terminate()
            for p in procs:
                p.join()
            codes = [p.exitcode for p in failed]
            raise RuntimeError(
                f"{len(failed)} worker process(es) failed with exit codes {codes}"
            ) from None

    for proc in procs:
        proc.join()

    correlations = np.concatenate(all_corrs)
    np.save(CORR_SAVE_PATH, correlations)
    print(f"Saved correlations to {CORR_SAVE_PATH}. Median correlation: {np.median(correlations):.4f}")

# Entry-point guard: main() uses the 'spawn' start method, under which child
# processes import this module, so main() must not run at import time.
if __name__ == "__main__":
    main()
