import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from dataset import CellDataModule, to_rgb
from diffusers.models import AutoencoderKL
from metrics_utils import calculate_metrics_from_scratch
from models.sit import SiT_models
from omegaconf import OmegaConf
from openphenom import OpenPhenomEncoder
from pytorch_lightning import seed_everything
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from train import generate_perturbation_matched_samples
from utils import load_encoders

# Module-level helpers: organelle display names and a minimal in-memory Dataset wrapper.

# Display name for each organelle index. The index is the position along the
# second axis of the per-organelle feature arrays used below
# (``features[:, organelle_idx, :]``); the names appear in plot titles and in
# output file names.
organelle_id_to_name = {
    0: "Nuclei",
    1: "ER",
    2: "Actin",
    3: "Nucleoli",
    4: "Mitochondria",  # spelling fixed (was "Mitochandria")
    5: "Golgi",
}


class CustomDataset(torch.utils.data.Dataset):
    """Thin ``Dataset`` adapter around any sized, indexable container."""

    def __init__(self, data):
        # Kept by reference: no copy and no validation is performed.
        self.data = data

    def __getitem__(self, idx):
        """Return ``self.data[idx]`` exactly as stored."""
        item = self.data[idx]
        return item

    def __len__(self):
        """Return the number of items in the wrapped container."""
        n_items = len(self.data)
        return n_items


def find_generated_files_by_perturbation_and_celltype(
    generated_path, perturbation_id, cell_type_id
):
    """
    Find all generated numpy files for a specific perturbation ID and cell type ID.

    Files are expected at
    ``<generated_path>/p<pid>/p<pid>_c<cell_type_id>_sample<sample_id>.npy``.

    Args:
        generated_path: Path to the directory containing generated data
        perturbation_id: Perturbation ID to filter by
        cell_type_id: Cell type ID to filter by

    Returns:
        Sorted list of file paths matching both the perturbation and cell type.
        Empty if the perturbation folder does not exist or nothing matches.
    """
    pert_path = os.path.join(generated_path, f"p{perturbation_id}")

    if not os.path.exists(pert_path):
        return []

    # Cell-type marker inside the file name (p<pid>_c<cell_type_id>_sample<sample_id>.npy)
    pattern = f"_c{cell_type_id}_sample"

    # Escape the directory part so characters such as "[" in the path are
    # taken literally instead of being interpreted as glob wildcards.
    npy_files = glob.glob(os.path.join(glob.escape(pert_path), "*.npy"))

    # Match on the file name only: testing the full path would select every
    # file whenever a parent directory happens to contain the marker.
    # Sort because glob order is arbitrary, and a seeded random.sample over
    # the result is only reproducible when the input order is stable.
    return sorted(f for f in npy_files if pattern in os.path.basename(f))


def load_numpy_files(file_paths, max_samples):
    """
    Read ``.npy`` files and stack them into one float tensor.

    When more than ``max_samples`` paths are given, a random subset of exactly
    ``max_samples`` paths is drawn with ``random.sample``. A file that fails
    to load is reported on stdout and skipped.

    Args:
        file_paths: List of numpy file paths to load
        max_samples: Maximum number of samples to load

    Returns:
        Tensor built with ``torch.stack`` from the loaded arrays (each cast
        with ``.float()``), or ``None`` when no file could be loaded.
    """
    selected_paths = file_paths
    if len(selected_paths) > max_samples:
        selected_paths = random.sample(selected_paths, max_samples)

    tensors = []
    for path in tqdm(selected_paths, desc="Loading numpy files"):
        try:
            array = np.load(path)
            tensors.append(torch.from_numpy(array).float())
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole batch.
            print(f"Error loading {path}: {e}")

    if not tensors:
        return None
    return torch.stack(tensors)


def create_cell_type_metadata(num_samples=500, perturbation_id=1138, cell_type=1):
    """
    Create a list of identical perturbation metadata entries.

    Args:
        num_samples: Number of metadata entries to create
        perturbation_id: The perturbation ID to use for all entries
        cell_type: The cell type ID to use for all entries (defaults to 1)

    Returns:
        List of ``num_samples`` independent metadata dictionaries, each with
        keys ``perturbation_id``, ``cell_type_id`` and ``is_generated``
        (always ``False``).
    """
    # A fresh dict per entry so callers can mutate one without affecting others.
    perturbation_metadata = [
        {
            "perturbation_id": perturbation_id,
            # Honour the caller's value; it used to be hard-coded to 1.
            "cell_type_id": cell_type,
            "is_generated": False,
        }
        for _ in range(num_samples)
    ]

    print(f"Created perturbation metadata with {len(perturbation_metadata)} entries")
    print(f"All entries have cell_type_id set to {cell_type}")

    return perturbation_metadata


