import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import zscore
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def preprocess_features(df, feature_columns):
    """
    Return a copy of `df` whose feature columns have been standardized.

    A fresh StandardScaler is fitted on the selected columns and its
    transformed output is written back into the copy; the input frame is
    not modified.

    Args:
        df: pandas DataFrame containing features and metadata.
        feature_columns: list of feature column names to standardize.
    Returns:
        df_scaled: DataFrame with standardized features (other columns unchanged).
    """
    standardized = df.copy()
    # Fit and transform in one step; columns outside feature_columns are untouched.
    standardized[feature_columns] = StandardScaler().fit_transform(
        standardized[feature_columns]
    )
    return standardized


def select_features(df, feature_columns, variance_thresh=1e-5, z_thresh=5):
    """
    Remove features with low variance and features with extreme outliers.

    Filtering happens in two passes:
      1. Columns are screened with sklearn's VarianceThreshold fitted on
         df[feature_columns] using `variance_thresh`.
      2. Among the survivors, a column is kept only if every absolute
         z-score in that column is strictly below `z_thresh`.

    Args:
        df: pandas DataFrame containing features and metadata.
        feature_columns: list of feature column names to filter.
        variance_thresh: float, threshold for removing low-variance features.
        z_thresh: float, z-score threshold for outlier-based feature removal.
    Returns:
        selected_columns: list of selected feature columns, in the same
            relative order as in `feature_columns`.
    """
    from sklearn.feature_selection import VarianceThreshold

    # Pass 1: low-variance screen. The support mask is fetched once rather
    # than once per column.
    selector = VarianceThreshold(threshold=variance_thresh)
    selector.fit(df[feature_columns])
    support = selector.get_support()
    kept_columns = [col for col, keep in zip(feature_columns, support) if keep]

    # Pass 2: outlier screen on absolute z-scores computed per column.
    zscores = np.abs(zscore(df[kept_columns], nan_policy="omit"))
    # A column survives only when all of its entries compare below z_thresh.
    # NOTE(review): NaN entries presumably stay NaN in the z-score output and
    # therefore fail this comparison, dropping the column - confirm intended.
    mask = (zscores < z_thresh).all(axis=0)
    selected_columns = [col for col, keep in zip(kept_columns, mask) if keep]
    return selected_columns


def remove_highly_correlated_features(df, feature_columns, corr_thresh=0.85):
    """
    Remove features with a pairwise correlation above corr_thresh.

    Only the strict upper triangle of the absolute correlation matrix is
    inspected, so for every offending pair the column that appears later in
    `feature_columns` is the one that gets dropped.

    Args:
        df: pandas DataFrame containing features.
        feature_columns: list of feature column names to filter.
        corr_thresh: float, correlation threshold for removal.
    Returns:
        reduced_columns: list of feature columns after removing highly correlated ones
            (original order preserved).
    """
    corr_matrix = df[feature_columns].corr().abs()

    # Keep entries strictly above the diagonal; everything else becomes NaN
    # and can never exceed the threshold.
    upper_mask = np.triu(np.ones(corr_matrix.shape, dtype=bool), k=1)
    upper = corr_matrix.where(upper_mask)

    # Vectorized per-column check instead of a Python-level any() per column.
    exceeds = (upper > corr_thresh).any(axis=0)
    # Set membership keeps the final filter linear in the number of columns.
    to_drop = set(exceeds.index[exceeds.to_numpy()])

    reduced_columns = [col for col in feature_columns if col not in to_drop]
    return reduced_columns


def load_cp_features(base_path, perturbations, groups=("real", "generated")):
    """
    Read every available <base_path>/<group>/p<pert>/Image.csv into one DataFrame.

    Missing files are skipped. Each loaded frame is tagged with the group it
    came from ('Source') and the perturbation value as passed in
    ('Perturbation'), then all frames are restricted to the columns they
    share and concatenated.

    Args:
        base_path: str, path to CP_outputs directory.
        perturbations: list of perturbation directory names (e.g., ["1108", ...]);
            a leading "p" is added to the directory name when absent.
        groups: tuple of group names (default: ("real", "generated")).
    Returns:
        df_all: concatenated DataFrame with a new column 'Source' (real/generated) and 'Perturbation'.
        common_cols: list of columns present in all files (order preserved from the first file).
    Raises:
        FileNotFoundError: if no Image.csv exists for any group/perturbation pair.
    """
    frames = []
    for group in groups:
        for pert in perturbations:
            label = str(pert)
            folder = label if label.startswith("p") else f"p{pert}"
            csv_path = os.path.join(base_path, group, folder, "Image.csv")
            if not os.path.exists(csv_path):
                continue
            # low_memory=False for consistent dtype inference across files
            frame = pd.read_csv(csv_path, low_memory=False)
            frame["Source"] = group
            frame["Perturbation"] = pert
            frames.append(frame)

    if len(frames) == 0:
        raise FileNotFoundError("No Image.csv files found in the specified structure.")

    # Walk the first frame's columns in order and keep those every other
    # frame also has, so the resulting column order is deterministic.
    reference, *others = frames
    common_cols = []
    for name in reference.columns:
        if all(name in other.columns for other in others):
            common_cols.append(name)

    # Restrict every frame to the shared, identically ordered column set.
    aligned = [frame[common_cols] for frame in frames]
    df_all = pd.concat(aligned, ignore_index=True)
    return df_all, common_cols


# def plot_pca(
#     df, features, result_path, color_by="Perturbation", style_by="Source", prefix="all"
# ):
#     """
#     Plot PCA and t-SNE for the given dataframe.
#     Args:
#         df: DataFrame with features and metadata.
#         features: list of feature columns to use.
#         result_path: directory to save plots.
#         color_by: column to color points by.
#         style_by: column to style points by (e.g., Source: real/generated).
#         prefix: prefix for output filenames.
#     """
#     import os

#     import matplotlib.pyplot as plt
#     import seaborn as sns
#     from sklearn.decomposition import PCA
#     from sklearn.manifold import TSNE

#     if not os.path.exists(result_path):
#         os.makedirs(result_path)

#     # --- PCA: all points, color by perturbation, style by source ---
#     # pca = PCA(n_components=2)
#     pca = PCA(n_components=2, svd_solver="full", random_state=0)

