import argparse
import glob
import logging
import os
import random
import sys
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch_fidelity
import torchvision.transforms.functional as TF
from diffusers.models import AutoencoderKL
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything

# your modules
from sc_perturb.dataset import CellDataModule, to_rgb
from sc_perturb.metrics_utils import calculate_metrics_from_scratch  # (unused here)
from sc_perturb.models.sit import SiT_models
from sc_perturb.utils.generation_utils import generate_perturbation_matched_samples
from sc_perturb.utils.utils import load_encoders
from tqdm import tqdm

# =========================
# Eval controls / heuristics
# =========================
# Index of the nucleus channel within each [C, H, W] image (set this if your
# nucleus channel index differs).
NUCLEUS_CHANNEL = 0
# Side length, in pixels, of the square crops that are fed to FID.
CROP_SIZE = 96
# Candidate crop centers closer than this to any image border are discarded.
# With CROP_SIZE // 2, a crop around any remaining center lies fully inside
# the image, so no padding is needed.
BORDER_MARGIN = CROP_SIZE // 2  # disallow centers too near borders
# Minimum fraction of crop pixels (nucleus channel) at or above the image's
# Otsu threshold for the crop to be kept.
MIN_FG_FRAC = 0.01
# Minimum mean intensity of the nucleus channel inside a kept crop.
MIN_NUCLEUS_MEAN = 0.015
# When True, to_eval_rgb_uint8 casts RGB crops to uint8 before they reach
# torch-fidelity; when False, float RGB is passed through unchanged.
UPSCALE_TO_UINT8 = True
# sample_nuclei_crops gives up after MAX_TRIALS_FACTOR * (requested samples)
# attempts.
MAX_TRIALS_FACTOR = 50


# =========================
# Small dataset for torch-fidelity
# =========================
class CustomDataset(torch.utils.data.Dataset):
    """
    Minimal map-style dataset that hands a pre-built image tensor to torch-fidelity.

    Args:
        data_4d: Tensor of shape [N, 3, H, W], expected to be uint8 or float
            in [0, 1]. Items are returned exactly as stored (no conversion).
    """

    def __init__(self, data_4d):
        self.data = data_4d

    def __len__(self):
        # One item per entry along the leading (batch) dimension.
        return self.data.size(0)

    def __getitem__(self, idx):
        images = self.data
        return images[idx]


# =========================
# I/O helpers
# =========================
def safe_load_npy(file_path: str) -> Optional[torch.Tensor]:
    """
    Read a ``.npy`` file into a float tensor laid out as [C, H, W] in [0, 1].

    * 2-D arrays gain a leading channel axis.
    * 3-D arrays whose trailing axis has 1, 3, 4, 5 or 6 entries are treated
      as HWC and moved to CHW; any other 3-D array is kept as-is.
    * If the maximum value exceeds 1.5, the tensor is divided by that maximum.
    * The result is clamped to [0, 1].

    Arrays of any other rank, and any exception while loading/converting,
    print a warning and yield ``None``.
    """
    channel_like = (1, 3, 4, 5, 6)
    try:
        tensor = torch.from_numpy(np.load(file_path))
        rank = tensor.ndim
        if rank == 2:
            tensor = tensor[None]
        elif rank == 3 and tensor.shape[-1] in channel_like:
            tensor = tensor.permute(2, 0, 1).contiguous()
        elif rank != 3:
            print(f"[warn] unexpected shape {tensor.shape} in {file_path}; skipping")
            return None

        tensor = tensor.float()
        peak = 1.0 if tensor.numel() == 0 else tensor.max().item()
        if peak > 1.5:
            # Looks like a non-normalized range (e.g. 0..255): rescale by the max.
            tensor = tensor / peak
        return tensor.clamp(0, 1)
    except Exception as err:
        print(f"[warn] failed to load {file_path}: {err}")
        return None


