import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from dataset import CellDataModule, to_rgb
from diffusers.models import AutoencoderKL
from metrics_utils import calculate_metrics_from_scratch
from models.sit import SiT_models
from omegaconf import OmegaConf
from openphenom import OpenPhenomEncoder
from pytorch_lightning import seed_everything
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from train import generate_perturbation_matched_samples
from utils import load_encoders

# Minimal map-style Dataset that wraps an in-memory, indexable collection.


class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset backed by an in-memory indexable collection.

    Items are returned exactly as stored; no transforms are applied.
    """

    def __init__(self, data):
        # Hold a reference to the caller's collection (no copy is made).
        self.data = data

    def __len__(self):
        """Return the number of stored items."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.data[idx]
        return item


def find_generated_files_by_perturbation_and_celltype(
    generated_path, perturbation_id, cell_type_id
):
    """
    Find all generated numpy files for a specific perturbation ID and cell type ID.

    Files are expected under ``<generated_path>/p<perturbation_id>/`` and named
    ``p<pid>_c<cell_type_id>_sample<sample_id>.npy``. The perturbation is selected
    by the folder; the cell type is selected by the ``_c<cell_type_id>_sample``
    marker in the file *name* (not the full path, so directory names that happen
    to contain the marker cannot produce false matches).

    Args:
        generated_path: Path to the directory containing generated data
        perturbation_id: Perturbation ID to filter by
        cell_type_id: Cell type ID to filter by

    Returns:
        Sorted list of file paths matching both the perturbation and cell type.
        Sorting makes the result independent of filesystem listing order, so a
        subsequent seeded ``random.sample`` is reproducible. Returns an empty
        list if the perturbation folder does not exist.
    """
    pert_path = os.path.join(generated_path, f"p{perturbation_id}")

    if not os.path.isdir(pert_path):
        return []

    # Marker identifying the cell type inside the file name.
    marker = f"_c{cell_type_id}_sample"

    # Escape the directory part so characters like '[' in paths are taken literally.
    npy_files = glob.glob(os.path.join(glob.escape(pert_path), "*.npy"))

    return sorted(f for f in npy_files if marker in os.path.basename(f))


def load_numpy_files(file_paths, max_samples):
    """
    Load up to ``max_samples`` numpy files and stack them into one float tensor.

    When more paths than ``max_samples`` are given, a random subset (via the
    global ``random`` module) is loaded. Files that fail to load are reported
    and skipped.

    Args:
        file_paths: List of numpy file paths to load
        max_samples: Maximum number of samples to load

    Returns:
        A tensor produced by ``torch.stack`` over the loaded arrays (cast to
        float32), or None if nothing was loaded successfully.
    """
    selected_paths = (
        random.sample(file_paths, max_samples)
        if len(file_paths) > max_samples
        else file_paths
    )

    tensors = []
    for path in tqdm(selected_paths, desc="Loading numpy files"):
        try:
            array = np.load(path)
            tensors.append(torch.from_numpy(array).float())
        except Exception as exc:
            # Best-effort loading: report and continue with the remaining files.
            print(f"Error loading {path}: {exc}")

    if not tensors:
        return None
    return torch.stack(tensors)


def create_cell_type_metadata(num_samples=500, perturbation_id=1138, cell_type=1):
    """
    Create a list of identical perturbation metadata entries.

    Every entry has the given perturbation ID and cell type ID, and
    ``is_generated`` set to False. Previously ``cell_type`` was ignored and
    every entry was hard-coded to cell type 1; it is now honored (the default
    is still 1, so default calls behave as before).

    Args:
        num_samples: Number of metadata entries to create
        perturbation_id: The perturbation ID to use for all entries
        cell_type: The cell type ID to use for all entries

    Returns:
        List of ``num_samples`` independent dicts with keys
        ``perturbation_id``, ``cell_type_id`` and ``is_generated``.
    """
    # Build a fresh dict per entry so callers can mutate entries independently.
    perturbation_metadata = [
        {
            "perturbation_id": perturbation_id,
            "cell_type_id": cell_type,
            "is_generated": False,  # This is typically False for real data
        }
        for _ in range(num_samples)
    ]

    print(f"Created perturbation metadata with {len(perturbation_metadata)} entries")
    print(f"All entries have cell_type_id set to {cell_type}")

    return perturbation_metadata


def augment_image(image, augmentation_type=None):
    """
    Apply a simple geometric augmentation to an image.

    Args:
        image: Image passed to torchvision's functional ``rotate``/``hflip``/
            ``vflip`` (documented by the caller as a [C, H, W] tensor).
        augmentation_type: One of 'rotate', 'flip' or 'unchanged'. If None, one
            of the three is picked at random. Any other value leaves the image
            unchanged.

    Returns:
        The augmented image, or the input object itself when unchanged.
    """
    if augmentation_type is None:
        augmentation_type = random.choice(["rotate", "flip", "unchanged"])

    if augmentation_type == "rotate":
        # Rotate by 90, 180 or 270 degrees, chosen uniformly.
        return TF.rotate(image, random.choice([90, 180, 270]))

    if augmentation_type == "flip":
        # Horizontal or vertical flip with equal probability.
        flip_fn = TF.hflip if random.random() > 0.5 else TF.vflip
        return flip_fn(image)

    # 'unchanged' (or any unrecognized value): return the input as-is.
    return image