#     pca_result = pca.fit_transform(df[features].values)
#     df["pca-one"] = pca_result[:, 0]
#     df["pca-two"] = pca_result[:, 1]
#     print(
#         f"[PCA] Explained variance ratio (2 components): {pca.explained_variance_ratio_}"
#     )
#     print(
#         f"[PCA] Total explained variance (2 components): {np.sum(pca.explained_variance_ratio_):.4f}"
#     )
#     plt.figure(figsize=(8, 6))
#     sns.scatterplot(
#         data=df, x="pca-one", y="pca-two", hue="Perturbation", style="Source", alpha=0.7
#     )
#     plt.title(f"PCA: all points (color=Perturbation, style=Source)")
#     plt.tight_layout()
#     plt.savefig(os.path.join(result_path, f"all_pca.png"), dpi=200)
#     plt.close()

#     # --- PCA: per perturbation, color by Source ---
#     for pert in df["Perturbation"].unique():
#         df_pert = df[df["Perturbation"] == pert]
#         plt.figure(figsize=(7, 5))
#         sns.scatterplot(data=df_pert, x="pca-one", y="pca-two", hue="Source", alpha=0.7)
#         plt.title(f"PCA: Perturbation {pert} (color=Source)")
#         plt.tight_layout()
#         plt.savefig(os.path.join(result_path, f"pca_pert_{pert}.png"), dpi=200)
#         plt.close()

#     return pca


# def plot_pca(
#     df, features, result_path, color_by="Perturbation", style_by="Source", prefix="all"
# ):
#     """
#     Plot PCA for the given dataframe.
#     Legends are titled 'PerturbationID' (hue) and 'Type' (style).
#     Uses global style constants if defined:
#       TITLE_FONTSIZE, LABEL_FONTSIZE, LEGEND_FONTSIZE, LEGEND_TITLE_FONTSIZE, TICK_FONTSIZE
#     And global maps if defined:
#       GLOBAL_PERTURBATION_COLOR_MAP (dict), GLOBAL_TYPE_MARKERS (dict)
#     """
#     import os

#     import matplotlib.pyplot as plt
#     import seaborn as sns
#     from sklearn.decomposition import PCA

#     os.makedirs(result_path, exist_ok=True)

#     # Fonts (fall back to sensible defaults if not defined globally)
#     TITLE_FONTSIZE = globals().get("TITLE_FONTSIZE", 26)
#     LABEL_FONTSIZE = globals().get("LABEL_FONTSIZE", 22)
#     LEGEND_FONTSIZE = globals().get("LEGEND_FONTSIZE", 20)
#     LEGEND_TITLE_FONTSIZE = globals().get("LEGEND_TITLE_FONTSIZE", 20)
#     TICK_FONTSIZE = globals().get("TICK_FONTSIZE", 18)

#     # --- PCA on features (deterministic) ---
#     pca = PCA(n_components=2, svd_solver="full", random_state=0)
#     pca_result = pca.fit_transform(df[features].values)
#     var_pct = pca.explained_variance_ratio_ * 100.0

#     # Store back (keeps old behavior)
#     df["pca-one"] = pca_result[:, 0]
#     df["pca-two"] = pca_result[:, 1]
#     print(
#         f"[PCA] Explained variance ratio (2 components): {pca.explained_variance_ratio_}"
#     )
#     print(
#         f"[PCA] Total explained variance (2 components): {np.sum(pca.explained_variance_ratio_):.4f}"
#     )

#     # Prepare a plotting view with renamed columns for legend titles
#     df_plot = df.rename(columns={color_by: "PerturbationID", style_by: "Type"})
#     hue_name, style_name = "PerturbationID", "Type"

#     # Build color palette for hue (PerturbationID)
#     levels_hue = pd.unique(df_plot[hue_name])
#     palette = None
#     if "GLOBAL_PERTURBATION_COLOR_MAP" in globals():
#         gmap = globals().get("GLOBAL_PERTURBATION_COLOR_MAP", {})
#         pal = {}
#         for lvl in levels_hue:
#             # accept either exact key or 'p{lvl}' style
#             key = lvl if lvl in gmap else f"p{lvl}"
#             if key in gmap:
#                 pal[lvl] = gmap[key]
#         if len(pal) == len(levels_hue):
#             palette = pal
#     if palette is None:
#         # fall back to seaborn palette with deterministic mapping
#         palette_list = sns.color_palette(n_colors=len(levels_hue))
#         palette = {lvl: palette_list[i] for i, lvl in enumerate(levels_hue)}

#     # Build markers for style (Type)
#     levels_style = pd.unique(df_plot[style_name])
#     if "GLOBAL_TYPE_MARKERS" in globals():
#         tmap = globals().get("GLOBAL_TYPE_MARKERS", {})
#         markers = {t: tmap.get(t, "o") for t in levels_style}
#     else:
#         # sensible defaults; handles either 'Real'/'Generated' or lowercase
#         default = {"Real": "o", "Generated": "X", "real": "o", "generated": "X"}
#         markers = {
#             t: default.get(t, ["o", "X", "s", "D", "^", "v"][i % 6])
#             for i, t in enumerate(levels_style)
#         }

#     # --- PCA scatter: all points ---
#     plt.figure(figsize=(12, 10))
#     ax = sns.scatterplot(
#         data=df_plot,
#         x="pca-one",
#         y="pca-two",
#         hue=hue_name,
#         style=style_name,
#         palette=palette,
#         markers=markers,
#         alpha=0.7,
#         s=80,
#         edgecolor=None,
#     )
#     plt.xlabel(f"PC1", fontsize=LABEL_FONTSIZE)
#     plt.ylabel(f"PC2", fontsize=LABEL_FONTSIZE)
#     plt.title(
#         "PCA of CellProfiler features by Perturbation (HUVEC)", fontsize=TITLE_FONTSIZE
#     )

#     # Legend styling: two section headers ('PerturbationID', 'Type') + items
#     leg = plt.legend(loc="upper right")
#     if leg is not None:
#         for text_obj in leg.findobj(plt.Text):
#             txt = text_obj.get_text()
#             if txt in ("PerturbationID", "Type"):
#                 text_obj.set_fontsize(LEGEND_TITLE_FONTSIZE)
#             else:
#                 text_obj.set_fontsize(LEGEND_FONTSIZE)

#     plt.xticks(fontsize=TICK_FONTSIZE)
#     plt.yticks(fontsize=TICK_FONTSIZE)
#     plt.tight_layout()
#     plt.savefig(os.path.join(result_path, f"{prefix}_pca.png"), dpi=600)
#     plt.close()

#     # --- PCA: per perturbation (legend should say 'Type') ---
#     for pert in levels_hue:
#         df_pert = df_plot[df_plot[hue_name] == pert]
#         if df_pert.empty:
#             continue

