import os
import json
import torch
import numpy as np
from diffusers import UNet2DModel, DDPMScheduler
from torchvision import datasets, transforms
from torch.utils.data import Subset
import re

import sys
# APPEND PATH TO THE PROJECT CODE TO ENABLE IMPORTS
from utils.cfdm_ddpm_conversion import ScoreFromEps
from closed_form_diffusion.cfdm import CFDM, CFDM_NN

def load_unet_from_lightning_ckpt(ckpt_path, dataset="cifar10"):
    """
    Load a Diffusers UNet2DModel from a Lightning checkpoint.

    The checkpoint must be a dict whose "state_dict" entry holds the UNet
    weights. Keys may optionally start with "model."; that leading prefix is
    stripped before the weights are loaded (strictly) into a freshly built
    UNet2DModel.

    Args:
        ckpt_path: path to the Lightning ``.ckpt`` file.
        dataset: "celeba" (64x64 UNet) or "cifar10" / "cifar10_iid"
            (32x32 UNet); matched case-insensitively.

    Returns:
        The UNet2DModel with the checkpoint weights loaded.

    Raises:
        ValueError: if ``dataset`` is not supported. This is checked before
            the checkpoint file is read.
    """
    cifar10_config = dict(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(64, 128, 256, 512),
        norm_num_groups=32,
    )
    unet_configs = {
        "celeba": dict(
            sample_size=64,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(128, 256, 512, 512),  # Increased for 64x64 images
            norm_num_groups=32,
        ),
        "cifar10": cifar10_config,
        "cifar10_iid": cifar10_config,
    }
    config = unet_configs.get(dataset.lower())
    if config is None:
        # Validate before torch.load so a typo does not cost a full checkpoint read.
        raise ValueError(f"Unsupported dataset: {dataset}")

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Recent PyTorch versions default to weights_only=True, which
    # may reject Lightning checkpoints that store non-tensor objects — confirm
    # against the installed torch version.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt["state_dict"]

    # Strip "model." only when it is a leading prefix; str.replace would also
    # rewrite a "model." that happens to occur inside a key.
    prefix = "model."
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model = UNet2DModel(**config)
    model.load_state_dict(new_state_dict)
    return model

# Matches exactly "epoch=<digits>.ckpt" (no suffixes such as ".tmp" or "-v1").
_EPOCH_CKPT_RE = re.compile(r"epoch=(\d+)\.ckpt")


def find_latest_ckpt(run_dir):
    """
    Return the path of the checkpoint with the highest epoch in ``run_dir``.

    Only files whose whole name is ``epoch=<N>.ckpt`` are considered. The
    returned path uses the file's actual name, so zero-padded epochs such as
    ``epoch=007.ckpt`` resolve to an existing file.

    Args:
        run_dir: directory containing Lightning checkpoints.

    Returns:
        Path to the checkpoint with the largest epoch number.

    Raises:
        FileNotFoundError: if no matching checkpoint file exists.
    """
    candidates = []
    for fname in os.listdir(run_dir):
        match = _EPOCH_CKPT_RE.fullmatch(fname)
        if match:
            candidates.append((int(match.group(1)), fname))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint files found in {run_dir}")
    # Ties on epoch (e.g. "epoch=7" vs "epoch=007") break on name, deterministically.
    _, latest_name = max(candidates)
    return os.path.join(run_dir, latest_name)

def load_models(run_dir_subset, run_dir_full, dataset="cifar10", epoch=None):
    """
    Load the subset-trained epsilon model and the full-data score model.

    Args:
        run_dir_subset: run directory of the UNet trained on the data subset.
        run_dir_full: run directory of the UNet trained on the full data.
        dataset: dataset name forwarded to ``load_unet_from_lightning_ckpt``.
        epoch: if given, load ``epoch=<epoch>.ckpt`` from both run directories;
            otherwise use the latest checkpoint found in each.

    Returns:
        ``((subset_model, scheduler), (score_model, scheduler))`` where
        ``subset_model`` is the raw UNet2DModel (epsilon predictor),
        ``score_model`` wraps the full-data UNet in ``ScoreFromEps``, and both
        tuples share the same DDPMScheduler instance.
    """
    if epoch is None:
        subset_ckpt = find_latest_ckpt(run_dir_subset)
        full_ckpt = find_latest_ckpt(run_dir_full)
    else:
        ckpt_name = f"epoch={epoch}.ckpt"
        subset_ckpt = os.path.join(run_dir_subset, ckpt_name)
        full_ckpt = os.path.join(run_dir_full, ckpt_name)
        print(f"Loading subset model from {subset_ckpt} and full model from {full_ckpt}")

    eps_model_subset = load_unet_from_lightning_ckpt(subset_ckpt, dataset)
    eps_model_full = load_unet_from_lightning_ckpt(full_ckpt, dataset)

    # Linear beta schedule over 1000 steps; shared by both returned pairs.
    ddpm_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=1e-4,
        beta_end=0.02,
        beta_schedule="linear",
    )
    score_model = ScoreFromEps(eps_model_full, ddpm_scheduler)

    return (eps_model_subset, ddpm_scheduler), (score_model, ddpm_scheduler)


# Default dataset locations, used when ``data_root`` is not given.
_CELEBA_ROOT = "/home/ec2-user/data"
_CIFAR10_ROOT = "/home/ec2-user/data-attribution/diffusion_trainers/data"

# Accepted values of ``return_train_or_complement_samples``.
_SAMPLE_MODES = (None, "train", "complement")


