import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from diffusers.models import AutoencoderKL
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule, to_rgb
from sc_perturb.metrics_utils import calculate_metrics_from_scratch
from sc_perturb.models.sit import SiT_models
from sc_perturb.openphenom import OpenPhenomEncoder
from sc_perturb.utils.generation_utils import generate_perturbation_matched_samples
from sc_perturb.utils.utils import load_encoders
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

# Minimal torch Dataset wrapper around an in-memory indexable collection.


class CustomDataset(torch.utils.data.Dataset):
    """Thin ``Dataset`` adapter over any sized, indexable container.

    Items are returned exactly as stored; no transforms are applied.
    """

    def __init__(self, data):
        # Kept under the public attribute name ``data`` so callers can reach it.
        self.data = data

    def __len__(self):
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        item = self.data[idx]
        return item


def find_generated_files_by_perturbation_and_celltype(
    generated_path, perturbation_id, cell_type_id
):
    """
    Find all generated numpy files for a specific perturbation ID and cell type ID.

    Files are expected under ``<generated_path>/p<perturbation_id>/`` with names
    of the form ``p<pid>_c<cell_type_id>_sample<sample_id>.npy``.

    Args:
        generated_path: Path to the directory containing generated data
        perturbation_id: Perturbation ID to filter by
        cell_type_id: Cell type ID to filter by

    Returns:
        Sorted list of file paths matching both the perturbation and cell type.
        Empty list if the perturbation folder does not exist.
    """
    pert_path = os.path.join(generated_path, f"p{perturbation_id}")
    if not os.path.isdir(pert_path):
        return []

    # The trailing "_sample" keeps cell type 1 from also matching "_c11_sample".
    marker = f"_c{cell_type_id}_sample"

    # Escape the directory part so glob metacharacters ("[", "*", "?") in the
    # path are treated literally; only the "*.npy" file part is a pattern.
    npy_files = glob.glob(os.path.join(glob.escape(pert_path), "*.npy"))

    # Match on the file name only (a parent directory containing the marker must
    # not pull in other cell types), and sort so that downstream random sampling
    # is reproducible under a fixed seed (glob order is filesystem-dependent).
    return sorted(f for f in npy_files if marker in os.path.basename(f))


def load_numpy_files(file_paths, max_samples):
    """
    Load up to ``max_samples`` ``.npy`` files and stack them into one float tensor.

    When more than ``max_samples`` paths are given, a random subset is drawn
    with the module-level ``random`` state (so caller seeding applies). Files
    that fail to load or convert are reported via ``print`` and skipped.

    Args:
        file_paths: List of numpy file paths to load
        max_samples: Maximum number of samples to load

    Returns:
        ``torch.stack`` of the loaded arrays as float tensors, or ``None`` when
        nothing could be loaded.
    """
    selected = file_paths
    if len(selected) > max_samples:
        selected = random.sample(selected, max_samples)

    tensors = []
    for path in tqdm(selected, desc="Loading numpy files"):
        # Best effort: a single unreadable file should not abort the whole batch.
        try:
            tensors.append(torch.from_numpy(np.load(path)).float())
        except Exception as err:
            print(f"Error loading {path}: {err}")

    if not tensors:
        return None
    return torch.stack(tensors)


def create_cell_type_metadata(num_samples=500, perturbation_id=1138, cell_type=1):
    """
    Create a list of identical metadata entries for one perturbation and cell type.

    Args:
        num_samples: Number of metadata entries to create
        perturbation_id: The perturbation ID to use for all entries
        cell_type: The cell type ID to use for all entries. This was previously
            ignored (every entry got 1); the default of 1 keeps that behaviour.

    Returns:
        List of ``num_samples`` independent dicts with keys ``perturbation_id``,
        ``cell_type_id`` and ``is_generated`` (always ``False``).
    """
    perturbation_metadata = [
        {
            "perturbation_id": perturbation_id,
            "cell_type_id": cell_type,
            "is_generated": False,
        }
        for _ in range(num_samples)
    ]

    print(f"Created perturbation metadata with {len(perturbation_metadata)} entries")
    print(f"All entries have cell_type_id set to {cell_type}")

    return perturbation_metadata