#         plt.figure(figsize=(12, 10))
#         # For per-perturbation view we color by Type to keep legend title as 'Type'
#         # (colors chosen deterministically; markers optional here)
#         # Build a small 2-color palette for the present Type levels
#         present_types = pd.unique(df_pert[style_name])
#         type_palette = {
#             t: {
#                 "Real": "#1f77b4",
#                 "Generated": "#d62728",
#                 "real": "#1f77b4",
#                 "generated": "#d62728",
#             }.get(t, sns.color_palette()[i % 10])
#             for i, t in enumerate(present_types)
#         }

#         sns.scatterplot(
#             data=df_pert,
#             x="pca-one",
#             y="pca-two",
#             hue=style_name,  # legend title becomes 'Type'
#             palette=type_palette,
#             alpha=0.75,
#             s=80,
#             edgecolor=None,
#         )

#         plt.xlabel(f"PC1", fontsize=LABEL_FONTSIZE)
#         plt.ylabel(f"PC2", fontsize=LABEL_FONTSIZE)
#         plt.title(
#             f"PCA of CellProfiler Features - Perturbation {pert} (HUVEC)",
#             fontsize=TITLE_FONTSIZE,
#         )

#         leg2 = plt.legend(title="Type", loc="upper right")
#         if leg2 is not None:
#             if leg2.get_title() is not None:
#                 leg2.get_title().set_fontsize(LEGEND_TITLE_FONTSIZE)
#             for t in leg2.get_texts():
#                 t.set_fontsize(LEGEND_FONTSIZE)

#         plt.xticks(fontsize=TICK_FONTSIZE)
#         plt.yticks(fontsize=TICK_FONTSIZE)
#         plt.tight_layout()
#         plt.savefig(os.path.join(result_path, f"pca_pert_{pert}.png"), dpi=600)
#         plt.close()

#     return pca


def plot_pca(
    df, features, result_path, color_by="Perturbation", style_by="Source", prefix="all",
    axes_edgewidth=3.0, gray_for_1138="#7f7f7f", max_points=300
):
    """
    Fit a 2-component PCA on `features` and save scatter plots of the projection.

    The PCA is fitted on every row of df[features]; the input frame itself is
    not modified (the projection is stored on an internal copy). Two kinds of
    figures are written into `result_path` at dpi=600:
      - "<prefix>_pca.png": all points, colored by `color_by` and
        marker-styled by `style_by`.
      - "pca_pert_<level>.png": one figure per level of `color_by`, colored
        by `style_by`.
    Axis labels carry the explained-variance percentage of each component.
    Legends are titled 'PerturbationID' (hue) and 'Type' (style).

    Uses global style constants if defined:
      TITLE_FONTSIZE, LABEL_FONTSIZE, LEGEND_FONTSIZE, LEGEND_TITLE_FONTSIZE, TICK_FONTSIZE
    And global maps if defined:
      GLOBAL_PERTURBATION_COLOR_MAP (dict), GLOBAL_TYPE_MARKERS (dict)

    Args:
        df: DataFrame with feature columns plus the `color_by` and `style_by` columns.
        features: list of feature column names fed to the PCA.
        result_path: output directory (created if missing).
        color_by: column whose levels determine point color.
        style_by: column whose levels determine marker style.
        prefix: filename prefix of the all-points figure.
        axes_edgewidth: line width applied to the axes spines.
        gray_for_1138: color forced onto the hue level 1138 / "1138" when present.
        max_points: maximum number of rows plotted. When the frame is larger,
            a random subset of this size (random_state=0) is drawn once and
            used for the all-points figure AND the per-perturbation figures.
            Pass None to plot every row.
    Returns:
        pca: the fitted sklearn PCA object.
    """
    import os
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.decomposition import PCA

    os.makedirs(result_path, exist_ok=True)

    # Fonts (fall back to sensible defaults if not defined globally)
    TITLE_FONTSIZE = globals().get("TITLE_FONTSIZE", 26)
    LABEL_FONTSIZE = globals().get("LABEL_FONTSIZE", 22)
    LEGEND_FONTSIZE = globals().get("LEGEND_FONTSIZE", 20)
    LEGEND_TITLE_FONTSIZE = globals().get("LEGEND_TITLE_FONTSIZE", 20)
    TICK_FONTSIZE = globals().get("TICK_FONTSIZE", 18)

    # --- PCA on features (deterministic) ---
    pca = PCA(n_components=2, svd_solver="full", random_state=0)
    pca_result = pca.fit_transform(df[features].values)
    var_pct = pca.explained_variance_ratio_ * 100.0

    # Store the projection on a copy so the caller's frame stays untouched
    df = df.copy()
    df["pca-one"] = pca_result[:, 0]
    df["pca-two"] = pca_result[:, 1]
    print(f"[PCA] Explained variance ratio (2 components): {pca.explained_variance_ratio_}")
    print(f"[PCA] Total explained variance (2 components): {np.sum(pca.explained_variance_ratio_):.4f}")

    # Plotting view: rename the grouping columns so seaborn's legend section
    # headers read 'PerturbationID' and 'Type'
    df_plot = df.rename(columns={color_by: "PerturbationID", style_by: "Type"})
    hue_name, style_name = "PerturbationID", "Type"

    # Hue levels are taken from the full data, before any subsampling
    levels_hue = pd.unique(df_plot[hue_name])

    # Priority 1: global map, but only if it covers every level
    # (either under the level itself or under the 'p<level>' spelling)
    palette = None
    gmap = globals().get("GLOBAL_PERTURBATION_COLOR_MAP")
    if gmap is not None:
        pal = {}
        for lvl in levels_hue:
            key = lvl if lvl in gmap else f"p{lvl}"
            if key in gmap:
                pal[lvl] = gmap[key]
        if len(pal) == len(levels_hue):
            palette = pal

    # Priority 2: fallback seaborn palette, assigned in order of appearance
    if palette is None:
        palette_list = sns.color_palette(n_colors=len(levels_hue))
        palette = {lvl: palette_list[i] for i, lvl in enumerate(levels_hue)}

    # Force 1138 to gray regardless of palette source (handle int or str levels)
    for k in (1138, "1138"):
        if k in palette:
            palette[k] = gray_for_1138

    # Build markers for style (Type)
    levels_style = pd.unique(df_plot[style_name])
    if "GLOBAL_TYPE_MARKERS" in globals():
        tmap = globals().get("GLOBAL_TYPE_MARKERS", {})
        markers = {t: tmap.get(t, "o") for t in levels_style}
    else:
        default = {"Real": "o", "Generated": "X", "real": "o", "generated": "X"}
        fallback_cycle = ["o", "X", "s", "D", "^", "v"]
        markers = {
            t: default.get(t, fallback_cycle[i % 6]) for i, t in enumerate(levels_style)
        }

    # Cap the number of plotted points; the same subset feeds every figure below
    if max_points is not None and len(df_plot) > max_points:
        df_plot = df_plot.sample(n=max_points, random_state=0)

    # --- PCA scatter: all points ---
    plt.figure(figsize=(12, 10))
    ax = sns.scatterplot(
        data=df_plot,
        x="pca-one", y="pca-two",
        hue=hue_name, style=style_name,
        palette=palette, markers=markers,
        alpha=0.7, s=150, edgecolor=None,
    )

    # Thicken plot frame (axes spines)
    for spine in ax.spines.values():
        spine.set_linewidth(axes_edgewidth)

    plt.xlabel(f"PC1 ({var_pct[0]:.1f}%)", fontsize=LABEL_FONTSIZE)
    plt.ylabel(f"PC2 ({var_pct[1]:.1f}%)", fontsize=LABEL_FONTSIZE)
    plt.title("PCA of CellProfiler features by Perturbation (HUVEC)", fontsize=TITLE_FONTSIZE)

    # Legend styling: the two section headers get the title size,
    # every other entry the regular legend size
    leg = plt.legend(loc="upper right")
    if leg is not None:
        for text_obj in leg.findobj(plt.Text):
            txt = text_obj.get_text()
            if txt in ("PerturbationID", "Type"):
                text_obj.set_fontsize(LEGEND_TITLE_FONTSIZE)
            else:
                text_obj.set_fontsize(LEGEND_FONTSIZE)

    plt.xticks(fontsize=TICK_FONTSIZE)
    plt.yticks(fontsize=TICK_FONTSIZE)
    plt.tight_layout()
    plt.savefig(os.path.join(result_path, f"{prefix}_pca.png"), dpi=600)
    plt.close()

    # --- PCA: per-perturbation view (legend says 'Type') ---
    type_colors = {
        "Real": "#1f77b4",
        "Generated": "#d62728",
        "real": "#1f77b4",
        "generated": "#d62728",
    }
    for pert in levels_hue:
        df_pert = df_plot[df_plot[hue_name] == pert]
        # A level can vanish entirely after subsampling
        if df_pert.empty:
            continue

        plt.figure(figsize=(12, 10))
        present_types = pd.unique(df_pert[style_name])
        type_palette = {
            t: type_colors.get(t, sns.color_palette()[i % 10])
            for i, t in enumerate(present_types)
        }

        ax = sns.scatterplot(
            data=df_pert,
            x="pca-one", y="pca-two",
            hue=style_name, palette=type_palette,
            alpha=0.75, s=80, edgecolor=None,
        )

        for spine in ax.spines.values():
            spine.set_linewidth(axes_edgewidth)

        plt.xlabel(f"PC1 ({var_pct[0]:.1f}%)", fontsize=LABEL_FONTSIZE)
        plt.ylabel(f"PC2 ({var_pct[1]:.1f}%)", fontsize=LABEL_FONTSIZE)

        plt.title(f"PCA of CellProfiler Features - Perturbation {pert} (HUVEC)", fontsize=TITLE_FONTSIZE)

        leg2 = plt.legend(title="Type", loc="upper right")
        if leg2 is not None:
            if leg2.get_title() is not None:
                leg2.get_title().set_fontsize(LEGEND_TITLE_FONTSIZE)
            for t in leg2.get_texts():
                t.set_fontsize(LEGEND_FONTSIZE)

        plt.xticks(fontsize=TICK_FONTSIZE)
        plt.yticks(fontsize=TICK_FONTSIZE)
        plt.tight_layout()
        plt.savefig(os.path.join(result_path, f"pca_pert_{pert}.png"), dpi=600)
        plt.close()

    return pca