def find_generated_files_by_perturbation_and_celltype(
    generated_path, perturbation_id, cell_type_id
):
    """
    List generated ``.npy`` samples for one (perturbation, cell type) pair.

    Files are expected in ``{generated_path}/p{perturbation_id}/`` and are
    selected when their *file name* contains ``_c{cell_type_id}_sample``
    (e.g. ``..._c3_sample0.npy``).

    Args:
        generated_path: Root directory holding one ``p{pid}`` folder per perturbation.
        perturbation_id: Perturbation id used to build the folder name.
        cell_type_id: Cell-type id that must appear in the file name.

    Returns:
        Sorted list of matching file paths; empty if the folder does not exist.
    """
    pert_path = os.path.join(generated_path, f"p{perturbation_id}")
    if not os.path.isdir(pert_path):
        return []
    marker = f"_c{cell_type_id}_sample"
    # Match on the basename only, so a directory name that happens to contain
    # the marker cannot select every file. Sort because glob order is
    # filesystem-dependent; a stable order keeps the downstream image ordering
    # (and therefore seeded crop sampling) reproducible.
    return sorted(
        path
        for path in glob.glob(os.path.join(pert_path, "*.npy"))
        if marker in os.path.basename(path)
    )


def load_numpy_files(file_paths, max_samples=None) -> List[torch.Tensor]:
    """
    Load ``.npy`` files with ``safe_load_npy``, silently dropping unreadable ones.

    If ``max_samples`` is given and smaller than the number of paths, a random
    subset of that size is drawn with ``random.sample`` first.

    Returns:
        List of [C, H, W] float tensors in [0, 1], in load order.
    """
    selected = list(file_paths)
    if max_samples is not None and len(selected) > max_samples:
        selected = random.sample(selected, max_samples)
    loaded = (safe_load_npy(p) for p in tqdm(selected, desc="Loading numpy files"))
    return [tensor for tensor in loaded if tensor is not None]


# =========================
# Nuclei-centered cropping
# =========================
def otsu_threshold_approx(x: torch.Tensor, bins: int = 256) -> float:
    """
    Approximate Otsu threshold for an image with values in [0, 1].

    The values are histogrammed (on CPU) into ``bins`` equal-width bins over
    [0, 1]. Otsu's criterion picks the split index ``k`` that maximises the
    between-class variance, with bins ``0..k`` forming the background class.
    The returned threshold is the upper edge of bin ``k``, ``(k + 1) / bins``,
    so ``x >= threshold`` selects exactly the foreground class.

    Args:
        x: Image tensor, e.g. [H, W]; it is flattened, so any shape works.
        bins: Number of histogram bins (default 256).

    Returns:
        Threshold as a Python float in (0, 1]; 0.0 for an empty input.
    """
    xv = x.detach().cpu().flatten().float()
    if xv.numel() == 0:
        return 0.0
    hist = torch.histc(xv, bins=bins, min=0.0, max=1.0)
    p = hist / hist.sum().clamp(min=1)
    omega = torch.cumsum(p, dim=0)  # background-class probability per split
    mu = torch.cumsum(p * torch.arange(bins, dtype=torch.float32), dim=0)
    mu_t = mu[-1]
    sigma_b2 = (mu_t * omega - mu) ** 2 / (omega * (1.0 - omega)).clamp(min=1e-8)
    sigma_b2[torch.isnan(sigma_b2)] = -1
    k = int(torch.argmax(sigma_b2).item())
    # Bin k covers [k/bins, (k+1)/bins) and belongs to the background class,
    # so the foreground starts at its upper edge. Returning k/255 (roughly the
    # bin's lower edge) would let bin k's background pixels into the mask.
    return (k + 1) / bins