def augment_image(image, augmentation_type=None):
    """
    Return an augmented version of ``image``.

    Args:
        image: Image passed to the torchvision functional transforms
            (presumably a tensor of shape [C, H, W] — confirm with callers)
        augmentation_type: One of 'rotate', 'flip', 'unchanged'. When None,
            one of the three is picked at random.

    Returns:
        For 'rotate', the image rotated by a randomly chosen 90, 180 or 270
        degrees; for 'flip', a horizontal or vertical flip chosen at random;
        for any other value, the input image itself.
    """
    kind = augmentation_type
    if kind is None:
        kind = random.choice(["rotate", "flip", "unchanged"])

    if kind == "rotate":
        return TF.rotate(image, random.choice([90, 180, 270]))

    if kind == "flip":
        # One random draw decides the flip axis.
        flip_fn = TF.hflip if random.random() > 0.5 else TF.vflip
        return flip_fn(image)

    # 'unchanged' (and any unrecognised value) is passed through untouched.
    return image


def plot_pca_comparison(
    features1,
    features2,
    label1_desc,
    label2_desc,
    title_pert_id1,
    title_pert_id2,
    plot_title_prefix,
    output_filename_prefix,
    base_save_path="ophenom_qual_results/pca_comparisons",
):
    """
    Project two feature sets onto their joint first two principal components
    and save the scatter plot as a PDF.

    The two sets are stacked, standardized with ``StandardScaler`` and reduced
    with a 2-component ``PCA`` fitted on the stacked data. Points are coloured
    by group, where each group label is ``"<label_desc> (p<pert_id>)"``.

    Args:
        features1, features2: Feature arrays that can be stacked with ``np.vstack``
        label1_desc, label2_desc: Legend descriptions for the two sets
        title_pert_id1, title_pert_id2: Perturbation IDs used in labels, title and file name
        plot_title_prefix: Text placed after "PCA: " in the title
        output_filename_prefix: Leading part of the output file name
        base_save_path: Output directory (created if missing)
    """
    os.makedirs(base_save_path, exist_ok=True)

    stacked = np.vstack((features1, features2))

    # One legend label per row, in the same order as the stacked rows.
    first_label = f"{label1_desc} (p{title_pert_id1})"
    second_label = f"{label2_desc} (p{title_pert_id2})"
    group_labels = np.array(
        [first_label] * len(features1) + [second_label] * len(features2)
    )

    standardized = StandardScaler().fit_transform(stacked)

    reducer = PCA(n_components=2)
    coords = reducer.fit_transform(standardized)
    pct_variance = reducer.explained_variance_ratio_ * 100

    frame = pd.DataFrame(
        {
            "PC1": coords[:, 0],
            "PC2": coords[:, 1],
            "Type": group_labels,
        }
    )

    plt.figure(figsize=(12, 10))

    # Assign default seaborn colours to the labels in order of first appearance.
    palette = {}
    for position, label in enumerate(frame["Type"].unique()):
        palette[label] = sns.color_palette()[position]

    sns.scatterplot(
        data=frame,
        x="PC1",
        y="PC2",
        hue="Type",
        palette=palette,
        alpha=0.7,
        s=50,
        edgecolor=None,
    )

    plt.xlabel(f"PC1 ({pct_variance[0]:.1f}%)")
    plt.ylabel(f"PC2 ({pct_variance[1]:.1f}%)")
    plt.title(f"PCA: {plot_title_prefix} - p{title_pert_id1} vs p{title_pert_id2}")
    plt.legend(loc="upper right")

    pdf_name = f"{output_filename_prefix}_p{title_pert_id1}_vs_p{title_pert_id2}.pdf"
    output_file = os.path.join(base_save_path, pdf_name)
    plt.savefig(output_file, dpi=600)
    plt.close()
    print(f"Saved PCA comparison plot to {output_file}")