def train_evaluate_linear_classifier(
    df,
    features,
    result_path,
    label_col="Perturbation",
    source_col="Source",
    seeds=(7, 42, 1337),
    test_size=0.2,
    n_pca=32,
):
    """
    Benchmark a linear classifier across four train/test settings.

    One run is performed per entry in `seeds` for every setting, and the mean
    accuracy / macro-F1 is reported with a 95% t-confidence interval.
    Settings:
      - real→real (stratified split of real)
      - gen→gen (stratified split of generated)
      - gen→real (train on all gen, test on all real; seeds affect model randomness only)
      - both→real (split real; the scaler is fitted on real_train only, PCA is
        fitted on scaled real_train plus all generated rows; the classifier is
        trained on real_train plus a random subsample of generated rows of at
        most the same size as real_train; test on real_test)
    The first three settings use a StandardScaler → PCA → LogisticRegression
    pipeline fitted on the training data of that setting.

    Writes "clf_report_<setting>_seed<seed>.txt" per run and
    "clf_report_<setting>_summary.txt" per setting into result_path.

    Args:
        df: DataFrame holding the feature columns, `label_col` and `source_col`.
        features: list of feature column names.
        result_path: output directory (created if missing).
        label_col: name of the class-label column.
        source_col: column whose values "real" / "generated" select the two pools.
        seeds: iterable of integer seeds, one run per seed and setting.
        test_size: test fraction passed to train_test_split.
        n_pca: PCA n_components. In the both→real setting, None or a value
            <= 0 disables PCA, an int is clamped to the number of features,
            and a float is passed through unchanged.
    Returns:
        None. Results are written to disk only.
    """
    import os

    import numpy as np
    import pandas as pd
    from sklearn.decomposition import PCA
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import (
        accuracy_score,
        classification_report,
        confusion_matrix,
        f1_score,
    )
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    os.makedirs(result_path, exist_ok=True)

    real = df[df[source_col] == "real"].copy()
    gen = df[df[source_col] == "generated"].copy()
    X_real, y_real = real[features], real[label_col]
    X_gen, y_gen = gen[features], gen[label_col]

    # Optional safety check; remove if you truly guarantee this.
    # (Being an assert, it is skipped when Python runs with -O.)
    assert set(y_real.unique()) == set(
        y_gen.unique()
    ), "Label sets differ (real vs generated)."

    def split_stratified(X, y, seed):
        """Stratified split, with graceful fallback if some classes are tiny."""
        try:
            return train_test_split(
                X, y, test_size=test_size, random_state=seed, stratify=y
            )
        except ValueError:
            # Fallback: unstratified split (rare; e.g., class with <2 samples)
            return train_test_split(X, y, test_size=test_size, random_state=seed)

    def make_classifier(seed):
        """Build the logistic-regression model shared by every setting."""
        # NOTE(review): recent scikit-learn releases deprecate `multi_class`;
        # confirm against the pinned version before upgrading.
        return LogisticRegression(
            max_iter=1000, random_state=seed, multi_class="multinomial"
        )

    def write_seed_report(split_name, seed, y_true, y_pred, evr):
        """Compute metrics for one run, write its report file, return (acc, macro-F1)."""
        acc = accuracy_score(y_true, y_pred)
        f1m = f1_score(y_true, y_pred, average="macro")
        # Confusion-matrix axes cover the labels present in the test set only
        labels = sorted(y_true.unique())
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        report_file = os.path.join(
            result_path, f"clf_report_{split_name}_seed{seed}.txt"
        )
        with open(report_file, "w") as f:
            f.write(f"Test Accuracy: {acc:.4f}\n")
            f.write(f"Macro-F1: {f1m:.4f}\n\n")
            f.write("Classification Report:\n")
            f.write(classification_report(y_true, y_pred, digits=4))
            f.write("\nConfusion Matrix (rows=true, cols=pred):\n")
            f.write(pd.DataFrame(cm, index=labels, columns=labels).to_string())
            # evr is None when PCA was skipped
            if evr is not None:
                f.write(f"\n\nPCA EVR (first {len(evr)}): {np.round(evr,4).tolist()}\n")
                f.write(f"Total EVR: {float(evr.sum()):.4f}\n")
        return acc, f1m

    def run_once(X_train, y_train, X_test, y_test, split_name, seed):
        """Fit the scaler → PCA → classifier pipeline on train and score it on test."""
        pipe = Pipeline(
            [
                ("scaler", StandardScaler()),
                ("pca", PCA(n_components=n_pca, random_state=seed)),
                ("clf", make_classifier(seed)),
            ]
        )
        pipe.fit(X_train, y_train)
        y_pred = pipe.predict(X_test)
        evr = pipe.named_steps["pca"].explained_variance_ratio_
        return write_seed_report(split_name, seed, y_test, y_pred, evr)

    def summarize(accs, f1s, split_name):
        """Write mean metrics with a 95% t-interval over seeds (n small)."""
        accs = np.asarray(accs, dtype=float)
        f1s = np.asarray(f1s, dtype=float)
        n = len(accs)
        mean_acc = float(accs.mean())
        mean_f1 = float(f1s.mean())
        if n > 1:
            from math import sqrt

            try:
                from scipy.stats import t

                tcrit = float(t.ppf(0.975, df=n - 1))
            except ImportError:
                # Fallback: tcrit for n=3 (df=2) ≈ 4.303; n=2 (df=1) ≈ 12.706.
                # Any other n also gets the df=1 value, which only widens the interval.
                tcrit = 4.303 if n == 3 else 12.706
            se_acc = float(accs.std(ddof=1) / sqrt(n))
            se_f1 = float(f1s.std(ddof=1) / sqrt(n))
            acc_lo, acc_hi = mean_acc - tcrit * se_acc, mean_acc + tcrit * se_acc
            f1_lo, f1_hi = mean_f1 - tcrit * se_f1, mean_f1 + tcrit * se_f1
        else:
            # A single run has no spread; report a degenerate interval
            acc_lo = acc_hi = mean_acc
            f1_lo = f1_hi = mean_f1

        with open(
            os.path.join(result_path, f"clf_report_{split_name}_summary.txt"), "w"
        ) as f:
            f.write(f"Seeds: {list(seeds)}\n")
            f.write(f"Accuracy per seed: {np.round(accs,4).tolist()}\n")
            f.write(f"Macro-F1 per seed: {np.round(f1s,4).tolist()}\n\n")
            f.write(
                f"Mean Test Accuracy: {mean_acc:.4f}  (95% t-CI [{acc_lo:.4f}, {acc_hi:.4f}])\n"
            )
            f.write(
                f"Mean Macro-F1:      {mean_f1:.4f}  (95% t-CI [{f1_lo:.4f}, {f1_hi:.4f}])\n"
            )

    # -------- real→real --------
    accs, f1s = [], []
    for s in seeds:
        Xtr, Xte, ytr, yte = split_stratified(X_real, y_real, s)
        a, f = run_once(Xtr, ytr, Xte, yte, "real_real", s)
        accs.append(a)
        f1s.append(f)
    summarize(accs, f1s, "real_real")

    # -------- gen→gen --------
    accs, f1s = [], []
    for s in seeds:
        Xtr, Xte, ytr, yte = split_stratified(X_gen, y_gen, s)
        a, f = run_once(Xtr, ytr, Xte, yte, "gen_gen", s)
        accs.append(a)
        f1s.append(f)
    summarize(accs, f1s, "gen_gen")

    # -------- gen→real --------
    accs, f1s = [], []
    for s in seeds:
        a, f = run_once(X_gen, y_gen, X_real, y_real, "gen_real", s)
        accs.append(a)
        f1s.append(f)
    summarize(accs, f1s, "gen_real")

    # -------- both→real --------
    # Scaler is fitted on real_train only, PCA on scaled real_train ∪ generated;
    # train on real_train + subsampled generated; test on real_test.
    accs, f1s = [], []
    for s in seeds:
        # split real
        Xr_tr, Xr_te, yr_tr, yr_te = split_stratified(X_real, y_real, s)

        # fit scaler on real_train only; transform real_train, real_test, and generated
        scaler = StandardScaler().fit(Xr_tr)
        Xr_tr_s = scaler.transform(Xr_tr)
        Xr_te_s = scaler.transform(Xr_te)
        Xg_s = scaler.transform(X_gen)

        # optional PCA: fit on training union (real_train ∪ gen) AFTER scaling
        evr = None
        if (n_pca is not None) and (n_pca > 0):
            X_union_for_pca = np.vstack([Xr_tr_s, Xg_s])

            # if n_pca is a float (e.g., 0.95), use full SVD; random_state is ignored in this mode
            if isinstance(n_pca, float):
                pca = PCA(n_components=n_pca, svd_solver="full")
            else:
                pca = PCA(
                    n_components=min(n_pca, X_union_for_pca.shape[1]), svd_solver="full"
                )

            pca.fit(X_union_for_pca)  # real_test is never seen by the PCA fit
            Xr_tr_s = pca.transform(Xr_tr_s)
            Xr_te_s = pca.transform(Xr_te_s)
            Xg_s = pca.transform(Xg_s)
            evr = pca.explained_variance_ratio_

        # Randomly subsample generated rows to match the number of real_train
        # rows. The size is clamped because choice(..., replace=False) raises
        # ValueError when asked for more rows than exist.
        n_real = Xr_tr_s.shape[0]
        n_sub = min(n_real, len(Xg_s))
        rng = np.random.RandomState(s)  # use seed for reproducibility
        idx = rng.choice(len(Xg_s), size=n_sub, replace=False)
        Xg_sub = Xg_s[idx]
        yg_sub = y_gen.iloc[idx]

        # concatenate train: real_train + sampled generated
        X_tr = np.vstack([Xr_tr_s, Xg_sub])
        y_tr = pd.concat([yr_tr, yg_sub], axis=0)

        # train classifier and evaluate on real_test
        clf = make_classifier(s)
        clf.fit(X_tr, y_tr)
        y_pred = clf.predict(Xr_te_s)

        # metrics + per-seed report
        acc, f1m = write_seed_report("both_real", s, yr_te, y_pred, evr)
        accs.append(acc)
        f1s.append(f1m)

    summarize(accs, f1s, "both_real")