def _normalized_transform(*pre_transforms):
    """Compose ``pre_transforms`` with ToTensor and a per-channel (x - 0.5) / 0.5 normalization."""
    return transforms.Compose([
        *pre_transforms,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


def _stack_images(dataset, indices):
    """Return the images of ``dataset`` at ``indices`` (in that order) as one batched tensor."""
    if len(indices) == 0:
        raise ValueError("Cannot build an image batch from an empty index list")
    subset = Subset(dataset, indices)
    loader = torch.utils.data.DataLoader(subset, batch_size=len(subset), shuffle=False)
    return next(iter(loader))[0]


def _cifar10_train_set(data_root):
    """Return the normalized CIFAR-10 training split (no download)."""
    return datasets.CIFAR10(
        root=_CIFAR10_ROOT if data_root is None else data_root,
        train=True,
        download=False,
        transform=_normalized_transform(),
    )


def _celeba_split(run_dir_subset, data_root):
    """Return (CelebA train split, subset indices read from image_indices.npy)."""
    selected_indices = np.load(os.path.join(run_dir_subset, "image_indices.npy"))
    full_dataset = datasets.CelebA(
        root=_CELEBA_ROOT if data_root is None else data_root,
        split="train",
        download=False,
        transform=_normalized_transform(transforms.CenterCrop(140), transforms.Resize(64)),
    )
    return full_dataset, set(selected_indices.tolist())


def _cifar10_class_split(run_dir_subset, data_root):
    """Return (CIFAR-10 train split, indices of images whose class is in config.json's selected_classes)."""
    with open(os.path.join(run_dir_subset, "config.json"), "r") as f:
        selected_classes = json.load(f)["selected_classes"]
    full_dataset = _cifar10_train_set(data_root)
    class_to_idx = {cls: idx for idx, cls in enumerate(full_dataset.classes)}
    selected_class_ids = {class_to_idx[cls] for cls in selected_classes}
    # Read labels from ``targets`` instead of iterating the dataset, which
    # would decode and transform every image just to obtain its label.
    subset_indices = {
        i for i, label in enumerate(full_dataset.targets) if label in selected_class_ids
    }
    return full_dataset, subset_indices


def _cifar10_iid_split(run_dir_subset, data_root):
    """Return (CIFAR-10 train split, indices read from image_indices_<run id>.npy)."""
    # The run directory name ends in "_<run id>" (possibly zero-padded); the
    # index file uses the id as a plain integer. normpath drops any trailing
    # separator so basename yields the directory name itself.
    run_name = os.path.basename(os.path.normpath(run_dir_subset))
    run_id = int(run_name.split("_")[-1])
    selected_indices = np.load(os.path.join(run_dir_subset, f"image_indices_{run_id}.npy"))
    return _cifar10_train_set(data_root), set(selected_indices.tolist())


# Maps a lower-cased dataset name to the loader of (full dataset, subset indices).
_SPLIT_LOADERS = {
    "celeba": _celeba_split,
    "cifar10": _cifar10_class_split,
    "cifar10_iid": _cifar10_iid_split,
}


def build_cfdm_from_subset_metadata(run_dir_subset,
                                    dataset="cifar10",
                                    device_id=0,
                                    return_train_or_complement_samples=None,
                                    num_dct_coeffs=None,
                                    data_root=None,
                                    ):
    """
    Build a CFDM from the complement of the data a subset model was trained on.

    The subset is recovered from metadata in ``run_dir_subset``:
      - "celeba": indices in ``image_indices.npy``;
      - "cifar10": class names in ``config.json["selected_classes"]``;
      - "cifar10_iid": indices in ``image_indices_<id>.npy``, where ``<id>`` is
        the integer after the last "_" in the run directory name.
    All training images not in the subset form the complement, which is moved
    to ``cuda:<device_id>`` and used as the CFDM's data.

    Args:
        run_dir_subset: path to the run directory of the subset-trained model.
        dataset: "celeba", "cifar10" or "cifar10_iid" (case-insensitive).
        device_id: GPU device ID to use.
        return_train_or_complement_samples: None to return only the CFDM;
            "train" to also return the subset (training) samples; "complement"
            to also return the complement samples.
        num_dct_coeffs: number of DCT coefficients to use, applied in every
            mode. If None, performs no DCT compression.
        data_root: optional root directory of the torchvision dataset;
            defaults to the built-in location for the chosen dataset.

    Returns:
        The CFDM instance (acting as s1), or ``(cfdm, samples)`` when
        ``return_train_or_complement_samples`` is "train" or "complement".
        Samples are sorted by dataset index and normalized to [-1, 1].

    Raises:
        ValueError: for an unsupported ``dataset``, an invalid
            ``return_train_or_complement_samples``, or an empty complement.
    """
    # Validate arguments before any (slow) dataset loading.
    if return_train_or_complement_samples not in _SAMPLE_MODES:
        raise ValueError(
            "return_train_or_complement_samples must be None, 'train' or "
            f"'complement', got {return_train_or_complement_samples!r}"
        )
    load_split = _SPLIT_LOADERS.get(dataset.lower())
    if load_split is None:
        raise ValueError(f"Unsupported dataset: {dataset}")

    full_dataset, subset_indices = load_split(run_dir_subset, data_root)
    complement_indices = sorted(set(range(len(full_dataset))) - subset_indices)

    device = f"cuda:{device_id}"
    complement_samples = _stack_images(full_dataset, complement_indices).to(device)

    # Define scheduler (same as in load_models)
    scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=1e-4,
        beta_end=0.02,
        beta_schedule="linear"
    )
    cfdm = CFDM(complement_samples, scheduler, num_dct_coeffs=num_dct_coeffs)

    if return_train_or_complement_samples == "train":
        # Only stack the subset images when they are actually requested.
        train_samples = _stack_images(full_dataset, sorted(subset_indices)).to(device)
        return cfdm, train_samples
    if return_train_or_complement_samples == "complement":
        return cfdm, complement_samples
    return cfdm