def plot_pca_dual_perturbation_comparison(
    real_features_ref,
    generated_features_ref,
    real_features_other,
    generated_features_other,
    ref_pert_id,
    other_pert_id,
    organelle_idx,  # Added
    color_map,
    type_markers_map,
    base_save_path_template="ophenom_qual_results_organelle/organelle_{org_idx}/pca_comparisons_dual",  # Modified template
):
    """
    Plot real and generated features of two perturbations in a shared PCA
    space for one organelle and save the figure as a PDF.

    Colour encodes the perturbation ID and marker style encodes Real vs
    Generated. Any of the four feature arrays may be ``None`` or empty and is
    then left out. Nothing is plotted when no features remain or when fewer
    than two samples are available in total.

    Args:
        real_features_ref, generated_features_ref: Features of the reference perturbation
        real_features_other, generated_features_other: Features of the other perturbation
        ref_pert_id, other_pert_id: Perturbation IDs (rendered as ``p<id>``)
        organelle_idx: Key into ``organelle_id_to_name``; also fills ``{org_idx}`` in the path template
        color_map: Mapping from ``p<id>`` to a colour
        type_markers_map: Mapping from "Real"/"Generated" to a marker
        base_save_path_template: Output directory template containing ``{org_idx}``
    """
    out_dir = base_save_path_template.format(org_idx=organelle_idx)
    os.makedirs(out_dir, exist_ok=True)

    # (features, perturbation id, kind) in the order the rows are stacked.
    groups = (
        (real_features_ref, ref_pert_id, "Real"),
        (generated_features_ref, ref_pert_id, "Generated"),
        (real_features_other, other_pert_id, "Real"),
        (generated_features_other, other_pert_id, "Generated"),
    )

    feature_blocks = []
    pert_labels = []
    kind_labels = []
    for feats, pid, kind in groups:
        if feats is None or not feats.size > 0:
            continue
        feature_blocks.append(feats)
        pert_labels.extend([f"p{pid}"] * len(feats))
        kind_labels.extend([kind] * len(feats))

    if not feature_blocks:
        print(
            f"No features provided for PCA dual comparison between p{ref_pert_id} and p{other_pert_id}. Skipping."
        )
        return

    stacked = np.vstack(feature_blocks)
    n_rows = stacked.shape[0]

    if n_rows < 2:
        print(
            f"Not enough samples ({n_rows}) for PCA dual comparison between p{ref_pert_id} and p{other_pert_id}. Skipping."
        )
        return

    standardized = StandardScaler().fit_transform(stacked)

    reducer = PCA(n_components=2)
    coords = reducer.fit_transform(standardized)
    pct_variance = reducer.explained_variance_ratio_ * 100

    frame = pd.DataFrame(
        {
            "PC1": coords[:, 0],
            "PC2": coords[:, 1],
            "PerturbationID": pert_labels,
            "Type": kind_labels,
        }
    )

    plt.figure(figsize=(12, 10))

    # Restrict the shared colour map to the two perturbations being compared.
    active_color_palette = {}
    for pert_key in (f"p{ref_pert_id}", f"p{other_pert_id}"):
        if pert_key in color_map:
            active_color_palette[pert_key] = color_map.get(pert_key)

    sns.scatterplot(
        data=frame,
        x="PC1",
        y="PC2",
        hue="PerturbationID",
        style="Type",
        palette=active_color_palette,
        markers=type_markers_map,
        alpha=0.7,
        s=90,
        edgecolor=None,
    )

    plt.xlabel(f"PC1 ({pct_variance[0]:.1f}%)")
    plt.ylabel(f"PC2 ({pct_variance[1]:.1f}%)")
    organelle_name = organelle_id_to_name[organelle_idx]
    plt.title(
        f"PCA: p{ref_pert_id} vs p{other_pert_id} ({organelle_name})",
    )
    plt.legend(loc="upper right")

    pdf_name = f"pca_dual_p{ref_pert_id}_vs_p{other_pert_id}_{organelle_name}.pdf"
    output_file = os.path.join(out_dir, pdf_name)
    plt.savefig(output_file, dpi=600)
    plt.close()
    print(f"Saved dual PCA comparison plot to {output_file}")


