import os
import numpy as np
import torch
from tqdm import tqdm
from diffusers import UNet2DModel, DDPMScheduler
import sys
# Before the project-local imports below, append the project root to sys.path (e.g. sys.path.append("/path/to/project")) so they resolve.
import model_sample_sensitivity.sample_sensitivity as sensitivity
import utils.loaders as loaders
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import ot

# ---------------------------------------------------------------------------
# User-specified configuration (fill in before running).
# ---------------------------------------------------------------------------
# Run directories handed to load_ddpm_sensitivity for the two models compared.
EPS_MODEL_OLD_PATH: str = ""
EPS_MODEL_NEW_PATH: str = ""
# Files (read with np.load) holding CelebA train-split indices for the base
# block and the "old man" block of the target mixture.
TRAIN_INDICES_PATH: str = ""
OLD_MAN_INDICES_PATH: str = ""
# Destination passed to np.save for the per-sample correlations.
CORR_SAVE_PATH: str = ""
# Total number of Gaussian noise samples (one generated image each).
BATCH_SIZE: int = 256
# Total target-mixture weight placed on the old-man images (rest goes to train).
P: float = 0.1

# Model loading helper
def load_ddpm_sensitivity(run_dir, device):
    """Wrap a CelebA DDPM checkpoint in a ``DDPMSampleSensitivity`` object.

    Args:
        run_dir: Run directory given to ``loaders.load_models`` as both of its
            first two positional arguments.
        device: Device the noise-prediction network is moved to.

    Returns:
        ``sensitivity.DDPMSampleSensitivity`` built from the loaded network and
        scheduler, using one Hutchinson sample, an ODE step size of 1e-3 and a
        clamp range of [1e-1, 1e1].
    """
    # Only the first (network, scheduler) pair returned by the loader is used.
    (eps_net, noise_scheduler), (_, _) = loaders.load_models(
        run_dir, run_dir, dataset="celeba"
    )
    eps_net = eps_net.to(device)
    solver_options = dict(
        num_hutchinson_samples=1,
        ode_step_size=1e-3,
        min_clamp=1e-1,
        max_clamp=1e1,
    )
    return sensitivity.DDPMSampleSensitivity(eps_net, noise_scheduler, **solver_options)

def _unique_indices(path):
    """Load an index array with ``np.load`` and return its distinct values as a sorted list."""
    return np.unique(np.load(path)).tolist()


def _stack_images(dataset, indices):
    """Stack the image (element 0 of each item) of ``dataset[i]`` for every ``i`` in ``indices``.

    Raises:
        ValueError: if ``indices`` is empty. ``torch.stack`` cannot stack
            nothing, and an empty block would also make the uniform mixture
            weight ``mass / len(block)`` divide by zero.
    """
    if len(indices) == 0:
        raise ValueError("index set is empty; cannot build a mixture block from it")
    return torch.stack([dataset[i][0] for i in indices], dim=0)


def _barycentric_rays(coupling, source_flat, target_flat):
    """Return each source row's displacement to its barycentric projection under ``coupling``.

    Row ``i`` of an OT coupling sums to the source weight ``a_i`` (``1 / n`` for
    uniform weights), not to 1. ``coupling @ target_flat`` must therefore be
    divided by the row mass to obtain the conditional mean target for ``x_i``.
    Without this normalisation the projection shrinks by a factor of ``n`` and
    the ray collapses to roughly ``-x_i``.
    """
    row_mass = coupling.sum(axis=1, keepdims=True)
    # A row with no mass (e.g. numerical underflow) has no defined projection;
    # mark it NaN rather than dividing by zero.
    row_mass = np.where(row_mass > 0, row_mass, np.nan)
    projection = (coupling @ target_flat) / row_mass
    return projection - source_flat