def pick_nucleus_center(mask: torch.Tensor, margin: int) -> Optional[Tuple[int, int]]:
    """
    Choose a random foreground pixel of ``mask`` away from the image borders.

    Args:
        mask: 2-D [H, W] tensor; nonzero entries are candidate centers.
        margin: Width in pixels of the border band excluded on every side;
            values <= 0 disable the exclusion.

    Returns:
        ``(y, x)`` of an eligible pixel chosen uniformly with ``torch.randint``,
        or ``None`` when no pixel qualifies. ``mask`` itself is not modified.
    """
    height, width = mask.shape
    offset = 0
    region = mask
    if margin > 0:
        # Look only at the interior window instead of zeroing the border band;
        # the candidate set and its row-major order are the same.
        region = mask[margin : height - margin, margin : width - margin]
        offset = margin
    candidates = region.nonzero(as_tuple=False)
    n_candidates = candidates.shape[0]
    if n_candidates == 0:
        return None
    row, col = candidates[torch.randint(0, n_candidates, (1,)).item()].tolist()
    return row + offset, col + offset


def crop_centered(
    img: torch.Tensor, center: Tuple[int, int], size: int
) -> torch.Tensor:
    """
    Cut a ``size`` x ``size`` window whose pixel ``(size//2, size//2)`` is ``center``.

    Parts of the window that fall outside the image are filled by padding on
    the side where they fall, so the crop stays centered on ``center`` at every
    border. Reflect padding is used when possible; if a pad would be at least
    as large as the visible extent (which reflect padding cannot handle),
    replicate padding is used instead.

    Args:
        img: [C, H, W] tensor.
        center: ``(y, x)`` pixel coordinates inside the image.
        size: Side length of the output crop.

    Returns:
        [C, size, size] tensor.

    Raises:
        ValueError: if the window does not overlap the image at all.
    """
    _, H, W = img.shape
    y, x = center
    half = size // 2
    top, left = y - half, x - half
    bottom, right = top + size, left + size

    crop = img[:, max(0, top) : min(H, bottom), max(0, left) : min(W, right)]
    h, w = crop.shape[1], crop.shape[2]
    if h == 0 or w == 0:
        raise ValueError(f"Crop window around {center} does not overlap image {tuple(img.shape)}")

    pad_top, pad_bottom = max(0, -top), max(0, bottom - H)
    pad_left, pad_right = max(0, -left), max(0, right - W)
    if pad_top or pad_bottom or pad_left or pad_right:
        # Reflect padding requires every pad to be smaller than the padded dim.
        reflect_ok = max(pad_top, pad_bottom) < h and max(pad_left, pad_right) < w
        crop = F.pad(
            crop,
            (pad_left, pad_right, pad_top, pad_bottom),
            mode="reflect" if reflect_ok else "replicate",
        )
    return crop


def nuclei_centered_crop(
    img: torch.Tensor,
    nucleus_channel: int = NUCLEUS_CHANNEL,
    crop_size: int = CROP_SIZE,
    border_margin: int = BORDER_MARGIN,
    min_fg_frac: float = MIN_FG_FRAC,
    min_nucleus_mean: float = MIN_NUCLEUS_MEAN,
) -> Optional[torch.Tensor]:
    """
    Try to cut one crop centered on a random nucleus-foreground pixel.

    The nucleus channel (clamped to [0, 1]) is thresholded with
    ``otsu_threshold_approx``; pixels ``>=`` the threshold are foreground.
    A center is drawn with ``pick_nucleus_center`` (excluding
    ``border_margin`` pixels at each border), and the crop is rejected if the
    fraction of its nucleus pixels at or above the same threshold is below
    ``min_fg_frac`` or its nucleus mean is below ``min_nucleus_mean``.

    Args:
        img: [C, H, W] image tensor.

    Returns:
        [C, crop_size, crop_size] crop, or ``None`` if no acceptable crop was found.
    """
    assert img.ndim == 3, "Expected [C,H,W]"
    nucleus = img[nucleus_channel].clamp(0, 1)
    threshold = otsu_threshold_approx(nucleus)

    center = pick_nucleus_center(nucleus >= threshold, border_margin)
    if center is None:
        return None

    crop = crop_centered(img, center, crop_size)
    crop_nucleus = crop[nucleus_channel]
    fg_fraction = (crop_nucleus >= threshold).float().mean().item()
    too_sparse = fg_fraction < min_fg_frac
    if too_sparse or crop_nucleus.mean().item() < min_nucleus_mean:
        return None
    return crop


