"""
Hybrid CellFlux Preprocessed Comparison Script (CORRECTED VERSION)

This script merges approaches from perturbation_type_cellflux_comparison_2.py and
perturbation_type_cellflux_preprocessed_comparison.py to compare:
- Real images: Preprocessed .npy files from RxRx1 dataset (what CellFlux was trained on)
- Generated images: REPA generated .npy files from our model

KEY CORRECTIONS IN THIS VERSION:
1. **Fixed Target Distribution**: For each compound, we preselect a fixed subset of real images
   that serves as our consistent target distribution across all FID calculations.
2. **Consistent Sample Sizes**: All FID calculations use the same fixed real subsets, ensuring
   reproducible and fair comparisons.
3. **Proper Generated Sampling**: Generated images are sampled up to the size of half the fixed
   real subset, maintaining fairness without reducing the target distribution size.

The key innovation is matching perturbation compounds between:
1. RxRx1 compounds (organized by compound names)
2. Our datamodule perturbations (organized by numeric siRNA IDs)
3. REPA generated data (organized by perturbation IDs)

Usage:
    python perturbation_type_cellflux_preprocessed_hybrid_comparison_corrected.py --num_perturbations 50
"""

import argparse
import glob
import logging
import os
import random
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch_fidelity
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything

# Import from your modules
from sc_perturb.dataset import CellDataModule, to_rgb
from tqdm import tqdm

# =========================
# Constants for cropping pipeline
# =========================
# Index of the channel that is Otsu-thresholded to locate nuclei
# (default ``nucleus_channel`` of nuclei_centered_crop).
NUCLEUS_CHANNEL = 0  # nucleus channel index
# Side length (pixels) of the square crops that are converted to RGB for FID.
CROP_SIZE = 96
# Candidate nucleus centres closer than this to any image edge are discarded,
# so a CROP_SIZE window around an accepted centre lies inside the image.
BORDER_MARGIN = CROP_SIZE // 2
MIN_FG_FRAC = 0.01  # min fraction of nucleus-foreground pixels in crop
MIN_NUCLEUS_MEAN = 0.015  # min mean intensity of nucleus channel in crop
# sample_nuclei_crops stops after num_samples * MAX_TRIALS_FACTOR attempts.
MAX_TRIALS_FACTOR = 50  # we allow up to factor * requested samples attempts