def plot_pca_comparison(
    features1,
    features2,
    label1_desc,
    label2_desc,
    title_pert_id1,
    title_pert_id2,
    plot_title_prefix,
    output_filename_prefix,
    base_save_path="ophenom_qual_results/pca_comparisons",
    title_fontsize=26,
    label_fontsize=22,
    legend_fontsize=20,
    tick_fontsize=18,
):
    """
    Generate and save a 2-D PCA scatter plot comparing two sets of features.

    The two feature sets are stacked, standardized with ``StandardScaler`` and
    projected with ``PCA(n_components=2)``. Points are colored by
    ``"<label_desc> (p<pert_id>)"``. The plot is saved as
    ``<base_save_path>/<output_filename_prefix>_p<id1>_vs_p<id2>.pdf``.

    Args:
        features1, features2: Feature arrays, presumably (n_samples, n_features);
            None or empty arrays are treated as "no samples".
        label1_desc, label2_desc: Descriptions used in the legend labels.
        title_pert_id1, title_pert_id2: Perturbation IDs used in labels, title
            and filename.
        plot_title_prefix: Text placed before the IDs in the plot title.
        output_filename_prefix: Prefix of the output filename.
        base_save_path: Output directory (created if missing).
        title_fontsize, label_fontsize, legend_fontsize, tick_fontsize: Font sizes.

    Returns:
        None. Prints a message and returns early (instead of raising inside
        PCA) when there are no samples, fewer than two samples, or fewer than
        two feature dimensions.
    """
    os.makedirs(base_save_path, exist_ok=True)

    label1 = f"{label1_desc} (p{title_pert_id1})"
    label2 = f"{label2_desc} (p{title_pert_id2})"

    # Collect only non-empty inputs so a missing set doesn't break np.vstack.
    feature_blocks = []
    labels = []
    for feats, label in ((features1, label1), (features2, label2)):
        if feats is None:
            continue
        feats = np.asarray(feats)
        if feats.size == 0:
            continue
        feature_blocks.append(feats)
        labels.extend([label] * len(feats))

    if not feature_blocks:
        print(
            f"No features provided for PCA comparison between p{title_pert_id1} and p{title_pert_id2}. Skipping."
        )
        return

    combined_features = np.vstack(feature_blocks)

    # PCA(n_components=2) needs at least 2 samples and 2 feature dimensions.
    if min(combined_features.shape) < 2:
        print(
            f"Not enough data ({combined_features.shape[0]} samples x {combined_features.shape[1]} features) for PCA comparison between p{title_pert_id1} and p{title_pert_id2}. Skipping."
        )
        return

    scaled_features = StandardScaler().fit_transform(combined_features)
    pca = PCA(n_components=2)
    embedding = pca.fit_transform(scaled_features)
    var_explained = pca.explained_variance_ratio_ * 100

    plot_df = pd.DataFrame(
        {
            "PC1": embedding[:, 0],
            "PC2": embedding[:, 1],
            "Type": np.array(labels),  # Descriptive labels drive the hue
        }
    )

    # One default seaborn color per distinct label, in order of appearance.
    default_colors = sns.color_palette()
    palette = {
        label: default_colors[i] for i, label in enumerate(plot_df["Type"].unique())
    }

    output_file = os.path.join(
        base_save_path,
        f"{output_filename_prefix}_p{title_pert_id1}_vs_p{title_pert_id2}.pdf",
    )

    fig = plt.figure(figsize=(12, 10))
    try:
        sns.scatterplot(
            data=plot_df,
            x="PC1",
            y="PC2",
            hue="Type",
            palette=palette,
            alpha=0.7,
            s=50,
            edgecolor=None,
        )

        plt.xlabel(f"PC1 ({var_explained[0]:.1f}%)", fontsize=label_fontsize)
        plt.ylabel(f"PC2 ({var_explained[1]:.1f}%)", fontsize=label_fontsize)
        plt.title(
            f"PCA: {plot_title_prefix} - p{title_pert_id1} vs p{title_pert_id2}",
            fontsize=title_fontsize,
        )
        plt.legend(loc="upper right", fontsize=legend_fontsize)
        plt.xticks(fontsize=tick_fontsize)
        plt.yticks(fontsize=tick_fontsize)

        plt.savefig(output_file, dpi=600)
    finally:
        # Always release the figure, even if plotting fails, so repeated calls
        # in a loop don't accumulate open figures.
        plt.close(fig)
    print(f"Saved PCA comparison plot to {output_file}")


def plot_pca_dual_perturbation_comparison(
    real_features_ref,
    generated_features_ref,
    real_features_other,
    generated_features_other,
    ref_pert_id,
    other_pert_id,
    color_map,
    type_markers_map,
    base_save_path="ophenom_qual_results/pca_comparisons_dual",
    title_fontsize=26,
    label_fontsize=22,
    legend_fontsize=20,
    legend_title_fontsize=18,
    tick_fontsize=18,
):
    """
    Generate and save a PCA plot comparing real/generated features of two perturbations.

    Hue is the perturbation (``"p<id>"``) and marker style is the type
    (``"Real"``/``"Generated"``). Features are stacked, standardized with
    ``StandardScaler`` and projected with ``PCA(n_components=2)``. The plot is
    saved as ``<base_save_path>/pca_dual_p<ref>_vs_p<other>.pdf``.

    Args:
        real_features_ref: Real features of the reference perturbation,
            presumably (n_samples, n_features); None or empty to omit.
        generated_features_ref: Generated features of the reference perturbation.
        real_features_other: Real features of the other perturbation.
        generated_features_other: Generated features of the other perturbation.
        ref_pert_id: Reference perturbation ID (labels, title, filename).
        other_pert_id: Other perturbation ID (labels, title, filename).
        color_map: Mapping from ``"p<id>"`` to a color. Perturbations present
            in the data but missing from the map get a default seaborn color.
        type_markers_map: Mapping from type label to marker, forwarded to
            ``sns.scatterplot(markers=...)``.
        base_save_path: Output directory (created if missing).
        title_fontsize, label_fontsize, legend_fontsize, legend_title_fontsize,
        tick_fontsize: Font sizes for the respective plot elements.

    Returns:
        None. Prints a message and returns early when there are no features,
        fewer than two samples, or fewer than two feature dimensions.
    """
    os.makedirs(base_save_path, exist_ok=True)

    ref_label = f"p{ref_pert_id}"
    other_label = f"p{other_pert_id}"
    groups = (
        (real_features_ref, ref_label, "Real"),
        (generated_features_ref, ref_label, "Generated"),
        (real_features_other, other_label, "Real"),
        (generated_features_other, other_label, "Generated"),
    )

    # Collect non-empty groups along with per-row perturbation/type labels.
    features_list = []
    pert_labels_list = []
    type_labels_list = []
    for feats, pert_label, type_label in groups:
        if feats is None or feats.size == 0:
            continue
        features_list.append(feats)
        pert_labels_list.extend([pert_label] * len(feats))
        type_labels_list.extend([type_label] * len(feats))

    if not features_list:
        print(
            f"No features provided for PCA dual comparison between p{ref_pert_id} and p{other_pert_id}. Skipping."
        )
        return

    combined_features = np.vstack(features_list)

    if combined_features.shape[0] < 2:
        print(
            f"Not enough samples ({combined_features.shape[0]}) for PCA dual comparison between p{ref_pert_id} and p{other_pert_id}. Skipping."
        )
        return

    # PCA(n_components=2) also needs at least 2 feature dimensions.
    if combined_features.shape[1] < 2:
        print(
            f"Not enough feature dimensions ({combined_features.shape[1]}) for PCA dual comparison between p{ref_pert_id} and p{other_pert_id}. Skipping."
        )
        return

    scaled_features = StandardScaler().fit_transform(combined_features)
    pca = PCA(n_components=2)
    embedding = pca.fit_transform(scaled_features)
    var_explained = pca.explained_variance_ratio_ * 100

    plot_df = pd.DataFrame(
        {
            "PC1": embedding[:, 0],
            "PC2": embedding[:, 1],
            "PerturbationID": pert_labels_list,
            "Type": type_labels_list,
        }
    )

    # Palette for the perturbations actually plotted. seaborn raises if a hue
    # level is missing from a dict palette, so fall back to default colors.
    default_colors = sns.color_palette()
    present_perturbations = list(dict.fromkeys(pert_labels_list))
    active_color_palette = {
        p: color_map.get(p, default_colors[i % len(default_colors)])
        for i, p in enumerate(present_perturbations)
    }

    output_file = os.path.join(
        base_save_path, f"pca_dual_p{ref_pert_id}_vs_p{other_pert_id}.pdf"
    )

    fig = plt.figure(figsize=(12, 10))
    try:
        sns.scatterplot(
            data=plot_df,
            x="PC1",
            y="PC2",
            hue="PerturbationID",
            style="Type",
            palette=active_color_palette,
            markers=type_markers_map,
            alpha=0.7,
            s=90,
            edgecolor=None,
        )

        plt.xlabel(f"PC1 ({var_explained[0]:.1f}%)", fontsize=label_fontsize)
        plt.ylabel(f"PC2 ({var_explained[1]:.1f}%)", fontsize=label_fontsize)
        plt.title(
            f"PCA: p{ref_pert_id} (Real/Gen) vs p{other_pert_id} (Real/Gen)",
            fontsize=title_fontsize,
        )
        plt.legend(
            loc="upper right",
            fontsize=legend_fontsize,
            title_fontsize=legend_title_fontsize,
        )
        plt.xticks(fontsize=tick_fontsize)
        plt.yticks(fontsize=tick_fontsize)

        plt.savefig(output_file, dpi=600)
    finally:
        # Release the figure even if plotting fails, so callers looping over
        # many perturbation pairs don't accumulate open figures.
        plt.close(fig)
    print(f"Saved dual PCA comparison plot to {output_file}")


