"""
MorphGen Comparison Script

This script compares:
- Real images: From CellDataModule (U2OS cell line)
- Generated images: MorphGen generated images

The script uses the same 100 perturbations approach as the CellFlux comparison:
1. Gets siRNA IDs from the compound-to-siRNA mapping
2. Filters for U2OS cell line
3. Generates crops from real images using nuclei-centered cropping
4. Compares against MorphGen generated images
5. Calculates conditional and unconditional FID scores

KEY FEATURES:
1. **Fixed Target Distribution**: For each perturbation, we preselect a fixed subset of real images
   that serves as our consistent target distribution across all FID calculations.
2. **Consistent Sample Sizes**: All FID calculations use the same fixed real subsets, ensuring
   reproducible and fair comparisons.
3. **Proper Generated Sampling**: Generated images are sampled up to the size of half the fixed
   real subset, maintaining fairness without reducing the target distribution size.

Usage:
    python perturbation_type_morphgen_on_crops.py --num_perturbations 100
"""

import argparse
import glob
import logging
import os
import random
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch_fidelity
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything

# Import from your modules
from sc_perturb.dataset import CellDataModule, to_rgb
from tqdm import tqdm

# =========================
# Constants for cropping pipeline
# =========================
NUCLEUS_CHANNEL = 0  # nucleus channel index
CROP_SIZE = 96
BORDER_MARGIN = CROP_SIZE // 2
MIN_FG_FRAC = 0.005  # min fraction of nucleus-foreground pixels in crop
MIN_NUCLEUS_MEAN = 0.01  # min mean intensity of nucleus channel in crop
MAX_TRIALS_FACTOR = 100  # we allow up to factor * requested samples attempts


class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that exposes an in-memory batch to torch_fidelity.

    The wrapped container is indexed along its first axis and items are
    handed back exactly as the container's own indexing returns them.
    """

    def __init__(self, data_4d):
        # Stored as-is; no copy or dtype conversion is performed here.
        self.data = data_4d

    def __len__(self):
        # Number of items == size of the leading (batch) axis.
        n_items = self.data.shape[0]
        return n_items

    def __getitem__(self, idx):
        item = self.data[idx]
        return item


def normalize_image_to_01(img: torch.Tensor) -> torch.Tensor:
    """Min-max rescale a single image so its values span [0, 1].

    The minimum and maximum are taken over the whole tensor (all channels
    together). An image with no dynamic range (max not greater than min) is
    returned untouched, since there is nothing to stretch.
    """
    lo, hi = img.min(), img.max()
    if not hi > lo:
        # Constant image: avoid dividing by zero and hand back the input.
        return img
    return (img - lo) / (hi - lo)


def safe_load_npy(file_path: str) -> Optional[torch.Tensor]:
    """Load a ``.npy`` file as a float32 tensor with values in [0, 1].

    Layout handling:
      * 2-D arrays ``[H, W]`` gain a leading channel axis -> ``[1, H, W]``.
      * 3-D arrays whose last axis has 1, 3, 4, 5 or 6 entries are treated as
        channel-last (``HWC``) and permuted to ``[C, H, W]``. This is a
        heuristic on the axis length only.
      * Any other shape is passed through unchanged.

    Value handling: if the maximum exceeds 1.5 the array is divided by its
    maximum; the result is then clamped to [0, 1].

    Args:
        file_path: Path to the ``.npy`` file.

    Returns:
        The loaded tensor, or ``None`` if loading/conversion failed (a warning
        is printed in that case).
    """
    try:
        # Cast to float32 in NumPy *before* handing the buffer to torch:
        # torch.from_numpy only accepts a fixed set of dtypes and (depending
        # on the PyTorch version) rejects e.g. uint16 microscopy data, which
        # previously made this loader silently return None.
        arr = np.asarray(np.load(file_path), dtype=np.float32)
        ten = torch.from_numpy(arr)

        if ten.ndim == 2:  # [H,W] -> [1,H,W]
            ten = ten.unsqueeze(0)
        elif ten.ndim == 3 and ten.shape[-1] in (1, 3, 4, 5, 6):
            # Last axis looks like channels: HWC -> CHW
            ten = ten.permute(2, 0, 1).contiguous()

        # Rescale only when the data is clearly not already in [0, 1].
        mx = ten.max().item() if ten.numel() > 0 else 1.0
        if mx > 1.5:
            ten = ten / mx
        return ten.clamp(0, 1)
    except Exception as e:
        # Best-effort loader: report and let the caller skip this file.
        print(f"[warn] failed to load {file_path}: {e}")
        return None


def safe_load_png(file_path: str) -> Optional[torch.Tensor]:
    """Load an image file as an RGB float tensor ``[3, H, W]`` in [0, 1].

    Non-RGB images (grayscale, palette, RGBA, ...) are converted to RGB
    before being turned into a tensor.

    Args:
        file_path: Path to the image file.

    Returns:
        The image tensor, or ``None`` if the file could not be read or
        converted (a warning is printed in that case).
    """
    try:
        # Image.open is lazy and keeps the file handle open until the pixel
        # data is read; the context manager guarantees the handle is released
        # even if conversion fails part-way.
        with Image.open(file_path) as img:
            rgb = img if img.mode == "RGB" else img.convert("RGB")
            # Must run inside the `with`: to_tensor forces the pixel data to
            # be loaded while the file is still open.
            return TF.to_tensor(rgb)  # [C, H, W] in range [0, 1]
    except Exception as e:
        # Best-effort loader: report and let the caller skip this file.
        print(f"[warn] failed to load {file_path}: {e}")
        return None


def get_cellflux_compounds(cellflux_path: str) -> List[str]:
    """List the compound names stored under the CellFlux image directory.

    Every immediate sub-directory of ``cellflux_path`` is taken to be one
    compound. Plain files are ignored.

    Returns:
        The sub-directory names in sorted order, or an empty list when the
        path does not exist or cannot be read (an error is printed).
    """
    try:
        if not os.path.exists(cellflux_path):
            print(f"[error] CellFlux path does not exist: {cellflux_path}")
            return []

        # One directory per compound; keep only entries that are directories.
        compounds = []
        for entry in os.listdir(cellflux_path):
            if os.path.isdir(os.path.join(cellflux_path, entry)):
                compounds.append(entry)

        # NOTE: the preview below is printed in directory-listing order,
        # i.e. before sorting.
        print(f"Found {len(compounds)} compounds in CellFlux directory")
        print(f"First 10 compounds: {compounds[:10]}")

        compounds.sort()
        return compounds

    except Exception as e:
        print(f"[error] Failed to read CellFlux directory: {e}")
        return []


def get_available_compounds(
    metadata_path: str, cellflux_compounds: List[str], cell_line: str = "U2OS"
) -> List[str]:
    """Keep the CellFlux compounds that also occur in the metadata CSV.

    A compound counts as present when the CSV has at least one row with
    ``CELL_LINE == cell_line``, ``ANNOT == "treated"`` and a matching
    ``CPD_NAME``.

    Args:
        metadata_path: Path to the metadata CSV file.
        cellflux_compounds: Candidate compound names.
        cell_line: Cell line to filter the metadata on.

    Returns:
        The candidates found in the metadata, in their input order. An empty
        list is returned if the CSV cannot be loaded or processed (an error
        is printed).
    """
    try:
        df = pd.read_csv(metadata_path)
        treated_rows = (df["CELL_LINE"] == cell_line) & (df["ANNOT"] == "treated")
        metadata_compounds = set(df.loc[treated_rows, "CPD_NAME"].unique())

        # Partition the candidates in one pass, preserving input order.
        available_compounds: List[str] = []
        missing_compounds: List[str] = []
        for compound in cellflux_compounds:
            if compound in metadata_compounds:
                available_compounds.append(compound)
            else:
                missing_compounds.append(compound)

        print(
            f"Found {len(available_compounds)} CellFlux compounds that exist in metadata for {cell_line} treated samples"
        )

        if missing_compounds:
            print(
                f"[warn] {len(missing_compounds)} CellFlux compounds not found in metadata: {missing_compounds[:5]}"
            )

        return available_compounds

    except Exception as e:
        print(f"[error] Failed to load metadata: {e}")
        return []


def create_compound_to_sirna_mapping(
    datamodule: CellDataModule,
    metadata_df: pd.DataFrame,
    available_compounds: List[str],
    cell_line: str = "U2OS",
) -> Dict[str, Dict]:
    """Map compound names to siRNA IDs by matching sample/site keys.

    IMPA rows (``metadata_df``) are keyed by ``SAMPLE_KEY`` with its last
    ``_``-separated field removed, e.g. ``"U2OS-01_1_B02_s1_14"`` ->
    ``"U2OS-01_1_B02_s1"``. MorphGen rows (``datamodule.metadata``) carry a
    ``site_id`` such as ``"U2OS-01_1_B02_1"``, which is rewritten into the
    same ``{experiment}_{plate}_{well}_s{site}`` form and looked up.

    Args:
        datamodule: Object whose ``metadata`` DataFrame has the columns
            ``cell_type``, ``site_id``, ``sirna`` and ``sirna_id``.
        metadata_df: IMPA metadata with the columns ``CPD_NAME``, ``ANNOT``,
            ``CELL_LINE`` and ``SAMPLE_KEY``.
        available_compounds: Compounds to build a mapping for.
        cell_line: Cell line used to filter both metadata tables.

    Returns:
        ``{compound: {"sirna", "sirna_id", "sample_count"}}``. ``sirna`` and
        ``sirna_id`` come from the first matching MorphGen row for that
        compound; ``sample_count`` counts all matching MorphGen rows.
    """
    print("Creating compound to siRNA mapping...")

    # Get MorphGen metadata
    morphgen_df = datamodule.metadata

    # Filter for the specified cell line
    morphgen_cellline = morphgen_df[morphgen_df["cell_type"] == cell_line]
    print(f"Found {len(morphgen_cellline)} {cell_line} samples in MorphGen")

    # Filter IMPA metadata for treated samples with available compounds
    impa_filtered = metadata_df[
        (metadata_df["CPD_NAME"].isin(available_compounds))
        & (metadata_df["ANNOT"] == "treated")
        & (metadata_df["CELL_LINE"] == cell_line)
    ]
    print(
        f"Found {len(impa_filtered)} IMPA records for available compounds ({cell_line}, treated)"
    )

    print("Matching sample keys between MorphGen and IMPA...")

    # Set lookup keeps the per-row membership test O(1) instead of scanning
    # the list once for every MorphGen row.
    available_set = set(available_compounds)

    # Group IMPA compounds by base key (SAMPLE_KEY without its last field).
    impa_base_keys: Dict[str, List[str]] = defaultdict(list)
    for _, row in impa_filtered.iterrows():
        parts = row["SAMPLE_KEY"].split("_")
        if len(parts) >= 4:
            impa_base_keys["_".join(parts[:-1])].append(row["CPD_NAME"])

    print(f"Created {len(impa_base_keys)} unique base keys from IMPA data")

    # Only the first compound per base key is used below, on the assumption
    # that all rows at one location share a compound. Surface violations
    # instead of silently picking one.
    ambiguous_keys = sum(
        1 for cpds in impa_base_keys.values() if len(set(cpds)) > 1
    )
    if ambiguous_keys:
        print(
            f"[warn] {ambiguous_keys} IMPA base keys list more than one compound; using the first one for each"
        )

    # Match with MorphGen data
    compound_to_sirna_mapping: Dict[str, Dict] = {}
    successful_matches = 0
    for _, morphgen_row in morphgen_cellline.iterrows():
        # MorphGen: "U2OS-01_1_B02_1" -> IMPA base: "U2OS-01_1_B02_s1"
        parts = morphgen_row["site_id"].split("_")
        if len(parts) < 4:
            continue
        experiment, plate, well, site = parts[:4]

        compounds = impa_base_keys.get(f"{experiment}_{plate}_{well}_s{site}")
        if not compounds:
            continue

        compound = compounds[0]
        if compound not in available_set:
            continue

        if compound not in compound_to_sirna_mapping:
            # First hit for this compound fixes its siRNA identity.
            compound_to_sirna_mapping[compound] = {
                "sirna": morphgen_row["sirna"],
                "sirna_id": morphgen_row["sirna_id"],
                "sample_count": 0,
            }
        compound_to_sirna_mapping[compound]["sample_count"] += 1
        successful_matches += 1

    print(
        f"Successfully created mappings for {len(compound_to_sirna_mapping)} compounds"
    )
    print(f"Total matched samples: {successful_matches}")

    # Report compounds that never matched a MorphGen site.
    if len(compound_to_sirna_mapping) != len(available_compounds):
        missing_compounds = set(available_compounds) - set(compound_to_sirna_mapping)
        print(
            f"[warn] Missing mappings for {len(missing_compounds)} compounds: {list(missing_compounds)[:10]}"
        )

    # Summarise the distinct siRNA IDs / names that were hit.
    unique_sirna_ids = {m["sirna_id"] for m in compound_to_sirna_mapping.values()}
    unique_sirnas = {m["sirna"] for m in compound_to_sirna_mapping.values()}

    print(f"Mapped to {len(unique_sirna_ids)} unique siRNA IDs")
    print(f"Mapped to {len(unique_sirnas)} unique siRNA names")

    # Invert the mapping (siRNA ID -> compounds) for a short preview.
    sirna_id_counts: Dict = defaultdict(list)
    for compound, mapping in compound_to_sirna_mapping.items():
        sirna_id_counts[mapping["sirna_id"]].append(compound)

    print("siRNA ID to compound mapping (first 10):")
    for sirna_id, compounds in sorted(sirna_id_counts.items())[:10]:
        print(f"  siRNA ID {sirna_id}: {compounds}")

    # The CellFlux comparison set is expected to contain 100 perturbations;
    # a different count is reported but does not abort.
    if len(unique_sirna_ids) != 100:
        print(
            f"[warn] Expected 100 unique siRNA IDs for CellFlux compounds, but found {len(unique_sirna_ids)}"
        )
    else:
        print("✓ Successfully validated 100 unique siRNA IDs for CellFlux compounds")

    return compound_to_sirna_mapping


def collect_morphgen_generated_images(
    sirna_id: int,
    morphgen_data_path: str,
    cell_line_code: int = 0,  # 0 for U2OS
    max_samples: Optional[int] = None,
) -> List[torch.Tensor]:
    """Collect MorphGen generated NPY images for a siRNA ID.

    Files are expected at
    ``{morphgen_data_path}/p{sirna_id}/p{sirna_id}_c{cell_line_code}_sample*.npy``.

    Args:
        sirna_id: Perturbation ID used in the directory and file names.
        morphgen_data_path: Root directory of the generated images.
        cell_line_code: Cell line code used in the file names.
        max_samples: If truthy and fewer than the number of files found, a
            random subset of this size is loaded instead of all files.

    Returns:
        One per-image min-max normalized tensor for each file that loaded
        successfully. Empty if the perturbation directory does not exist.
    """
    perturbation_dir = os.path.join(morphgen_data_path, f"p{sirna_id}")
    if not os.path.exists(perturbation_dir):
        return []

    # glob returns files in filesystem order, which differs between machines.
    # Sorting makes both the seeded random subset and the order of the
    # returned images reproducible.
    pattern = f"p{sirna_id}_c{cell_line_code}_sample*.npy"
    npy_files = sorted(glob.glob(os.path.join(perturbation_dir, pattern)))

    if max_samples and len(npy_files) > max_samples:
        npy_files = random.sample(npy_files, max_samples)

    images = []
    for file_path in tqdm(
        npy_files, desc=f"Loading MorphGen images for siRNA ID {sirna_id}", leave=False
    ):
        img = safe_load_npy(file_path)
        if img is None:
            # Unreadable file: safe_load_npy already printed a warning.
            continue
        # Normalize each generated image to [0,1] range for better illumination matching
        images.append(normalize_image_to_01(img))

    return images


def collect_real_images_from_datamodule(
    sirna_id: int,
    datamodule: CellDataModule,
    cell_line: str = "U2OS",
    max_samples: Optional[int] = None,
) -> List[torch.Tensor]:
    """Load the real images recorded for one siRNA ID and cell line.

    Rows of ``datamodule.metadata`` are selected on ``sirna_id`` and
    ``cell_type``; each selected index is loaded with
    ``datamodule.load_sample``. When ``max_samples`` is truthy and smaller
    than the number of matches, a random subset of indices is loaded.

    Each loaded image is cast to float32, divided by its maximum if that
    maximum exceeds 1.5, and clamped to [0, 1]. Samples that fail to load
    are skipped with a warning.

    Returns:
        The successfully loaded images; empty if no metadata row matches.
    """
    metadata = datamodule.metadata
    row_filter = (metadata["sirna_id"] == sirna_id) & (
        metadata["cell_type"] == cell_line
    )
    sirna_data = metadata[row_filter]
    if len(sirna_data) == 0:
        return []

    sample_indices = list(sirna_data.index)
    if max_samples and len(sample_indices) > max_samples:
        sample_indices = random.sample(sample_indices, max_samples)

    images = []
    progress = tqdm(
        sample_indices, desc=f"Loading real images for siRNA ID {sirna_id}", leave=False
    )
    for idx in progress:
        try:
            # load_sample is expected to return a torch.Tensor [C, H, W]
            image = datamodule.load_sample(idx)

            if image.dtype != torch.float32:
                image = image.float()

            # Rescale only when values are clearly outside [0, 1].
            if image.max() > 1.5:
                image = image / image.max()

            images.append(image.clamp(0, 1))
        except Exception as e:
            print(f"[warn] Failed to load image at index {idx}: {e}")

    return images


# =========================
# Nuclei-centered cropping (from comparison_2.py)
# =========================
def otsu_threshold_approx(x: torch.Tensor) -> float:
    """Estimate an Otsu threshold for values in [0, 1] from a 256-bin histogram.

    The bin index maximising the between-class variance is selected and
    returned as ``index / 255``. An empty input yields ``0.0``.
    """
    values = x.detach().cpu().flatten()
    if values.numel() == 0:
        return 0.0

    n_bins = 256
    counts = torch.histc(values, bins=n_bins, min=0.0, max=1.0)
    # Normalise to probabilities; the clamp guards against an all-zero histogram.
    prob = counts / counts.sum().clamp(min=1)
    levels = torch.arange(n_bins, dtype=torch.float32)

    class_weight = torch.cumsum(prob, dim=0)  # cumulative probability up to each bin
    class_mean = torch.cumsum(prob * levels, dim=0)  # cumulative first moment
    total_mean = class_mean[-1]

    # Between-class variance per candidate split; denominator is clamped so
    # that splits with an empty class do not divide by zero.
    denom = (class_weight * (1.0 - class_weight)).clamp(min=1e-8)
    between_var = (total_mean * class_weight - class_mean) ** 2 / denom
    between_var = between_var.masked_fill(torch.isnan(between_var), -1)

    best_bin = int(torch.argmax(between_var).item())
    return best_bin / 255.0


def pick_nucleus_center(mask: torch.Tensor, margin: int) -> Optional[Tuple[int, int]]:
    """Draw a random non-zero pixel of a 2-D mask, away from the borders.

    Pixels closer than ``margin`` to any edge are never selected. The input
    mask is not modified.

    Returns:
        ``(row, col)`` of the chosen pixel, or ``None`` when no non-zero
        pixel lies inside the allowed region.
    """
    height, width = mask.shape
    inset = margin if margin > 0 else 0

    # Look only at the interior window instead of zeroing a copy's borders;
    # the candidates (and their row-major order) are the same.
    interior = mask[inset : height - inset, inset : width - inset]
    candidates = interior.nonzero(as_tuple=False)
    if candidates.numel() == 0:
        return None

    choice = torch.randint(0, candidates.shape[0], (1,)).item()
    row, col = candidates[choice]
    # Shift window coordinates back into full-mask coordinates.
    return int(row.item()) + inset, int(col.item()) + inset


def crop_centered(
    img: torch.Tensor, center: Tuple[int, int], size: int
) -> torch.Tensor:
    """Cut a ``size`` x ``size`` window around ``center`` from a ``[C, H, W]`` image.

    The window's top-left corner is ``center - size // 2``, clipped at 0, so
    near the top/left edge the window is shifted rather than padded. When the
    window runs past the bottom/right edge, the missing rows/columns are
    filled by reflection padding on those sides.
    """
    _, height, width = img.shape
    cy, cx = center
    offset = size // 2

    top = max(0, cy - offset)
    left = max(0, cx - offset)
    bottom = min(height, top + size)
    right = min(width, left + size)
    window = img[:, top:bottom, left:right]

    # Rows/columns that fell outside the image on the bottom/right side.
    missing_rows = size - window.shape[1]
    missing_cols = size - window.shape[2]
    if missing_rows <= 0 and missing_cols <= 0:
        return window
    return F.pad(window, (0, missing_cols, 0, missing_rows), mode="reflect")


def nuclei_centered_crop(
    img: torch.Tensor,
    nucleus_channel: int = NUCLEUS_CHANNEL,
    crop_size: int = CROP_SIZE,
    border_margin: int = BORDER_MARGIN,
    min_fg_frac: float = MIN_FG_FRAC,
    min_nucleus_mean: float = MIN_NUCLEUS_MEAN,
) -> Optional[torch.Tensor]:
    """Try to cut one crop centred on a random nucleus pixel.

    The nucleus channel is thresholded with `otsu_threshold_approx`, a random
    above-threshold pixel at least ``border_margin`` from the edges is chosen
    as centre, and a ``crop_size`` window is extracted from all channels.

    Returns:
        The ``[C, crop_size, crop_size]`` crop, or ``None`` when no valid
        centre exists or the crop fails a quality check (fraction of
        above-threshold nucleus pixels below ``min_fg_frac``, or mean nucleus
        intensity below ``min_nucleus_mean``).
    """
    assert img.ndim == 3, "Expected [C,H,W]"

    nucleus = img[nucleus_channel].clamp(0, 1)
    threshold = otsu_threshold_approx(nucleus)

    center = pick_nucleus_center(nucleus >= threshold, border_margin)
    if center is None:
        return None

    crop = crop_centered(img, center, crop_size)
    nucleus_crop = crop[nucleus_channel]

    # Quality gates: reject crops that are mostly background or too dim.
    foreground_fraction = (nucleus_crop >= threshold).float().mean().item()
    if foreground_fraction < min_fg_frac:
        return None
    if nucleus_crop.mean().item() < min_nucleus_mean:
        return None
    return crop


def sample_nuclei_crops(
    images: List[torch.Tensor],
    num_samples: int,
    max_trials_factor: int = MAX_TRIALS_FACTOR,
) -> torch.Tensor:
    """Draw up to ``num_samples`` nucleus-centred crops from an image pool.

    Each attempt picks a random image (with replacement) and asks
    `nuclei_centered_crop` for one crop; rejected attempts are retried until
    either enough crops exist or ``num_samples * max_trials_factor`` attempts
    have been made.

    Returns:
        The stacked crops, or ``torch.empty(0)`` when none were collected.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("Empty image pool.")

    pool_size = len(images)
    attempt_budget = num_samples * max_trials_factor
    crops = []
    attempts = 0

    while attempts < attempt_budget and len(crops) < num_samples:
        attempts += 1
        candidate = nuclei_centered_crop(images[random.randrange(pool_size)])
        if candidate is None:
            continue
        crops.append(candidate)

    if len(crops) < num_samples:
        print(
            f"[warn] Only collected {len(crops)}/{num_samples} valid crops after {attempts} attempts."
        )

    if len(crops) > 0:
        return torch.stack(crops)
    return torch.empty(0)


def to_rgb_batch(x: torch.Tensor) -> torch.Tensor:
    """Render one image or a batch of images to RGB with the dataset's `to_rgb`.

    Accepts a single ``[C, H, W]`` image or a ``[N, C, H, W]`` batch. Every
    image is moved to CPU and converted on its own; the results are stacked
    along a new leading axis, so a single image comes back with batch size 1.

    Raises:
        ValueError: If ``x`` is neither 3-D nor 4-D.
    """
    if x.ndim not in (3, 4):
        raise ValueError(f"Unexpected tensor shape: {x.shape}")

    # Treat a lone image as a batch of one so both cases share one path.
    batch = x[None] if x.ndim == 3 else x
    converted = [
        to_rgb(batch[i][None].cpu()).squeeze(0)  # [3,H,W]
        for i in range(batch.shape[0])
    ]
    return torch.stack(converted, 0)


def to_eval_rgb_uint8(batch_chw: torch.Tensor) -> torch.Tensor:
    """Convert images to the uint8 RGB batch used for FID evaluation.

    The input is rendered to RGB via `to_rgb_batch`, clamped to [0, 1],
    scaled by 255 and cast to ``torch.uint8`` (the cast truncates the
    fractional part).
    """
    rgb_float = to_rgb_batch(batch_chw)  # float, expected in [0,1]
    scaled = rgb_float.clamp(0, 1) * 255.0
    return scaled.to(torch.uint8)


def preselect_target_real_images(
    sirna_id: int,
    real_images: List[torch.Tensor],
    min_samples: int = 32,
    max_samples: Optional[int] = None,
    seed: int = 42,
) -> List[torch.Tensor]:
    """Preselect a fixed subset of real images to serve as target distribution.

    The same ``(seed, sirna_id)`` pair always yields the same subset, so the
    target distribution is identical across FID calculations and runs.

    Args:
        sirna_id: Perturbation ID; mixed into the seed so every perturbation
            gets its own, but still reproducible, selection.
        real_images: All real images available for this perturbation.
        min_samples: If fewer images than this are available, the input list
            is returned unchanged (no sub-sampling).
        max_samples: Upper bound on the subset size; ``None`` means no bound.
        seed: Base seed for the selection.

    Returns:
        The selected images. Unless the input was returned unchanged, the
        subset has an even length so it can be split into two equal halves.
    """
    if len(real_images) < min_samples:
        return real_images

    # Private generator: seeding/reseeding the *global* `random` module here
    # would discard the run-wide seed and make every later random draw in the
    # process non-reproducible. A local instance seeded with the same value
    # produces the same selection without touching global state.
    # NOTE(review): hash() is stable for ints; a str sirna_id would hash
    # differently per process unless PYTHONHASHSEED is fixed.
    rng = random.Random(seed + hash(sirna_id) % 1000)

    # Determine target size
    if max_samples is None:
        target_size = len(real_images)
    else:
        target_size = min(max_samples, len(real_images))

    # Ensure even number for clean splitting
    target_size -= target_size % 2

    selected_images = rng.sample(real_images, target_size)

    print(
        f"[siRNA {sirna_id}] Selected {len(selected_images)} real images as fixed target distribution"
    )

    return selected_images


def calculate_fid_for_sirna(
    sirna_id: int,
    fixed_real_images: List[torch.Tensor],  # Now expects pre-selected fixed subset
    generated_images: List[torch.Tensor],
    min_samples: int = 32,
) -> Optional[Dict]:
    """Calculate FID metrics for a single siRNA using a fixed real image subset.

    Nucleus-centred crops are sampled from both image pools. The real crops
    are split into two equal halves: the first half is compared against the
    generated crops, and the two halves are compared against each other as a
    real-vs-real baseline.

    Args:
        sirna_id: Perturbation ID (used for logging and in the result).
        fixed_real_images: Pre-selected real images for this perturbation.
        generated_images: Generated images for this perturbation.
        min_samples: Minimum number of real crops to aim for; more are drawn
            when more raw real images are available.

    Returns:
        A dict with the sample counts plus ``real_vs_generated_fid`` and
        ``real_vs_real_fid``, or ``None`` when there is nothing to compare
        (empty input, no valid crops) or the FID computation fails.
    """
    n_raw_images = len(fixed_real_images)

    # Nothing to compare: bail out instead of letting sample_nuclei_crops
    # raise on an empty pool and abort the whole run.
    if n_raw_images == 0 or len(generated_images) == 0:
        print(
            f"[warn] Skipping siRNA {sirna_id}: {n_raw_images} real and "
            f"{len(generated_images)} generated images available"
        )
        return None

    # Target generating min_samples crops, but at least one per raw image;
    # rounded down to an even number for clean splitting.
    n_real_total = max(min_samples, n_raw_images)
    n_real_total -= n_real_total % 2
    n_real_per_half = n_real_total // 2

    # For generated images, use as many as available up to n_real_per_half
    n_gen_to_use = min(len(generated_images), n_real_per_half)

    print(
        f"Debug: Processing siRNA {sirna_id} with {n_raw_images} raw images, targeting {n_real_total} crops total"
    )
    print(
        f"Debug: Will generate {n_real_per_half} real crops per half, {n_gen_to_use} generated crops"
    )
    print(
        f"Debug: Fixed real image shape: {fixed_real_images[0].shape}, range: [{fixed_real_images[0].min():.4f}, {fixed_real_images[0].max():.4f}]"
    )
    print(
        f"Debug: Generated image shape: {generated_images[0].shape}, range: [{generated_images[0].min():.4f}, {generated_images[0].max():.4f}]"
    )

    # Sample crops from both real and generated images
    real_crops = sample_nuclei_crops(fixed_real_images, n_real_total)
    if real_crops.shape[0] < n_real_total:
        print(
            f"[warn] Could only generate {real_crops.shape[0]} valid real crops for siRNA {sirna_id}, expected {n_real_total}"
        )
        n_real_total = real_crops.shape[0]
        n_real_per_half = n_real_total // 2

    generated_crops = sample_nuclei_crops(generated_images, n_gen_to_use)
    if generated_crops.shape[0] < n_gen_to_use:
        print(
            f"[warn] Could only generate {generated_crops.shape[0]} valid generated crops for siRNA {sirna_id}, expected {n_gen_to_use}"
        )
        n_gen_to_use = generated_crops.shape[0]

    # With no usable crops the sampler returns an empty 1-D tensor, which the
    # RGB conversion cannot handle; stop here rather than crash.
    if n_real_per_half == 0 or n_gen_to_use == 0:
        print(
            f"[warn] Skipping siRNA {sirna_id}: not enough valid crops "
            f"(real per half={n_real_per_half}, generated={n_gen_to_use})"
        )
        return None

    # Convert to RGB
    real_rgb = to_eval_rgb_uint8(real_crops)
    generated_rgb = to_eval_rgb_uint8(generated_crops)

    # Split fixed real images into two equal halves for real vs real comparison
    real_rgb_half1 = real_rgb[:n_real_per_half]
    real_rgb_half2 = real_rgb[n_real_per_half : 2 * n_real_per_half]

    # Use the actual number of generated images we have (up to n_real_per_half)
    gen_rgb_final = generated_rgb[:n_gen_to_use]

    print(
        f"Debug: Final sizes - real_half1: {real_rgb_half1.shape[0]}, real_half2: {real_rgb_half2.shape[0]}, gen: {gen_rgb_final.shape[0]}"
    )

    results = {
        "sirna_id": sirna_id,
        "num_real_raw_images": n_raw_images,
        "num_real_crops_generated": n_real_total,
        "num_real_per_half": n_real_per_half,
        "num_generated_available": len(generated_images),
        "num_generated_used": n_gen_to_use,
    }

    use_cuda = torch.cuda.is_available()

    def _fid_between(first: torch.Tensor, second: torch.Tensor) -> float:
        """Run torch_fidelity on two uint8 RGB batches and return the FID."""
        metrics = torch_fidelity.calculate_metrics(
            input1=CustomDataset(first),
            input2=CustomDataset(second),
            cuda=use_cuda,
            fid=True,
            kid=False,
            verbose=False,
        )
        return float(metrics["frechet_inception_distance"])

    try:
        # Real vs Generated FID (use n_real_per_half vs n_gen_to_use)
        results["real_vs_generated_fid"] = _fid_between(real_rgb_half1, gen_rgb_final)

        # Real vs Real FID (using both halves of fixed real images)
        results["real_vs_real_fid"] = _fid_between(real_rgb_half1, real_rgb_half2)

        print(
            f"siRNA {sirna_id}: Real vs Gen FID = {results['real_vs_generated_fid']:.4f} "
            f"(n_real={n_real_per_half}, n_gen={n_gen_to_use}), "
            f"Real vs Real FID = {results['real_vs_real_fid']:.4f} (n={n_real_per_half} each)"
        )

        return results

    except Exception as e:
        # torch_fidelity can fail e.g. on very small sample counts; report
        # and let the caller skip this perturbation.
        print(f"[error] FID calculation failed for siRNA {sirna_id}: {e}")
        return None


def calculate_unconditional_fid(
    all_fixed_real_images: List[torch.Tensor],  # Now expects pooled fixed subsets
    all_generated_images: List[torch.Tensor],
    num_samples: int = 5000,
) -> Optional[Dict]:
    """Calculate unconditional FID using pooled fixed real image subsets.

    At most ``num_samples`` real images are used (randomly sub-sampled when
    more are available, rounded down to an even count). One crop per selected
    real image is targeted; the real crops are split into two equal halves.
    Generated images are limited to the size of one real half. The first real
    half is compared against the generated crops, and the two real halves
    against each other as a baseline.

    Args:
        all_fixed_real_images: Real images pooled over all perturbations.
        all_generated_images: Generated images pooled over all perturbations.
        num_samples: Upper bound on the number of real images used.

    Returns:
        A dict with ``sirna_id == "Unconditional"``, the sample counts,
        ``real_vs_generated_fid`` and ``real_vs_real_fid``; or ``None`` when
        there is nothing to compare or any processing step fails.
    """
    print(f"\nCalculating unconditional FID with up to {num_samples} samples each...")

    if len(all_fixed_real_images) == 0 or len(all_generated_images) == 0:
        print(
            f"[error] Failed to process images for unconditional FID: "
            f"{len(all_fixed_real_images)} real and {len(all_generated_images)} generated images available"
        )
        return None

    # Use fixed real images (already preselected per siRNA)
    if len(all_fixed_real_images) > num_samples:
        real_sample = random.sample(all_fixed_real_images, num_samples)
    else:
        real_sample = all_fixed_real_images

    # Ensure even number for clean splitting (slicing copies; the caller's
    # list is left untouched).
    if len(real_sample) % 2 == 1:
        real_sample = real_sample[:-1]

    n_real_total = len(real_sample)
    n_real_per_half = n_real_total // 2

    # For generated images, use as many as available up to n_real_per_half
    if len(all_generated_images) > n_real_per_half:
        gen_sample = random.sample(all_generated_images, n_real_per_half)
    else:
        gen_sample = all_generated_images

    n_gen_to_use = len(gen_sample)

    try:
        # Sample crops from both real and generated images
        real_crops = sample_nuclei_crops(real_sample, n_real_total)
        if real_crops.shape[0] < n_real_total:
            print(
                f"[warn] Could only generate {real_crops.shape[0]} valid real crops for unconditional FID, expected {n_real_total}"
            )
            n_real_total = real_crops.shape[0]
            n_real_per_half = n_real_total // 2

        generated_crops = sample_nuclei_crops(gen_sample, n_gen_to_use)
        if generated_crops.shape[0] < n_gen_to_use:
            print(
                f"[warn] Could only generate {generated_crops.shape[0]} valid generated crops for unconditional FID, expected {n_gen_to_use}"
            )
            n_gen_to_use = generated_crops.shape[0]

    except Exception as e:
        print(f"[error] Failed to process images for unconditional FID: {e}")
        return None

    # With no usable crops the sampler returns an empty 1-D tensor, which the
    # RGB conversion cannot handle; stop here rather than crash.
    if n_real_per_half == 0 or n_gen_to_use == 0:
        print(
            f"[error] Failed to process images for unconditional FID: not enough valid crops "
            f"(real per half={n_real_per_half}, generated={n_gen_to_use})"
        )
        return None

    # Convert to RGB
    real_rgb = to_eval_rgb_uint8(real_crops)
    gen_rgb = to_eval_rgb_uint8(generated_crops)

    # Split fixed real images for real vs real comparison
    real_rgb_half1 = real_rgb[:n_real_per_half]
    real_rgb_half2 = real_rgb[n_real_per_half : 2 * n_real_per_half]

    # Use actual number of generated images
    gen_rgb_final = gen_rgb[:n_gen_to_use]

    results = {
        "sirna_id": "Unconditional",
        "num_real_raw_images": len(all_fixed_real_images),
        "num_real_crops_generated": n_real_total,
        "num_real_per_half": n_real_per_half,
        "num_generated_available": len(all_generated_images),
        "num_generated_used": n_gen_to_use,
    }

    use_cuda = torch.cuda.is_available()

    def _fid_between(first: torch.Tensor, second: torch.Tensor) -> float:
        """Run torch_fidelity on two uint8 RGB batches and return the FID."""
        metrics = torch_fidelity.calculate_metrics(
            input1=CustomDataset(first),
            input2=CustomDataset(second),
            cuda=use_cuda,
            fid=True,
            kid=False,
            verbose=False,
        )
        return float(metrics["frechet_inception_distance"])

    try:
        # Real vs Generated FID (using n_real_per_half vs n_gen_to_use)
        results["real_vs_generated_fid"] = _fid_between(real_rgb_half1, gen_rgb_final)

        # Real vs Real FID (using both halves of fixed real images)
        results["real_vs_real_fid"] = _fid_between(real_rgb_half1, real_rgb_half2)

        print(
            f"Unconditional: Real vs Gen FID = {results['real_vs_generated_fid']:.4f} "
            f"(n_real={n_real_per_half}, n_gen={n_gen_to_use}), "
            f"Real vs Real FID = {results['real_vs_real_fid']:.4f} (n={n_real_per_half} each)"
        )

        return results

    except Exception as e:
        print(f"[error] Unconditional FID calculation failed: {e}")
        return None


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="MorphGen comparison using datamodule real images and nuclei crops"
    )
    parser.add_argument("--seed", type=int, default=1337, help="Random seed")
    parser.add_argument(
        "--target_selection_seed",
        type=int,
        default=None,
        help="Seed for target real image selection (if None, uses main seed)",
    )
    parser.add_argument(
        "--num_perturbations",
        type=int,
        default=100,
        help="Number of perturbations to sample for comparison",
    )
    parser.add_argument(
        "--cell_line",
        type=str,
        default="U2OS",
        choices=["U2OS", "HUVEC", "RPE", "HEPG2"],
        help="Cell line to analyze",
    )
    parser.add_argument(
        "--min_samples",
        type=int,
        default=32,
        help="Minimum crop samples required per perturbation (raw images needed = min_samples // 2 since multiple crops per image)",
    )
    parser.add_argument(
        "--max_samples_per_perturbation",
        type=int,
        default=None,
        help="Maximum samples to use per perturbation",
    )
    parser.add_argument(
        "--unconditional_samples",
        type=int,
        default=2000,
        help="Number of samples for unconditional FID",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default="/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml",
        help="Path to config file for datamodule",
    )
    parser.add_argument(
        "--morphgen_data_path",
        type=str,
        default="/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data",
        help="Path to MorphGen generated images directory",
    )

    args = parser.parse_args()

    seed_everything(args.seed)

    # Use separate seed for target selection if provided
    target_seed = (
        args.target_selection_seed
        if args.target_selection_seed is not None
        else args.seed
    )

    # Paths (still using same metadata paths to get the 100 compounds mapping)
    cellflux_path = "/mnt/pvc/CellFlux/images/cellflux/rxrx1"
    metadata_path = "/mnt/pvc/IMPA_reproducibility/IMPA_reproducibility/datasets/rxrx1_extracted/rxrx1/metadata/rxrx1_df.csv"

    print(f"Using seed: {args.seed}")
    print(f"Cell line: {args.cell_line}")
    print(f"CellFlux compounds path: {cellflux_path}")
    print(f"MorphGen generated images path: {args.morphgen_data_path}")
    print(f"Metadata path: {metadata_path}")

    # Step 1: Get the 100 compound names from CellFlux directory (for mapping)
    cellflux_compounds = get_cellflux_compounds(cellflux_path)
    if len(cellflux_compounds) == 0:
        print("[error] No CellFlux compounds found. Exiting.")
        sys.exit(1)

    # Load configuration and create datamodule
    config = OmegaConf.load(args.config_path)
    datamodule = CellDataModule(config)

    # Load metadata
    metadata_df = pd.read_csv(metadata_path)
    print(f"Loaded metadata with {len(metadata_df)} rows")

    # Step 2: Filter CellFlux compounds to those that exist in metadata
    available_compounds = get_available_compounds(
        metadata_path, cellflux_compounds, args.cell_line
    )
    if len(available_compounds) == 0:
        print("[error] No available compounds found. Exiting.")
        sys.exit(1)

    # Step 3: Create mapping from compounds to siRNA IDs (only for the available compounds)
    compound_to_sirna_mapping = create_compound_to_sirna_mapping(
        datamodule, metadata_df, available_compounds, args.cell_line
    )

    if len(compound_to_sirna_mapping) == 0:
        print("[error] No compound-to-siRNA mappings found. Exiting.")
        sys.exit(1)

    print(
        f"Successfully mapped {len(compound_to_sirna_mapping)} compounds to siRNA IDs"
    )

    # Extract unique siRNA IDs from the mapping
    unique_sirna_ids = set(
        mapping["sirna_id"] for mapping in compound_to_sirna_mapping.values()
    )
    sirna_id_list = sorted(list(unique_sirna_ids))

    print(f"Found {len(sirna_id_list)} unique siRNA IDs for analysis")

    # Sample siRNA IDs for evaluation
    if len(sirna_id_list) > args.num_perturbations:
        sampled_sirna_ids = random.sample(sirna_id_list, args.num_perturbations)
    else:
        sampled_sirna_ids = sirna_id_list

    print(f"Analyzing {len(sampled_sirna_ids)} siRNA IDs: {sampled_sirna_ids[:10]}...")

    # Collect results
    conditional_results = []
    all_fixed_real_images = []  # Store preselected fixed real subsets
    all_generated_images = []

    print("\n" + "=" * 80)
    print("PRESELECTING FIXED REAL IMAGE SUBSETS")
    print("=" * 80)

    # First pass: collect and preselect fixed real image subsets for each siRNA
    sirna_fixed_real_images = {}  # Store fixed subsets per siRNA

    for sirna_id in tqdm(sampled_sirna_ids, desc="Preselecting real images"):
        # Collect all real images for this siRNA from datamodule
        all_real_images = collect_real_images_from_datamodule(
            sirna_id,
            datamodule,
            args.cell_line,
            max_samples=None,  # Don't limit here, we'll preselect properly
        )

        # if (
        #     len(all_real_images) < args.min_samples // 2
        # ):  # Allow fewer raw images since we can generate multiple crops per image
        #     print(
        #         f"[warn] Skipping siRNA {sirna_id}: only {len(all_real_images)} real images, need at least {args.min_samples // 2} (will generate multiple crops per image)"
        #     )
        #     continue

        # Preselect fixed subset
        fixed_real_subset = preselect_target_real_images(
            sirna_id,
            all_real_images,
            args.min_samples,
            args.max_samples_per_perturbation,
            target_seed,  # Use target selection seed
        )

        sirna_fixed_real_images[sirna_id] = fixed_real_subset
        all_fixed_real_images.extend(fixed_real_subset)

    print(
        f"Preselected fixed real subsets for {len(sirna_fixed_real_images)} siRNA IDs"
    )
    print(f"Total fixed real images: {len(all_fixed_real_images)}")

    print("\n" + "=" * 80)
    print("CALCULATING CONDITIONAL FID PER SIRNA")
    print("=" * 80)

    for sirna_id in tqdm(sampled_sirna_ids, desc="Processing siRNA IDs"):
        if sirna_id not in sirna_fixed_real_images:
            continue  # Skip siRNAs that didn't meet minimum requirements

        # Get preselected fixed real images
        fixed_real_images = sirna_fixed_real_images[sirna_id]

        # Collect MorphGen generated images
        generated_images = collect_morphgen_generated_images(
            sirna_id,
            args.morphgen_data_path,
            cell_line_code=0,  # 0 for U2OS
            max_samples=args.max_samples_per_perturbation,
        )

        # Calculate FID for this siRNA using fixed real subset
        result = calculate_fid_for_sirna(
            sirna_id,
            fixed_real_images,  # Use preselected fixed subset
            generated_images,
            args.min_samples,
        )

        if result is not None:
            conditional_results.append(result)
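            # Add this siRNA's generated images to the global pool used for the
            # unconditional FID (the real images were already pooled during preselection)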
            all_generated_images.extend(generated_images)

    print(f"\nSuccessfully processed {len(conditional_results)} siRNA IDs")

    # Calculate unconditional FID
    print("\n" + "=" * 80)
    print("CALCULATING UNCONDITIONAL FID")
    print("=" * 80)

    unconditional_result = calculate_unconditional_fid(
        all_fixed_real_images,  # Use pooled fixed real subsets
        all_generated_images,
        args.unconditional_samples,
    )

    # Compile results
    results_df = pd.DataFrame(conditional_results)

    if not results_df.empty:
        # Add summary statistics
        avg_conditional_fid = float(results_df["real_vs_generated_fid"].mean())
        avg_conditional_real_vs_real_fid = float(results_df["real_vs_real_fid"].mean())

        summary_row = {
            "sirna_id": "Average_Conditional",
            "num_real_raw_images": float(results_df["num_real_raw_images"].mean()),
            "num_real_crops_generated": float(
                results_df["num_real_crops_generated"].mean()
            ),
            "num_real_per_half": float(results_df["num_real_per_half"].mean()),
            "num_generated_available": float(
                results_df["num_generated_available"].mean()
            ),
            "num_generated_used": float(results_df["num_generated_used"].mean()),
            "real_vs_generated_fid": avg_conditional_fid,
            "real_vs_real_fid": avg_conditional_real_vs_real_fid,
        }

        # Add summary and unconditional results
        all_results = conditional_results + [summary_row]
        if unconditional_result is not None:
            all_results.append(unconditional_result)

        results_df = pd.DataFrame(all_results)
    else:
        print("[warn] No valid conditional FID results")
        if unconditional_result is not None:
            results_df = pd.DataFrame([unconditional_result])

    # Save results
    output_file = (
        f"MMmorphgen_datamodule_comparison_{args.cell_line}_seed_{args.seed}.csv"
    )
    results_df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")

    # Print summary
    print("\n" + "=" * 80)
    print("SUMMARY OF RESULTS")
    print("=" * 80)

    if unconditional_result is not None:
        print(
            f"Unconditional FID (Real vs Generated): {unconditional_result['real_vs_generated_fid']:.4f}"
        )
        print(
            f"Unconditional FID (Real vs Real): {unconditional_result['real_vs_real_fid']:.4f}"
        )

    if not results_df.empty and "Average_Conditional" in results_df["sirna_id"].values:
        avg_row = results_df[results_df["sirna_id"] == "Average_Conditional"].iloc[0]
        print(
            f"Average Conditional FID (Real vs Generated): {avg_row['real_vs_generated_fid']:.4f}"
        )
        print(
            f"Average Conditional FID (Real vs Real): {avg_row['real_vs_real_fid']:.4f}"
        )
        print(f"Number of siRNA IDs evaluated: {len(conditional_results)}")

    print(
        f"\nProcessed {len(all_fixed_real_images)} fixed real images and {len(all_generated_images)} generated images total"
    )
    print("\nDetailed Results:")
    print(results_df.to_string(index=False))