class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset view over an already-batched tensor.

    torch_fidelity accepts a ``Dataset`` as input; this wrapper lets a
    materialised ``[N, ...]`` tensor be handed over directly. Sample ``i`` is
    ``data_4d[i]``; the tensor is stored by reference, not copied.
    """

    def __init__(self, data_4d):
        self.data = data_4d

    def __len__(self):
        # The leading dimension indexes samples.
        n_items = self.data.shape[0]
        return n_items

    def __getitem__(self, idx):
        item = self.data[idx]
        return item


def normalize_image_to_01(img: torch.Tensor) -> torch.Tensor:
    """Min-max rescale one image to [0, 1] using its global min and max.

    The extremes are taken over every element (all channels together), so
    relative channel intensities are preserved. An image with no dynamic
    range (max not greater than min) is returned unchanged, as the same
    tensor object.
    """
    lo = img.min()
    hi = img.max()
    if not (hi > lo):
        # Nothing to stretch; also avoids dividing by zero.
        return img
    return (img - lo) / (hi - lo)


def safe_load_npy(file_path: str) -> Optional[torch.Tensor]:
    """Load a ``.npy`` image as a float32 ``[C, H, W]`` tensor clamped to [0, 1].

    Layout handling:
      * 2-D ``[H, W]`` arrays gain a leading channel axis -> ``[1, H, W]``.
      * 3-D arrays whose last axis has size 1, 3, 4, 5 or 6 are treated as
        channels-last and permuted to ``[C, H, W]``; other 3-D arrays are
        assumed to already be channels-first.
      * Arrays of any other rank are passed through unchanged.

    Intensity handling: if the maximum value exceeds 1.5 the data is divided
    by its own maximum; the result is then clamped to [0, 1].

    Args:
        file_path: Path to the ``.npy`` file.

    Returns:
        The tensor, or ``None`` if the file could not be loaded (a warning is
        printed).
    """
    try:
        # Convert to native-endian float32 in NumPy before handing the data to
        # torch: torch.from_numpy rejects non-native byte order, and (on older
        # torch releases) uint16/uint32 arrays, which would otherwise make
        # such files silently fail to load. For dtypes torch already accepted
        # the resulting float32 values are identical to ``ten.float()``.
        arr = np.ascontiguousarray(np.asarray(np.load(file_path), dtype=np.float32))
        ten = torch.from_numpy(arr)

        if ten.ndim == 2:  # [H,W] -> [1,H,W]
            ten = ten.unsqueeze(0)
        elif ten.ndim == 3 and ten.shape[-1] in (1, 3, 4, 5, 6):
            # Last dim looks like channels: HWC -> CHW
            ten = ten.permute(2, 0, 1).contiguous()

        # Rescale integer-like ranges (e.g. 0..255) by the image maximum.
        mx = ten.max().item() if ten.numel() > 0 else 1.0
        if mx > 1.5:
            ten = ten / mx
        return ten.clamp(0, 1)
    except Exception as e:
        # Best-effort loader: callers skip files that return None.
        print(f"[warn] failed to load {file_path}: {e}")
        return None


def safe_load_png(file_path: str) -> Optional[torch.Tensor]:
    """Load an image file as a float ``[3, H, W]`` RGB tensor in [0, 1].

    Images in any other PIL mode (grayscale, palette, RGBA, ...) are
    converted to RGB first, then converted with
    ``torchvision.transforms.functional.to_tensor``.

    Args:
        file_path: Path to the image.

    Returns:
        The tensor, or ``None`` if the file could not be read (a warning is
        printed).
    """
    try:
        # Image.open is lazy and keeps the file handle open until the object
        # is garbage-collected; the context manager closes it deterministically,
        # which matters when thousands of PNGs are loaded in a loop.
        with Image.open(file_path) as img:
            if img.mode != "RGB":
                img = img.convert("RGB")
            return TF.to_tensor(img)  # [C, H, W] in range [0, 1]
    except Exception as e:
        # Best-effort loader: callers skip files that return None.
        print(f"[warn] failed to load {file_path}: {e}")
        return None


def get_cellflux_compounds(cellflux_path: str) -> List[str]:
    """Return the sorted names of all sub-directories of ``cellflux_path``.

    Each sub-directory of the CellFlux output folder is treated as one
    compound; regular files are ignored.

    Args:
        cellflux_path: Root folder of the CellFlux outputs.

    Returns:
        Sorted directory names, or ``[]`` if the path does not exist or cannot
        be read (an error is printed).
    """
    try:
        if not os.path.exists(cellflux_path):
            print(f"[error] CellFlux path does not exist: {cellflux_path}")
            return []

        # Sort before previewing: os.listdir order is filesystem-dependent, so
        # printing the unsorted list gave a non-deterministic "first 10" that
        # did not match the order actually returned.
        compounds = sorted(
            d
            for d in os.listdir(cellflux_path)
            if os.path.isdir(os.path.join(cellflux_path, d))
        )

        print(f"Found {len(compounds)} compounds in CellFlux directory")
        print(f"First 10 compounds: {compounds[:10]}")

        return compounds

    except Exception as e:
        print(f"[error] Failed to read CellFlux directory: {e}")
        return []


def get_available_compounds(
    metadata_path: str, cellflux_compounds: List[str], cell_line: str = "U2OS"
) -> List[str]:
    """Keep only the CellFlux compounds present as treated samples in metadata.

    Args:
        metadata_path: CSV with at least ``CELL_LINE``, ``ANNOT`` and
            ``CPD_NAME`` columns.
        cellflux_compounds: Candidate compound names; their order is kept.
        cell_line: ``CELL_LINE`` value to restrict to.

    Returns:
        The names from ``cellflux_compounds`` that occur in ``CPD_NAME`` for
        rows of ``cell_line`` with ``ANNOT == "treated"``, or ``[]`` if the
        metadata cannot be loaded or filtered (an error is printed).
    """
    try:
        meta = pd.read_csv(metadata_path)
        row_filter = (meta["CELL_LINE"] == cell_line) & (meta["ANNOT"] == "treated")
        known_names = set(meta.loc[row_filter, "CPD_NAME"].unique())

        kept: List[str] = []
        dropped: List[str] = []
        for name in cellflux_compounds:
            if name in known_names:
                kept.append(name)
            else:
                dropped.append(name)

        print(
            f"Found {len(kept)} CellFlux compounds that exist in metadata for {cell_line} treated samples"
        )
        if dropped:
            print(
                f"[warn] {len(dropped)} CellFlux compounds not found in metadata: {dropped[:5]}"
            )
        return kept

    except Exception as e:
        print(f"[error] Failed to load metadata: {e}")
        return []


def create_compound_to_sirna_mapping(
    datamodule: CellDataModule,
    metadata_df: pd.DataFrame,
    available_compounds: List[str],
    cell_line: str = "U2OS",
    expected_num_sirna_ids: int = 100,
) -> Dict[str, Dict]:
    """Map compound names to the siRNA labels of co-located MorphGen samples.

    Matching is done per imaging site:

    * an IMPA ``SAMPLE_KEY`` such as ``"U2OS-01_1_B02_s1_14"`` is reduced to
      its site key ``"U2OS-01_1_B02_s1"`` (the trailing frame number is
      dropped);
    * a MorphGen ``site_id`` such as ``"U2OS-01_1_B02_1"`` is rewritten into
      the same ``{experiment}_{plate}_{well}_s{site}`` form.

    Every MorphGen row whose site key matches an IMPA site holding one of
    ``available_compounds`` is attributed to that compound.

    Args:
        datamodule: Object exposing a ``metadata`` DataFrame with
            ``cell_type``, ``site_id``, ``sirna`` and ``sirna_id`` columns.
        metadata_df: IMPA metadata with ``CPD_NAME``, ``ANNOT``,
            ``CELL_LINE`` and ``SAMPLE_KEY`` columns.
        available_compounds: Compound names to map.
        cell_line: Cell line used to filter both tables.
        expected_num_sirna_ids: Number of distinct siRNA IDs the mapping is
            expected to cover. It is only used for the final validation
            message.

    Returns:
        ``{compound: {"sirna", "sirna_id", "sample_count"}}``. ``sirna`` and
        ``sirna_id`` come from the first matched MorphGen row, and
        ``sample_count`` is the number of matched MorphGen rows.
    """
    print("Creating compound to siRNA mapping...")
    available = set(available_compounds)  # O(1) membership in the row loops

    morphgen_df = datamodule.metadata
    morphgen_cellline = morphgen_df[morphgen_df["cell_type"] == cell_line]
    print(f"Found {len(morphgen_cellline)} {cell_line} samples in MorphGen")

    # IMPA records for treated samples of the requested compounds/cell line
    impa_filtered = metadata_df[
        (metadata_df["CPD_NAME"].isin(available_compounds))
        & (metadata_df["ANNOT"] == "treated")
        & (metadata_df["CELL_LINE"] == cell_line)
    ]
    print(
        f"Found {len(impa_filtered)} IMPA records for available compounds ({cell_line}, treated)"
    )

    print("Matching sample keys between MorphGen and IMPA...")

    # site key (SAMPLE_KEY without its trailing frame number) -> compounds
    impa_base_keys: Dict[str, List[str]] = defaultdict(list)
    for sample_key, compound in zip(
        impa_filtered["SAMPLE_KEY"], impa_filtered["CPD_NAME"]
    ):
        parts = sample_key.split("_")
        if len(parts) >= 4:
            impa_base_keys["_".join(parts[:-1])].append(compound)

    print(f"Created {len(impa_base_keys)} unique base keys from IMPA data")

    compound_to_sirna_mapping: Dict[str, Dict] = {}
    successful_matches = 0
    # compound -> extra siRNA IDs seen besides the one recorded
    conflicting_ids: Dict[str, set] = defaultdict(set)

    for site_id, sirna, sirna_id in zip(
        morphgen_cellline["site_id"],
        morphgen_cellline["sirna"],
        morphgen_cellline["sirna_id"],
    ):
        # MorphGen: "U2OS-01_1_B02_1" -> IMPA base: "U2OS-01_1_B02_s1"
        parts = site_id.split("_")
        if len(parts) < 4:
            continue
        experiment, plate, well, site = parts[:4]
        compounds = impa_base_keys.get(f"{experiment}_{plate}_{well}_s{site}")
        if not compounds:
            continue

        # All IMPA records at one site are expected to share a compound.
        compound = compounds[0]
        if compound not in available:
            continue

        entry = compound_to_sirna_mapping.setdefault(
            compound, {"sirna": sirna, "sirna_id": sirna_id, "sample_count": 0}
        )
        if entry["sirna_id"] != sirna_id:
            # The mapping is assumed 1:1; surface violations instead of
            # silently keeping the first ID.
            conflicting_ids[compound].add(sirna_id)
        entry["sample_count"] += 1
        successful_matches += 1

    print(
        f"Successfully created mappings for {len(compound_to_sirna_mapping)} compounds"
    )
    print(f"Total matched samples: {successful_matches}")

    if conflicting_ids:
        print(
            f"[warn] {len(conflicting_ids)} compounds matched more than one siRNA ID "
            f"(keeping the first seen): {sorted(conflicting_ids)[:10]}"
        )

    missing_compounds = sorted(available - compound_to_sirna_mapping.keys())
    if missing_compounds:
        print(
            f"[warn] Missing mappings for {len(missing_compounds)} compounds: {missing_compounds[:10]}"
        )

    unique_sirna_ids = {m["sirna_id"] for m in compound_to_sirna_mapping.values()}
    unique_sirnas = {m["sirna"] for m in compound_to_sirna_mapping.values()}
    print(f"Mapped to {len(unique_sirna_ids)} unique siRNA IDs")
    print(f"Mapped to {len(unique_sirnas)} unique siRNA names")

    # Reverse view for inspection: siRNA ID -> compounds mapped to it
    sirna_id_to_compounds: Dict = defaultdict(list)
    for compound, mapping in compound_to_sirna_mapping.items():
        sirna_id_to_compounds[mapping["sirna_id"]].append(compound)

    print("siRNA ID to compound mapping (first 10):")
    for sirna_id, compounds in sorted(sirna_id_to_compounds.items())[:10]:
        print(f"  siRNA ID {sirna_id}: {compounds}")

    if len(unique_sirna_ids) != expected_num_sirna_ids:
        print(
            f"[warn] Expected {expected_num_sirna_ids} unique siRNA IDs for CellFlux compounds, but found {len(unique_sirna_ids)}"
        )
    else:
        print(
            f"✓ Successfully validated {expected_num_sirna_ids} unique siRNA IDs for CellFlux compounds"
        )

    return compound_to_sirna_mapping


def collect_cellflux_generated_images(
    compound: str,
    cellflux_data_path: str,
    max_samples: Optional[int] = None,
) -> List[torch.Tensor]:
    """Load the CellFlux-generated PNGs in ``<cellflux_data_path>/<compound>/``.

    The PNGs are expected to be RGB 96x96 crops (per the original notes).
    Each file is loaded with ``safe_load_png``, which converts it to RGB.

    Args:
        compound: Compound name (a sub-directory of ``cellflux_data_path``).
        cellflux_data_path: Root folder of the CellFlux outputs.
        max_samples: If truthy and smaller than the number of PNGs found, a
            random subset of this size is loaded, drawn with the module-level
            ``random`` RNG. ``None`` or ``0`` loads every file.

    Returns:
        List of ``[3, H, W]`` float tensors in [0, 1]. Files that fail to load
        are skipped. The list is empty if the compound directory is missing.
    """
    compound_dir = os.path.join(cellflux_data_path, compound)
    if not os.path.exists(compound_dir):
        return []

    # glob.escape: treat "[", "*", "?" in the directory name literally.
    # sorted: glob returns filesystem order, which is arbitrary, so without
    # sorting a seeded random.sample would pick different files per machine.
    png_files = sorted(glob.glob(os.path.join(glob.escape(compound_dir), "*.png")))

    if max_samples and len(png_files) > max_samples:
        png_files = random.sample(png_files, max_samples)

    images = []
    for file_path in tqdm(
        png_files, desc=f"Loading CellFlux images for {compound}", leave=False
    ):
        img = safe_load_png(file_path)  # RGB [3,H,W] in [0,1]
        if img is not None:
            images.append(img)

    return images


def collect_repa_generated_images(
    sirna_id: int,
    repa_data_path: str,
    cell_line_code: int = 0,  # 0 for U2OS
    max_samples: Optional[int] = None,
) -> List[torch.Tensor]:
    """Load REPA-generated ``.npy`` images for one siRNA ID and cell line.

    Files are looked up as
    ``<repa_data_path>/p{sirna_id}/p{sirna_id}_c{cell_line_code}_sample*.npy``.
    Each loaded image is min-max rescaled to [0, 1] over all of its values
    with ``normalize_image_to_01``.

    Args:
        sirna_id: Perturbation (siRNA) ID used in the directory and file names.
        repa_data_path: Root folder of the REPA outputs.
        cell_line_code: Cell-line code in the file name (0 for U2OS).
        max_samples: If truthy and smaller than the number of files found, a
            random subset of this size is loaded, drawn with the module-level
            ``random`` RNG. ``None`` or ``0`` loads every file.

    Returns:
        List of float tensors. Files that fail to load are skipped. The list
        is empty if the perturbation directory does not exist.
    """
    perturbation_dir = os.path.join(repa_data_path, f"p{sirna_id}")
    if not os.path.exists(perturbation_dir):
        return []

    pattern = f"p{sirna_id}_c{cell_line_code}_sample*.npy"
    # Escape the directory part so glob metacharacters in the path are literal.
    # Sort because glob order is filesystem-dependent; otherwise a seeded
    # random.sample is not reproducible across machines.
    npy_files = sorted(glob.glob(os.path.join(glob.escape(perturbation_dir), pattern)))

    if max_samples and len(npy_files) > max_samples:
        npy_files = random.sample(npy_files, max_samples)

    images = []
    for file_path in tqdm(
        npy_files, desc=f"Loading REPA images for siRNA ID {sirna_id}", leave=False
    ):
        img = safe_load_npy(file_path)
        if img is not None:
            # Per-image min-max rescale for better illumination matching
            images.append(normalize_image_to_01(img))

    return images


def collect_real_preprocessed_images(
    compound: str,
    metadata_df: pd.DataFrame,
    real_data_path: str,
    cell_line: str = "U2OS",
    max_samples: Optional[int] = None,
) -> List[torch.Tensor]:
    """Load the real preprocessed ``.npy`` images of one compound.

    Rows of ``metadata_df`` are selected by ``CPD_NAME == compound``,
    ``CELL_LINE == cell_line`` and ``ANNOT == "treated"``. A ``SAMPLE_KEY``
    such as ``"U2OS-01_2_G16_s1_13"`` maps to the file
    ``<real_data_path>/01_2/G16/s1_13.npy``.

    Args:
        compound: Compound name.
        metadata_df: Metadata with ``CPD_NAME``, ``CELL_LINE``, ``ANNOT`` and
            ``SAMPLE_KEY`` columns.
        real_data_path: Root folder of the preprocessed ``.npy`` files.
        cell_line: Cell line to restrict to.
        max_samples: If truthy and smaller than the number of sample keys, a
            random subset of keys is used, drawn with the module-level
            ``random`` RNG.

    Returns:
        Loaded images. Keys with an unexpected format, missing files, and
        files that fail to load are skipped; a summary warning is printed for
        malformed keys and missing files.
    """
    compound_data = metadata_df[
        (metadata_df["CPD_NAME"] == compound)
        & (metadata_df["CELL_LINE"] == cell_line)
        & (metadata_df["ANNOT"] == "treated")
    ]

    if len(compound_data) == 0:
        return []

    images: List[torch.Tensor] = []
    sample_keys = list(compound_data["SAMPLE_KEY"].values)

    if max_samples and len(sample_keys) > max_samples:
        sample_keys = random.sample(sample_keys, max_samples)

    # Counted instead of silently dropped, so a shrunken real set is visible.
    n_malformed = 0
    n_missing = 0
    for sample_key in tqdm(
        sample_keys, desc=f"Loading real images for {compound}", leave=False
    ):
        try:
            # Parse sample key: e.g., "U2OS-01_2_G16_s1_13"
            parts = sample_key.split("_")
            if len(parts) < 4:
                n_malformed += 1
                continue
            plate_info = "_".join(parts[:2])  # "U2OS-01_2"
            plate_dir = plate_info.split("-")[1]  # "01_2"
            well = parts[2]  # "G16"
            site_frame = "_".join(parts[3:])  # "s1_13"

            npy_file = os.path.join(
                real_data_path, plate_dir, well, f"{site_frame}.npy"
            )
            if not os.path.exists(npy_file):
                n_missing += 1
                continue

            img = safe_load_npy(npy_file)
            if img is not None:
                images.append(img)

        except Exception as e:
            print(f"[warn] Failed to parse sample key {sample_key}: {e}")

    if n_malformed or n_missing:
        print(
            f"[warn] {compound}: skipped {n_malformed} malformed sample keys and "
            f"{n_missing} missing files out of {len(sample_keys)}"
        )

    return images


# =========================
# Nuclei-centered cropping (from comparison_2.py)
# =========================
def otsu_threshold_approx(x: torch.Tensor) -> float:
    """Approximate Otsu threshold for intensities expected in [0, 1].

    Builds a 256-bin histogram over [0, 1], picks the bin index ``k`` that
    maximises the between-class variance, and returns ``k / 255``.

    NOTE: bin ``k`` covers ``[k/256, (k+1)/256)``, so ``k / 255`` is only an
    approximation of that bin's intensity (hence "approx"). ``torch.histc``
    ignores values outside ``[min, max]``; nuclei_centered_crop clamps its
    input to [0, 1] before calling this.

    Args:
        x: Tensor of any shape; it is detached, moved to CPU and flattened.

    Returns:
        Threshold in [0, 1]; 0.0 for an empty tensor.
    """
    xv = x.detach().cpu().flatten()
    if xv.numel() == 0:
        return 0.0
    hist = torch.histc(xv, bins=256, min=0.0, max=1.0)
    # Per-bin probability; the clamp avoids 0/0 when no value is in range.
    p = hist / hist.sum().clamp(min=1)
    # omega[k]: probability mass of class 0 (bins 0..k).
    omega = torch.cumsum(p, dim=0)
    # mu[k]: cumulative first moment up to bin k, in bin-index units.
    mu = torch.cumsum(p * torch.arange(256, dtype=torch.float32), dim=0)
    mu_t = mu[-1]  # global mean, in bin-index units
    # Between-class variance per candidate split; the clamp keeps the
    # denominator positive where omega is 0 or 1.
    sigma_b2 = (mu_t * omega - mu) ** 2 / (omega * (1.0 - omega)).clamp(min=1e-8)
    # Defensive: any NaN entry can never win the argmax.
    sigma_b2[torch.isnan(sigma_b2)] = -1
    k = int(torch.argmax(sigma_b2).item())
    return k / 255.0


def pick_nucleus_center(mask: torch.Tensor, margin: int) -> Optional[Tuple[int, int]]:
    """Return a random ``(row, col)`` of a nonzero mask pixel away from the edges.

    Pixels within ``margin`` of any border are excluded before sampling. The
    caller's mask is never modified. The choice uses ``torch.randint``, so it
    follows the global torch RNG.

    Args:
        mask: 2-D mask; nonzero entries are candidate centres.
        margin: Border width to exclude; ``<= 0`` disables the exclusion.

    Returns:
        ``(row, col)`` as Python ints, or ``None`` if no candidate remains.
    """
    height, width = mask.shape
    if margin > 0:
        # Copy only the interior window into an otherwise-empty mask.
        interior = torch.zeros_like(mask)
        rows = slice(margin, height - margin)
        cols = slice(margin, width - margin)
        interior[rows, cols] = mask[rows, cols]
    else:
        interior = mask

    candidates = interior.nonzero(as_tuple=False)
    n_candidates = candidates.shape[0]
    if n_candidates == 0:
        return None

    pick = torch.randint(0, n_candidates, (1,)).item()
    row, col = candidates[pick].tolist()
    return int(row), int(col)


def crop_centered(
    img: torch.Tensor, center: Tuple[int, int], size: int
) -> torch.Tensor:
    """Return a ``[C, size, size]`` window of ``img`` centred on ``center``.

    The window covers rows ``[y - size // 2, y - size // 2 + size)``, and the
    columns are chosen the same way. As a result ``center`` always lands at
    index ``size // 2`` of the crop, even next to a border.

    Parts of the window outside the image are filled by reflection padding.
    When the visible part is too small to reflect, edge replication is used
    instead, because torch requires each reflect pad to be smaller than the
    padded dimension.

    Args:
        img: ``[C, H, W]`` image.
        center: ``(y, x)`` pixel coordinates.
        size: Side length of the square crop.

    Returns:
        The crop. It is a view of ``img`` when no padding is needed.

    Raises:
        ValueError: If the window does not overlap the image.
    """
    _, H, W = img.shape
    half = size // 2
    y, x = center
    top, left = y - half, x - half
    bottom, right = top + size, left + size

    y0, y1 = max(0, top), min(H, bottom)
    x0, x1 = max(0, left), min(W, right)
    if y1 <= y0 or x1 <= x0:
        raise ValueError(
            f"Crop window of size {size} at {center} does not overlap image of size {(H, W)}"
        )
    crop = img[:, y0:y1, x0:x1]

    # Pad every side that was clipped. Padding before (top/left) as well as
    # after is what keeps the centre in place near the top/left border.
    pad_left, pad_right = x0 - left, right - x1
    pad_top, pad_bottom = y0 - top, bottom - y1
    if pad_left or pad_right or pad_top or pad_bottom:
        h, w = crop.shape[1], crop.shape[2]
        can_reflect = max(pad_top, pad_bottom) < h and max(pad_left, pad_right) < w
        # Add a batch dim: 2-D reflect/replicate padding is defined for 4-D input.
        crop = F.pad(
            crop.unsqueeze(0),
            (pad_left, pad_right, pad_top, pad_bottom),
            mode="reflect" if can_reflect else "replicate",
        ).squeeze(0)
    return crop


def nuclei_centered_crop(
    img: torch.Tensor,
    nucleus_channel: int = NUCLEUS_CHANNEL,
    crop_size: int = CROP_SIZE,
    border_margin: int = BORDER_MARGIN,
    min_fg_frac: float = MIN_FG_FRAC,
    min_nucleus_mean: float = MIN_NUCLEUS_MEAN,
) -> Optional[torch.Tensor]:
    """Cut one crop centred on a random nucleus pixel, or reject the attempt.

    The nucleus channel is clamped to [0, 1] and thresholded with
    ``otsu_threshold_approx``. A random foreground pixel at least
    ``border_margin`` pixels from every edge becomes the crop centre.

    Args:
        img: ``[C, H, W]`` image.
        nucleus_channel: Channel used to locate nuclei.
        crop_size: Side length of the square crop.
        border_margin: Minimum distance of the centre from the image edges.
        min_fg_frac: Reject the crop if the fraction of its nucleus-channel
            pixels at or above the threshold is below this.
        min_nucleus_mean: Reject the crop if the mean of its nucleus channel
            is below this.

    Returns:
        A ``[C, crop_size, crop_size]`` crop, or ``None`` if no valid centre
        exists or the crop fails a quality check.

    Raises:
        ValueError: If ``img`` is not 3-D.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if img.ndim != 3:
        raise ValueError(f"Expected [C,H,W], got shape {tuple(img.shape)}")
    nuc = img[nucleus_channel].clamp(0, 1)
    thr = otsu_threshold_approx(nuc)
    mask = nuc >= thr

    center = pick_nucleus_center(mask, border_margin)
    if center is None:
        return None

    crop = crop_centered(img, center, crop_size)
    # Quality checks use the crop's (unclamped) nucleus channel.
    nuc_crop = crop[nucleus_channel]
    fg_frac = (nuc_crop >= thr).float().mean().item()
    if fg_frac < min_fg_frac:
        return None
    if nuc_crop.mean().item() < min_nucleus_mean:
        return None
    return crop


def sample_nuclei_crops(
    images: List[torch.Tensor],
    num_samples: int,
    max_trials_factor: int = MAX_TRIALS_FACTOR,
) -> torch.Tensor:
    """Draw up to ``num_samples`` nucleus-centred crops from random pool images.

    Each attempt picks an image uniformly at random (module-level ``random``)
    and calls ``nuclei_centered_crop`` on it. Rejected attempts are retried
    until ``num_samples * max_trials_factor`` attempts have been made. The
    same image may contribute several crops.

    Args:
        images: Pool of ``[C, H, W]`` images.
        num_samples: Number of crops wanted.
        max_trials_factor: Attempt budget per requested crop.

    Returns:
        ``[K, C, h, w]`` stacked crops with ``K <= num_samples``, or a 1-D
        empty tensor if no crop succeeded. A warning is printed when
        ``K < num_samples``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("Empty image pool.")

    budget = num_samples * max_trials_factor
    collected: List[torch.Tensor] = []
    tries = 0
    for _ in range(budget):
        if len(collected) >= num_samples:
            break
        tries += 1
        candidate = nuclei_centered_crop(images[random.randrange(len(images))])
        if candidate is not None:
            collected.append(candidate)

    if len(collected) < num_samples:
        print(
            f"[warn] Only collected {len(collected)}/{num_samples} valid crops after {tries} attempts."
        )

    if not collected:
        return torch.empty(0)
    return torch.stack(collected)


def to_rgb_batch(x: torch.Tensor) -> torch.Tensor:
    """Convert one image or a batch to RGB with ``to_rgb`` from ``sc_perturb.dataset``.

    Every image is moved to CPU and passed to ``to_rgb`` on its own, as a
    batch of one. The results are squeezed and re-stacked.

    Args:
        x: ``[C, H, W]`` (single image) or ``[N, C, H, W]`` (batch).

    Returns:
        A stacked batch: ``[1, ...]`` for a single image, ``[N, ...]`` for a
        batch. Presumably ``[*, 3, H, W]`` floats, depending on ``to_rgb``;
        TODO confirm.

    Raises:
        ValueError: For inputs that are neither 3-D nor 4-D.
    """
    if x.ndim not in (3, 4):
        raise ValueError(f"Unexpected tensor shape: {x.shape}")
    batch = x.unsqueeze(0) if x.ndim == 3 else x
    converted = [to_rgb(single.unsqueeze(0).cpu()).squeeze(0) for single in batch]
    return torch.stack(converted, 0)


def to_eval_rgb_uint8(batch_chw: torch.Tensor) -> torch.Tensor:
    """Render multi-channel images as the uint8 RGB batches used for FID.

    Converts with ``to_rgb_batch``, clamps to [0, 1] and scales by 255.
    The cast to uint8 truncates toward zero rather than rounding.
    """
    as_float = to_rgb_batch(batch_chw).clamp(0, 1)
    return (as_float * 255.0).to(torch.uint8)


def images_to_rgb_uint8_preprocessed(images: List[torch.Tensor]) -> torch.Tensor:
    """Build a uint8 RGB batch from preprocessed ``[C, H, W]`` float images.

    Channels are selected, not colour-mapped:
      * ``C >= 3``: the first three channels become R, G, B.
      * ``C == 1``: the single channel is repeated into R, G and B.
      * ``C == 0`` or ``C == 2``: zero channels are appended up to three.
    Values are clamped to [0, 1] and scaled by 255, truncating toward zero.

    Args:
        images: Images with equal spatial size.

    Returns:
        ``[N, 3, H, W]`` uint8 tensor, or ``[0, 3, 0, 0]`` for an empty list.
    """
    if len(images) == 0:
        return torch.empty(0, 3, 0, 0, dtype=torch.uint8)

    rgb_images = []
    for img in images:
        n_ch = img.shape[0]
        if n_ch >= 3:
            rgb = img[:3]
        elif n_ch == 1:
            rgb = img.repeat(3, 1, 1)
        else:
            # Without padding a 2-channel input would produce a 2-channel
            # "RGB" tensor, which breaks stacking with true RGB images and
            # every RGB consumer downstream.
            filler = img.new_zeros((3 - n_ch, *img.shape[1:]))
            rgb = torch.cat([img, filler], dim=0)

        rgb_images.append((rgb.clamp(0, 1) * 255).to(torch.uint8))

    return torch.stack(rgb_images, dim=0)


def images_to_rgb_uint8_png(images: List[torch.Tensor]) -> torch.Tensor:
    """Stack already-RGB float images into a uint8 ``[N, 3, H, W]`` batch.

    Intended for CellFlux PNGs loaded as ``[3, H, W]`` tensors. Values are
    clamped to [0, 1] and scaled by 255, truncating toward zero.

    Returns:
        The stacked batch, or ``[0, 3, 0, 0]`` for an empty list.
    """
    if len(images) == 0:
        return torch.empty(0, 3, 0, 0, dtype=torch.uint8)
    as_bytes = [(picture.clamp(0, 1) * 255).to(torch.uint8) for picture in images]
    return torch.stack(as_bytes, dim=0)


def save_sample_images(real_rgb, gen_rgb, compound, output_dir="sample_images"):
    """
    Write up to 8 real and 8 generated RGB images as PNGs for visual checks.

    Files are named ``{compound}_real_{i:02d}.png`` and
    ``{compound}_gen_{i:02d}.png`` inside ``output_dir``, which is created
    if needed. Real images are written first.

    Args:
        real_rgb: [N, 3, H, W] tensor of real RGB images (uint8)
        gen_rgb: [N, 3, H, W] tensor of generated RGB images (uint8)
        compound: compound name used as the file-name prefix
        output_dir: directory to save images
    """
    os.makedirs(output_dir, exist_ok=True)

    n_real = min(8, real_rgb.shape[0])
    n_gen = min(8, gen_rgb.shape[0])
    print(f"Saving {n_real} real and {n_gen} generated images for {compound}")

    for tag, batch, count in (("real", real_rgb, n_real), ("gen", gen_rgb, n_gen)):
        for idx in range(count):
            out_path = os.path.join(output_dir, f"{compound}_{tag}_{idx:02d}.png")
            TF.to_pil_image(batch[idx]).save(out_path)

    print(f"Saved sample images for {compound} to {output_dir}")


def images_to_rgb_uint8_repa(images: List[torch.Tensor]) -> torch.Tensor:
    """Convert REPA-generated multi-channel images to a uint8 RGB batch.

    The images are stacked into one batch and rendered with
    ``to_rgb_batch``. The result is clamped to [0, 1] and scaled by 255,
    truncating toward zero.

    Returns:
        ``[N, 3, H, W]`` uint8 tensor, or ``[0, 3, 0, 0]`` for an empty list.
    """
    if len(images) == 0:
        return torch.empty(0, 3, 0, 0, dtype=torch.uint8)
    rgb_float = to_rgb_batch(torch.stack(images, dim=0)).clamp(0, 1)
    return (rgb_float * 255.0).to(torch.uint8)


def preselect_target_real_images(
    compound: str,
    real_images: List[torch.Tensor],
    min_samples: int = 32,
    max_samples: Optional[int] = None,
    seed: int = 42,
) -> List[torch.Tensor]:
    """
    Preselect a fixed subset of real images to serve as the target distribution.

    The subset is drawn with a private RNG seeded from ``seed`` plus a stable
    CRC32 hash of ``compound``. The same inputs therefore give the same
    subset in every run and process, and each compound gets its own subset.
    The module-level ``random`` state is not touched.

    Args:
        compound: Compound name, mixed into the seed.
        real_images: Candidate images.
        min_samples: If fewer candidates are available, ``real_images`` is
            returned unchanged (no sampling, possibly odd length).
        max_samples: Upper bound on the subset size; ``None`` means use all.
        seed: Base seed.

    Returns:
        A random subset of size ``min(max_samples, len(real_images))``,
        rounded down to an even number so it splits into two equal halves.
    """
    import zlib  # deterministic string hash (builtin hash() of str is salted)

    if len(real_images) < min_samples:
        return real_images

    # hash(str) is randomized per interpreter (PYTHONHASHSEED), which made the
    # "fixed" subset differ between runs; CRC32 is stable across processes.
    compound_seed = seed + zlib.crc32(compound.encode("utf-8")) % 1000
    # Private RNG: reseeding the global RNG (and resetting it afterwards with
    # random.seed(), which pulls OS entropy) would break reproducibility of
    # every later random call in the script.
    rng = random.Random(compound_seed)

    if max_samples is None:
        target_size = len(real_images)
    else:
        target_size = min(max_samples, len(real_images))

    # Ensure even number for clean splitting
    target_size -= target_size % 2

    selected_images = rng.sample(real_images, target_size)

    print(
        f"[{compound}] Selected {len(selected_images)} real images as fixed target distribution"
    )

    return selected_images


def _fid_between(rgb_a: torch.Tensor, rgb_b: torch.Tensor) -> float:
    """Return the FID between two uint8 ``[N, 3, H, W]`` batches via torch_fidelity."""
    metrics = torch_fidelity.calculate_metrics(
        input1=CustomDataset(rgb_a),
        input2=CustomDataset(rgb_b),
        cuda=torch.cuda.is_available(),
        fid=True,
        kid=False,
        verbose=False,
    )
    return float(metrics["frechet_inception_distance"])


def calculate_fid_for_compound(
    compound: str,
    fixed_real_images: List[torch.Tensor],  # Now expects pre-selected fixed subset
    generated_images: List[torch.Tensor],
    min_samples: int = 32,
    use_cellflux_generated: bool = False,
) -> Optional[Dict]:
    """Compute real-vs-generated and real-vs-real FID for one compound.

    The fixed real subset is split into two halves of
    ``n = len(fixed_real_images) // 2`` images. Two FIDs are computed:

    * between half 1 and up to ``n`` generated images;
    * between half 1 and half 2, as a same-distribution baseline.

    Generated images are handled according to their source:

    * CellFlux (``use_cellflux_generated=True``): already-RGB tensors; the
      first ``n`` are converted to uint8 as-is.
    * REPA: ``n`` nucleus-centred crops are sampled from the pool and
      converted to RGB.

    Args:
        compound: Compound name, used in logs and in the result.
        fixed_real_images: Pre-selected real multi-channel crops, converted
            with ``to_eval_rgb_uint8``.
        generated_images: Generated images for this compound.
        min_samples: Minimum required size of ``fixed_real_images``.
        use_cellflux_generated: Selects the generated-image handling above.

    Returns:
        A dict with sample counts, ``real_vs_generated_fid`` and
        ``real_vs_real_fid``. Returns ``None`` (and prints a message) if there
        are too few images or any processing/FID step fails.
    """
    if len(fixed_real_images) < min_samples:
        print(
            f"[warn] Insufficient fixed real samples for {compound}: {len(fixed_real_images)}"
        )
        return None

    n_real_per_half = len(fixed_real_images) // 2
    # Generated sample count is capped at one real half.
    n_gen_to_use = min(len(generated_images), n_real_per_half)

    print(
        f"Debug: Processing {compound} with {len(fixed_real_images)} fixed real, {len(generated_images)} available generated images"
    )
    print(
        f"Debug: Will use {n_real_per_half} real images per half, {n_gen_to_use} generated samples"
    )
    if len(fixed_real_images) > 0:
        print(
            f"Debug: Fixed real image shape: {fixed_real_images[0].shape}, range: [{fixed_real_images[0].min():.4f}, {fixed_real_images[0].max():.4f}]"
        )
    if len(generated_images) > 0:
        print(
            f"Debug: Generated image shape: {generated_images[0].shape}, range: [{generated_images[0].min():.4f}, {generated_images[0].max():.4f}]"
        )

    # FID estimates a covariance per side, which needs at least two samples.
    if n_real_per_half < 2 or n_gen_to_use < 2:
        print(
            f"[warn] Too few images for FID on {compound}: {n_real_per_half} real per half, "
            f"{n_gen_to_use} generated (need >= 2 each)"
        )
        return None

    try:
        # Fixed real images (6-channel crops) -> RGB uint8
        real_rgb = to_eval_rgb_uint8(torch.stack(fixed_real_images))

        if use_cellflux_generated:
            # CellFlux images are already RGB, no cropping needed
            generated_crops_rgb = images_to_rgb_uint8_png(
                generated_images[:n_gen_to_use]
            )
        else:
            # REPA images need nuclei-centered cropping, then RGB conversion
            generated_crops = sample_nuclei_crops(generated_images, n_gen_to_use)
            if generated_crops.shape[0] < n_gen_to_use:
                print(
                    f"[warn] Could only generate {generated_crops.shape[0]} valid crops for {compound}, expected {n_gen_to_use}"
                )
                n_gen_to_use = generated_crops.shape[0]  # Update to actual number
            if n_gen_to_use < 2:
                print(
                    f"[warn] Too few valid generated crops for FID on {compound}: {n_gen_to_use}"
                )
                return None
            generated_crops_rgb = to_eval_rgb_uint8(generated_crops)

    except Exception as e:
        print(f"[warn] Failed to process images for {compound}: {e}")
        return None

    # Two equal, disjoint halves of the fixed real subset
    real_rgb_half1 = real_rgb[:n_real_per_half]
    real_rgb_half2 = real_rgb[n_real_per_half : 2 * n_real_per_half]
    gen_rgb_final = generated_crops_rgb[:n_gen_to_use]

    print(
        f"Debug: Final sizes - real_half1: {real_rgb_half1.shape[0]}, real_half2: {real_rgb_half2.shape[0]}, gen: {gen_rgb_final.shape[0]}"
    )

    results = {
        "compound": compound,
        "num_real_total": len(fixed_real_images),
        "num_real_per_half": n_real_per_half,
        "num_generated_available": len(generated_images),
        "num_generated_used": n_gen_to_use,
    }

    try:
        results["real_vs_generated_fid"] = _fid_between(real_rgb_half1, gen_rgb_final)
        results["real_vs_real_fid"] = _fid_between(real_rgb_half1, real_rgb_half2)
    except Exception as e:
        print(f"[error] FID calculation failed for {compound}: {e}")
        return None

    print(
        f"{compound}: Real vs Gen FID = {results['real_vs_generated_fid']:.4f} "
        f"(n_real={n_real_per_half}, n_gen={n_gen_to_use}), "
        f"Real vs Real FID = {results['real_vs_real_fid']:.4f} (n={n_real_per_half} each)"
    )

    return results


def calculate_unconditional_fid(
    all_fixed_real_images: List[torch.Tensor],  # Now expects pooled fixed subsets
    all_generated_images: List[torch.Tensor],
    num_samples: int = 5000,
    use_cellflux_generated: bool = False,
) -> Optional[Dict]:
    """Calculate unconditional FID using pooled fixed real image subsets.

    Up to ``num_samples`` real images are drawn (via the global ``random``
    RNG) from the pooled per-compound fixed subsets, trimmed to an even count
    and split into two equal halves. Half 1 is compared against the generated
    images ("Real vs Gen") and against half 2 ("Real vs Real", a same-source
    baseline). At most one half's worth of generated images is used.

    Args:
        all_fixed_real_images: Pooled real image tensors (already cropped);
            they are combined with ``torch.stack`` so they must share a shape.
        all_generated_images: Pooled generated image tensors.
        num_samples: Upper bound on the number of real images drawn.
        use_cellflux_generated: If True, generated images are converted with
            ``images_to_rgb_uint8_png`` without cropping; otherwise they are
            cropped with ``sample_nuclei_crops`` before conversion.

    Returns:
        A dict with sample counts plus ``real_vs_generated_fid`` and
        ``real_vs_real_fid``, or None when there are too few images to
        estimate FID or when preprocessing / FID computation fails.
    """
    # FID fits a mean and covariance to each image set; a covariance estimate
    # needs at least two samples, otherwise the score is NaN or an error.
    min_fid_samples = 2

    print(f"\nCalculating unconditional FID with up to {num_samples} samples each...")

    # Use fixed real images (already preselected per compound)
    if len(all_fixed_real_images) > num_samples:
        real_sample = random.sample(all_fixed_real_images, num_samples)
    else:
        real_sample = all_fixed_real_images

    # Trim to an even count so both real halves have the same size
    n_real_per_half = len(real_sample) // 2
    real_sample = real_sample[: 2 * n_real_per_half]

    if n_real_per_half < min_fid_samples:
        print(
            f"[warn] Not enough real images for unconditional FID: "
            f"{len(all_fixed_real_images)} available, need at least {2 * min_fid_samples}"
        )
        return None

    # For generated images, use as many as available up to n_real_per_half
    if len(all_generated_images) > n_real_per_half:
        gen_sample = random.sample(all_generated_images, n_real_per_half)
    else:
        gen_sample = all_generated_images

    n_gen_to_use = len(gen_sample)
    if n_gen_to_use < min_fid_samples:
        print(
            f"[warn] Not enough generated images for unconditional FID: "
            f"{n_gen_to_use} available, need at least {min_fid_samples}"
        )
        return None

    # Handle generated images based on source
    try:
        # Real images are already cropped, just convert to RGB
        real_rgb = to_eval_rgb_uint8(torch.stack(real_sample))

        if use_cellflux_generated:
            # CellFlux images are already RGB 96x96, no cropping needed
            gen_rgb = images_to_rgb_uint8_png(gen_sample)
        else:
            # REPA images need nuclei-centered cropping from 512x512 to 96x96
            gen_crops = sample_nuclei_crops(gen_sample, n_gen_to_use)
            if gen_crops.shape[0] < n_gen_to_use:
                print(
                    f"[warn] Could only generate {gen_crops.shape[0]} valid crops for unconditional FID, expected {n_gen_to_use}"
                )
                n_gen_to_use = gen_crops.shape[0]
            gen_rgb = to_eval_rgb_uint8(gen_crops)

    except Exception as e:
        print(f"[error] Failed to process images for unconditional FID: {e}")
        return None

    # Split fixed real images for real vs real comparison
    real_rgb_half1 = real_rgb[:n_real_per_half]
    real_rgb_half2 = real_rgb[n_real_per_half : 2 * n_real_per_half]

    # Use actual number of generated images (the converter may return fewer)
    gen_rgb_final = gen_rgb[:n_gen_to_use]
    n_gen_to_use = len(gen_rgb_final)

    # Cropping can leave too few generated images for a valid covariance
    if n_gen_to_use < min_fid_samples:
        print(
            f"[warn] Not enough generated images for unconditional FID: "
            f"{n_gen_to_use} available, need at least {min_fid_samples}"
        )
        return None

    results = {
        "compound": "Unconditional",
        "num_real_total": len(all_fixed_real_images),
        "num_real_per_half": n_real_per_half,
        "num_generated_available": len(all_generated_images),
        "num_generated_used": n_gen_to_use,
    }

    def _fid(rgb_a, rgb_b) -> float:
        """Return the torch-fidelity FID between two image sets."""
        metrics = torch_fidelity.calculate_metrics(
            input1=CustomDataset(rgb_a),
            input2=CustomDataset(rgb_b),
            cuda=torch.cuda.is_available(),
            fid=True,
            kid=False,
            verbose=False,
        )
        return float(metrics["frechet_inception_distance"])

    try:
        # Real vs Generated FID (using n_real_per_half vs n_gen_to_use)
        results["real_vs_generated_fid"] = _fid(real_rgb_half1, gen_rgb_final)

        # Real vs Real FID (using both halves of fixed real images)
        results["real_vs_real_fid"] = _fid(real_rgb_half1, real_rgb_half2)

        print(
            f"Unconditional: Real vs Gen FID = {results['real_vs_generated_fid']:.4f} "
            f"(n_real={n_real_per_half}, n_gen={n_gen_to_use}), "
            f"Real vs Real FID = {results['real_vs_real_fid']:.4f} (n={n_real_per_half} each)"
        )

        return results

    except Exception as e:
        print(f"[error] Unconditional FID calculation failed: {e}")
        return None


if __name__ == "__main__":
    # Script entry point: parse CLI options, map CellFlux compound names to
    # siRNA IDs, preselect a fixed real-image subset per compound, compute a
    # conditional FID per compound plus one unconditional FID over pooled
    # images, then save every row to a CSV and print a summary.
    parser = argparse.ArgumentParser(
        description="Hybrid comparison of CellFlux generated vs RxRx1 preprocessed real images"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--target_selection_seed",
        type=int,
        default=None,
        help="Seed for target real image selection (if None, uses main seed)",
    )
    parser.add_argument(
        "--num_compounds",
        type=int,
        default=100,
        help="Number of compounds to sample for comparison",
    )
    parser.add_argument(
        "--cell_line",
        type=str,
        default="U2OS",
        choices=["U2OS", "HUVEC", "RPE", "HEPG2"],
        help="Cell line to analyze",
    )
    parser.add_argument(
        "--min_samples",
        type=int,
        default=20,
        help="Minimum samples required per compound",
    )
    parser.add_argument(
        "--max_samples_per_compound",
        type=int,
        default=None,
        help="Maximum samples to use per compound",
    )
    parser.add_argument(
        "--unconditional_samples",
        type=int,
        default=2000,
        help="Number of samples for unconditional FID",
    )
    parser.add_argument(
        "--use_cellflux_generated",
        action="store_true",
        default=False,
        help="Use CellFlux generated images instead of REPA generated images (already 96x96, no cropping needed)",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default="/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml",
        help="Path to config file for datamodule",
    )

    args = parser.parse_args()

    seed_everything(args.seed)

    # Use separate seed for target selection if provided
    target_seed = (
        args.target_selection_seed
        if args.target_selection_seed is not None
        else args.seed
    )

    # Paths (hard-coded to the shared volume layout)
    cellflux_path = "/mnt/pvc/CellFlux/images/cellflux/rxrx1"
    repa_data_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"
    real_data_path = "/mnt/pvc/IMPA_reproducibility/IMPA_reproducibility/datasets/rxrx1_extracted/rxrx1"
    metadata_path = "/mnt/pvc/IMPA_reproducibility/IMPA_reproducibility/datasets/rxrx1_extracted/rxrx1/metadata/rxrx1_df.csv"

    print(f"Using seed: {args.seed}")
    print(f"Cell line: {args.cell_line}")
    print(f"CellFlux images path: {cellflux_path}")
    print(f"REPA generated images path: {repa_data_path}")
    print(f"Real preprocessed data path: {real_data_path}")
    print(f"Metadata path: {metadata_path}")

    # Step 1: Get the compound names from the CellFlux directory
    cellflux_compounds = get_cellflux_compounds(cellflux_path)
    if len(cellflux_compounds) == 0:
        print("[error] No CellFlux compounds found. Exiting.")
        sys.exit(1)

    # Load configuration and create datamodule
    config = OmegaConf.load(args.config_path)
    datamodule = CellDataModule(config)

    # Load metadata
    metadata_df = pd.read_csv(metadata_path)
    print(f"Loaded metadata with {len(metadata_df)} rows")

    # Step 2: Filter CellFlux compounds to those that exist in metadata
    available_compounds = get_available_compounds(
        metadata_path, cellflux_compounds, args.cell_line
    )
    if len(available_compounds) == 0:
        print("[error] No available compounds found. Exiting.")
        sys.exit(1)

    # Step 3: Create mapping from compounds to siRNA IDs (only for the available compounds)
    compound_to_sirna_mapping = create_compound_to_sirna_mapping(
        datamodule, metadata_df, available_compounds, args.cell_line
    )

    if len(compound_to_sirna_mapping) == 0:
        print("[error] No compound-to-siRNA mappings found. Exiting.")
        sys.exit(1)

    print(
        f"Successfully mapped {len(compound_to_sirna_mapping)} compounds to siRNA IDs"
    )

    # Validate mapping completeness
    mapped_compounds = set(compound_to_sirna_mapping.keys())
    available_compounds_set = set(available_compounds)
    missing_compounds = available_compounds_set - mapped_compounds

    if missing_compounds:
        # Sorted so the preview is deterministic across runs
        print(
            f"[warn] {len(missing_compounds)} compounds could not be mapped: {sorted(missing_compounds)[:5]}..."
        )

    # Count unique siRNA IDs
    unique_sirna_ids = set(
        mapping["sirna_id"] for mapping in compound_to_sirna_mapping.values()
    )
    print(
        f"Using {len(unique_sirna_ids)} unique siRNA IDs for {len(mapped_compounds)} compounds"
    )

    # "Perfect" means every CellFlux compound mapped to its own siRNA ID
    n_cellflux = len(cellflux_compounds)
    if len(mapped_compounds) == n_cellflux and len(unique_sirna_ids) == n_cellflux:
        print(
            f"✓ Perfect mapping: {n_cellflux} compounds → {n_cellflux} unique siRNA IDs"
        )
    elif len(unique_sirna_ids) != len(mapped_compounds):
        print("[info] Some siRNA IDs map to multiple compounds (many-to-one mapping)")

    # Sample compounds for evaluation (only from successfully mapped ones)
    # Sort compounds to ensure deterministic order across runs
    sorted_mapped_compounds = sorted(mapped_compounds)
    if len(sorted_mapped_compounds) > args.num_compounds:
        sampled_compounds = random.sample(sorted_mapped_compounds, args.num_compounds)
    else:
        sampled_compounds = sorted_mapped_compounds

    print(f"Analyzing {len(sampled_compounds)} compounds: {sampled_compounds[:10]}...")

    # Collect results
    conditional_results = []
    all_fixed_real_images = []  # Every preselected fixed real subset
    # Real/generated pools for the unconditional FID; both are only extended
    # for compounds whose conditional FID succeeded, so they cover the same
    # compound mixture.
    pooled_real_images = []
    all_generated_images = []

    print("\n" + "=" * 80)
    print("PRESELECTING FIXED REAL IMAGE SUBSETS")
    print("=" * 80)

    # First pass: collect and preselect fixed real image subsets for each compound
    compound_fixed_real_images = {}  # Store fixed subsets per compound

    for compound in tqdm(sampled_compounds, desc="Preselecting real images"):
        # Collect all real preprocessed images for this compound
        all_real_images = collect_real_preprocessed_images(
            compound,
            metadata_df,
            real_data_path,
            args.cell_line,
            max_samples=None,  # Don't limit here, we'll preselect properly
        )

        if len(all_real_images) < args.min_samples:
            print(
                f"[warn] Skipping {compound}: only {len(all_real_images)} real images, need {args.min_samples}"
            )
            continue

        # Preselect fixed subset
        fixed_real_subset = preselect_target_real_images(
            compound,
            all_real_images,
            args.min_samples,
            args.max_samples_per_compound,
            target_seed,  # Use target selection seed
        )

        compound_fixed_real_images[compound] = fixed_real_subset
        all_fixed_real_images.extend(fixed_real_subset)

    print(
        f"Preselected fixed real subsets for {len(compound_fixed_real_images)} compounds"
    )
    print(f"Total fixed real images: {len(all_fixed_real_images)}")

    print("\n" + "=" * 80)
    print("CALCULATING CONDITIONAL FID PER COMPOUND")
    print("=" * 80)

    for compound in tqdm(sampled_compounds, desc="Processing compounds"):
        if compound not in compound_fixed_real_images:
            continue  # Skip compounds that didn't meet minimum requirements

        # Get the mapping for this compound
        compound_mapping = compound_to_sirna_mapping[compound]
        sirna_id = compound_mapping["sirna_id"]

        # Get preselected fixed real images
        fixed_real_images = compound_fixed_real_images[compound]

        # Collect generated images (CellFlux PNG or REPA NPY)
        if args.use_cellflux_generated:
            generated_images = collect_cellflux_generated_images(
                compound,
                cellflux_path,
                max_samples=args.max_samples_per_compound,
            )
        else:
            # NOTE(review): cell_line_code is hard-coded to 0 regardless of
            # --cell_line; confirm it matches args.cell_line in the REPA export.
            generated_images = collect_repa_generated_images(
                sirna_id,
                repa_data_path,
                cell_line_code=0,
                max_samples=args.max_samples_per_compound,
            )

        # Calculate FID for this compound using fixed real subset
        result = calculate_fid_for_compound(
            compound,
            fixed_real_images,  # Use preselected fixed subset
            generated_images,
            args.min_samples,
            args.use_cellflux_generated,
        )

        if result is not None:
            conditional_results.append(result)
            # Add to global pools for unconditional FID (same compounds on
            # both sides, so the pooled distributions are comparable)
            pooled_real_images.extend(fixed_real_images)
            all_generated_images.extend(generated_images)

    print(f"\nSuccessfully processed {len(conditional_results)} compounds")
    if len(pooled_real_images) < len(all_fixed_real_images):
        print(
            f"[info] Unconditional real pool restricted to the {len(conditional_results)} "
            f"evaluated compounds ({len(pooled_real_images)} of {len(all_fixed_real_images)} fixed real images)"
        )

    # Calculate unconditional FID
    print("\n" + "=" * 80)
    print("CALCULATING UNCONDITIONAL FID")
    print("=" * 80)

    unconditional_result = calculate_unconditional_fid(
        pooled_real_images,  # Fixed real subsets of successfully evaluated compounds
        all_generated_images,
        args.unconditional_samples,
        args.use_cellflux_generated,
    )

    # Compile results
    results_df = pd.DataFrame(conditional_results)

    if not results_df.empty:
        # Add summary statistics
        avg_conditional_fid = float(results_df["real_vs_generated_fid"].mean())
        avg_conditional_real_vs_real_fid = float(results_df["real_vs_real_fid"].mean())

        summary_row = {
            "compound": "Average_Conditional",
            "num_real_total": float(results_df["num_real_total"].mean()),
            "num_real_per_half": float(results_df["num_real_per_half"].mean()),
            "num_generated_available": float(
                results_df["num_generated_available"].mean()
            ),
            "num_generated_used": float(results_df["num_generated_used"].mean()),
            "real_vs_generated_fid": avg_conditional_fid,
            "real_vs_real_fid": avg_conditional_real_vs_real_fid,
        }

        # Add summary and unconditional results
        all_results = conditional_results + [summary_row]
        if unconditional_result is not None:
            all_results.append(unconditional_result)

        results_df = pd.DataFrame(all_results)
    else:
        print("[warn] No valid conditional FID results")
        if unconditional_result is not None:
            results_df = pd.DataFrame([unconditional_result])

    # Save results. The target-selection seed is appended only when given
    # explicitly, so runs that differ only in target seed don't overwrite
    # each other while the default filename stays unchanged.
    model_tag = "cellflux" if args.use_cellflux_generated else "repa"
    output_file = f"{model_tag}_preprocessed_hybrid_comparison_corrected_{args.cell_line}_seed_{args.seed}"
    if args.target_selection_seed is not None:
        output_file += f"_target_seed_{args.target_selection_seed}"
    output_file += ".csv"
    results_df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")

    # Print summary
    print("\n" + "=" * 80)
    print("SUMMARY OF RESULTS")
    print("=" * 80)

    if unconditional_result is not None:
        print(
            f"Unconditional FID (Real vs Generated): {unconditional_result['real_vs_generated_fid']:.4f}"
        )
        print(
            f"Unconditional FID (Real vs Real): {unconditional_result['real_vs_real_fid']:.4f}"
        )

    if not results_df.empty and "Average_Conditional" in results_df["compound"].values:
        avg_row = results_df[results_df["compound"] == "Average_Conditional"].iloc[0]
        print(
            f"Average Conditional FID (Real vs Generated): {avg_row['real_vs_generated_fid']:.4f}"
        )
        print(
            f"Average Conditional FID (Real vs Real): {avg_row['real_vs_real_fid']:.4f}"
        )
        print(f"Number of compounds evaluated: {len(conditional_results)}")

    print(
        f"\nProcessed {len(pooled_real_images)} fixed real images and {len(all_generated_images)} generated images total"
    )
    print("\nDetailed Results:")
    print(results_df.to_string(index=False))