# def correlation_analysis_real_generated(df, features, result_path, source_col="Source"):
#     """
#     Compute and visualize correlation matrices for real and generated features, and compare them.
#     Saves heatmaps and difference matrix to result_path.
#     """
#     import os

#     import matplotlib.pyplot as plt
#     import seaborn as sns

#     if not os.path.exists(result_path):
#         os.makedirs(result_path)
#     # Use PCA to select 10 most important features (by absolute loading in PC1)
#     from scipy.stats import pearsonr
#     from sklearn.decomposition import PCA

#     # pretty_map = {
#     #     "AreaOccupied_AreaOccupied_Cells": "Area Occupied (Cells)",
#     #     "StDev_Nuclei_RadialDistribution_MeanFrac_OrigGray_3of4": "Std Dev — Nuclei Radial Distribution: Mean Fraction",
#     #     "StDev_Cells_AreaShape_Zernike_0_0": "Std Dev — Cells Area Shape: Zernike",
#     #     "StDev_Cells_AreaShape_FormFactor": "Std Dev — Cells Area Shape: Form Factor",
#     #     "StDev_Nuclei_Intensity_MassDisplacement_OrigGray": "Std Dev — Nuclei Intensity: Mass Displacement",
#     #     "Mean_Cells_AreaShape_Area": "Mean — Cells Area",
#     #     "Median_Cells_AreaShape_Zernike_2_0": "Median — Cells Area Shape: Zernike",
#     #     "Mean_Nuclei_RadialDistribution_RadialCV_OrigGray_1of4": "Mean — Nuclei Radial Distribution: Radial CV",
#     #     "StDev_Cells_Granularity_8_OrigGray": "Std Dev — Cells Granularity",
#     #     "Mean_Cells_Neighbors_AngleBetweenNeighbors_5": "Mean — Cells Neighbors: Angle Between (k=5)",
#     # }
#     pretty_map = {
#         "AreaOccupied_AreaOccupied_Cells": "Area Occupied (Cells)",
#         "StDev_Nuclei_RadialDistribution_MeanFrac_OrigGray_3of4": "Nuc Radial MeanFrac (σ)",
#         "StDev_Cells_AreaShape_Zernike_0_0": "Cells Zernike (σ)",
#         "StDev_Cells_AreaShape_FormFactor": "Cells FormFactor (σ)",
#         "StDev_Nuclei_Intensity_MassDisplacement_OrigGray": "Nuc MassDisp (σ)",
#         "Mean_Cells_AreaShape_Area": "Cells Area (μ)",
#         "Median_Cells_AreaShape_Zernike_2_0": "Cells Zernike (Med)",
#         "Mean_Nuclei_RadialDistribution_RadialCV_OrigGray_1of4": "Nuc Radial CV (μ)",
#         "StDev_Cells_Granularity_8_OrigGray": "Granularity (σ)",
#         "Mean_Cells_Neighbors_AngleBetweenNeighbors_5": "Neighbors Angle k=5 (μ)",
#     }