if __name__ == "__main__":
    seed = 0
    MANUAL_GENERATION = False
    cell_type_id = 1  # This script focuses on one cell_type for organelle analysis
    seed_everything(seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # load yaml file
    filename = "diffusion_sit_full.yaml"
    generated_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"
    # Example numpy path: generated_path/p<pid>/p<pid>_c<c_id>_sample<sample_id>.npy
    # load yaml
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    # Define font sizes for plots
    TITLE_FONTSIZE = 26
    LABEL_FONTSIZE = 22
    LEGEND_FONTSIZE = 18
    LEGEND_TITLE_FONTSIZE = 18
    TICK_FONTSIZE = 18

    NUM_ORGANELLES = 6  # Define number of organelles

    # Fixed set of 4 perturbation IDs to analyse (valid IDs are 1..1138); the random
    # sampling variant is kept commented out below.
    # all_perturbation_ids = list(range(1, 1139))  # 1 to 1138
    # sampled_perturbation_ids = random.sample(all_perturbation_ids, 50)
    sampled_perturbation_ids = [1138, 1137, 1108, 1124]
    print(f"Sampled perturbation IDs: {sampled_perturbation_ids}")

    NUM_SAMPLES = 500
    encoder = OpenPhenomEncoder().to(device)
    # Iterate through each perturbation ID and calculate metrics
    results = []

    # Store 3D features (samples, organelles, dim) for the reference perturbation
    real_features_1138_all_organs = None
    generated_features_1138_all_organs = None
    REFERENCE_PERT_ID = 1138

    # --- Create Global Color Map and Marker Map for Consistent Plotting ---
    GLOBAL_PERTURBATION_COLOR_MAP = {}
    # Ensure sampled_perturbation_ids are actual numeric IDs if they are not already
    # For this example, assuming sampled_perturbation_ids contains numeric IDs like [1138, 100, 200]
    # If they are strings like ['p1138'], adjust accordingly.

    # Get unique sorted numeric perturbation IDs to ensure consistent color assignment
    # This assumes sampled_perturbation_ids contains the numeric values
    unique_numeric_ids = sorted(list(set(sampled_perturbation_ids)))

    color_palette_list = sns.color_palette()  # Get a list of colors
    color_idx_non_ref = 0

    for pid_numeric in unique_numeric_ids:
        pid_str = f"p{pid_numeric}"  # Construct string ID like 'p1138'
        if pid_numeric == REFERENCE_PERT_ID:
            GLOBAL_PERTURBATION_COLOR_MAP[pid_str] = "gray"
        else:
            # Cycle through the seaborn palette for non-reference perturbations
            GLOBAL_PERTURBATION_COLOR_MAP[pid_str] = color_palette_list[
                color_idx_non_ref % len(color_palette_list)
            ]
            color_idx_non_ref += 1

    GLOBAL_TYPE_MARKERS = {"Real": "o", "Generated": "X"}
    print(f"Global Perturbation Color Map: {GLOBAL_PERTURBATION_COLOR_MAP}")
    print(f"Global Type Markers: {GLOBAL_TYPE_MARKERS}")
    # --- End Global Map Creation ---

    # Initialize lists to store features per organelle for aggregated PCA
    all_real_features_per_org_lists = [[] for _ in range(NUM_ORGANELLES)]
    all_generated_features_per_org_lists = [[] for _ in range(NUM_ORGANELLES)]
    all_pert_ids_for_real_per_org_lists = [[] for _ in range(NUM_ORGANELLES)]
    all_pert_ids_for_generated_per_org_lists = [[] for _ in range(NUM_ORGANELLES)]

    ate_scores_list = []  # To store ATE scores (will include organelle_id)

    for i, pert_id in enumerate(sampled_perturbation_ids):
        print(f"\n\n{'='*80}")
        print(
            f"Processing perturbation ID: {pert_id}, {i+1}/{len(sampled_perturbation_ids)}"
        )
        print(f"{'='*80}")

        # Filter real images using CellDataModule
        real_filtered_dataset = datamodule.filter_samples(
            perturbation_id=pert_id, cell_type_id=cell_type_id
        )

        if real_filtered_dataset is None or len(real_filtered_dataset) == 0:
            print(f"No real data found for perturbation ID {pert_id}")
            continue

        # Get real images
        real_images = [
            real_filtered_dataset[i][0] for i in range(len(real_filtered_dataset))
        ]
        print(f"Found {len(real_images)} real images for perturbation ID {pert_id}")
        # Convert to tensor
        real_images_tensor = torch.stack(real_images)

        # Find all generated files for this perturbation
        cell_types = [cell_type_id]  # [0, 1, 2, 3]
        generated_files = []
        for cell_type_id in cell_types:
            generated_files_ct = find_generated_files_by_perturbation_and_celltype(
                generated_path,
                pert_id,
                cell_type_id,
            )
            generated_files.extend(generated_files_ct)
        print(
            f"Found {len(generated_files)} generated files for perturbation ID {pert_id}"
        )

        if not generated_files:
            print(f"No generated data found for perturbation ID {pert_id}")
            continue

        # Load generated images (sample up to NUM_SAMPLES)
        generated_images_tensor = load_numpy_files(
            generated_files, max_samples=NUM_SAMPLES
        )

        if generated_images_tensor is None:
            print(f"Failed to load generated images for perturbation ID {pert_id}")
            continue

        print(f"Extracting OpenPhenom features for perturbation ID {pert_id}")
        # Create directory for saving results if it doesn't exist
        os.makedirs("ophenom_qual_results/pca", exist_ok=True)

        # Extract OpenPhenom features for real and generated images
        # These will be 3D tensors: (num_samples, num_organelles, feature_dim)
        real_features_all_organs_current_pert = None
        generated_features_all_organs_current_pert = None
        with torch.no_grad():
            if real_images_tensor is not None and real_images_tensor.size(0) > 0:
                real_features_all_organs_current_pert = (
                    encoder.forward_features(
                        real_images_tensor.to(device), flatten=False
                    )
                    .cpu()
                    .numpy()
                )
            if (
                generated_images_tensor is not None
                and generated_images_tensor.size(0) > 0
            ):
                generated_features_all_organs_current_pert = (
                    encoder.forward_features(
                        generated_images_tensor.to(device), flatten=False
                    )
                    .cpu()
                    .numpy()
                )

        if real_features_all_organs_current_pert is not None:
            print(
                f"Extracted real features shape (all organelles): {real_features_all_organs_current_pert.shape}"
            )
        if generated_features_all_organs_current_pert is not None:
            print(
                f"Extracted generated features shape (all organelles): {generated_features_all_organs_current_pert.shape}"
            )

        # Store 3D features for the reference perturbation ID
        if pert_id == REFERENCE_PERT_ID:
            real_features_1138_all_organs = real_features_all_organs_current_pert
            generated_features_1138_all_organs = (
                generated_features_all_organs_current_pert
            )
            print(
                f"Stored 3D features for reference perturbation ID {REFERENCE_PERT_ID}"
            )

        # Loop through each organelle for PCA and feature storage
        for organelle_idx in range(NUM_ORGANELLES):
            print(
                f"\n--- Processing Organelle ID: {organelle_idx} for Perturbation ID: {pert_id} ---"
            )

            base_organelle_save_path = (
                f"ophenom_qual_results_organelle/organelle_{organelle_idx}"
            )
            pca_save_path = os.path.join(base_organelle_save_path, "pca")
            os.makedirs(pca_save_path, exist_ok=True)

            # Select features for the current organelle
            real_features_org = None
            if real_features_all_organs_current_pert is not None:
                real_features_org = real_features_all_organs_current_pert[
                    :, organelle_idx, :
                ]

            generated_features_org = None
            if generated_features_all_organs_current_pert is not None:
                generated_features_org = generated_features_all_organs_current_pert[
                    :, organelle_idx, :
                ]

            if (real_features_org is None or real_features_org.size == 0) and (
                generated_features_org is None or generated_features_org.size == 0
            ):
                print(
                    f"No features for organelle {organelle_idx}, perturbation {pert_id}. Skipping."
                )
                continue

            print(
                f"Organelle {organelle_idx} features shapes - Real: {real_features_org.shape if real_features_org is not None else 'N/A'}, Generated: {generated_features_org.shape if generated_features_org is not None else 'N/A'}"
            )

            # Store organelle-specific features for aggregated PCA
            if real_features_org is not None and real_features_org.size > 0:
                all_real_features_per_org_lists[organelle_idx].append(real_features_org)
                all_pert_ids_for_real_per_org_lists[organelle_idx].extend(
                    [f"p{pert_id}"] * len(real_features_org)
                )
            if generated_features_org is not None and generated_features_org.size > 0:
                all_generated_features_per_org_lists[organelle_idx].append(
                    generated_features_org
                )
                all_pert_ids_for_generated_per_org_lists[organelle_idx].extend(
                    [f"p{pert_id}"] * len(generated_features_org)
                )

            # Combine features for PCA for the current organelle
            current_features_list_org = []
            current_labels_list_org = []
            if real_features_org is not None and real_features_org.size > 0:
                current_features_list_org.append(real_features_org)
                current_labels_list_org.extend(["Real"] * len(real_features_org))
            if generated_features_org is not None and generated_features_org.size > 0:
                current_features_list_org.append(generated_features_org)
                current_labels_list_org.extend(
                    ["Generated"] * len(generated_features_org)
                )

            if not current_features_list_org:
                print(
                    f"No features to plot for p{pert_id}, organelle {organelle_idx}. Skipping PCA plot."
                )
                # Save empty features if needed, or just continue
                if real_features_org is not None:
                    np.save(
                        os.path.join(
                            pca_save_path,
                            f"features_real_p{pert_id}_org{organelle_idx}.npy",
                        ),
                        real_features_org,
                    )
                if generated_features_org is not None:
                    np.save(
                        os.path.join(
                            pca_save_path,
                            f"features_generated_p{pert_id}_org{organelle_idx}.npy",
                        ),
                        generated_features_org,
                    )
                continue

            combined_features_org = np.vstack(current_features_list_org)
            type_labels_org = np.array(current_labels_list_org)

            if combined_features_org.shape[0] < 2:  # Need at least 2 samples for PCA
                print(
                    f"Not enough combined samples ({combined_features_org.shape[0]}) for PCA for p{pert_id}, organelle {organelle_idx}. Skipping plot."
                )
                if real_features_org is not None:
                    np.save(
                        os.path.join(
                            pca_save_path,
                            f"features_real_p{pert_id}_org{organelle_idx}.npy",
                        ),
                        real_features_org,
                    )
                if generated_features_org is not None:
                    np.save(
                        os.path.join(
                            pca_save_path,
                            f"features_generated_p{pert_id}_org{organelle_idx}.npy",
                        ),
                        generated_features_org,
                    )
                continue

            # Apply PCA for the current organelle
            scaler_org = StandardScaler()
            scaled_features_org = scaler_org.fit_transform(combined_features_org)
            pca_org = PCA(n_components=2)
            embedding_org = pca_org.fit_transform(scaled_features_org)
            var_explained_org = pca_org.explained_variance_ratio_ * 100

            # Create DataFrame for plotting for the current organelle
            plot_df_org = pd.DataFrame(
                {
                    "PC1": embedding_org[:, 0],
                    "PC2": embedding_org[:, 1],
                    "Type": type_labels_org,
                    # "Perturbation": f"p{pert_id}", # Perturbation is in the title
                }
            )

            # Create PCA plot for the current organelle
            plt.figure(figsize=(12, 10))
            palette_org = {"Real": "blue", "Generated": "red"}

            sns.scatterplot(
                data=plot_df_org,
                x="PC1",
                y="PC2",
                hue="Type",
                palette=palette_org,
                style="Type",  # Differentiate by marker as well
                markers=GLOBAL_TYPE_MARKERS,  # Using global markers
                alpha=0.7,
                s=50,
                edgecolor=None,
            )

            plt.xlabel(f"PC1 ({var_explained_org[0]:.1f}%)", fontsize=LABEL_FONTSIZE)
            plt.ylabel(f"PC2 ({var_explained_org[1]:.1f}%)", fontsize=LABEL_FONTSIZE)
            plt.title(
                f"PCA of OpenPhenom Features - Perturbation {pert_id}, {organelle_id_to_name[organelle_idx]}",
                fontsize=TITLE_FONTSIZE,
            )
            legend_obj_individual_org = plt.legend(loc="upper right")
            for text_obj in legend_obj_individual_org.findobj(plt.Text):
                current_text = text_obj.get_text()
                if current_text == "Type":
                    text_obj.set_fontsize(LEGEND_TITLE_FONTSIZE)
                else:
                    text_obj.set_fontsize(LEGEND_FONTSIZE)

            plt.xticks(fontsize=TICK_FONTSIZE)
            plt.yticks(fontsize=TICK_FONTSIZE)
            # plt.tight_layout()

            output_file_org = os.path.join(
                pca_save_path, f"pca_perturbation_{pert_id}_org{organelle_idx}.pdf"
            )
            plt.savefig(output_file_org, dpi=600)
            plt.close()
            print(f"Saved PCA plot for organelle {organelle_idx} to {output_file_org}")

            # Save the organelle-specific features
            if real_features_org is not None:
                np.save(
                    os.path.join(
                        pca_save_path,
                        f"features_real_p{pert_id}_org{organelle_idx}.npy",
                    ),
                    real_features_org,
                )
            if generated_features_org is not None:
                np.save(
                    os.path.join(
                        pca_save_path,
                        f"features_generated_p{pert_id}_org{organelle_idx}.npy",
                    ),
                    generated_features_org,
                )

    # --- Comparisons against Perturbation 1138 (per organelle) ---
    if (
        real_features_1138_all_organs is not None
        and generated_features_1138_all_organs is not None
    ):
        print(f"\n\n{'='*80}")
        print(
            f"Starting comparisons against Perturbation ID {REFERENCE_PERT_ID} (per organelle)"
        )
        print(f"{'='*80}")

        for organelle_idx in range(NUM_ORGANELLES):
            print(f"\n--- Processing Comparisons for Organelle ID: {organelle_idx} ---")

            real_features_1138_org = real_features_1138_all_organs[:, organelle_idx, :]
            generated_features_1138_org = generated_features_1138_all_organs[
                :, organelle_idx, :
            ]

            if (
                real_features_1138_org.size == 0
                and generated_features_1138_org.size == 0
            ):
                print(
                    f"No reference features for organelle {organelle_idx}. Skipping comparisons for this organelle."
                )
                continue

            for other_pert_id in sampled_perturbation_ids:
                if other_pert_id == REFERENCE_PERT_ID:
                    continue

                print(
                    f"\n--- Comparing p{REFERENCE_PERT_ID} with p{other_pert_id} for Organelle {organelle_idx} ---"
                )

                # Define paths for organelle-specific features
                base_organelle_save_path_other = f"ophenom_qual_results_organelle/organelle_{organelle_idx}"  # Path for current organelle
                pca_save_path_other = os.path.join(
                    base_organelle_save_path_other, "pca"
                )

                real_features_other_org_path = os.path.join(
                    pca_save_path_other,
                    f"features_real_p{other_pert_id}_org{organelle_idx}.npy",
                )
                generated_features_other_org_path = os.path.join(
                    pca_save_path_other,
                    f"features_generated_p{other_pert_id}_org{organelle_idx}.npy",
                )

                real_features_other_org = None
                generated_features_other_org = None

                if os.path.exists(real_features_other_org_path):
                    real_features_other_org = np.load(real_features_other_org_path)
                else:
                    print(
                        f"Real feature file for p{other_pert_id}, organelle {organelle_idx} not found at {real_features_other_org_path}."
                    )

                if os.path.exists(generated_features_other_org_path):
                    generated_features_other_org = np.load(
                        generated_features_other_org_path
                    )
                else:
                    print(
                        f"Generated feature file for p{other_pert_id}, organelle {organelle_idx} not found at {generated_features_other_org_path}."
                    )

                if (
                    real_features_other_org is None or real_features_other_org.size == 0
                ) and (
                    generated_features_other_org is None
                    or generated_features_other_org.size == 0
                ):
                    print(
                        f"No features found for p{other_pert_id}, organelle {organelle_idx}. Skipping comparison."
                    )
                    continue

                print(
                    f"Loaded features for p{other_pert_id}, organelle {organelle_idx}"
                )

                # New Dual Comparison Plot for the current organelle
                plot_pca_dual_perturbation_comparison(
                    real_features_1138_org,
                    generated_features_1138_org,
                    real_features_other_org,
                    generated_features_other_org,
                    REFERENCE_PERT_ID,
                    other_pert_id,
                    organelle_idx,  # Pass organelle_idx
                    color_map=GLOBAL_PERTURBATION_COLOR_MAP,
                    type_markers_map=GLOBAL_TYPE_MARKERS,
                    # base_save_path_template is already updated in function definition
                )

                # --- ATE Calculation for the current organelle ---
                ate_real_org = np.nan
                if (
                    real_features_1138_org is not None
                    and real_features_other_org is not None
                    and real_features_1138_org.size > 0
                    and real_features_other_org.size > 0
                ):
                    mean_real_control_org = np.mean(real_features_1138_org, axis=0)
                    mean_real_other_org = np.mean(real_features_other_org, axis=0)
                    ate_real_org = (
                        np.linalg.norm(mean_real_control_org - mean_real_other_org) ** 2
                    )
                    print(
                        f"ATE_real (p{REFERENCE_PERT_ID} vs p{other_pert_id}, Organelle {organelle_idx}): {ate_real_org:.4f}"
                    )
                else:
                    print(
                        f"Could not calculate ATE_real for p{other_pert_id}, Organelle {organelle_idx} due to missing/empty real features."
                    )

                ate_generated_org = np.nan
                if (
                    generated_features_1138_org is not None
                    and generated_features_other_org is not None
                    and generated_features_1138_org.size > 0
                    and generated_features_other_org.size > 0
                ):
                    mean_generated_control_org = np.mean(
                        generated_features_1138_org, axis=0
                    )
                    mean_generated_other_org = np.mean(
                        generated_features_other_org, axis=0
                    )
                    ate_generated_org = (
                        np.linalg.norm(
                            mean_generated_control_org - mean_generated_other_org
                        )
                        ** 2
                    )
                    print(
                        f"ATE_generated (p{REFERENCE_PERT_ID} vs p{other_pert_id}, Organelle {organelle_idx}): {ate_generated_org:.4f}"
                    )
                else:
                    print(
                        f"Could not calculate ATE_generated for p{other_pert_id}, Organelle {organelle_idx} due to missing/empty generated features."
                    )

                ate_scores_list.append(
                    {
                        "comparison": f"p{REFERENCE_PERT_ID}_vs_p{other_pert_id}",
                        "organelle_id": organelle_idx,  # Add organelle ID
                        "ATE_real": ate_real_org,
                        "ATE_generated": ate_generated_org,
                    }
                )
    else:
        print(
            f"Features for reference perturbation ID {REFERENCE_PERT_ID} (all organelles) were not found. Skipping comparisons."
        )

    # --- Aggregated PCA over all collected features, one plot per organelle ---
    banner_rule = "=" * 80
    print(f"\n\n{banner_rule}")
    print("Starting PCA for all aggregated features (per organelle)")
    print(banner_rule)

    for organelle_idx in range(NUM_ORGANELLES):
        print(f"\n--- Aggregated PCA for Organelle ID: {organelle_idx} ---")

        # The output directory is created up front, so it exists even when the
        # organelle is skipped further down.
        pca_out_dir = os.path.join(
            f"ophenom_qual_results_organelle/organelle_{organelle_idx}",
            "pca_aggregated",
        )
        os.makedirs(pca_out_dir, exist_ok=True)

        real_chunks = all_real_features_per_org_lists[organelle_idx]
        generated_chunks = all_generated_features_per_org_lists[organelle_idx]

        # Guard: nothing was collected for this organelle at all.
        if not (real_chunks or generated_chunks):
            print(
                f"No features collected for organelle {organelle_idx} to generate aggregated PCA plot."
            )
            continue

        # A missing source is represented by an empty array so that the
        # `.size` checks below treat it uniformly.
        real_matrix = np.vstack(real_chunks) if real_chunks else np.array([])
        generated_matrix = (
            np.vstack(generated_chunks) if generated_chunks else np.array([])
        )

        # Real features are appended first, generated second; the type labels
        # and perturbation labels are extended in that same order.
        # NOTE(review): assumes each perturbation-ID list has one entry per
        # feature row of its source -- confirm where the lists are filled.
        sources = (
            (
                "Real",
                real_matrix,
                all_pert_ids_for_real_per_org_lists[organelle_idx],
            ),
            (
                "Generated",
                generated_matrix,
                all_pert_ids_for_generated_per_org_lists[organelle_idx],
            ),
        )
        stacked_parts = []
        type_labels = []
        pert_labels = []
        for source_name, source_matrix, source_pert_ids in sources:
            if source_matrix.size > 0:
                stacked_parts.append(source_matrix)
                type_labels.extend([source_name] * len(source_matrix))
                pert_labels.extend(source_pert_ids)

        if not stacked_parts:
            print(f"No features to aggregate for organelle {organelle_idx}. Skipping.")
            continue

        combined_matrix = np.vstack(stacked_parts)

        # Guard: PCA needs at least 2 samples.
        if combined_matrix.shape[0] <= 1:
            print(
                f"Not enough data (less than 2 samples) for organelle {organelle_idx} to generate aggregated PCA plot."
            )
            continue

        # Standardize the features, then project onto the first two components.
        standardized = StandardScaler().fit_transform(combined_matrix)
        pca_model = PCA(n_components=2)
        projected = pca_model.fit_transform(standardized)
        variance_pct = pca_model.explained_variance_ratio_ * 100

        scatter_df = pd.DataFrame(
            {
                "PC1": projected[:, 0],
                "PC2": projected[:, 1],
                "Type": type_labels,
                "PerturbationID": pert_labels,
            }
        )

        plt.figure(figsize=(12, 10))

        # Colour encodes the perturbation, marker shape encodes real/generated;
        # both mappings come from the shared global palette and marker tables.
        sns.scatterplot(
            data=scatter_df,
            x="PC1",
            y="PC2",
            hue="PerturbationID",
            style="Type",
            palette=GLOBAL_PERTURBATION_COLOR_MAP,
            markers=GLOBAL_TYPE_MARKERS,
            alpha=0.7,
            s=80,
            edgecolor=None,
        )

        plt.xlabel(f"PC1 ({variance_pct[0]:.1f}%)", fontsize=LABEL_FONTSIZE)
        plt.ylabel(f"PC2 ({variance_pct[1]:.1f}%)", fontsize=LABEL_FONTSIZE)
        plt.title(
            f"Aggregated PCA by Perturbation - {organelle_id_to_name[organelle_idx]}",
            fontsize=TITLE_FONTSIZE,
        )

        # Legend entries whose text is exactly the hue/style column name get
        # the legend-title font size; every other entry gets the regular one.
        legend_box = plt.legend(loc="upper right")
        legend_section_names = ("PerturbationID", "Type")
        for legend_text in legend_box.findobj(plt.Text):
            is_section_name = legend_text.get_text() in legend_section_names
            legend_text.set_fontsize(
                LEGEND_TITLE_FONTSIZE if is_section_name else LEGEND_FONTSIZE
            )

        plt.xticks(fontsize=TICK_FONTSIZE)
        plt.yticks(fontsize=TICK_FONTSIZE)

        pdf_path = os.path.join(
            pca_out_dir,
            f"pca_all_perturbations_combined_org{organelle_idx}.pdf",
        )
        plt.savefig(pdf_path, dpi=600)
        plt.close()
        print(
            f"Saved aggregated PCA plot for organelle {organelle_idx} to {pdf_path}"
        )

    # --- Report the per-organelle ATE scores and persist them as CSV ---
    if ate_scores_list:
        summary_rule = "=" * 80
        print(f"\n\n{summary_rule}")
        print("Average Treatment Effect (ATE) Scores Summary (per Organelle)")
        print(f"Control Perturbation ID: {REFERENCE_PERT_ID}")
        print(summary_rule)

        # One row per entry appended to ate_scores_list.
        ate_table = pd.DataFrame(ate_scores_list)
        print(ate_table.to_string())

        # Make sure the results directory exists before writing the CSV into it.
        results_dir = "ophenom_qual_results_organelle"
        os.makedirs(results_dir, exist_ok=True)
        csv_path = os.path.join(results_dir, "ate_scores_organelle.csv")
        ate_table.to_csv(csv_path, index=False)
        print(f"\nATE scores saved to {csv_path}")

    print("\n\nScript finished.")