def augment_image(image, augmentation_type=None):
    """
    Apply a simple geometric augmentation to an image tensor.

    Args:
        image: Tensor image of shape [C, H, W]
        augmentation_type: One of 'rotate' (random multiple of 90 degrees),
            'flip' (random horizontal or vertical flip) or 'unchanged'. If None,
            one of these three is chosen at random. Any other value returns the
            image untouched.

    Returns:
        The augmented image tensor (the input object itself when unchanged).
    """
    mode = (
        augmentation_type
        if augmentation_type is not None
        else random.choice(["rotate", "flip", "unchanged"])
    )

    if mode == "flip":
        use_horizontal = random.random() > 0.5
        return TF.hflip(image) if use_horizontal else TF.vflip(image)

    if mode == "rotate":
        return TF.rotate(image, random.choice([90, 180, 270]))

    # 'unchanged' and any unrecognised mode fall through to the identity.
    return image


def plot_pca_comparison(
    features1,
    features2,
    label1_desc,
    label2_desc,
    title_pert_id1,
    title_pert_id2,
    plot_title_prefix,
    output_filename_prefix,
    base_save_path="ophenom_qual_results/pca_comparisons",
):
    """
    Project two feature sets into a shared 2-D PCA space and save a scatter plot.

    The two sets are stacked, standardised together and reduced with a
    2-component PCA. Points are coloured by a legend label built from each set's
    description and perturbation ID. The figure is written as
    ``<output_filename_prefix>_p<id1>_vs_p<id2>.pdf`` inside ``base_save_path``,
    which is created if missing.

    Args:
        features1, features2: Per-sample feature arrays (rows are samples); they
            are combined with ``np.vstack``, so presumably they must share the
            same number of columns.
        label1_desc, label2_desc: Human-readable descriptions for the legend.
        title_pert_id1, title_pert_id2: Perturbation IDs shown in labels/title.
        plot_title_prefix: Text placed before the IDs in the title.
        output_filename_prefix: Prefix of the saved PDF file name.
        base_save_path: Output directory.
    """
    os.makedirs(base_save_path, exist_ok=True)

    stacked = np.vstack((features1, features2))

    name_a = f"{label1_desc} (p{title_pert_id1})"
    name_b = f"{label2_desc} (p{title_pert_id2})"
    legend_labels = np.array([name_a] * len(features1) + [name_b] * len(features2))

    standardized = StandardScaler().fit_transform(stacked)
    reducer = PCA(n_components=2)
    coords = reducer.fit_transform(standardized)
    pct = reducer.explained_variance_ratio_ * 100

    frame = pd.DataFrame(
        {
            "PC1": coords[:, 0],
            "PC2": coords[:, 1],
            "Type": legend_labels,
        }
    )

    plt.figure(figsize=(12, 10))

    # One default-palette colour per distinct legend label, in order of appearance.
    base_colors = sns.color_palette()
    color_for = {label: base_colors[k] for k, label in enumerate(frame["Type"].unique())}

    sns.scatterplot(
        data=frame,
        x="PC1",
        y="PC2",
        hue="Type",
        palette=color_for,
        alpha=0.7,
        s=50,
        edgecolor=None,
    )

    plt.xlabel(f"PC1 ({pct[0]:.1f}%)")
    plt.ylabel(f"PC2 ({pct[1]:.1f}%)")
    plt.title(f"PCA: {plot_title_prefix} - p{title_pert_id1} vs p{title_pert_id2}")
    plt.legend(loc="upper right")

    target = os.path.join(
        base_save_path,
        f"{output_filename_prefix}_p{title_pert_id1}_vs_p{title_pert_id2}.pdf",
    )
    plt.savefig(target, dpi=600)
    plt.close()
    print(f"Saved PCA comparison plot to {target}")