def main():
    """Correlate the old-to-new model sample shift with entropic-OT transport rays.

    Steps:
      1. Push one shared batch of Gaussian noise of shape (BATCH_SIZE, 3, 64, 64)
         through both models' ``precompute_sample_path_no_logpt``. The last
         entry along dim 1 of each returned path is used as the sample.
      2. Build a target mixture of CelebA train images. The old-man block carries
         total weight ``P`` and the train block ``1 - P``, uniform within each block.
      3. Solve a log-domain Sinkhorn problem between the old samples (uniform
         weights) and the mixture under a Euclidean cost. Take each old sample's
         barycentric projection minus the sample as its transport ray.
      4. Save the per-sample Pearson correlation between ``new - old`` and the
         ray to ``CORR_SAVE_PATH`` and print the median, ignoring NaNs.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ddpm_old = load_ddpm_sensitivity(EPS_MODEL_OLD_PATH, device)
    ddpm_new = load_ddpm_sensitivity(EPS_MODEL_NEW_PATH, device)

    # Draw BATCH_SIZE Gaussian noise samples; the same noise drives both models
    # so their outputs are paired sample-by-sample.
    z1 = torch.randn((BATCH_SIZE, 3, 64, 64)).to(device)
    print("Computing sample paths...")
    with torch.no_grad():
        zt_old = ddpm_old.precompute_sample_path_no_logpt(z1)
        zt_new = ddpm_new.precompute_sample_path_no_logpt(z1)
        samples_old = zt_old[:, -1, :, :, :].cpu()
        samples_new = zt_new[:, -1, :, :, :].cpu()

    # Load CelebA images for train_indices and old_man_indices
    transform = transforms.Compose([
        transforms.CenterCrop(140),
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    celeba_dataset = datasets.CelebA(
        root="",  # pass the path to the CelebA dataset
        split="train",
        download=False,
        transform=transform
    )
    train_indices = _unique_indices(TRAIN_INDICES_PATH)
    old_man_indices = _unique_indices(OLD_MAN_INDICES_PATH)
    # NOTE(review): an index present in both files appears in both mixture
    # blocks and so receives weight from each; confirm this is intended.
    train_imgs = _stack_images(celeba_dataset, train_indices)
    old_man_imgs = _stack_images(celeba_dataset, old_man_indices)
    mixture_imgs = torch.cat([old_man_imgs, train_imgs], dim=0)

    # Sinkhorn coupling
    print("Computing Sinkhorn coupling and correlations...")
    n_old = samples_old.shape[0]
    n_mixture = mixture_imgs.shape[0]
    # reshape (not view): the sliced sample tensor may be non-contiguous
    # (e.g. when .cpu() is a no-op on a CPU device).
    samples_old_flat = samples_old.reshape(n_old, -1).numpy()
    samples_new_flat = samples_new.reshape(n_old, -1).numpy()
    mixture_imgs_flat = mixture_imgs.reshape(n_mixture, -1).numpy()

    # Source: uniform over generated samples. Target: mass P spread uniformly
    # over the old-man block, 1 - P spread uniformly over the train block.
    a = np.full(n_old, 1.0 / n_old)
    b = np.concatenate([
        np.full(len(old_man_imgs), P / len(old_man_imgs)),
        np.full(len(train_imgs), (1 - P) / len(train_imgs)),
    ])
    cost_matrix = ot.dist(samples_old_flat, mixture_imgs_flat, metric='euclidean')
    coupling = ot.sinkhorn(a, b, cost_matrix, reg=5e-2, method='sinkhorn_log', numItermax=100000)

    # Transport rays from the row-normalised (barycentric) projection.
    rays = _barycentric_rays(coupling, samples_old_flat, mixture_imgs_flat)  # (n_old, n_pixels)

    # Correlation between new_samples - old_samples and transport rays
    diff = samples_new_flat - samples_old_flat
    correlations = np.array([np.corrcoef(d, r)[0, 1] for d, r in zip(diff, rays)])
    np.save(CORR_SAVE_PATH, correlations)
    # nanmedian: zero-variance rows or massless coupling rows yield NaN correlations.
    print(f"Saved correlations to {CORR_SAVE_PATH}. Median correlation: {np.nanmedian(correlations):.4f}")

if __name__ == "__main__":
    # Run the correlation pipeline only when executed as a script, not on import.
    main()