def sample_nuclei_crops(
    images: List[torch.Tensor],
    num_samples: int,
    max_trials_factor: int = MAX_TRIALS_FACTOR,
) -> torch.Tensor:
    """
    Draw images with replacement and keep valid nucleus-centered crops.

    Sampling stops once ``num_samples`` crops are collected or after
    ``num_samples * max_trials_factor`` attempts, whichever comes first; a
    warning is printed if the target was not reached.

    Raises:
        ValueError: if ``images`` is empty.

    Returns:
        Stacked crops [K, C, crop, crop] with K <= num_samples, or
        ``torch.empty(0)`` if none were found.
    """
    if len(images) == 0:
        raise ValueError("Empty image pool.")
    pool_size = len(images)
    budget = num_samples * max_trials_factor
    collected: List[torch.Tensor] = []
    tries = 0
    while tries < budget and len(collected) < num_samples:
        tries += 1
        candidate = nuclei_centered_crop(images[random.randrange(pool_size)])
        if candidate is not None:
            collected.append(candidate)
    if len(collected) < num_samples:
        print(
            f"[warn] Only collected {len(collected)}/{num_samples} valid crops after {tries} attempts."
        )
    if not collected:
        return torch.empty(0)
    return torch.stack(collected)


# =========================
# RGB conversion (robust wrapper around your to_rgb)
# =========================
def to_rgb_batch(x: torch.Tensor) -> torch.Tensor:
    """
    Apply the project's ``to_rgb`` one sample at a time (on CPU).

    Args:
        x: [C, H, W] or [N, C, H, W] tensor.

    Returns:
        [1, 3, H, W] for a single image, or [N, 3, H, W] for a batch
        (3 channels assumed from ``to_rgb``'s output).

    Raises:
        ValueError: for any other rank.
    """

    def _convert_one(sample):
        # to_rgb is fed a batch of one; drop that batch axis again.
        return to_rgb(sample[None].cpu()).squeeze(0)

    if x.ndim == 3:
        return _convert_one(x)[None]
    if x.ndim == 4:
        return torch.stack([_convert_one(sample) for sample in x], 0)
    raise ValueError(f"Unexpected tensor shape for to_rgb_batch: {x.shape}")


def to_eval_rgb_uint8(batch_chw: torch.Tensor) -> torch.Tensor:
    """
    Convert a multi-channel batch into the RGB form fed to torch-fidelity.

    Args:
        batch_chw: [N, C, H, W] float tensor in [0, 1].

    Returns:
        RGB batch from ``to_rgb_batch``. When ``UPSCALE_TO_UINT8`` is set it is
        clamped to [0, 1], scaled by 255 and cast to uint8 (the cast truncates,
        it does not round); otherwise the float RGB batch is returned.
    """
    rgb = to_rgb_batch(batch_chw)
    if not UPSCALE_TO_UINT8:
        return rgb
    return rgb.clamp(0, 1).mul(255.0).to(torch.uint8)