def plot_pca_dual_perturbation_comparison(
    real_features_ref,
    generated_features_ref,
    real_features_other,
    generated_features_other,
    ref_pert_id,
    other_pert_id,
    color_map,
    type_markers_map,
    base_save_path="ophenom_qual_results/pca_comparisons_dual",
):
    """
    Save a 2-D PCA plot of real/generated features for two perturbations.

    Points are coloured by perturbation ('p<id>') and marker-styled by type
    ('Real' / 'Generated'). Any feature array that is None or empty is left out.
    Nothing is plotted (a message is printed instead) when no arrays remain or
    fewer than two samples remain. The output directory is created first in
    every case.

    Args:
        real_features_ref, generated_features_ref: Features of the reference
            perturbation (rows are samples).
        real_features_other, generated_features_other: Features of the other
            perturbation.
        ref_pert_id, other_pert_id: Perturbation IDs used in labels and the
            output file name.
        color_map: Mapping from 'p<id>' strings to colours; only the entries for
            the two perturbations present in it are passed to seaborn.
        type_markers_map: Mapping from 'Real'/'Generated' to marker symbols.
        base_save_path: Output directory for the PDF.
    """
    os.makedirs(base_save_path, exist_ok=True)

    # Order matters for the stacked matrix: ref real, ref generated,
    # other real, other generated.
    groups = (
        (real_features_ref, ref_pert_id, "Real"),
        (generated_features_ref, ref_pert_id, "Generated"),
        (real_features_other, other_pert_id, "Real"),
        (generated_features_other, other_pert_id, "Generated"),
    )

    blocks = []
    pert_labels = []
    type_labels = []
    for feats, pid, kind in groups:
        if feats is None or feats.size == 0:
            continue
        count = len(feats)
        blocks.append(feats)
        pert_labels += [f"p{pid}"] * count
        type_labels += [kind] * count

    if not blocks:
        print(
            f"No features provided for PCA dual comparison between p{ref_pert_id} and p{other_pert_id}. Skipping."
        )
        return

    stacked = np.vstack(blocks)
    n_rows = stacked.shape[0]
    if n_rows < 2:
        print(
            f"Not enough samples ({n_rows}) for PCA dual comparison between p{ref_pert_id} and p{other_pert_id}. Skipping."
        )
        return

    standardized = StandardScaler().fit_transform(stacked)
    reducer = PCA(n_components=2)
    coords = reducer.fit_transform(standardized)
    pct = reducer.explained_variance_ratio_ * 100

    frame = pd.DataFrame(
        {
            "PC1": coords[:, 0],
            "PC2": coords[:, 1],
            "PerturbationID": pert_labels,
            "Type": type_labels,
        }
    )

    plt.figure(figsize=(12, 10))

    # Restrict the shared colour map to the two perturbations on this plot.
    palette_subset = {}
    for key in (f"p{ref_pert_id}", f"p{other_pert_id}"):
        if key in color_map:
            palette_subset[key] = color_map[key]

    sns.scatterplot(
        data=frame,
        x="PC1",
        y="PC2",
        hue="PerturbationID",
        style="Type",
        palette=palette_subset,
        markers=type_markers_map,
        alpha=0.7,
        s=90,
        edgecolor=None,
    )

    plt.xlabel(f"PC1 ({pct[0]:.1f}%)")
    plt.ylabel(f"PC2 ({pct[1]:.1f}%)")
    plt.title(f"PCA: p{ref_pert_id} (Real/Gen) vs p{other_pert_id} (Real/Gen)")
    plt.legend(loc="upper right")

    target = os.path.join(
        base_save_path, f"pca_dual_p{ref_pert_id}_vs_p{other_pert_id}.pdf"
    )
    plt.savefig(target, dpi=600)
    plt.close()
    print(f"Saved dual PCA comparison plot to {target}")