if __name__ == "__main__":
    seed = 0
    MANUAL_GENERATION = False
    MODEL_CHECKPOINT_PATH = "/mnt/pvc/REPA/exps/OOD_ct1_p1137_ophenomdeneme-b-enc8-in512/checkpoints/min_AVG_FID_78.83929616.pt" # Example
    ENC_TYPE = "openphenom-vit-b" # Encoder type for SiT model features
    RESOLUTION = 512 # Image resolution
    # cell_type_id is used for filtering real data, CELL_TYPE_IDS_TO_PROCESS for generated
    cell_type_id = 1 # Default cell type for filtering real data via datamodule
    seed_everything(seed)
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    # load yaml file
    filename = "diffusion_sit_full.yaml"
    generated_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"
    # generated_path = "/mnt/pvc/REPA/generated_ood_128"
    # Example numpy path: generated_path/p<pid>/p<pid>_c<c_id>_sample<sample_id>.npy
    # load yaml
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    # --- Model Loading for Manual Generation (if enabled) ---
    model_for_generation = None
    vae_for_generation = None
    latents_scale_for_generation = None
    latents_bias_for_generation = None
    path_type_for_generation = "linear"
    latent_size_for_generation = None

    if MANUAL_GENERATION:
        print("MANUAL_GENERATION is True. Loading model for image generation...")
        if not os.path.exists(MODEL_CHECKPOINT_PATH):
            # Try to find a model checkpoint if the default path doesn't exist (optional)
            # For now, strictly require MODEL_CHECKPOINT_PATH to be valid
            print(f"Error: Model checkpoint not found at {MODEL_CHECKPOINT_PATH}")
            print("Please set MODEL_CHECKPOINT_PATH correctly.")
            sys.exit(1) # Exit if model path is not found

        ckpt = torch.load(MODEL_CHECKPOINT_PATH, map_location="cpu") # weights_only=False is default

        latent_size_for_generation = RESOLUTION // 8
        
        # Load encoders for SiT model's z_dims
        # Assuming 'cpu' for loading these components initially, will be moved to device with model
        encoders_for_sit, _, _ = load_encoders(ENC_TYPE, 'cpu', RESOLUTION) 
        z_dims_for_sit = [encoder.embed_dim for encoder in encoders_for_sit] if ENC_TYPE != "None" else [0]
        block_kwargs_for_sit = {"fused_attn": True, "qk_norm": False}

        model_for_generation = SiT_models["SiT-XL/2"](
            input_size=latent_size_for_generation,
            num_classes=1139, # Number of perturbation classes
            use_cfg=True,
            z_dims=z_dims_for_sit,
            encoder_depth=8, 
            in_channels=24, 
            **block_kwargs_for_sit,
        )
        model_for_generation.load_state_dict(ckpt["model"])
        model_for_generation = model_for_generation.to(device)
        model_for_generation.eval()

        vae_for_generation = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-mse").to(device)
        latents_scale_for_generation = (
            torch.tensor([0.18215, 0.18215, 0.18215, 0.18215])
            .view(1, 4, 1, 1)
            .to(device)
        )
        latents_bias_for_generation = torch.tensor([0.0, 0.0, 0.0, 0.0]).view(1, 4, 1, 1).to(device)
        print("Model, VAE, and generation parameters loaded successfully.")
    # --- End Model Loading ---

    # Define font sizes for plots
    TITLE_FONTSIZE = 26
    LABEL_FONTSIZE = 22
    LEGEND_FONTSIZE = 20
    LEGEND_TITLE_FONTSIZE = 18
    TICK_FONTSIZE = 18

    # Sample 4 random perturbation IDs out of 1138
    # all_perturbation_ids = list(range(1, 1139))  # 1 to 1138
    # sampled_perturbation_ids = random.sample(all_perturbation_ids, 50)
    sampled_perturbation_ids = [1138, 1137, 1108, 1124]
    print(f"Sampled perturbation IDs: {sampled_perturbation_ids}")

    NUM_SAMPLES = 500
    BATCH_SIZE_FEATURES = 64  # Define batch size for feature extraction
    encoder = OpenPhenomEncoder().to(device).eval()
    # Iterate through each perturbation ID and calculate metrics
    results = []
    real_features_1138 = None
    generated_features_1138 = None
    REFERENCE_PERT_ID = 1138

    # --- Create Global Color Map and Marker Map for Consistent Plotting ---
    # Define GLOBAL_PERTURBATION_COLOR_MAP as it's used by the aggregated PCA plot
    GLOBAL_PERTURBATION_COLOR_MAP = {}
    # Assuming sampled_perturbation_ids contains numeric IDs
    _unique_numeric_ids_sorted_for_map = sorted(list(set(sampled_perturbation_ids)))
    _color_palette_list_for_map = sns.color_palette()
    _color_idx_non_ref_for_map = 0

    for _pid_numeric in _unique_numeric_ids_sorted_for_map:
        _pid_str = f"p{_pid_numeric}"
        if _pid_numeric == REFERENCE_PERT_ID:
            GLOBAL_PERTURBATION_COLOR_MAP[_pid_str] = "gray"
        else:
            GLOBAL_PERTURBATION_COLOR_MAP[_pid_str] = _color_palette_list_for_map[
                _color_idx_non_ref_for_map % len(_color_palette_list_for_map)
            ]
            _color_idx_non_ref_for_map += 1
    print(
        f"Global Perturbation Color Map (for aggregated plot): {GLOBAL_PERTURBATION_COLOR_MAP}"
    )

    # Define cell types and their colors (adjust if your cell type IDs or number differ)
    # CELL_TYPE_IDS_TO_PROCESS = [
    #     0,
    #     1,
    #     2,
    #     3,
    # ]  # Define which cell types to look for in generated data
    CELL_TYPE_IDS_TO_PROCESS = [
        1,
    ]  # Define which cell types to look for in generated data
    GLOBAL_CELL_TYPE_COLOR_MAP = {
        0: "purple",
        1: "orange",
        2: "green",
        3: "brown",
        # Add more if needed, or make it dynamic if cell IDs are not fixed
    }
    GLOBAL_TYPE_MARKERS = {
        "Real": "o",
        "Generated": "X",
    }  # Re-define for aggregated plot
    REAL_GENERATED_COLOR_MAP = {"Real": "blue", "Generated": "red"}
    # GLOBAL_CELL_TYPE_MARKER_MAP = { ... } # No longer needed if style is removed for this plot type

    # print(f"Global Perturbation Color Map: {GLOBAL_PERTURBATION_COLOR_MAP}")
    print(
        f"Global Type Markers: {GLOBAL_TYPE_MARKERS}"
    )  # Uncommented print for clarity
    print(f"Global Cell Type Color Map: {GLOBAL_CELL_TYPE_COLOR_MAP}")
    print(f"Real/Generated Color Map: {REAL_GENERATED_COLOR_MAP}")
    # --- End Global Map Creation ---

    all_real_features_list = []
    all_generated_features_list = []
    all_pert_ids_for_real = []
    all_pert_ids_for_generated = []
    all_cell_types_for_real = []  # Added for aggregated PCA by cell type
    all_cell_types_for_generated = []  # Added for aggregated PCA by cell type
    ate_scores_list = []  # To store ATE scores

    for i, pert_id in enumerate(sampled_perturbation_ids):
        print(f"\n\n{'='*80}")
        print(
            f"Processing perturbation ID: {pert_id}, {i+1}/{len(sampled_perturbation_ids)}"
        )
        print(f"{'='*80}")

        # Filter real images using CellDataModule
        real_filtered_dataset = datamodule.filter_samples(
            perturbation_id=pert_id,
            cell_type_id=cell_type_id,  # Get all cell types for this perturbation
        )

        real_images_tensor = None
        real_images_cell_type_ids = np.array([])  # Initialize as empty numpy array

        if real_filtered_dataset is None or len(real_filtered_dataset) == 0:
            print(f"No real data found for perturbation ID {pert_id}")
        else:
            real_images = [
                real_filtered_dataset[i][0] for i in range(len(real_filtered_dataset))
            ]
            # Assuming real_filtered_dataset[i][2] is the cell_type_id
            real_images_cell_type_ids = np.array(
                [real_filtered_dataset[i][2] for i in range(len(real_filtered_dataset))]
            )
            print(
                f"Found {len(real_images)} real images for perturbation ID {pert_id} with corresponding cell types."
            )
            real_images_tensor = torch.stack(real_images)
            # Safety check for length consistency
            if len(real_images_tensor) != len(real_images_cell_type_ids):
                print(
                    f"Warning: Mismatch in real images ({len(real_images_tensor)}) and cell types ({len(real_images_cell_type_ids)}) for p{pert_id}. Truncating to shorter length."
                )
                min_len = min(len(real_images_tensor), len(real_images_cell_type_ids))
                real_images_tensor = real_images_tensor[:min_len]
                real_images_cell_type_ids = real_images_cell_type_ids[:min_len]

        # --- Generated Data Handling ---
        generated_images_tensor = None
        generated_images_cell_type_ids = np.array([])

        if MANUAL_GENERATION:
            print(f"Manually generating images for perturbation ID {pert_id}...")
            temp_generated_images_list = []
            temp_generated_cell_type_ids_list = []
            
            num_cell_types_to_gen_for = len(CELL_TYPE_IDS_TO_PROCESS)
            if num_cell_types_to_gen_for == 0:
                print(f"  Warning: CELL_TYPE_IDS_TO_PROCESS is empty for p{pert_id}. No images will be generated.")
            else:
                # Distribute NUM_SAMPLES among the cell types to process
                samples_per_ct_target = NUM_SAMPLES // num_cell_types_to_gen_for
                if samples_per_ct_target == 0 and NUM_SAMPLES > 0:
                    samples_per_ct_target = 1 # Ensure at least 1 sample if NUM_SAMPLES > 0 and CTTIP not empty
                
                generation_batch_size = 16 # Batch size for generation calls

                for ct_id_to_generate in CELL_TYPE_IDS_TO_PROCESS:
                    print(f"  Generating up to {samples_per_ct_target} samples for cell type {ct_id_to_generate}...")
                    num_batches_to_generate = (samples_per_ct_target + generation_batch_size - 1) // generation_batch_size
                    
                    generated_count_for_this_ct = 0
                    for batch_idx in tqdm(range(num_batches_to_generate), desc=f"    Generating p{pert_id}, c{ct_id_to_generate}", leave=False):
                        current_batch_size = min(generation_batch_size, samples_per_ct_target - generated_count_for_this_ct)
                        if current_batch_size <= 0:
                            break

                        perturbation_metadata_for_gen = create_cell_type_metadata(
                            num_samples=current_batch_size,
                            perturbation_id=pert_id,
                            cell_type=ct_id_to_generate 
                        )
                        
                        with torch.no_grad():
                            generated_batch_tensor, gen_images_metadata = generate_perturbation_matched_samples(
                                model=model_for_generation,
                                p_id=pert_id,
                                p_metadata=perturbation_metadata_for_gen,
                                vae=vae_for_generation,
                                latent_size=latent_size_for_generation,
                                resolution=RESOLUTION,
                                latents_bias=latents_bias_for_generation,
                                latents_scale=latents_scale_for_generation,
                                path_type=path_type_for_generation,
                                device=device,
                            )
                        
                        temp_generated_images_list.append(generated_batch_tensor.cpu())
                        for meta_item in gen_images_metadata:
                            temp_generated_cell_type_ids_list.append(meta_item.get('cell_type_id', ct_id_to_generate))
                        
                        generated_count_for_this_ct += generated_batch_tensor.shape[0]

                if temp_generated_images_list:
                    generated_images_tensor = torch.cat(temp_generated_images_list, dim=0)
                    generated_images_cell_type_ids = np.array(temp_generated_cell_type_ids_list)
                    
                    # Ensure total generated samples do not exceed NUM_SAMPLES for this pert_id
                    if generated_images_tensor.shape[0] > NUM_SAMPLES:
                        print(f"  Generated {generated_images_tensor.shape[0]} images, sampling down to {NUM_SAMPLES}.")
                        indices = random.sample(range(generated_images_tensor.shape[0]), NUM_SAMPLES)
                        generated_images_tensor = generated_images_tensor[indices]
                        generated_images_cell_type_ids = generated_images_cell_type_ids[indices]
                    print(f"  Total {generated_images_tensor.shape[0]} images manually generated for p{pert_id}.")
        else:
            # --- Original Generated Data Loading with Cell Types (if not MANUAL_GENERATION) ---
            all_gen_files_with_cell_type_info = []
            for ct_id in CELL_TYPE_IDS_TO_PROCESS:
                files_for_ct = find_generated_files_by_perturbation_and_celltype(
                    generated_path,
                    pert_id,
                    ct_id,
                )
                for f_path in files_for_ct:
                    all_gen_files_with_cell_type_info.append(
                        {"path": f_path, "cell_type": ct_id}
                    )

            print(
                f"Found {len(all_gen_files_with_cell_type_info)} total generated files for perturbation ID {pert_id} across specified cell types (from path: {generated_path})."
            )

            if all_gen_files_with_cell_type_info:
                if len(all_gen_files_with_cell_type_info) > NUM_SAMPLES:
                    sampled_gen_files_info = random.sample(
                        all_gen_files_with_cell_type_info, NUM_SAMPLES
                    )
                    print(f"Sampled down to {NUM_SAMPLES} generated files.")
                else:
                    sampled_gen_files_info = all_gen_files_with_cell_type_info

                temp_generated_images_list = []
                temp_generated_cell_type_ids_list = []
                if sampled_gen_files_info:
                    for file_info in tqdm(
                        sampled_gen_files_info,
                        desc=f"Loading sampled generated files for p{pert_id}",
                    ):
                        try:
                            img = np.load(file_info["path"])
                            temp_generated_images_list.append(torch.from_numpy(img).float())
                            temp_generated_cell_type_ids_list.append(file_info["cell_type"])
                        except Exception as e:
                            print(f"Error loading {file_info['path']}: {e}")

                    if temp_generated_images_list:
                        generated_images_tensor = torch.stack(temp_generated_images_list)
                        generated_images_cell_type_ids = np.array(
                            temp_generated_cell_type_ids_list
                        )
        # --- End Generated Data Handling ---

        if generated_images_tensor is None or generated_images_tensor.size(0) == 0:
            print(f"No generated images available (loaded or created) for perturbation ID {pert_id}.")
            # Ensure generated_features is an empty array if no images, to prevent errors later
            generated_features = np.array([])
        else:
            print(f"Generated images tensor shape for p{pert_id}: {generated_images_tensor.shape}")
            print(f"Generated images cell type IDs shape for p{pert_id}: {generated_images_cell_type_ids.shape}")


        print(f"Extracting OpenPhenom features for perturbation ID {pert_id}")
        # Create directory for saving results if it doesn't exist
        os.makedirs("ophenom_qual_results/pca", exist_ok=True)

        # Extract OpenPhenom features for real and generated images in batches
        real_features_list = []
        if real_images_tensor is not None and real_images_tensor.size(0) > 0:
            with torch.no_grad():
                for i_batch in range(
                    0, real_images_tensor.size(0), BATCH_SIZE_FEATURES
                ):
                    batch_tensor = real_images_tensor[
                        i_batch : i_batch + BATCH_SIZE_FEATURES
                    ].to(device)
                    features_batch = encoder(batch_tensor).cpu().numpy()
                    real_features_list.append(features_batch)
            real_features = (
                np.vstack(real_features_list) if real_features_list else np.array([])
            )
        else:
            real_features = np.array([])

        generated_features_list = []
        if generated_images_tensor is not None and generated_images_tensor.size(0) > 0:
            with torch.no_grad():
                for i_batch in range(
                    0, generated_images_tensor.size(0), BATCH_SIZE_FEATURES
                ):
                    batch_tensor = generated_images_tensor[
                        i_batch : i_batch + BATCH_SIZE_FEATURES
                    ].to(device)
                    features_batch = encoder(batch_tensor).cpu().numpy()
                    generated_features_list.append(features_batch)
            generated_features = (
                np.vstack(generated_features_list)
                if generated_features_list
                else np.array([])
            )
        else:
            generated_features = np.array([])

        print(
            f"Extracted features shapes - Real: {real_features.shape}, Generated: {generated_features.shape}"
        )
        print(
            f"Cell type arrays shapes - Real: {real_images_cell_type_ids.shape}, Generated: {generated_images_cell_type_ids.shape}"
        )

        # Store features for the reference perturbation ID
        if pert_id == REFERENCE_PERT_ID:
            real_features_1138 = real_features
            generated_features_1138 = generated_features
            # Store cell types for reference perturbation if needed for specific plots later
            # real_cell_types_1138 = real_images_cell_type_ids
            # generated_cell_types_1138 = generated_images_cell_type_ids
            print(f"Stored features for reference perturbation ID {REFERENCE_PERT_ID}")

        # Store all features for aggregated PCA
        if real_features is not None and real_features.size > 0:
            all_real_features_list.append(real_features)
            all_pert_ids_for_real.extend(
                [f"p{pert_id}"] * len(real_features)
            )  # Store with 'p' prefix for clarity in plot
            all_cell_types_for_real.extend(
                real_images_cell_type_ids
            )  # Store cell types
        if generated_features is not None and generated_features.size > 0:
            all_generated_features_list.append(generated_features)
            all_pert_ids_for_generated.extend(
                [f"p{pert_id}"] * len(generated_features)
            )  # Store with 'p' prefix
            all_cell_types_for_generated.extend(
                generated_images_cell_type_ids
            )  # Store cell types

        # Combine features for PCA
        # Ensure labels match available features
        current_labels_list = []
        current_features_list = []
        current_cell_types_list = []  # For storing cell types for the current plot

        if real_features.size > 0:
            current_features_list.append(real_features)
            current_labels_list.extend(["Real"] * len(real_features))
            current_cell_types_list.extend(real_images_cell_type_ids)

        if generated_features.size > 0:
            current_features_list.append(generated_features)
            current_labels_list.extend(["Generated"] * len(generated_features))
            current_cell_types_list.extend(generated_images_cell_type_ids)

        if not current_features_list:
            print(
                f"No features (real or generated) to plot for p{pert_id}. Skipping PCA plot."
            )
            # Save empty features and cell types if needed, or just continue
            # np.save(f"ophenom_qual_results/pca/features_real_p{pert_id}.npy", real_features) # real_features would be empty
            # np.save(f"ophenom_qual_results/pca/celltypes_real_p{pert_id}.npy", real_images_cell_type_ids)
            # np.save(f"ophenom_qual_results/pca/features_generated_p{pert_id}.npy", generated_features)
            # np.save(f"ophenom_qual_results/pca/celltypes_generated_p{pert_id}.npy", generated_images_cell_type_ids)
            continue  # Skip to next perturbation ID if no features

        combined_features_pert = np.vstack(current_features_list)
        type_labels_pert = np.array(current_labels_list)
        cell_type_labels_pert = np.array(current_cell_types_list)

        if combined_features_pert.shape[0] < 2:  # Need at least 2 samples for PCA
            print(
                f"Not enough combined samples ({combined_features_pert.shape[0]}) for PCA for p{pert_id}. Skipping plot."
            )
            # Save features and cell types even if not plotting
            if real_features.size > 0:
                np.save(
                    f"ophenom_qual_results/pca/features_real_p{pert_id}.npy",
                    real_features,
                )
                np.save(
                    f"ophenom_qual_results/pca/celltypes_real_p{pert_id}.npy",
                    real_images_cell_type_ids,
                )
            if generated_features.size > 0:
                np.save(
                    f"ophenom_qual_results/pca/features_generated_p{pert_id}.npy",
                    generated_features,
                )
                np.save(
                    f"ophenom_qual_results/pca/celltypes_generated_p{pert_id}.npy",
                    generated_images_cell_type_ids,
                )
            continue

        # Apply PCA
        scaler_pert = StandardScaler()
        scaled_features_pert = scaler_pert.fit_transform(combined_features_pert)
        pca_pert = PCA(n_components=2)
        embedding_pert = pca_pert.fit_transform(scaled_features_pert)
        var_explained_pert = pca_pert.explained_variance_ratio_ * 100

        # Create DataFrame for plotting
        plot_df_pert = pd.DataFrame(
            {
                "PC1": embedding_pert[:, 0],
                "PC2": embedding_pert[:, 1],
                "Type": type_labels_pert,
                "CellType": cell_type_labels_pert,  # Add CellType for hue
                # "Perturbation": f"p{pert_id}" # Perturbation is in the title
            }
        )

        # Create PCA plot
        plt.figure(figsize=(12, 10))  # Slightly larger for better legend

        # Create scatter plot with CellType as hue
        sns.scatterplot(
            data=plot_df_pert,
            x="PC1",
            y="PC2",
            hue="CellType",
            palette=GLOBAL_CELL_TYPE_COLOR_MAP,
            style="Type",  # Added to differentiate Real/Generated by marker
            markers=GLOBAL_TYPE_MARKERS,  # Added to specify markers for Real/Generated
            alpha=0.7,
            s=70,  # Adjusted size
            edgecolor=None,
        )

        plt.xlabel(f"PC1 ({var_explained_pert[0]:.1f}%)", fontsize=LABEL_FONTSIZE)
        plt.ylabel(f"PC2 ({var_explained_pert[1]:.1f}%)", fontsize=LABEL_FONTSIZE)
        plt.title(
            f"PCA of OpenPhenom Features - Perturbation {pert_id}",
            fontsize=TITLE_FONTSIZE,
        )

        legend_ct = plt.legend(
            loc="upper right", fontsize=LEGEND_FONTSIZE
        )  # No explicit title here
        # Iterate over legend text objects to customize section titles
        for text_obj in legend_ct.findobj(plt.Text):
            if text_obj.get_text() == "CellType":
                text_obj.set_text("Cell Type")  # Correct the text
                text_obj.set_fontsize(LEGEND_TITLE_FONTSIZE)  # Apply title font size
            elif text_obj.get_text() == "Type":
                text_obj.set_fontsize(LEGEND_TITLE_FONTSIZE)  # Apply title font size

        plt.xticks(fontsize=TICK_FONTSIZE)
        plt.yticks(fontsize=TICK_FONTSIZE)
        # plt.tight_layout()

        # Save the plot
        output_file_by_celltype = (
            f"ophenom_qual_results/pca/pca_perturbation_{pert_id}_by_celltype.pdf"
        )
        plt.savefig(output_file_by_celltype, dpi=600)
        plt.close()
        print(f"Saved PCA plot (colored by cell type) to {output_file_by_celltype}")

        # --- Plot 2: Color by Type (Real/Generated) ---
        plt.figure(figsize=(12, 10))
        sns.scatterplot(
            data=plot_df_pert,
            x="PC1",
            y="PC2",
            hue="Type",  # Color by Real/Generated
            palette=REAL_GENERATED_COLOR_MAP,
            alpha=0.7,
            s=70,
            edgecolor=None,
        )
        plt.xlabel(f"PC1 ({var_explained_pert[0]:.1f}%)", fontsize=LABEL_FONTSIZE)
        plt.ylabel(f"PC2 ({var_explained_pert[1]:.1f}%)", fontsize=LABEL_FONTSIZE)
        plt.title(
            f"PCA of OpenPhenom Features - Perturbation {pert_id}",
            fontsize=TITLE_FONTSIZE,
        )
        legend_type = plt.legend(
            loc="upper right", title="Type", fontsize=LEGEND_FONTSIZE
        )
        plt.setp(legend_type.get_title(), fontsize=LEGEND_TITLE_FONTSIZE)
        plt.xticks(fontsize=TICK_FONTSIZE)
        plt.yticks(fontsize=TICK_FONTSIZE)
        # plt.tight_layout()

        output_file_by_type = (
            f"ophenom_qual_results/pca/pca_perturbation_{pert_id}_by_type.pdf"
        )
        plt.savefig(output_file_by_type, dpi=600)
        plt.close()
        print(f"Saved PCA plot (colored by type) to {output_file_by_type}")

        # Also save the features AND cell types for potential future use
        if real_features.size > 0:
            np.save(
                f"ophenom_qual_results/pca/features_real_p{pert_id}.npy", real_features
            )
            np.save(
                f"ophenom_qual_results/pca/celltypes_real_p{pert_id}.npy",
                real_images_cell_type_ids,
            )
        else:  # Save empty arrays if no real features, to prevent load errors later if expected
            np.save(
                f"ophenom_qual_results/pca/features_real_p{pert_id}.npy", np.array([])
            )
            np.save(
                f"ophenom_qual_results/pca/celltypes_real_p{pert_id}.npy", np.array([])
            )

        if generated_features.size > 0:
            np.save(
                f"ophenom_qual_results/pca/features_generated_p{pert_id}.npy",
                generated_features,
            )
            np.save(
                f"ophenom_qual_results/pca/celltypes_generated_p{pert_id}.npy",
                generated_images_cell_type_ids,
            )
        else:  # Save empty arrays if no generated features
            np.save(
                f"ophenom_qual_results/pca/features_generated_p{pert_id}.npy",
                np.array([]),
            )
            np.save(
                f"ophenom_qual_results/pca/celltypes_generated_p{pert_id}.npy",
                np.array([]),
            )

    # --- Comparisons against Perturbation 1138 ---
    # This section might need adjustment or removal if it relied on data structures changed above,
    # or if its plots are no longer desired as per the user rolling back dual plots.
    # For now, I will assume it needs to load cell types if it's still active.
    if real_features_1138 is not None and generated_features_1138 is not None:
        print(f"\n\n{'='*80}")
        print(f"Starting comparisons against Perturbation ID {REFERENCE_PERT_ID}")
        print(f"{'='*80}")

        for other_pert_id in sampled_perturbation_ids:
            if other_pert_id == REFERENCE_PERT_ID:
                continue

            print(f"\n--- Comparing p{REFERENCE_PERT_ID} with p{other_pert_id} ---")

            # Load features for the "other" perturbation
            try:
                real_features_other_path = (
                    f"ophenom_qual_results/pca/features_real_p{other_pert_id}.npy"
                )
                # Attempt to load cell types for other perturbation if this section is still used
                real_cell_types_other_path = (
                    f"ophenom_qual_results/pca/celltypes_real_p{other_pert_id}.npy"
                )

                generated_features_other_path = (
                    f"ophenom_qual_results/pca/features_generated_p{other_pert_id}.npy"
                )
                generated_cell_types_other_path = (
                    f"ophenom_qual_results/pca/celltypes_generated_p{other_pert_id}.npy"
                )

                if not (
                    os.path.exists(real_features_other_path)
                    # and os.path.exists(real_cell_types_other_path) # Add if comparisons need cell types
                    and os.path.exists(generated_features_other_path)
                    # and os.path.exists(generated_cell_types_other_path) # Add if comparisons need cell types
                ):
                    print(
                        f"Feature files for p{other_pert_id} not found at {real_features_other_path} or {generated_features_other_path}. Skipping comparison."
                    )
                    continue

                real_features_other = np.load(real_features_other_path)
                # real_cell_types_other = np.load(real_cell_types_other_path) # Add if needed
                generated_features_other = np.load(generated_features_other_path)
                # generated_cell_types_other = np.load(generated_cell_types_other_path) # Add if needed
                print(
                    f"Loaded features for p{other_pert_id}"
                )  # Modify if cell types loaded

            except FileNotFoundError:
                print(
                    f"Feature files not found for perturbation {other_pert_id}. Skipping comparison."
                )
                continue
            except Exception as e:
                print(
                    f"Error loading features for perturbation {other_pert_id}: {e}. Skipping comparison."
                )
                continue

            # New Dual Comparison Plot - This was part of the rolled-back changes.
            # If it's to be kept, it needs cell type arrays similar to how they were added before.
            # plot_pca_dual_perturbation_comparison(
            #     real_features_1138, generated_features_1138,
            #     real_features_other, generated_features_other,
            #     real_cell_types_1138, generated_cell_types_1138, # Requires these to be stored for 1138
            #     real_cell_types_other, generated_cell_types_other,
            #     REFERENCE_PERT_ID, other_pert_id,
            #     cell_type_color_map=GLOBAL_CELL_TYPE_COLOR_MAP,
            #     type_markers_map=GLOBAL_TYPE_MARKERS
            # )

            # --- ATE Calculation ---
            if (
                real_features_1138 is not None
                and real_features_other is not None
                and real_features_1138.size > 0
                and real_features_other.size > 0
            ):
                mean_real_control = np.mean(real_features_1138, axis=0)
                mean_real_other = np.mean(real_features_other, axis=0)
                ate_real = np.linalg.norm(mean_real_control - mean_real_other) ** 2
                print(
                    f"ATE_real (p{REFERENCE_PERT_ID} vs p{other_pert_id}): {ate_real:.4f}"
                )
            else:
                ate_real = np.nan
                print(
                    f"Could not calculate ATE_real for p{other_pert_id} due to missing/empty real features."
                )

            if (
                generated_features_1138 is not None
                and generated_features_other is not None
                and generated_features_1138.size > 0
                and generated_features_other.size > 0
            ):
                mean_generated_control = np.mean(generated_features_1138, axis=0)
                mean_generated_other = np.mean(generated_features_other, axis=0)
                ate_generated = (
                    np.linalg.norm(mean_generated_control - mean_generated_other) ** 2
                )
                print(
                    f"ATE_generated (p{REFERENCE_PERT_ID} vs p{other_pert_id}): {ate_generated:.4f}"
                )
            else:
                ate_generated = np.nan
                print(
                    f"Could not calculate ATE_generated for p{other_pert_id} due to missing/empty generated features."
                )

            ate_scores_list.append(
                {
                    "comparison": f"p{REFERENCE_PERT_ID}_vs_p{other_pert_id}",
                    "ATE_real": ate_real,
                    "ATE_generated": ate_generated,
                }
            )

    else:
        print(
            f"Features for reference perturbation ID {REFERENCE_PERT_ID} were not found. Skipping comparisons."
        )

    # --- PCA for All Features Combined ---
    if all_real_features_list and all_generated_features_list:
        print(f"\n\n{'='*80}")
        print("Starting PCA for all aggregated features")
        print(f"{'='*80}")

        final_all_real_features = np.vstack(all_real_features_list)
        final_all_generated_features = np.vstack(all_generated_features_list)

        combined_all_features = np.vstack(
            (final_all_real_features, final_all_generated_features)
        )

        type_labels_all = ["Real"] * len(final_all_real_features) + ["Generated"] * len(
            final_all_generated_features
        )
        pert_id_labels_all = all_pert_ids_for_real + all_pert_ids_for_generated

        if combined_all_features.shape[0] > 1:  # Need at least 2 samples for PCA
            scaler_all = StandardScaler()
            scaled_all_features = scaler_all.fit_transform(combined_all_features)

            pca_all = PCA(n_components=2)
            embedding_all = pca_all.fit_transform(scaled_all_features)
            var_explained_all = pca_all.explained_variance_ratio_ * 100

            plot_all_df = pd.DataFrame(
                {
                    "PC1": embedding_all[:, 0],
                    "PC2": embedding_all[:, 1],
                    "Type": type_labels_all,
                    "PerturbationID": pert_id_labels_all,
                }
            )

            plt.figure(figsize=(12, 10))

            unique_perturbations = sorted(plot_all_df["PerturbationID"].unique())
            # palette = sns.color_palette(n_colors=len(unique_perturbations))
            # Create a mapping from perturbation ID to color
            # perturbation_palette = {
            #     pert_id: palette[i] for i, pert_id in enumerate(unique_perturbations)
            # }
            perturbation_palette = {}
            color_idx = 0
            for pert_id_str in unique_perturbations:
                # Extract numeric part of pert_id_str, e.g., 'p1138' -> 1138
                try:
                    numeric_pert_id = int(pert_id_str.replace("p", ""))
                    if numeric_pert_id == 1138:
                        perturbation_palette[pert_id_str] = "gray"
                    else:
                        perturbation_palette[pert_id_str] = sns.color_palette()[
                            color_idx
                        ]
                        color_idx = (color_idx + 1) % len(
                            sns.color_palette()
                        )  # Cycle through palette colors
                except ValueError:
                    # Handle cases where pert_id_str might not be in the expected format 'p<number>'
                    perturbation_palette[pert_id_str] = sns.color_palette()[color_idx]
                    color_idx = (color_idx + 1) % len(sns.color_palette())

            sns.scatterplot(
                data=plot_all_df,
                x="PC1",
                y="PC2",
                hue="PerturbationID",
                style="Type",
                palette=GLOBAL_PERTURBATION_COLOR_MAP,  # Use global color map
                markers=GLOBAL_TYPE_MARKERS,  # Use global type markers
                alpha=0.7,
                s=80,  # Increased marker size
                edgecolor=None,
            )
            plt.xlabel(
                f"PC1 ({var_explained_all[0]:.1f}%)", fontsize=LABEL_FONTSIZE
            )  # Modified
            plt.ylabel(
                f"PC2 ({var_explained_all[1]:.1f}%)", fontsize=LABEL_FONTSIZE
            )  # Modified
            plt.title(
                "Aggregated PCA of Real and Generated Features by Perturbation",
                fontsize=TITLE_FONTSIZE,
            )  # Modified
            plt.legend(loc="upper right")
            # plt.tight_layout(
            #     rect=[0, 0, 0.85, 1]
            # )  # Adjust layout to make space for legend

            aggregated_pca_save_path = "ophenom_qual_results/pca_aggregated"
            os.makedirs(aggregated_pca_save_path, exist_ok=True)
            output_file_all = os.path.join(
                aggregated_pca_save_path, "pca_all_perturbations_combined.pdf"
            )
            plt.savefig(output_file_all, dpi=600)
            plt.close()
            print(f"Saved aggregated PCA plot to {output_file_all}")
        else:
            print(
                "Not enough data (less than 2 samples) to generate aggregated PCA plot."
            )
    else:
        print("No features collected to generate aggregated PCA plot.")

    # --- Print ATE Scores Summary ---
    if ate_scores_list:
        print(f"\n\n{'='*80}")
        print("Average Treatment Effect (ATE) Scores Summary")
        print(f"Control Perturbation ID: {REFERENCE_PERT_ID}")
        print(f"{'='*80}")
        ate_df = pd.DataFrame(ate_scores_list)
        print(ate_df.to_string())

        # Save ATE scores to CSV
        ate_csv_path = "ophenom_qual_results/ate_scores.csv"
        os.makedirs(os.path.dirname(ate_csv_path), exist_ok=True)
        ate_df.to_csv(ate_csv_path, index=False)
        print(f"\nATE scores saved to {ate_csv_path}")

    print("\n\nScript finished.")