# =========================
# Collect images (real / generated)
# =========================
def collect_images_for_perturbations(
    datamodule,
    perturbation_ids,
    allowed_cell_types,
    is_generated=False,
    generated_path=None,
    model=None,
    vae=None,
    latent_size=None,
    resolution=None,
    latents_bias=None,
    latents_scale=None,
    path_type=None,
    device=None,
):
    """
    Gather images per perturbation, restricted to ``allowed_cell_types``.

    Real images come from ``datamodule.filter_samples(...)``. The first element
    of each item is kept if it is a 3-D tensor, cast to float, clamped to
    [0, 1] and moved to CPU. Generated images are loaded from ``.npy`` files
    under ``generated_path``.

    On-the-fly generation is not implemented. With ``is_generated=True``, no
    ``generated_path`` and a ``model``, this raises ``NotImplementedError``;
    with neither, nothing is collected.

    ``vae``, ``latent_size``, ``resolution``, ``latents_bias``,
    ``latents_scale``, ``path_type`` and ``device`` are currently unused.

    Returns:
        ``(all_images, mapping)``: a flat list of [C, H, W] tensors and a dict
        ``{pert_id: (start, end)}`` giving each perturbation's slice of that
        list. Perturbations without images are left out of the mapping.
    """
    kind = "generated" if is_generated else "real"
    all_images = []
    perturbation_image_mapping = {}

    for pert_id in tqdm(perturbation_ids, desc=f"Collecting {kind} images"):
        pert_images = []
        for cell_type_id in allowed_cell_types:
            if is_generated:
                if generated_path is None:
                    # on-the-fly generation (kept here if you need it later)
                    if model is not None:
                        raise NotImplementedError(
                            "Manual generation omitted here for brevity."
                        )
                    continue
                files = find_generated_files_by_perturbation_and_celltype(
                    generated_path, pert_id, cell_type_id
                )
                if files:
                    pert_images.extend(load_numpy_files(files, max_samples=None))
                continue

            ds = datamodule.filter_samples(
                perturbation_id=pert_id, cell_type_id=cell_type_id
            )
            if ds is None or len(ds) == 0:
                continue
            for i in range(len(ds)):
                x = ds[i][0]
                if isinstance(x, torch.Tensor) and x.ndim == 3:
                    pert_images.append(x.float().clamp(0, 1).cpu())

        if not pert_images:
            print(f"[info] No {kind} images for perturbation {pert_id}")
            continue
        start = len(all_images)
        all_images.extend(pert_images)
        perturbation_image_mapping[pert_id] = (start, len(all_images))

    return all_images, perturbation_image_mapping


def save_sample_crops(real_crops, gen_crops, pert_id, output_dir="crop_samples"):
    """
    Write up to 8 real and up to 8 generated crops as PNGs for visual inspection.

    Args:
        real_crops: [N, C, H, W] tensor of real crops.
        gen_crops: [N, C, H, W] tensor of generated crops.
        pert_id: Perturbation ID, used in the file names.
        output_dir: Destination directory (created if missing).

    Files are named ``pert_{pert_id}_real_{i:02d}.png`` and
    ``pert_{pert_id}_gen_{i:02d}.png``.
    """
    os.makedirs(output_dir, exist_ok=True)

    n_real, n_gen = min(8, real_crops.shape[0]), min(8, gen_crops.shape[0])

    def _as_uint8_rgb(crops):
        # to_rgb_batch -> clamp to [0, 1] -> scale to [0, 255] (truncating cast).
        return (to_rgb_batch(crops).clamp(0, 1) * 255).to(torch.uint8)

    real_rgb = _as_uint8_rgb(real_crops[:n_real])
    gen_rgb = _as_uint8_rgb(gen_crops[:n_gen])

    for i, image in enumerate(real_rgb):
        TF.to_pil_image(image).save(
            os.path.join(output_dir, f"pert_{pert_id}_real_{i:02d}.png")
        )
    for i, image in enumerate(gen_rgb):
        TF.to_pil_image(image).save(
            os.path.join(output_dir, f"pert_{pert_id}_gen_{i:02d}.png")
        )

    print(
        f"Saved {n_real} real and {n_gen} generated crops for perturbation {pert_id} to {output_dir}"
    )