#     n_top = 10
#     # pca = PCA(n_components=1)
#     pca = PCA(n_components=1, svd_solver="full", random_state=0)

#     X = df[features].values
#     pca.fit(X)
#     pc1_loadings = np.abs(pca.components_[0])
#     top_idx = np.argsort(pc1_loadings)[-n_top:][::-1]
#     top_features = [features[i] for i in top_idx]
#     print(f"[Correlation] Top {n_top} features by PC1 loading: {top_features}")
#     real = df[df[source_col] == "real"][top_features]
#     gen = df[df[source_col] == "generated"][top_features]
#     corr_real = real.corr()
#     corr_gen = gen.corr()
#     diff = corr_real - corr_gen
#     # Plot real
#     plt.figure(figsize=(10, 8))
#     # sns.heatmap(corr_real, cmap="coolwarm", center=0, annot=True, fmt=".2f")
#     sns.heatmap(
#         corr_real,
#         cmap="coolwarm",
#         center=0,
#         annot=True,
#         fmt=".2f",
#         xticklabels=[pretty_map.get(c, c) for c in corr_real.columns],
#         yticklabels=[pretty_map.get(r, r) for r in corr_real.index],
#     )
#     plt.title(f"Correlation Matrix: Real (Top {n_top} PCA features)")
#     plt.tight_layout()
#     plt.savefig(os.path.join(result_path, f"corr_matrix_real_top{n_top}.png"), dpi=200)
#     plt.close()
#     # Plot generated
#     plt.figure(figsize=(10, 8))
#     # sns.heatmap(corr_gen, cmap="coolwarm", center=0, annot=True, fmt=".2f")
#     sns.heatmap(
#         corr_gen,
#         cmap="coolwarm",
#         center=0,
#         annot=True,
#         fmt=".2f",
#         xticklabels=[pretty_map.get(c, c) for c in corr_real.columns],
#         yticklabels=[pretty_map.get(r, r) for r in corr_real.index],
#     )
#     plt.title(f"Correlation Matrix: Generated (Top {n_top} PCA features)")
#     plt.tight_layout()
#     plt.savefig(
#         os.path.join(result_path, f"corr_matrix_generated_top{n_top}.png"), dpi=200
#     )
#     plt.close()
#     # Plot difference
#     plt.figure(figsize=(10, 8))
#     # sns.heatmap(diff, cmap="bwr", center=0, annot=True, fmt=".2f")
#     sns.heatmap(
#         diff,
#         cmap="bwr",
#         center=0,
#         annot=True,
#         fmt=".2f",
#         xticklabels=[pretty_map.get(c, c) for c in corr_real.columns],
#         yticklabels=[pretty_map.get(r, r) for r in corr_real.index],
#     )
#     plt.title(
#         f"Correlation Matrix Difference (Real - Generated, Top {n_top} PCA features)"
#     )
#     plt.tight_layout()
#     plt.savefig(os.path.join(result_path, f"corr_matrix_diff_top{n_top}.png"), dpi=200)
#     plt.close()