if __name__ == "__main__":
    seed = 0
    cell_type_id = 1
    seed_everything(seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Environment-specific inputs: the data config, and the directory of
    # generated samples laid out as
    # generated_path/p<pid>/p<pid>_c<c_id>_sample<sample_id>.npy
    filename = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
    generated_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    # Font sizes shared by the per-perturbation and aggregated plots.
    TITLE_FONTSIZE = 26
    LABEL_FONTSIZE = 22
    LEGEND_FONTSIZE = 20
    LEGEND_TITLE_FONTSIZE = 20
    TICK_FONTSIZE = 18

    # Perturbations to analyse. REFERENCE_PERT_ID must be in this list for the
    # pairwise comparisons and ATE scores to be produced.
    # To analyse a random subset instead:
    #   sampled_perturbation_ids = random.sample(list(range(1, 1139)), 50)
    sampled_perturbation_ids = [1138, 1137, 1108, 1124]
    print(f"Sampled perturbation IDs: {sampled_perturbation_ids}")

    NUM_SAMPLES = 500  # cap on generated images loaded per perturbation
    REFERENCE_PERT_ID = 1138
    encoder = OpenPhenomEncoder().to(device)

    # --- Global colour / marker maps so every plot uses the same encoding ---
    # The reference perturbation is always gray; the others cycle through the
    # default seaborn palette in ascending ID order.
    GLOBAL_PERTURBATION_COLOR_MAP = {}
    color_palette_list = sns.color_palette()
    color_idx_non_ref = 0
    for pid_numeric in sorted(set(sampled_perturbation_ids)):
        pid_str = f"p{pid_numeric}"
        if pid_numeric == REFERENCE_PERT_ID:
            GLOBAL_PERTURBATION_COLOR_MAP[pid_str] = "gray"
        else:
            GLOBAL_PERTURBATION_COLOR_MAP[pid_str] = color_palette_list[
                color_idx_non_ref % len(color_palette_list)
            ]
            color_idx_non_ref += 1

    GLOBAL_TYPE_MARKERS = {"Real": "o", "Generated": "X"}
    print(f"Global Perturbation Color Map: {GLOBAL_PERTURBATION_COLOR_MAP}")
    print(f"Global Type Markers: {GLOBAL_TYPE_MARKERS}")

    # (real_features, generated_features) for every perturbation processed in
    # THIS run. The reference comparisons read from here instead of reloading
    # the .npy files saved below, which could be stale leftovers from an earlier
    # run for perturbations that were skipped this time.
    features_by_pert = {}
    all_real_features_list = []
    all_generated_features_list = []
    all_pert_ids_for_real = []
    all_pert_ids_for_generated = []
    ate_scores_list = []
    pca_dir = "ophenom_qual_results/pca"

    for i, pert_id in enumerate(sampled_perturbation_ids):
        print(f"\n\n{'='*80}")
        print(
            f"Processing perturbation ID: {pert_id}, {i+1}/{len(sampled_perturbation_ids)}"
        )
        print(f"{'='*80}")

        # Real images for this perturbation and cell type.
        real_filtered_dataset = datamodule.filter_samples(
            perturbation_id=pert_id, cell_type_id=cell_type_id
        )
        if real_filtered_dataset is None or len(real_filtered_dataset) == 0:
            print(f"No real data found for perturbation ID {pert_id}")
            continue

        # Element 0 of each dataset item is used as the image.
        real_images = [
            real_filtered_dataset[idx][0] for idx in range(len(real_filtered_dataset))
        ]
        print(f"Found {len(real_images)} real images for perturbation ID {pert_id}")
        real_images_tensor = torch.stack(real_images)

        # Generated samples. Only the configured cell type is used; add more IDs
        # to `cell_types` (e.g. [0, 1, 2, 3]) to pool several cell types.
        cell_types = [cell_type_id]
        generated_files = []
        for ct_id in cell_types:
            generated_files.extend(
                find_generated_files_by_perturbation_and_celltype(
                    generated_path,
                    pert_id,
                    ct_id,
                )
            )
        print(
            f"Found {len(generated_files)} generated files for perturbation ID {pert_id}"
        )
        if not generated_files:
            print(f"No generated data found for perturbation ID {pert_id}")
            continue

        generated_images_tensor = load_numpy_files(
            generated_files, max_samples=NUM_SAMPLES
        )
        if generated_images_tensor is None:
            print(f"Failed to load generated images for perturbation ID {pert_id}")
            continue

        print(f"Extracting OpenPhenom features for perturbation ID {pert_id}")
        os.makedirs(pca_dir, exist_ok=True)

        # NOTE: all real images go through the encoder in a single forward pass
        # (they are not capped by NUM_SAMPLES); batch here if memory is tight.
        with torch.no_grad():
            real_features = encoder(real_images_tensor.to(device)).cpu().numpy()
            generated_features = (
                encoder(generated_images_tensor.to(device)).cpu().numpy()
            )
        print(
            f"Extracted features shapes - Real: {real_features.shape}, Generated: {generated_features.shape}"
        )

        features_by_pert[pert_id] = (real_features, generated_features)
        if pert_id == REFERENCE_PERT_ID:
            print(f"Stored features for reference perturbation ID {REFERENCE_PERT_ID}")

        # Accumulate for the aggregated PCA; labels carry the 'p' prefix used in plots.
        if real_features.size > 0:
            all_real_features_list.append(real_features)
            all_pert_ids_for_real.extend([f"p{pert_id}"] * len(real_features))
        if generated_features.size > 0:
            all_generated_features_list.append(generated_features)
            all_pert_ids_for_generated.extend(
                [f"p{pert_id}"] * len(generated_features)
            )

        # --- Per-perturbation PCA: real vs generated ---
        combined_features = np.vstack((real_features, generated_features))
        labels = np.array(
            ["Real"] * len(real_features) + ["Generated"] * len(generated_features)
        )
        scaled_features = StandardScaler().fit_transform(combined_features)
        pca = PCA(n_components=2)
        embedding = pca.fit_transform(scaled_features)
        var_explained = pca.explained_variance_ratio_ * 100

        plot_df = pd.DataFrame(
            {
                "PC1": embedding[:, 0],
                "PC2": embedding[:, 1],
                "Type": labels,
                "Perturbation": f"p{pert_id}",
            }
        )

        plt.figure(figsize=(12, 10))
        sns.scatterplot(
            data=plot_df,
            x="PC1",
            y="PC2",
            hue="Type",
            palette={"Real": "blue", "Generated": "red"},
            alpha=0.7,
            s=50,
            edgecolor=None,
        )

        plt.xlabel(f"PC1 ({var_explained[0]:.1f}%)", fontsize=LABEL_FONTSIZE)
        plt.ylabel(f"PC2 ({var_explained[1]:.1f}%)", fontsize=LABEL_FONTSIZE)
        plt.title(
            f"PCA of OpenPhenom Features - Perturbation {pert_id}",
            fontsize=TITLE_FONTSIZE,
        )

        # Legend title ("Type") and item labels get separate font sizes.
        legend_obj_individual = plt.legend(loc="upper right")
        for text_obj in legend_obj_individual.findobj(plt.Text):
            if text_obj.get_text() == "Type":
                text_obj.set_fontsize(LEGEND_TITLE_FONTSIZE)
            else:
                text_obj.set_fontsize(LEGEND_FONTSIZE)

        plt.xticks(fontsize=TICK_FONTSIZE)
        plt.yticks(fontsize=TICK_FONTSIZE)

        output_file = f"{pca_dir}/pca_perturbation_{pert_id}.pdf"
        plt.savefig(output_file, dpi=600)
        plt.close()
        print(f"Saved PCA plot to {output_file}")

        # Persist features for later analyses outside this script.
        np.save(f"{pca_dir}/features_real_p{pert_id}.npy", real_features)
        np.save(
            f"{pca_dir}/features_generated_p{pert_id}.npy",
            generated_features,
        )

    # --- Comparisons against the reference perturbation ---
    if REFERENCE_PERT_ID in features_by_pert:
        real_features_ref, generated_features_ref = features_by_pert[
            REFERENCE_PERT_ID
        ]
        print(f"\n\n{'='*80}")
        print(f"Starting comparisons against Perturbation ID {REFERENCE_PERT_ID}")
        print(f"{'='*80}")

        for other_pert_id in sampled_perturbation_ids:
            if other_pert_id == REFERENCE_PERT_ID:
                continue

            print(f"\n--- Comparing p{REFERENCE_PERT_ID} with p{other_pert_id} ---")

            if other_pert_id not in features_by_pert:
                print(
                    f"No features were extracted for p{other_pert_id} in this run. Skipping comparison."
                )
                continue
            real_features_other, generated_features_other = features_by_pert[
                other_pert_id
            ]

            plot_pca_dual_perturbation_comparison(
                real_features_ref,
                generated_features_ref,
                real_features_other,
                generated_features_other,
                REFERENCE_PERT_ID,
                other_pert_id,
                color_map=GLOBAL_PERTURBATION_COLOR_MAP,
                type_markers_map=GLOBAL_TYPE_MARKERS,
            )

            # ATE here = squared Euclidean distance between the mean feature
            # vectors of the reference and the other perturbation, computed
            # separately for real and for generated features.
            ate_entry = {"comparison": f"p{REFERENCE_PERT_ID}_vs_p{other_pert_id}"}
            for kind, ref_feats, other_feats in (
                ("real", real_features_ref, real_features_other),
                ("generated", generated_features_ref, generated_features_other),
            ):
                if ref_feats.size > 0 and other_feats.size > 0:
                    ate = (
                        np.linalg.norm(
                            np.mean(ref_feats, axis=0) - np.mean(other_feats, axis=0)
                        )
                        ** 2
                    )
                    print(
                        f"ATE_{kind} (p{REFERENCE_PERT_ID} vs p{other_pert_id}): {ate:.4f}"
                    )
                else:
                    ate = np.nan
                    print(
                        f"Could not calculate ATE_{kind} for p{other_pert_id} due to missing/empty {kind} features."
                    )
                ate_entry[f"ATE_{kind}"] = ate
            ate_scores_list.append(ate_entry)
    else:
        print(
            f"Features for reference perturbation ID {REFERENCE_PERT_ID} were not found. Skipping comparisons."
        )

    # --- PCA for all features combined ---
    if all_real_features_list and all_generated_features_list:
        print(f"\n\n{'='*80}")
        print("Starting PCA for all aggregated features")
        print(f"{'='*80}")

        final_all_real_features = np.vstack(all_real_features_list)
        final_all_generated_features = np.vstack(all_generated_features_list)
        combined_all_features = np.vstack(
            (final_all_real_features, final_all_generated_features)
        )

        type_labels_all = ["Real"] * len(final_all_real_features) + ["Generated"] * len(
            final_all_generated_features
        )
        pert_id_labels_all = all_pert_ids_for_real + all_pert_ids_for_generated

        if combined_all_features.shape[0] > 1:  # PCA needs at least 2 samples
            scaled_all_features = StandardScaler().fit_transform(combined_all_features)
            pca_all = PCA(n_components=2)
            embedding_all = pca_all.fit_transform(scaled_all_features)
            var_explained_all = pca_all.explained_variance_ratio_ * 100

            plot_all_df = pd.DataFrame(
                {
                    "PC1": embedding_all[:, 0],
                    "PC2": embedding_all[:, 1],
                    "Type": type_labels_all,
                    "PerturbationID": pert_id_labels_all,
                }
            )

            plt.figure(figsize=(12, 10))
            sns.scatterplot(
                data=plot_all_df,
                x="PC1",
                y="PC2",
                hue="PerturbationID",
                style="Type",
                palette=GLOBAL_PERTURBATION_COLOR_MAP,
                markers=GLOBAL_TYPE_MARKERS,
                alpha=0.7,
                s=80,
                edgecolor=None,
            )

            plt.xlabel(f"PC1 ({var_explained_all[0]:.1f}%)", fontsize=LABEL_FONTSIZE)
            plt.ylabel(f"PC2 ({var_explained_all[1]:.1f}%)", fontsize=LABEL_FONTSIZE)
            plt.title(
                "PCA of OpenPhenom Features by Perturbation",
                fontsize=TITLE_FONTSIZE,
            )

            # Legend section titles ("PerturbationID", "Type") vs item labels.
            legend_obj_aggregated = plt.legend(loc="upper right")
            for text_obj in legend_obj_aggregated.findobj(plt.Text):
                if text_obj.get_text() in ("PerturbationID", "Type"):
                    text_obj.set_fontsize(LEGEND_TITLE_FONTSIZE)
                else:
                    text_obj.set_fontsize(LEGEND_FONTSIZE)

            plt.xticks(fontsize=TICK_FONTSIZE)
            plt.yticks(fontsize=TICK_FONTSIZE)

            aggregated_pca_save_path = "ophenom_qual_results/pca_aggregated"
            os.makedirs(aggregated_pca_save_path, exist_ok=True)
            output_file_all = os.path.join(
                aggregated_pca_save_path, "pca_all_perturbations_combined.pdf"
            )
            plt.savefig(output_file_all, dpi=600)
            plt.close()
            print(f"Saved aggregated PCA plot to {output_file_all}")
        else:
            print(
                "Not enough data (less than 2 samples) to generate aggregated PCA plot."
            )
    else:
        print("No features collected to generate aggregated PCA plot.")

    # --- ATE scores summary ---
    if ate_scores_list:
        print(f"\n\n{'='*80}")
        print("Average Treatment Effect (ATE) Scores Summary")
        print(f"Control Perturbation ID: {REFERENCE_PERT_ID}")
        print(f"{'='*80}")
        ate_df = pd.DataFrame(ate_scores_list)
        print(ate_df.to_string())

        ate_csv_path = "ophenom_qual_results/ate_scores.csv"
        os.makedirs(os.path.dirname(ate_csv_path), exist_ok=True)
        ate_df.to_csv(ate_csv_path, index=False)
        print(f"\nATE scores saved to {ate_csv_path}")

    print("\n\nScript finished.")