# =========================
# Main
# =========================
if __name__ == "__main__":
    seed = 77
    MANUAL_GENERATION = False
    NUM_PERTURBATIONS = 100
    NUM_SAMPLES = 5000
    ALLOWED_CELL_TYPES = [3]  # HUVEC only
    # Per-perturbation (FIDc) budget. Up to N_TARGET_PER_PERT real crops are
    # split into two equal halves for the real-vs-real baseline. The generated
    # set is trimmed to the size of one half, so every FID compares equal-sized
    # sets.
    N_TARGET_PER_PERT = 100
    MIN_CROPS_PER_PERT = 32  # skip a perturbation if a compared set is smaller
    MAX_CROP_SAMPLE_PERTS = 5  # write inspection PNGs for this many perts only

    seed_everything(seed)
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    use_cuda = torch.cuda.is_available()

    def fid_between(images_a: torch.Tensor, images_b: torch.Tensor) -> float:
        """Return torch-fidelity's FID between two [N, 3, H, W] image batches."""
        metrics = torch_fidelity.calculate_metrics(
            input1=CustomDataset(images_a),
            input2=CustomDataset(images_b),
            cuda=use_cuda,
            fid=True,
            kid=False,
            verbose=False,
        )
        return float(metrics["frechet_inception_distance"])

    # Paths
    filename = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
    generated_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"

    # Data
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    # Sample random perturbations
    all_perturbation_ids = list(range(1, 1139))  # 1..1138
    sampled_perturbation_ids = random.sample(all_perturbation_ids, NUM_PERTURBATIONS)
    print(
        f"Sampled {NUM_PERTURBATIONS} perturbation IDs: {sampled_perturbation_ids[:10]}..."
    )
    print(f"Allowed cell types: {ALLOWED_CELL_TYPES}")

    # Collect real
    print("\n" + "=" * 80)
    print("COLLECTING REAL IMAGES")
    print("=" * 80)
    all_real_images, real_pert_mapping = collect_images_for_perturbations(
        datamodule,
        sampled_perturbation_ids,
        ALLOWED_CELL_TYPES,
        is_generated=False,
    )
    print(f"Collected {len(all_real_images)} real images total")

    # Collect generated
    print("\n" + "=" * 80)
    print("COLLECTING GENERATED IMAGES")
    print("=" * 80)
    if MANUAL_GENERATION:
        raise NotImplementedError(
            "Manual generation path not included in this snippet."
        )
    else:
        all_generated_images, generated_pert_mapping = collect_images_for_perturbations(
            datamodule,
            sampled_perturbation_ids,
            ALLOWED_CELL_TYPES,
            is_generated=True,
            generated_path=generated_path,
        )
    print(f"Collected {len(all_generated_images)} generated images total")

    # =========================
    # Unconditional (FIDo): nuclei-centered crops pooled across perts
    # =========================
    print("\n" + "=" * 80)
    print("SAMPLING NUCLEI-CENTERED CROPS (UNCONDITIONAL)")
    print("=" * 80)

    # Build big pools only from perts that exist on both sides (keeps symmetry tight even if you say all exist)
    valid_perts = [
        pid
        for pid in sampled_perturbation_ids
        if pid in real_pert_mapping and pid in generated_pert_mapping
    ]
    if len(valid_perts) != NUM_PERTURBATIONS:
        print(
            f"[warn] {NUM_PERTURBATIONS - len(valid_perts)} perturbations missing on one side."
        )
        # sys.exit rather than the interactive-only `exit` builtin, and with a
        # non-zero status so callers can tell the run was aborted.
        sys.exit(1)
    real_pool, gen_pool = [], []
    for pid in valid_perts:
        rs, re = real_pert_mapping[pid]
        gs, ge = generated_pert_mapping[pid]
        real_pool.extend(all_real_images[rs:re])
        gen_pool.extend(all_generated_images[gs:ge])

    # Collect up to NUM_SAMPLES valid crops per side (warns if fewer)
    sampled_real_images = sample_nuclei_crops(real_pool, NUM_SAMPLES)
    sampled_generated_images = sample_nuclei_crops(gen_pool, NUM_SAMPLES)

    print(f"Real crops shape: {tuple(sampled_real_images.shape)}")
    print(f"Gen  crops shape: {tuple(sampled_generated_images.shape)}")

    # Real-vs-generated and real-vs-real must use equally sized sets. Size them
    # from what was actually collected, because sample_nuclei_crops may return
    # fewer crops than requested.
    n_real_crops = sampled_real_images.shape[0]
    n_gen_crops = sampled_generated_images.shape[0]
    half_samples = min(NUM_SAMPLES // 2, n_real_crops // 2, n_gen_crops)
    if half_samples < 2:
        print(
            f"[error] Not enough crops for unconditional FID: {n_real_crops} real, {n_gen_crops} generated."
        )
        sys.exit(1)
    if half_samples < NUM_SAMPLES // 2:
        print(
            f"[warn] Using {half_samples} crops per set instead of {NUM_SAMPLES // 2}."
        )
    sampled_real_images_half1 = sampled_real_images[:half_samples]
    sampled_real_images_half2 = sampled_real_images[half_samples : half_samples * 2]
    sampled_generated_images_half = sampled_generated_images[:half_samples]

    print("\n" + "=" * 80)
    print("CALCULATING UNCONDITIONAL FID")
    print("=" * 80)

    # Real vs Generated FID (half-sized sets, matching the real-vs-real baseline)
    real_uint8_half1 = to_eval_rgb_uint8(sampled_real_images_half1)
    fake_uint8_half = to_eval_rgb_uint8(sampled_generated_images_half)
    unconditional_fid = fid_between(real_uint8_half1, fake_uint8_half)
    print(f"Unconditional FID (Real vs Generated): {unconditional_fid:.4f}")

    # Real vs Real FID (lower bound); the first real half is reused.
    real_uint8_half2 = to_eval_rgb_uint8(sampled_real_images_half2)
    unconditional_real_vs_real_fid = fid_between(real_uint8_half1, real_uint8_half2)
    print(f"Unconditional FID (Real vs Real): {unconditional_real_vs_real_fid:.4f}")

    # =========================
    # Conditional (FIDc): per-pert FID averaged
    # =========================
    print("\n" + "=" * 80)
    print("CALCULATING CONDITIONAL FID PER PERTURBATION")
    print("=" * 80)
    conditional_results = []
    saved_crop_perts = 0

    for pert_id in tqdm(sampled_perturbation_ids, desc="Per-perturbation FID"):
        if pert_id not in real_pert_mapping or pert_id not in generated_pert_mapping:
            continue

        r_s, r_e = real_pert_mapping[pert_id]
        g_s, g_e = generated_pert_mapping[pert_id]
        real_imgs_p = all_real_images[r_s:r_e]
        gen_imgs_p = all_generated_images[g_s:g_e]
        if len(real_imgs_p) == 0 or len(gen_imgs_p) == 0:
            continue

        # Request N_TARGET_PER_PERT real crops (split into two halves) and
        # half as many generated crops.
        real_pert_cropped_full = sample_nuclei_crops(real_imgs_p, N_TARGET_PER_PERT)
        gen_pert_cropped = sample_nuclei_crops(gen_imgs_p, N_TARGET_PER_PERT // 2)

        # All three compared sets (real half 1, real half 2, generated) must
        # have the same size, and none may be too small for a meaningful FID.
        n_real_p = real_pert_cropped_full.shape[0]
        n_gen_p = gen_pert_cropped.shape[0]
        n_used = min(n_real_p // 2, n_gen_p)
        if n_used < MIN_CROPS_PER_PERT:
            print(
                f"[warn] skipping pert {pert_id}: got {n_real_p} real / {n_gen_p} gen crops; "
                f"need at least {2 * MIN_CROPS_PER_PERT} real and {MIN_CROPS_PER_PERT} gen."
            )
            continue

        real_pert_cropped_half1 = real_pert_cropped_full[:n_used]
        real_pert_cropped_half2 = real_pert_cropped_full[n_used : 2 * n_used]
        gen_pert_cropped = gen_pert_cropped[:n_used]

        # Save sample crops for visual inspection (first few perturbations only)
        if saved_crop_perts < MAX_CROP_SAMPLE_PERTS:
            save_sample_crops(real_pert_cropped_half1, gen_pert_cropped, pert_id)
            saved_crop_perts += 1

        real_pert_uint8_half1 = to_eval_rgb_uint8(real_pert_cropped_half1)
        real_pert_uint8_half2 = to_eval_rgb_uint8(real_pert_cropped_half2)
        fake_pert_uint8 = to_eval_rgb_uint8(gen_pert_cropped)

        try:
            conditional_fid = fid_between(real_pert_uint8_half1, fake_pert_uint8)
            conditional_real_vs_real_fid = fid_between(
                real_pert_uint8_half1, real_pert_uint8_half2
            )
        except Exception as e:
            print(f"[warn] FID failed for pert {pert_id}: {e}")
            continue

        conditional_results.append(
            {
                "perturbation_id": pert_id,
                "num_real": len(real_imgs_p),
                "num_generated": len(gen_imgs_p),
                "num_samples_used": n_used,
                "conditional_fid": conditional_fid,
                "conditional_real_vs_real_fid": conditional_real_vs_real_fid,
            }
        )
        print(
            f"Pert {pert_id}: FID={conditional_fid:.4f}, Real vs Real={conditional_real_vs_real_fid:.4f} (n={n_used})"
        )

    # =========================
    # Save & summarize
    # =========================
    unconditional_row = {
        "perturbation_id": "Unconditional(FIDo)",
        "num_real": half_samples,
        "num_generated": half_samples,
        "num_samples_used": half_samples,
        "conditional_fid": unconditional_fid,
        "conditional_real_vs_real_fid": unconditional_real_vs_real_fid,
    }
    results_df = pd.DataFrame(conditional_results)

    if not results_df.empty:
        avg_conditional_fid = float(results_df["conditional_fid"].mean())
        avg_conditional_real_vs_real_fid = float(
            results_df["conditional_real_vs_real_fid"].mean()
        )
        summary_rows = [
            {
                "perturbation_id": "Average(FIDc)",
                "num_real": float(results_df["num_real"].mean()),
                "num_generated": float(results_df["num_generated"].mean()),
                "num_samples_used": float(results_df["num_samples_used"].mean()),
                "conditional_fid": avg_conditional_fid,
                "conditional_real_vs_real_fid": avg_conditional_real_vs_real_fid,
            },
            unconditional_row,
        ]
        results_df = pd.concat(
            [results_df, pd.DataFrame(summary_rows)], ignore_index=True
        )
    else:
        print("[warn] No valid per-perturbation FIDs computed.")
        # still append unconditional row for completeness
        results_df = pd.DataFrame([unconditional_row])

    output_file = f"perturbation_type_cellflux_comparison_results_seed_{seed}.csv"
    results_df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")

    print("\n" + "=" * 80)
    print("SUMMARY OF RESULTS")
    print("=" * 80)
    print(f"Unconditional FID (Real vs Generated): {unconditional_fid:.4f}")
    print(f"Unconditional FID (Real vs Real): {unconditional_real_vs_real_fid:.4f}")
    if conditional_results:
        print(
            f"Average Conditional FID (Real vs Generated): {avg_conditional_fid:.4f}"
        )
        print(
            f"Average Conditional FID (Real vs Real): {avg_conditional_real_vs_real_fid:.4f}"
        )
        print(f"Number of perturbations evaluated: {len(conditional_results)}")
    print("\nDetailed Results:")
    print(results_df.to_string(index=False))