def _corr_symmetric_limit(values, fallback=1.0):
    """
    Compute a symmetric colour-scale limit for a heatmap.
    Args:
        values: array-like of numbers (NaN entries are ignored).
        fallback: limit returned when no usable magnitude exists.
    Returns:
        float: max |value| rounded up to one decimal, or `fallback` when the
        input has no finite entries or all finite entries are zero.
    """
    magnitudes = np.abs(np.asarray(values, dtype=float))
    finite = magnitudes[np.isfinite(magnitudes)]
    if finite.size == 0:
        return float(fallback)
    limit = float(np.ceil(finite.max() * 10) / 10)
    # A zero limit would give vmin == vmax and a degenerate colour scale.
    return limit if limit > 0 else float(fallback)


def correlation_analysis_real_generated(
    df, features, result_path, source_col="Source",
    n_top=10, tick_rotation=45, tick_fontsize=11, bottom_margin=0.25
):
    """
    Plot correlation matrices of the top PC1-loading features for real vs generated rows.

    Fits a 1-component PCA on df[features], keeps the n_top features with the
    largest |PC1 loading|, and writes two PNG files into result_path:
      - corr_real_vs_generated_top{n_top}.png: side-by-side heatmaps with a
        shared horizontal colorbar and a text legend mapping index -> feature.
      - corr_matrix_diff_top{n_top}.png: heatmap of (real - generated).
    Args:
        df: pandas DataFrame holding the feature columns and source_col.
        features: sequence of feature column names considered by the PCA.
        result_path: output directory (created if missing).
        source_col: column whose values "real" / "generated" split the rows.
        n_top: number of features to keep (must be >= 1). File names always use
            this value, even if fewer features are available.
        tick_rotation: currently unused; accepted for backward compatibility.
        tick_fontsize: base font size for tick labels (axis labels are derived from it).
        bottom_margin: currently unused; accepted for backward compatibility.
    Returns:
        None. Figures are saved to disk and their paths are printed.
    Raises:
        ValueError: if features is empty or n_top < 1.
    """
    # Cheap argument checks first. Note that a slice such as [-0:] would
    # otherwise silently select *all* features when n_top == 0.
    features = list(features)
    if not features:
        raise ValueError("`features` must contain at least one column name.")
    if n_top < 1:
        raise ValueError(f"`n_top` must be a positive integer, got {n_top!r}.")

    # Plotting / ML dependencies are only needed once the inputs are known to be usable.
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.decomposition import PCA
    from textwrap import shorten

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(result_path, exist_ok=True)

    # Short names (edit as needed)
    pretty_map = {
        "AreaOccupied_AreaOccupied_Cells": "Area Occupied (Cells)",
        "StDev_Nuclei_RadialDistribution_MeanFrac_OrigGray_3of4": "Nuc Radial MeanFrac ($\\sigma$)",
        "StDev_Cells_AreaShape_Zernike_0_0": "Cells Zernike ($\\sigma$)",
        "StDev_Cells_AreaShape_FormFactor": "Cells FormFactor ($\\sigma$)",
        "StDev_Nuclei_Intensity_MassDisplacement_OrigGray": "Nuc MassDisp ($\\sigma$)",
        "Mean_Cells_AreaShape_Area": "Cells Area (μ)",
        "Median_Cells_AreaShape_Zernike_2_0": "Cells Zernike (Med)",
        "Mean_Nuclei_RadialDistribution_RadialCV_OrigGray_1of4": "Nuc Radial CV ($\\mu$)",
        "StDev_Cells_Granularity_8_OrigGray": "Granularity ($\\sigma$)",
        "Mean_Cells_Neighbors_AngleBetweenNeighbors_5": "Neighbors Angle k=5 ($\\mu$)",
    }

    # --- PCA: pick top-k by |PC1 loading| (descending) ---
    pca = PCA(n_components=1, svd_solver="full", random_state=0)
    pca.fit(df[features].values)
    pc1_loadings = np.abs(pca.components_[0])
    top_idx = np.argsort(pc1_loadings)[::-1][:n_top]
    top_features = [features[i] for i in top_idx]
    k = len(top_features)  # may be smaller than n_top when few features exist
    print(f"[Correlation] Top {k} features by PC1 loading: {top_features}")

    # Splits & correlations
    real = df.loc[df[source_col] == "real", top_features]
    gen = df.loc[df[source_col] == "generated", top_features]
    for group_name, subset in (("real", real), ("generated", gen)):
        if len(subset) < 2:
            # Pearson correlation is undefined for fewer than two observations.
            print(
                f"[Correlation] Warning: only {len(subset)} '{group_name}' row(s) "
                f"found in column '{source_col}'; its correlation matrix will be NaN."
            )
    corr_real = real.corr()
    ordered_cols = list(corr_real.columns)
    # Align the generated matrix to the real one so cells match position by position.
    corr_gen = gen.corr().reindex(index=ordered_cols, columns=ordered_cols)
    diff = corr_real - corr_gen

    # Labels & colour scales (fallbacks guard against all-NaN / all-zero matrices)
    idx_labels = list(range(1, len(ordered_cols) + 1))
    pretty_labels = [pretty_map.get(c, c) for c in ordered_cols]
    common_max = _corr_symmetric_limit(
        np.concatenate([corr_real.values.ravel(), corr_gen.values.ravel()]),
        fallback=1.0,
    )
    diff_max = _corr_symmetric_limit(diff.values, fallback=0.1)

    # =========================
    # COMBINED FIGURE
    # =========================
    fig = plt.figure(figsize=(16, 8.75), constrained_layout=True)
    fig.set_constrained_layout_pads(h_pad=0.01, hspace=0.01, wspace=0.08)

    # One row, three columns: real | gen | side (cbar + legend)
    outer = fig.add_gridspec(nrows=1, ncols=3, width_ratios=[1.0, 1.0, 0.85])
    ax_real = fig.add_subplot(outer[0, 0])
    ax_gen = fig.add_subplot(outer[0, 1])

    # Side column is split vertically into [spacer, colorbar, legend]
    spacer = 1.5
    side = outer[0, 2].subgridspec(
        nrows=3, ncols=1, height_ratios=[spacer, 0.5, 5]
    )
    ax_spacer = fig.add_subplot(side[0, 0])
    ax_spacer.axis("off")
    cax = fig.add_subplot(side[1, 0])
    ax_leg = fig.add_subplot(side[2, 0])
    ax_leg.axis("off")

    def _heatmap(ax, mat):
        # Both panels share the same symmetric scale so one colorbar serves both.
        hm = sns.heatmap(
            mat.values,
            cmap="coolwarm", vmin=-common_max, vmax=common_max, center=0.0,
            annot=True, fmt=".2f", annot_kws={"size": 12},
            xticklabels=idx_labels, yticklabels=idx_labels,
            cbar=False, square=True, linewidths=0.2, linecolor="white", ax=ax
        )
        ax.tick_params(axis="both", labelsize=tick_fontsize)
        ax.set_xlabel("Feature index", fontsize=tick_fontsize + 7)
        ax.set_ylabel("Feature index", fontsize=tick_fontsize + 7)
        return hm

    hm_real = _heatmap(ax_real, corr_real)
    _heatmap(ax_gen, corr_gen)

    # Shared horizontal colorbar, drawn from the first heatmap's mesh
    cb = fig.colorbar(hm_real.collections[0], cax=cax, orientation="horizontal")
    cb.ax.tick_params(labelsize=tick_fontsize)

    # Single-column legend mapping the numeric tick labels back to feature names
    legend_lines = [
        f"{idx:>2}. {shorten(label, width=55, placeholder='…')}"
        for idx, label in zip(idx_labels, pretty_labels)
    ]
    ax_leg.text(
        0.0, 1.0, "\nFeature Index\n\n" + "\n".join(legend_lines),
        ha="left", va="top", fontsize=17, family="monospace",
        transform=ax_leg.transAxes
    )

    # Single big title (placed manually rather than via suptitle)
    fig.text(
        0.5, 0.85,
        f"Correlation Matrices: Real vs Generated (Top {k} PCA Features)",
        ha="center", va="top", fontsize=20
    )

    side_by_side_name = os.path.join(result_path, f"corr_real_vs_generated_top{n_top}.png")
    fig.savefig(side_by_side_name, dpi=300)
    plt.close(fig)

    # =========================
    # DIFF PLOT
    # =========================
    diff_name = os.path.join(result_path, f"corr_matrix_diff_top{n_top}.png")
    fig_d, ax_d = plt.subplots(figsize=(8.8, 8.2))
    sns.heatmap(
        diff.values,
        cmap="bwr", vmin=-diff_max, vmax=diff_max, center=0.0,
        annot=True, fmt=".2f", annot_kws={"size": 8},
        xticklabels=idx_labels, yticklabels=idx_labels,
        cbar_kws=dict(shrink=0.85),
        square=True, linewidths=0.2, linecolor="white", ax=ax_d
    )
    ax_d.set_title("Correlation Difference (Real − Generated)", fontsize=14)
    ax_d.set_xlabel("Feature index", fontsize=tick_fontsize + 1)
    ax_d.set_ylabel("Feature index", fontsize=tick_fontsize + 1)
    ax_d.tick_params(axis="both", labelsize=tick_fontsize)
    fig_d.tight_layout()
    fig_d.savefig(diff_name, dpi=300)
    plt.close(fig_d)

    print(f"[Saved] {side_by_side_name}")
    print(f"[Saved] {diff_name}")




