import os
import numpy as np
import torch
from tqdm import tqdm
from diffusers import UNet2DModel, DDPMScheduler
import sys
# Before running, append the project root to sys.path (e.g. sys.path.append("/path/to/project")) so the project-local imports below resolve.
import model_sample_sensitivity.sample_sensitivity as sensitivity
import utils.loaders as loaders
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import ot

# User-specified paths and experiment settings (fill in before running).
EPS_MODEL_OLD_PATH = ""  # checkpoint of the "old" model, loaded via load_ddpm_sensitivity
EPS_MODEL_NEW_PATH = ""  # checkpoint of the "new" model, loaded via load_ddpm_sensitivity
CORR_SAVE_PATH = ""  # output file for np.save of the per-sample correlations
BATCH_SIZE = 256  # total number of noise samples (one batch shared by both models)
P = 1.0  # total target mass on the TMNIST images; MNIST train images get 1 - P

# Model loading helper
def load_ddpm_sensitivity(ckpt_path, device):
    """Build a ``DDPMSampleSensitivity`` wrapper around a checkpointed UNet.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``. Weights
            are taken from its ``'state_dict'`` entry when present; otherwise
            the loaded dict itself is treated as the state dict.
        device: Device the UNet is moved to.

    Returns:
        A ``sensitivity.DDPMSampleSensitivity`` built from a 28x28,
        single-channel ``UNet2DModel``, a 1000-step linear-beta
        ``DDPMScheduler`` and the fixed sensitivity settings below.

    Raises:
        RuntimeError: From ``load_state_dict`` if the (prefix-stripped) keys
            or tensor shapes do not match the UNet architecture built here.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Weights may be nested under "state_dict" (e.g. PyTorch Lightning
    # checkpoints); a bare state dict saved with torch.save is also accepted.
    state_dict = ckpt.get("state_dict", ckpt)

    # Strip a leading "model." from each key. Only a true prefix is removed;
    # the previous str.replace(..., 1) would also delete the first
    # "model." occurring anywhere inside a key.
    prefix = "model."
    fixed_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    eps_model = UNet2DModel(
        sample_size=28,
        in_channels=1,
        out_channels=1,
        layers_per_block=2,
        block_out_channels=(32, 64, 128),  # 3 blocks
        down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        norm_num_groups=8,
    )
    eps_model.load_state_dict(fixed_state_dict)
    # The model is only used for inference here; eval() switches off any
    # training-only module behaviour such as dropout.
    eps_model = eps_model.to(device).eval()

    scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=1e-4,
        beta_end=0.02,
        beta_schedule="linear",
    )
    # NOTE(review): semantics of the keyword arguments below are defined by
    # DDPMSampleSensitivity (not visible here); values kept as-is.
    return sensitivity.DDPMSampleSensitivity(
        eps_model,
        scheduler,
        num_hutchinson_samples=1,
        ode_step_size=1e-2,
        min_clamp=1e-1,
        max_clamp=1e1,
    )

def _generate_paired_samples(ddpm_old, ddpm_new, device):
    """Run one shared batch of Gaussian noise through both models.

    Returns:
        (samples_old, samples_new): CPU tensors holding the last entry along
        dim 1 of each model's sample path, paired by batch index.
    """
    # The same noise batch is fed to both models, so samples_old[i] and
    # samples_new[i] originate from the same starting point.
    z1 = torch.randn((BATCH_SIZE, 1, 28, 28)).to(device)
    with torch.no_grad():
        zt_old = ddpm_old.precompute_sample_path_no_logpt(z1)
        zt_new = ddpm_new.precompute_sample_path_no_logpt(z1)
    # Paths are indexed as 5-D tensors; index -1 on dim 1 is taken as the
    # final generated sample (assumes dim 1 is time ending at the data end
    # of the path -- TODO confirm against DDPMSampleSensitivity).
    return zt_old[:, -1, :, :, :].cpu(), zt_new[:, -1, :, :, :].cpu()


def _load_mixture_images():
    """Load the OT target set: TMNIST images followed by MNIST train images.

    Returns:
        (mixture_imgs, n_tmnist, n_train): the concatenated image tensor and
        the sizes of its TMNIST (leading) and MNIST (trailing) parts.
    """
    mnist = datasets.MNIST(
        root="",  # PATH TO MNIST DATASET
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )

    # Load TMNIST samples.
    # NOTE(review): torch.load unpickles; only load trusted files. Assumes the
    # stored tensor is (N, 1, 28, 28) in [0, 1] like ToTensor() output, so it
    # can be concatenated with MNIST -- TODO confirm.
    tmnist_path = ""
    tmnist_ims = torch.load(tmnist_path)
    train_imgs = torch.stack([img for img, _ in mnist], dim=0)
    mixture_imgs = torch.cat([tmnist_ims, train_imgs], dim=0)
    return mixture_imgs, len(tmnist_ims), len(train_imgs)


def _transport_ray_correlations(old_flat, new_flat, mixture_flat, n_tmnist, n_train, p):
    """Correlate each sample's update with its entropic-OT transport ray.

    Solves a log-domain Sinkhorn problem from the uniformly weighted old
    samples to the mixture (mass ``p`` spread uniformly over the first
    ``n_tmnist`` rows, ``1 - p`` over the following ``n_train`` rows) under
    Euclidean cost. The ray of sample i is its barycentric projection minus
    itself. Returns the Pearson correlation between ``new - old`` and that
    ray for each row; a row with zero variance yields NaN.
    """
    n_old = old_flat.shape[0]
    a = np.ones(n_old) / n_old
    b = np.concatenate([
        np.ones(n_tmnist) * (p / n_tmnist),
        np.ones(n_train) * ((1 - p) / n_train),
    ])

    # Zero-mass targets (e.g. the whole MNIST block when p == 1) receive no
    # transport mass. Excluding them leaves the coupling on the remaining
    # columns unchanged, avoids log(0) inside sinkhorn_log, and shrinks the
    # cost matrix and every Sinkhorn iteration.
    keep = b > 0
    if not keep.all():
        mixture_flat, b = mixture_flat[keep], b[keep]

    cost_matrix = ot.dist(old_flat, mixture_flat, metric="euclidean")
    coupling = ot.sinkhorn(
        a, b, cost_matrix, reg=5e-2, method="sinkhorn_log", numItermax=100000
    )

    # Barycentric projection T(x_i) = sum_j pi_ij y_j / sum_j pi_ij. Row i of
    # the coupling sums to a_i = 1/n_old (not 1), so without this division
    # the target would be shrunk by n_old and every ray would be ~ -x_i.
    # The actual row sums are used to stay robust to incomplete convergence.
    row_mass = coupling.sum(axis=1, keepdims=True)
    rays = (coupling @ mixture_flat) / row_mass - old_flat  # (n_old, n_pixels)

    diff = new_flat - old_flat
    return np.array([np.corrcoef(d, r)[0, 1] for d, r in zip(diff, rays)])


def main():
    """Correlate per-sample model updates with OT transport rays and save them.

    1. Load the old/new sensitivity wrappers from EPS_MODEL_OLD_PATH and
       EPS_MODEL_NEW_PATH.
    2. Generate BATCH_SIZE paired samples from one shared noise batch.
    3. Load the TMNIST + MNIST-train mixture and couple the old samples to it
       with Sinkhorn (TMNIST weighted P, MNIST weighted 1 - P).
    4. Save the per-sample correlations between (new - old) and the transport
       rays to CORR_SAVE_PATH via np.save, and print their median.

    Raises:
        ValueError: If P lies outside [0, 1] (the target weights would be
            negative).
    """
    if not 0.0 <= P <= 1.0:
        raise ValueError(f"P must lie in [0, 1], got {P}")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ddpm_old = load_ddpm_sensitivity(EPS_MODEL_OLD_PATH, device)
    ddpm_new = load_ddpm_sensitivity(EPS_MODEL_NEW_PATH, device)

    print("Computing sample paths...")
    samples_old, samples_new = _generate_paired_samples(ddpm_old, ddpm_new, device)

    mixture_imgs, n_tmnist, n_train = _load_mixture_images()

    print("Computing Sinkhorn coupling and correlations...")
    n_old = samples_old.shape[0]
    # Flatten images to (n, n_pixels) for the cost computation.
    old_flat = samples_old.reshape(n_old, -1).numpy()
    new_flat = samples_new.reshape(n_old, -1).numpy()
    mixture_flat = mixture_imgs.reshape(mixture_imgs.shape[0], -1).numpy()

    correlations = _transport_ray_correlations(
        old_flat, new_flat, mixture_flat, n_tmnist, n_train, P
    )
    np.save(CORR_SAVE_PATH, correlations)
    # nanmedian: a single zero-variance row (NaN correlation) should not turn
    # the summary statistic into NaN.
    print(f"Saved correlations to {CORR_SAVE_PATH}. Median correlation: {np.nanmedian(correlations):.4f}")

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