# Script entry point: runs the full real-vs-generated feature analysis pipeline
if __name__ == "__main__":
    import argparse

    # Command-line options, declared as (flag, add_argument keyword arguments)
    cli_options = (
        (
            "--cp_dir",
            dict(
                type=str,
                default="/mnt/pvc/CP_outputs",
                help="Path to CP_outputs directory",
            ),
        ),
        (
            "--perturbations",
            dict(
                nargs="+",
                default=["1108", "1124", "1137", "1138"],
                help="List of perturbation directories",
            ),
        ),
        (
            "--results_dir",
            dict(
                type=str,
                default="results_cp_analysis",
                help="Directory to save results",
            ),
        ),
    )
    parser = argparse.ArgumentParser(
        description="Analyze CellProfiler features for real and generated cells."
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Step 1: load the features of both groups for every requested perturbation
    df_all, all_cols = load_cp_features(
        args.cp_dir, args.perturbations, groups=("real", "generated")
    )

    # Step 2: build the candidate feature list.
    # A column counts as metadata when its lowercased name contains any marker.
    metadata_markers = ["metadata", "filename", "pathname", "imagenumber"]
    metadata_cols = []
    for col in all_cols:
        lowered = col.lower()
        if any(marker in lowered for marker in metadata_markers):
            metadata_cols.append(col)

    # Candidates must be numeric, non-metadata, and not one of the label columns
    numeric_cols = df_all.select_dtypes(include=["number"]).columns.tolist()
    label_cols = ["Source", "Perturbation"]
    feature_cols = []
    for col in all_cols:
        if col in metadata_cols or col in label_cols:
            continue
        if col in numeric_cols:
            feature_cols.append(col)

    # Step 3: drop low-variance features and features with extreme outliers
    print(f"Initial number of features: {len(feature_cols)}")
    selected_features = select_features(df_all, feature_cols)
    print(f"After low-variance/outlier removal: {len(selected_features)}")

    # Step 3b: prune highly correlated features (threshold 0.7 here)
    reduced_features = remove_highly_correlated_features(
        df_all, selected_features, corr_thresh=0.7
    )
    print(f"After removing highly correlated features: {len(reduced_features)}")

    # Step 4: PCA visualization on features standardized over the whole dataset
    df_all_scaled = preprocess_features(df_all, reduced_features)
    plot_pca(
        df_all_scaled,
        reduced_features,
        args.results_dir,
        color_by="Perturbation",
        style_by="Source",
        prefix="all",
    )

    # Step 5: linear classifier. The unscaled frame is passed on purpose;
    # presumably the callee scales internally to avoid leakage (verify there).
    train_evaluate_linear_classifier(
        df_all,
        reduced_features,
        args.results_dir,
        label_col="Perturbation",
        source_col="Source",
    )

    # Step 6: real-vs-generated correlation analysis on the scaled features
    correlation_analysis_real_generated(
        df_all_scaled, reduced_features, args.results_dir, source_col="Source"
    )

    print(f"Analysis complete. Results saved to {args.results_dir}")
