import os
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule
from sc_perturb.openphenom import OpenPhenomEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tqdm import tqdm

# Suppress warnings of every category for the whole process (this includes
# warnings raised by third-party libraries, not only this script's own).
warnings.filterwarnings("ignore")


def get_perturbation_cell_type_distribution_custom(datamodule):
    """
    Count training samples for every (perturbation, cell type) combination.

    The training rows are selected with the boolean ``train_index`` column of
    ``datamodule.metadata``. If the metadata has no integer ``cell_type_id``
    column, one is derived from the string ``cell_type`` column using a fixed
    name-to-id table.

    Returns:
        Dictionary mapping (perturbation_id, cell_type_id) to count

    Raises:
        ValueError: if the metadata has neither 'cell_type_id' nor 'cell_type'.
    """
    # Fixed mapping from cell type name to integer label
    cell_type_to_label = {
        "HEPG2": 0,
        "HUVEC": 1,
        "RPE": 2,
        "U2OS": 3,
    }

    metadata = datamodule.metadata
    train_rows = metadata[metadata["train_index"]]
    available_columns = train_rows.columns

    if "cell_type_id" not in available_columns:
        if "cell_type" not in available_columns:
            raise ValueError(
                "Neither 'cell_type_id' nor 'cell_type' column found in metadata"
            )
        # Work on a copy so the datamodule's metadata is left untouched
        train_rows = train_rows.copy()
        train_rows["cell_type_id"] = train_rows["cell_type"].map(cell_type_to_label)

    pair_counts = train_rows.groupby(["sirna_id", "cell_type_id"]).size()
    return {
        (int(sirna_id), int(ct_id)): n_samples
        for (sirna_id, ct_id), n_samples in pair_counts.items()
    }


def get_top_perturbations_for_cell_type(
    datamodule, cell_type_id=0, top_k=10, excluded_perturbation_ids=(1116,)
):
    """
    Get the top K perturbation IDs with the most samples for a specific cell type.

    Args:
        datamodule: CellDataModule instance
        cell_type_id: Cell type ID to filter for (default: 0)
        top_k: Number of top perturbations to return (default: 10)
        excluded_perturbation_ids: Perturbation IDs that must never be
            returned (default: ``(1116,)``, the class this script has always
            excluded). Pass ``()`` or ``None`` to disable the exclusion.

    Returns:
        List of perturbation IDs sorted by sample count (descending)
    """
    # Joint (perturbation_id, cell_type_id) -> count distribution of the training set
    distribution = get_perturbation_cell_type_distribution_custom(datamodule)

    excluded = set(excluded_perturbation_ids or ())

    # Keep only the requested cell type and drop excluded perturbations
    cell_type_counts = {
        pert_id: count
        for (pert_id, ct_id), count in distribution.items()
        if ct_id == cell_type_id and pert_id not in excluded
    }

    # sorted() is stable, so ties keep the order in which the distribution
    # yielded them.
    ranked = sorted(
        cell_type_counts.items(), key=lambda item: item[1], reverse=True
    )[:top_k]

    # Report exactly the perturbations that are returned, so the log cannot
    # list an excluded ID or disagree with the caller's view.
    print(f"Top {top_k} perturbations for cell_type_id {cell_type_id}:")
    for rank, (pert_id, count) in enumerate(ranked, start=1):
        print(f"  {rank}. Perturbation {pert_id}: {count} samples")

    return [pert_id for pert_id, _ in ranked]


def find_generated_files_by_perturbation_and_celltype(
    generated_path, perturbation_id, cell_type_id
):
    """
    Find all generated numpy files for a specific perturbation ID and cell type ID.

    Files are looked up in ``<generated_path>/p<perturbation_id>/`` and are
    expected to be named ``p<pid>_c<cell_type_id>_sample<sample_id>.npy``.

    Args:
        generated_path: Root directory holding one ``p<pid>`` folder per perturbation.
        perturbation_id: Perturbation ID selecting the sub-folder.
        cell_type_id: Cell type ID that must appear in the file name.

    Returns:
        Sorted list of matching ``.npy`` paths; empty if the perturbation
        folder does not exist or nothing matches.
    """
    import glob

    pert_path = os.path.join(generated_path, f"p{perturbation_id}")
    if not os.path.isdir(pert_path):
        return []

    # Marker for the cell type inside the file name
    cell_type_marker = f"_c{cell_type_id}_sample"

    # Escape the directory so glob metacharacters in the path are taken literally
    npy_files = glob.glob(os.path.join(glob.escape(pert_path), "*.npy"))

    # Test the marker against the file name only: a parent directory that
    # happens to contain it must not make every file match. Sort because glob
    # order is filesystem-dependent and callers subsample with a seeded RNG.
    return sorted(
        path for path in npy_files if cell_type_marker in os.path.basename(path)
    )


def extract_features_batch(images_tensor, encoder, device, batch_size=64):
    """
    Run the encoder over ``images_tensor`` in mini-batches and collect the outputs.

    Each slice of at most ``batch_size`` items along the first dimension is
    moved to ``device``, encoded with gradients disabled, moved back to the
    CPU and converted to numpy. The per-batch outputs are stacked vertically.

    Returns:
        The stacked features as a numpy array, or an empty array when
        ``images_tensor`` is None or has no items.
    """
    # Nothing to encode
    if images_tensor is None or images_tensor.size(0) <= 0:
        return np.array([])

    n_items = images_tensor.size(0)
    encoded_chunks = []
    with torch.no_grad():
        for start in range(0, n_items, batch_size):
            stop = start + batch_size
            chunk = images_tensor[start:stop].to(device)
            encoded_chunks.append(encoder(chunk).cpu().numpy())

    if not encoded_chunks:
        return np.array([])
    return np.vstack(encoded_chunks)


def load_and_extract_features_single_cell_type(
    datamodule,
    generated_path,
    perturbation_ids,
    cell_type_id,
    encoder,
    device,
    max_samples_per_pert=500,
    batch_size=64,
):
    """
    Load real and generated data and extract features for specified perturbations and a single cell type.

    For each perturbation, real images come from
    ``datamodule.filter_samples(...)`` (the first element of every dataset
    item is used as the image) and generated images come from the ``.npy``
    files found under ``generated_path``. Both sources are capped at
    ``max_samples_per_pert`` items by random subsampling with the ``random``
    module, so seed it beforehand for reproducible runs.

    Args:
        datamodule: Object exposing ``filter_samples(perturbation_id=..., cell_type_id=...)``.
        generated_path: Root directory of the generated ``.npy`` files.
        perturbation_ids: Iterable of perturbation IDs to process.
        cell_type_id: Cell type ID to restrict both sources to.
        encoder: Callable mapping an image batch tensor to a feature tensor.
        device: Device the batches are moved to before encoding.
        max_samples_per_pert: Per-perturbation cap applied to each source.
        batch_size: Encoder mini-batch size.

    Returns:
        real_features, real_pert_labels,
        generated_features, generated_pert_labels

        The feature entries are numpy arrays (an empty array when nothing was
        loaded); the label entries are Python lists with one perturbation ID
        per feature row.
    """
    all_real_features = []
    all_real_pert_labels = []

    all_generated_features = []
    all_generated_pert_labels = []

    for pert_id in tqdm(perturbation_ids, desc="Processing perturbations"):
        print(f"\nProcessing perturbation ID: {pert_id}")

        # Load real data for specific cell type
        real_filtered_dataset = datamodule.filter_samples(
            perturbation_id=pert_id,
            cell_type_id=cell_type_id,
        )

        if real_filtered_dataset is not None and len(real_filtered_dataset) > 0:
            n_real = len(real_filtered_dataset)

            # Choose the indices BEFORE touching the dataset so that at most
            # max_samples_per_pert images are ever loaded. The random.sample
            # call is the same one the subsampling always used.
            # NOTE(review): this assumes dataset item access does not itself
            # draw from the `random` module - confirm against the dataset.
            if n_real > max_samples_per_pert:
                indices = random.sample(range(n_real), max_samples_per_pert)
            else:
                indices = range(n_real)

            real_images = [real_filtered_dataset[i][0] for i in indices]

            # torch.stack rejects an empty list (e.g. max_samples_per_pert=0)
            if real_images:
                real_images_tensor = torch.stack(real_images)
                real_features = extract_features_batch(
                    real_images_tensor, encoder, device, batch_size
                )

                if real_features.size > 0:
                    all_real_features.append(real_features)
                    all_real_pert_labels.extend([pert_id] * len(real_features))
                    print(f"  Real: {len(real_features)} samples")

        # Load generated data for specific cell type
        gen_files = find_generated_files_by_perturbation_and_celltype(
            generated_path, pert_id, cell_type_id
        )

        if gen_files:
            # Sample if too many files
            if len(gen_files) > max_samples_per_pert:
                sampled_gen_files = random.sample(gen_files, max_samples_per_pert)
            else:
                sampled_gen_files = gen_files

            temp_generated_images = []

            for file_path in sampled_gen_files:
                # Deliberately best-effort: an unreadable file is reported and
                # skipped so it cannot abort the whole run.
                try:
                    img = np.load(file_path)
                    temp_generated_images.append(torch.from_numpy(img).float())
                except Exception as e:
                    print(f"Error loading {file_path}: {e}")

            if temp_generated_images:
                generated_images_tensor = torch.stack(temp_generated_images)
                generated_features = extract_features_batch(
                    generated_images_tensor, encoder, device, batch_size
                )

                if generated_features.size > 0:
                    all_generated_features.append(generated_features)
                    all_generated_pert_labels.extend(
                        [pert_id] * len(generated_features)
                    )
                    print(f"  Generated: {len(generated_features)} samples")

    # Combine all features
    real_features = np.vstack(all_real_features) if all_real_features else np.array([])
    generated_features = (
        np.vstack(all_generated_features) if all_generated_features else np.array([])
    )

    return (
        real_features,
        all_real_pert_labels,
        generated_features,
        all_generated_pert_labels,
    )


def train_and_evaluate_model(X_train, y_train, X_test, y_test, model_name, task_name):
    """
    Train a logistic regression model and evaluate it.

    Features are standardized with a scaler fitted on the training set only.
    Besides the test accuracy, a 5-fold cross-validation score on the training
    set and the accuracy of every class present in ``y_test`` are reported.

    Args:
        X_train: Training feature matrix.
        y_train: Training labels (array-like; converted to a numpy array).
        X_test: Test feature matrix.
        y_test: Test labels (array-like; converted to a numpy array).
        model_name: Name stored in the result and printed in the header.
        task_name: Task description stored in the result and printed in the header.

    Returns:
        Dict with the keys ``model_name``, ``task``, ``accuracy``, ``cv_mean``,
        ``cv_std``, ``n_train``, ``n_test``, ``n_classes``,
        ``per_class_accuracies``, ``per_class_counts``, ``model``, ``scaler``,
        ``y_true`` and ``y_pred``.
    """
    # The per-class masks below need element-wise comparison; with a plain
    # list `y_test == class_label` would be a single bool and the boolean
    # indexing would silently select the wrong thing.
    y_train = np.asarray(y_train)
    y_test = np.asarray(y_test)

    # Standardize features (statistics come from the training set only)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Train model
    model = LogisticRegression(max_iter=1000, random_state=42)
    model.fit(X_train_scaled, y_train)

    # Predict
    y_pred = model.predict(X_test_scaled)
    accuracy = accuracy_score(y_test, y_pred)
    n_classes = len(np.unique(y_train))

    print(f"\n{model_name} - {task_name}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Number of classes: {n_classes}")
    print(f"Train samples: {len(X_train)}, Test samples: {len(X_test)}")

    # Cross-validation on training set
    # NOTE(review): the folds reuse features scaled with statistics of the
    # full training set, so the CV estimate is slightly optimistic.
    cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=5)
    print(f"CV Accuracy: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")

    # Per-class accuracies. Every label comes from np.unique(y_test), so each
    # class has at least one test sample and needs no empty-class fallback.
    per_class_accuracies = {}
    per_class_counts = {}

    print(f"\nPer-Class Accuracies:")
    for class_label in np.unique(y_test):
        class_mask = y_test == class_label
        class_true = y_test[class_mask]
        class_accuracy = accuracy_score(class_true, y_pred[class_mask])

        per_class_accuracies[class_label] = class_accuracy
        per_class_counts[class_label] = len(class_true)
        print(
            f"  Class {class_label}: {class_accuracy:.4f} ({len(class_true)} samples)"
        )

    return {
        "model_name": model_name,
        "task": task_name,
        "accuracy": accuracy,
        "cv_mean": cv_scores.mean(),
        "cv_std": cv_scores.std(),
        "n_train": len(X_train),
        "n_test": len(X_test),
        "n_classes": n_classes,
        "per_class_accuracies": per_class_accuracies,
        "per_class_counts": per_class_counts,
        "model": model,
        "scaler": scaler,
        "y_true": y_test,
        "y_pred": y_pred,
    }


def plot_confusion_matrix(y_true, y_pred, labels, title, save_path):
    """
    Render the confusion matrix of ``y_pred`` against ``y_true`` and write it to disk.

    The matrix is drawn as an annotated heatmap whose tick labels on both axes
    are taken from ``labels``; the figure is saved to ``save_path`` and closed.

    NOTE(review): the matrix is built from the classes found in the data, so
    ``labels`` is presumably expected to list exactly those classes in sorted
    order - confirm at the call site.
    """
    matrix = confusion_matrix(y_true, y_pred)

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(
        matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
    )

    ax.set_title(title, fontsize=16)
    ax.set_xlabel("Predicted", fontsize=14)
    ax.set_ylabel("True", fontsize=14)

    # Slant the column labels, keep the row labels horizontal
    plt.setp(ax.get_xticklabels(), rotation=45)
    plt.setp(ax.get_yticklabels(), rotation=0)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def check_data_leakage_and_similarity(real_features, generated_features, output_dir):
    """
    Check for potential data leakage and similarity between real and generated features.

    Each feature set is randomly subsampled (without replacement, via
    ``np.random``) to at most 1000 rows. Cosine similarities and Euclidean
    distances are then computed between every real row and every generated
    row, summary statistics are printed, and warnings are printed when the
    similarities exceed fixed thresholds.

    Side effects:
        Writes ``similarity_analysis.png`` and ``similarity_statistics.txt``
        into ``output_dir`` (the directory is not created here).

    Args:
        real_features: Feature rows of real samples.
        generated_features: Feature rows of generated samples.
        output_dir: Directory the plot and the statistics file are written to.

    Returns:
        Dictionary of similarity statistics, or None when either feature set
        is empty.
    """
    print("\n" + "=" * 60)
    print("CHECKING FOR DATA LEAKAGE AND SIMILARITY")
    print("=" * 60)

    if len(real_features) == 0 or len(generated_features) == 0:
        print("Cannot perform similarity analysis - missing real or generated features")
        return

    # 1. Compute pairwise distances between real and generated features
    print("Computing pairwise distances...")

    # Sample if too many features (for computational efficiency)
    max_samples = 1000
    if len(real_features) > max_samples:
        real_sample_idx = np.random.choice(
            len(real_features), max_samples, replace=False
        )
        real_sample = real_features[real_sample_idx]
    else:
        real_sample = real_features

    if len(generated_features) > max_samples:
        gen_sample_idx = np.random.choice(
            len(generated_features), max_samples, replace=False
        )
        gen_sample = generated_features[gen_sample_idx]
    else:
        gen_sample = generated_features

    # Compute cosine similarities (rows: real samples, columns: generated samples)
    print(
        f"Computing similarities between {len(real_sample)} real and {len(gen_sample)} generated samples..."
    )
    similarities = cosine_similarity(real_sample, gen_sample)

    # Compute Euclidean distances (same row/column layout as `similarities`)
    distances = euclidean_distances(real_sample, gen_sample)

    # 2. Analyze similarity statistics
    max_similarities = np.max(
        similarities, axis=1
    )  # Max similarity for each real sample
    min_distances = np.min(distances, axis=1)  # Min distance for each real sample

    print(f"\nSimilarity Analysis:")
    print(f"  Cosine Similarity (Real vs Generated):")
    print(f"    Mean max similarity: {np.mean(max_similarities):.4f}")
    print(f"    Std max similarity:  {np.std(max_similarities):.4f}")
    print(f"    Max similarity found: {np.max(max_similarities):.4f}")
    print(
        f"    Samples with >0.95 similarity: {np.sum(max_similarities > 0.95)}/{len(max_similarities)}"
    )
    print(
        f"    Samples with >0.99 similarity: {np.sum(max_similarities > 0.99)}/{len(max_similarities)}"
    )

    print(f"\n  Euclidean Distance (Real vs Generated):")
    print(f"    Mean min distance: {np.mean(min_distances):.4f}")
    print(f"    Std min distance:  {np.std(min_distances):.4f}")
    print(f"    Min distance found: {np.min(min_distances):.4f}")

    # 3. Check for exact duplicates (or near-exact)
    # This counts (real, generated) PAIRS over the whole matrix, so the value
    # can be larger than the number of real samples.
    very_similar_threshold = 0.999
    very_similar_pairs = np.sum(similarities > very_similar_threshold)
    print(f"\n  Very similar pairs (>0.999 cosine similarity): {very_similar_pairs}")

    if very_similar_pairs > 0:
        print(
            f"  ⚠️  WARNING: Found {very_similar_pairs} very similar real-generated pairs!"
        )

    # 4. Compare with real-real and generated-generated similarities for baseline
    print(f"\nBaseline Comparisons:")

    # Real-real similarities (only the first 500 rows of the subsample are used)
    if len(real_sample) > 1:
        real_real_sim = cosine_similarity(real_sample[: min(500, len(real_sample))])
        # Remove diagonal (self-similarities); the strict upper triangle keeps
        # each unordered pair exactly once
        real_real_sim_no_diag = real_real_sim[np.triu_indices_from(real_real_sim, k=1)]
        print(f"  Real-Real similarity mean: {np.mean(real_real_sim_no_diag):.4f}")
        print(f"  Real-Real similarity max:  {np.max(real_real_sim_no_diag):.4f}")

    # Generated-generated similarities (same treatment as the real-real baseline)
    if len(gen_sample) > 1:
        gen_gen_sim = cosine_similarity(gen_sample[: min(500, len(gen_sample))])
        gen_gen_sim_no_diag = gen_gen_sim[np.triu_indices_from(gen_gen_sim, k=1)]
        print(f"  Gen-Gen similarity mean:   {np.mean(gen_gen_sim_no_diag):.4f}")
        print(f"  Gen-Gen similarity max:    {np.max(gen_gen_sim_no_diag):.4f}")

    # 5. Create visualizations (one figure with three side-by-side panels)
    plt.figure(figsize=(15, 5))

    # Plot 1: Distribution of max similarities
    plt.subplot(1, 3, 1)
    plt.hist(max_similarities, bins=50, alpha=0.7, color="skyblue", edgecolor="black")
    plt.axvline(
        np.mean(max_similarities),
        color="red",
        linestyle="--",
        label=f"Mean: {np.mean(max_similarities):.3f}",
    )
    plt.axvline(0.95, color="orange", linestyle="--", label="0.95 threshold")
    plt.xlabel("Max Cosine Similarity")
    plt.ylabel("Count")
    plt.title("Distribution of Max Real-Generated Similarities")
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot 2: Distribution of min distances
    plt.subplot(1, 3, 2)
    plt.hist(min_distances, bins=50, alpha=0.7, color="lightcoral", edgecolor="black")
    plt.axvline(
        np.mean(min_distances),
        color="red",
        linestyle="--",
        label=f"Mean: {np.mean(min_distances):.3f}",
    )
    plt.xlabel("Min Euclidean Distance")
    plt.ylabel("Count")
    plt.title("Distribution of Min Real-Generated Distances")
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot 3: Similarity matrix heatmap (top-left square subset of at most 50x50)
    plt.subplot(1, 3, 3)
    subset_size = min(50, len(real_sample), len(gen_sample))
    sim_subset = similarities[:subset_size, :subset_size]
    sns.heatmap(
        sim_subset, cmap="viridis", center=0, xticklabels=False, yticklabels=False
    )
    plt.title(f"Similarity Matrix\n(Real vs Generated, {subset_size}x{subset_size})")
    plt.xlabel("Generated Samples")
    plt.ylabel("Real Samples")

    plt.tight_layout()
    plt.savefig(f"{output_dir}/similarity_analysis.png", dpi=300, bbox_inches="tight")
    plt.close()

    # 6. Save detailed similarity statistics
    # The "total_*" entries are the sizes of the (possibly subsampled) sets
    # that were actually compared, not of the full inputs.
    similarity_stats = {
        "max_similarity_mean": np.mean(max_similarities),
        "max_similarity_std": np.std(max_similarities),
        "max_similarity_max": np.max(max_similarities),
        "samples_above_095": np.sum(max_similarities > 0.95),
        "samples_above_099": np.sum(max_similarities > 0.99),
        "very_similar_pairs": very_similar_pairs,
        "min_distance_mean": np.mean(min_distances),
        "min_distance_min": np.min(min_distances),
        "total_real_samples": len(real_sample),
        "total_generated_samples": len(gen_sample),
    }

    # Add baseline comparisons if available (same conditions under which the
    # baseline arrays were computed in step 4)
    if len(real_sample) > 1:
        similarity_stats["real_real_similarity_mean"] = np.mean(real_real_sim_no_diag)
        similarity_stats["real_real_similarity_max"] = np.max(real_real_sim_no_diag)

    if len(gen_sample) > 1:
        similarity_stats["gen_gen_similarity_mean"] = np.mean(gen_gen_sim_no_diag)
        similarity_stats["gen_gen_similarity_max"] = np.max(gen_gen_sim_no_diag)

    # Save to file (one "key: value" line per statistic)
    with open(f"{output_dir}/similarity_statistics.txt", "w") as f:
        f.write("SIMILARITY ANALYSIS RESULTS\n")
        f.write("=" * 50 + "\n\n")
        for key, value in similarity_stats.items():
            f.write(f"{key}: {value}\n")

    # 7. Provide recommendations
    print(f"\n{'='*60}")
    print("RECOMMENDATIONS:")
    print("=" * 60)

    if np.mean(max_similarities) > 0.9:
        print("⚠️  HIGH SIMILARITY WARNING:")
        print("   Generated features are very similar to real features.")
        print(
            "   This could explain the high accuracy and suggests potential overfitting."
        )

    # Compares a pair count against 10% of the number of real samples
    if very_similar_pairs > len(real_sample) * 0.1:  # More than 10% very similar
        print("⚠️  POTENTIAL DATA LEAKAGE WARNING:")
        print("   Found many very similar real-generated pairs.")
        print("   Check if generated samples are too close to training data.")

    if np.sum(max_similarities > 0.99) > 0:
        print("⚠️  NEAR-DUPLICATE WARNING:")
        print("   Found near-duplicate real-generated pairs.")
        print("   Generated model might be memorizing training data.")

    print("\n📊 Similarity analysis plots saved to similarity_analysis.png")
    print("📊 Detailed statistics saved to similarity_statistics.txt")

    return similarity_stats


def analyze_train_test_overlap(X_train, X_test, model_name, output_dir):
    """
    Measure how close each test sample is to its most similar training sample.

    For every row of ``X_test`` the highest cosine similarity to any row of
    ``X_train`` is computed; counts above 0.99 and 0.999 are printed together
    with the mean, and a warning is printed for near-identical samples.

    ``output_dir`` is accepted but not used: nothing is written to disk.

    Returns:
        Dictionary summarizing the overlap, or None when either set is empty.
    """
    print(f"\n🔍 Checking train/test overlap for {model_name}...")

    no_train = len(X_train) == 0
    if no_train or len(X_test) == 0:
        print(f"  Skipping - no data for {model_name}")
        return None

    n_test = len(X_test)

    # Best match in the training set for every test row
    nearest_train_similarity = cosine_similarity(X_test, X_train).max(axis=1)

    # Thresholds flagging likely leakage
    above_099 = (nearest_train_similarity > 0.99).sum()
    above_0999 = (nearest_train_similarity > 0.999).sum()

    print(f"  Test samples with >0.99 similarity to train: {above_099}/{n_test}")
    print(f"  Test samples with >0.999 similarity to train: {above_0999}/{n_test}")
    print(f"  Mean max similarity: {nearest_train_similarity.mean():.4f}")

    if above_0999 > 0:
        print(
            f"  ⚠️  WARNING: {above_0999} test samples are nearly identical to training samples!"
        )

    # Summary handed back to the caller
    return {
        "model_name": model_name,
        "high_similarity_count": above_099,
        "very_high_similarity_count": above_0999,
        "mean_max_similarity": nearest_train_similarity.mean(),
        "max_similarity": nearest_train_similarity.max(),
        "n_train": len(X_train),
        "n_test": n_test,
    }


def create_per_class_accuracy_report(results, pert_encoder, output_dir):
    """
    Create a comprehensive per-class accuracy report and visualization.

    Args:
        results: Iterable of result dicts. Each must contain "model_name" and
            may contain "per_class_accuracies" and "per_class_counts"
            (dicts keyed by class index; missing ones are treated as empty).
        pert_encoder: Object whose ``classes_`` attribute maps a class index
            back to a perturbation ID (presumably a fitted LabelEncoder -
            confirm at the call site).
        output_dir: Directory the CSV files and the figure are written to
            (the directory is not created here).

    Side effects:
        Writes ``per_class_accuracies.csv``, ``per_class_accuracies_pivot.csv``
        and ``per_class_accuracy_analysis.png`` into ``output_dir`` and prints
        the report to stdout.

    Returns:
        Tuple ``(per_class_df, pivot_df)``: the long-format table and the
        perturbation-by-model accuracy pivot (without the AVERAGE row), or
        None when no per-class data is found.
    """
    print(f"\n{'='*80}")
    print("PER-CLASS ACCURACY ANALYSIS")
    print("=" * 80)

    # Collect all per-class data
    per_class_data = []

    for result in results:
        model_name = result["model_name"]
        per_class_accs = result.get("per_class_accuracies", {})
        per_class_counts = result.get("per_class_counts", {})

        for class_idx, accuracy in per_class_accs.items():
            # Convert class index back to perturbation ID
            pert_id = pert_encoder.classes_[class_idx]
            count = per_class_counts.get(class_idx, 0)

            per_class_data.append(
                {
                    "Model": model_name,
                    "Perturbation_ID": pert_id,
                    "Class_Index": class_idx,
                    "Accuracy": accuracy,
                    "Test_Samples": count,
                }
            )

    # Create DataFrame (one row per model and class)
    per_class_df = pd.DataFrame(per_class_data)

    if len(per_class_df) == 0:
        print("No per-class data found.")
        return

    # Save detailed per-class results
    per_class_df.to_csv(f"{output_dir}/per_class_accuracies.csv", index=False)

    # Create pivot table for easier viewing
    # NOTE(review): DataFrame.pivot needs every (Perturbation_ID, Model) pair
    # to be unique - confirm the model names in `results` are distinct.
    pivot_df = per_class_df.pivot(
        index="Perturbation_ID", columns="Model", values="Accuracy"
    )
    pivot_df = pivot_df.fillna(0.0)  # Fill NaN with 0 for missing combinations

    # Add average row to the pivot table (computed after the zero-fill above,
    # so missing combinations count as 0 accuracy)
    avg_row = pivot_df.mean(
        axis=0
    )  # Calculate mean across perturbations for each model
    pivot_df_with_avg = pivot_df.copy()
    pivot_df_with_avg.loc["AVERAGE"] = avg_row

    # Save pivot table with averages
    pivot_df_with_avg.to_csv(f"{output_dir}/per_class_accuracies_pivot.csv")

    # Print summary statistics
    print(f"\nPer-Class Accuracy Summary:")
    print(f"Number of perturbations analyzed: {len(pert_encoder.classes_)}")
    print(f"Perturbation IDs: {list(pert_encoder.classes_)}")

    # Calculate statistics for each model
    model_stats = (
        per_class_df.groupby("Model")
        .agg({"Accuracy": ["mean", "std", "min", "max"], "Test_Samples": "sum"})
        .round(4)
    )

    print(f"\nModel Performance Statistics:")
    print(model_stats)

    # Create visualizations (2x2 grid of panels in one figure)
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))

    # Plot 1: Heatmap of per-class accuracies (models as rows, perturbations as columns)
    ax1 = axes[0, 0]
    sns.heatmap(
        pivot_df.T,
        annot=True,
        fmt=".3f",
        cmap="RdYlBu_r",
        ax=ax1,
        cbar_kws={"label": "Accuracy"},
    )
    ax1.set_title("Per-Class Accuracies Heatmap", fontsize=14)
    ax1.set_xlabel("Perturbation ID", fontsize=12)
    ax1.set_ylabel("Model", fontsize=12)
    ax1.tick_params(axis="x", rotation=45)

    # Plot 2: Bar plot comparing models for each perturbation
    ax2 = axes[0, 1]
    pivot_df.plot(kind="bar", ax=ax2, width=0.8)
    ax2.set_title("Per-Class Accuracies by Perturbation", fontsize=14)
    ax2.set_xlabel("Perturbation ID", fontsize=12)
    ax2.set_ylabel("Accuracy", fontsize=12)
    ax2.legend(title="Model", bbox_to_anchor=(1.05, 1), loc="upper left")
    ax2.set_ylim(0, 1.0)
    ax2.grid(True, alpha=0.3)

    # Plot 3: Box plot of accuracies per model
    ax3 = axes[1, 0]
    per_class_df.boxplot(column="Accuracy", by="Model", ax=ax3)
    ax3.set_title("Distribution of Per-Class Accuracies by Model", fontsize=14)
    ax3.set_xlabel("Model", fontsize=12)
    ax3.set_ylabel("Per-Class Accuracy", fontsize=12)
    ax3.tick_params(axis="x", rotation=45)
    plt.setp(ax3.xaxis.get_majorticklabels(), ha="right")

    # Plot 4: Scatter plot - Test samples vs Accuracy (one color per model)
    ax4 = axes[1, 1]
    models = per_class_df["Model"].unique()
    colors = plt.cm.Set1(np.linspace(0, 1, len(models)))

    for i, model in enumerate(models):
        model_data = per_class_df[per_class_df["Model"] == model]
        ax4.scatter(
            model_data["Test_Samples"],
            model_data["Accuracy"],
            alpha=0.7,
            label=model,
            color=colors[i],
            s=60,
        )

    ax4.set_xlabel("Number of Test Samples", fontsize=12)
    ax4.set_ylabel("Per-Class Accuracy", fontsize=12)
    ax4.set_title("Per-Class Accuracy vs Number of Test Samples", fontsize=14)
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    ax4.set_ylim(0, 1.0)

    plt.tight_layout()
    plt.savefig(
        f"{output_dir}/per_class_accuracy_analysis.png", dpi=300, bbox_inches="tight"
    )
    plt.close()

    # Find best and worst performing perturbations
    print(f"\n{'='*50}")
    print("BEST AND WORST PERFORMING PERTURBATIONS")
    print("=" * 50)

    for model in models:
        model_data = per_class_df[per_class_df["Model"] == model]
        if len(model_data) > 0:
            print(f"\n{model}:")

            # Best performing
            best_pert = model_data.loc[model_data["Accuracy"].idxmax()]
            print(
                f"  Best:  Pert {best_pert['Perturbation_ID']} - {best_pert['Accuracy']:.4f} accuracy ({best_pert['Test_Samples']} samples)"
            )

            # Worst performing
            worst_pert = model_data.loc[model_data["Accuracy"].idxmin()]
            print(
                f"  Worst: Pert {worst_pert['Perturbation_ID']} - {worst_pert['Accuracy']:.4f} accuracy ({worst_pert['Test_Samples']} samples)"
            )

    # Calculate improvement from adding generated data
    # (`improvement` defined here is reused by the balanced-vs-unbalanced
    # comparison further down)
    if "Real→Real" in pivot_df.columns and "Real+Generated→Real" in pivot_df.columns:
        improvement = pivot_df["Real+Generated→Real"] - pivot_df["Real→Real"]
        print(f"\n{'='*50}")
        print("IMPROVEMENT FROM ADDING GENERATED DATA (UNBALANCED)")
        print("=" * 50)
        print(f"Per-Perturbation Improvement (Real+Generated→Real - Real→Real):")

        for pert_id in improvement.index:
            impr = improvement[pert_id]
            print(f"  Pert {pert_id}: {impr:+.4f}")

        print(f"\nSummary:")
        print(f"  Mean improvement: {improvement.mean():+.4f}")
        print(f"  Std improvement:  {improvement.std():.4f}")
        print(f"  Perturbations improved: {(improvement > 0).sum()}/{len(improvement)}")
        print(f"  Perturbations degraded: {(improvement < 0).sum()}/{len(improvement)}")

    # Calculate improvement from adding balanced generated data
    if (
        "Real→Real" in pivot_df.columns
        and "Balanced Real+Generated→Real" in pivot_df.columns
    ):
        balanced_improvement = (
            pivot_df["Balanced Real+Generated→Real"] - pivot_df["Real→Real"]
        )
        print(f"\n{'='*50}")
        print("IMPROVEMENT FROM ADDING GENERATED DATA (BALANCED)")
        print("=" * 50)
        print(
            f"Per-Perturbation Improvement (Balanced Real+Generated→Real - Real→Real):"
        )

        for pert_id in balanced_improvement.index:
            impr = balanced_improvement[pert_id]
            print(f"  Pert {pert_id}: {impr:+.4f}")

        print(f"\nSummary:")
        print(f"  Mean improvement: {balanced_improvement.mean():+.4f}")
        print(f"  Std improvement:  {balanced_improvement.std():.4f}")
        print(
            f"  Perturbations improved: {(balanced_improvement > 0).sum()}/{len(balanced_improvement)}"
        )
        print(
            f"  Perturbations degraded: {(balanced_improvement < 0).sum()}/{len(balanced_improvement)}"
        )

        # Compare unbalanced vs balanced improvements
        # `improvement` is guaranteed to exist here: this branch requires both
        # "Real→Real" and "Real+Generated→Real", which is exactly the
        # condition of the unbalanced block above.
        if "Real+Generated→Real" in pivot_df.columns:
            print(f"\n{'='*50}")
            print("COMPARISON: UNBALANCED vs BALANCED IMPROVEMENTS")
            print("=" * 50)
            comparison = balanced_improvement - improvement
            print(f"Per-Perturbation Difference (Balanced - Unbalanced):")

            for pert_id in comparison.index:
                diff = comparison[pert_id]
                print(f"  Pert {pert_id}: {diff:+.4f}")

            print(f"\nSummary:")
            print(f"  Mean difference: {comparison.mean():+.4f}")
            print(f"  Std difference:  {comparison.std():.4f}")
            print(
                f"  Balanced performs better: {(comparison > 0).sum()}/{len(comparison)}"
            )
            print(
                f"  Unbalanced performs better: {(comparison < 0).sum()}/{len(comparison)}"
            )

    print(f"\n📊 Per-class analysis saved to:")
    print(f"   - per_class_accuracies.csv: Detailed per-class results")
    print(f"   - per_class_accuracies_pivot.csv: Pivot table format")
    print(f"   - per_class_accuracy_analysis.png: Visualizations")

    return per_class_df, pivot_df


def _subsample_rows(features, labels, n_keep, seed):
    """Return ``n_keep`` rows drawn without replacement, or the inputs as-is.

    Args:
        features: Indexable array of samples (first axis = samples).
        labels: Array of labels aligned with ``features``.
        n_keep: Maximum number of rows to keep.
        seed: Seed applied to the global NumPy RNG right before drawing, so
            the selected rows depend only on ``seed`` and the array length.

    Returns:
        Tuple ``(features, labels)``; unchanged when ``len(features) <= n_keep``.
    """
    n_available = len(features)
    if n_available <= n_keep:
        return features, labels
    np.random.seed(seed)
    keep = np.random.choice(n_available, n_keep, replace=False)
    return features[keep], labels[keep]


def _plot_performance_comparison(summary_df, cell_type_id, top_k, output_dir):
    """Save a grouped bar chart of test accuracy and CV mean for every model.

    Writes ``{output_dir}/performance_comparison.png``.
    """
    x_pos = np.arange(len(summary_df))
    width = 0.35

    # Exactly one figure is created and it is closed explicitly below, so no
    # stray figure is left open after the plot has been saved.
    fig, ax = plt.subplots(1, 1, figsize=(16, 8))

    # Test accuracy vs CV accuracy
    ax.bar(
        x_pos - width / 2,
        summary_df["Test_Accuracy"],
        width,
        label="Test Accuracy",
        alpha=0.8,
        color="skyblue",
    )
    ax.bar(
        x_pos + width / 2,
        summary_df["CV_Mean"],
        width,
        label="CV Mean",
        alpha=0.8,
        color="orange",
        yerr=summary_df["CV_Std"],
        capsize=5,
    )

    ax.set_xlabel("Model Type")
    ax.set_ylabel("Accuracy")
    ax.set_title(
        f"Perturbation Prediction Performance (Cell Type {cell_type_id}, Top {top_k} Perturbations)"
    )
    ax.set_xticks(x_pos)
    ax.set_xticklabels(summary_df["Model"], rotation=45, ha="right")
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1.0)

    # Add value labels on bars
    for i, (test_acc, cv_acc) in enumerate(
        zip(summary_df["Test_Accuracy"], summary_df["CV_Mean"])
    ):
        ax.text(
            i - width / 2,
            test_acc + 0.01,
            f"{test_acc:.3f}",
            ha="center",
            va="bottom",
            fontsize=10,
        )
        ax.text(
            i + width / 2,
            cv_acc + 0.01,
            f"{cv_acc:.3f}",
            ha="center",
            va="bottom",
            fontsize=10,
        )

    fig.tight_layout()
    fig.savefig(
        f"{output_dir}/performance_comparison.png", dpi=300, bbox_inches="tight"
    )
    plt.close(fig)


def _plot_sample_sizes(summary_df, cell_type_id, output_dir):
    """Save a grouped bar chart of training/test sample counts per model.

    Writes ``{output_dir}/sample_sizes_comparison.png``.
    """
    models = summary_df["Model"]
    train_sizes = summary_df["N_Train"]
    test_sizes = summary_df["N_Test"]

    x_pos = np.arange(len(models))
    width = 0.35

    fig = plt.figure(figsize=(12, 6))

    plt.bar(
        x_pos - width / 2,
        train_sizes,
        width,
        label="Training Samples",
        alpha=0.8,
        color="lightgreen",
    )
    plt.bar(
        x_pos + width / 2,
        test_sizes,
        width,
        label="Test Samples",
        alpha=0.8,
        color="lightcoral",
    )

    plt.xlabel("Model Type")
    plt.ylabel("Number of Samples")
    plt.title(f"Training and Test Sample Sizes (Cell Type {cell_type_id})")
    plt.xticks(x_pos, models, rotation=45, ha="right")
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Both label rows are lifted by the same amount (1% of the tallest
    # training bar); the value is loop-invariant, so compute it once.
    label_offset = train_sizes.max() * 0.01
    for i, (train_size, test_size) in enumerate(zip(train_sizes, test_sizes)):
        plt.text(
            i - width / 2,
            train_size + label_offset,
            str(train_size),
            ha="center",
            va="bottom",
            fontsize=9,
        )
        plt.text(
            i + width / 2,
            test_size + label_offset,
            str(test_size),
            ha="center",
            va="bottom",
            fontsize=9,
        )

    plt.tight_layout()
    plt.savefig(
        f"{output_dir}/sample_sizes_comparison.png", dpi=300, bbox_inches="tight"
    )
    plt.close(fig)


def _write_detailed_reports(
    results, pert_encoder, cell_type_id, top_k, perturbation_ids, output_dir
):
    """Write one sklearn classification report per experiment.

    Writes ``{output_dir}/detailed_reports.txt``. Each entry of ``results``
    must provide ``model_name``, ``task``, ``y_true`` and ``y_pred``.
    """
    class_names = [f"Pert_{label}" for label in pert_encoder.classes_]
    # LabelEncoder maps classes to 0..n-1. Passing the full index range keeps
    # the report aligned with ``class_names`` even when an experiment's
    # y_true/y_pred happen to miss a class (otherwise sklearn raises).
    class_indices = np.arange(len(class_names))

    # Model names contain non-ASCII characters ("→"), so the encoding must not
    # depend on the platform default.
    with open(f"{output_dir}/detailed_reports.txt", "w", encoding="utf-8") as f:
        f.write(
            f"Analysis of Cell Type {cell_type_id} with Top {top_k} Perturbations\n"
        )
        f.write(f"Perturbation IDs: {perturbation_ids}\n\n")

        for result in results:
            f.write(f"\n{'='*60}\n")
            f.write(f"{result['model_name']} - {result['task']}\n")
            f.write(f"{'='*60}\n")

            report = classification_report(
                result["y_true"],
                result["y_pred"],
                labels=class_indices,
                target_names=class_names,
            )
            f.write(report)
            f.write("\n")


def _write_perturbation_info(datamodule, cell_type_id, top_k, output_dir):
    """Write the ``top_k`` most frequent perturbations of one cell type.

    Writes ``{output_dir}/perturbation_info.txt``. Percentages are relative
    to the number of rows selected by the ``train_index`` column of
    ``datamodule.metadata`` (all cell types together).
    """
    distribution = get_perturbation_cell_type_distribution_custom(datamodule)
    cell_type_counts = {
        pert_id: count
        for (pert_id, ct_id), count in distribution.items()
        if ct_id == cell_type_id
    }
    sorted_perturbations = sorted(
        cell_type_counts.items(), key=lambda x: x[1], reverse=True
    )
    total_samples = len(datamodule.metadata[datamodule.metadata["train_index"]])

    with open(f"{output_dir}/perturbation_info.txt", "w", encoding="utf-8") as f:
        f.write(f"Top {top_k} Perturbations for Cell Type {cell_type_id}:\n\n")
        for rank, (pert_id, count) in enumerate(
            sorted_perturbations[:top_k], start=1
        ):
            f.write(
                f"{rank:2d}. Perturbation {pert_id:4d}: {count:4d} samples ({count/total_samples*100:.2f}%)\n"
            )


def main():
    """Run the perturbation-classification comparison for a single cell type.

    Steps:
        1. Seed all RNGs, load the config/datamodule and build the encoder.
        2. Select the ``top_k`` perturbations of ``cell_type_id`` and extract
           features for real and generated samples.
        3. Train/evaluate up to five models via ``train_and_evaluate_model``:
           Real→Real, Generated→Generated, Generated→Real,
           Real+Generated→Real and Balanced Real+Generated→Real. The three
           "→Real" variants that train on real data share one stratified
           70/30 split of the real samples so their test sets are identical.
        4. Save a summary CSV, bar charts, confusion matrices and text reports
           under ``linear_model_results_celltype_{id}_top_{k}_perturbations``.

    Experiments whose input data is empty are skipped; if none can run, the
    function prints a notice and returns without writing the summary
    artifacts.
    """
    # Set random seeds
    seed = 42
    seed_everything(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Setup device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load configuration and data module
    config_path = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
    generated_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"

    config = OmegaConf.load(config_path)
    datamodule = CellDataModule(config)

    # Initialize encoder
    encoder = OpenPhenomEncoder().to(device)

    # Get top perturbations for cell_type_id = 0
    cell_type_id = 0
    top_k = 10
    perturbation_ids = get_top_perturbations_for_cell_type(
        datamodule, cell_type_id=cell_type_id, top_k=top_k
    )

    print(f"\nAnalyzing perturbations: {perturbation_ids}")
    print(f"Cell type: {cell_type_id} (focusing on single cell type)")

    # Create output directory
    output_dir = (
        f"linear_model_results_celltype_{cell_type_id}_top_{top_k}_perturbations"
    )
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(f"{output_dir}/confusion_matrices", exist_ok=True)

    # Load and extract features
    print("Loading data and extracting features...")
    (
        real_features,
        real_pert_labels,
        generated_features,
        generated_pert_labels,
    ) = load_and_extract_features_single_cell_type(
        datamodule,
        generated_path,
        perturbation_ids,
        cell_type_id,
        encoder,
        device,
        max_samples_per_pert=500,
        batch_size=64,
    )

    print("\nFeature extraction complete!")
    print(f"Real features shape: {real_features.shape}")
    print(f"Generated features shape: {generated_features.shape}")

    # Check for data leakage and similarity issues (called for its reporting
    # side effects; the returned statistics are not needed here).
    check_data_leakage_and_similarity(real_features, generated_features, output_dir)

    # Convert labels to numpy arrays
    real_pert_labels = np.array(real_pert_labels)
    generated_pert_labels = np.array(generated_pert_labels)

    # Encode labels for sklearn; fit on real + generated so both domains share
    # one label space.
    pert_encoder = LabelEncoder()
    pert_encoder.fit(np.concatenate([real_pert_labels, generated_pert_labels]))

    real_pert_encoded = pert_encoder.transform(real_pert_labels)
    generated_pert_encoded = pert_encoder.transform(generated_pert_labels)

    print(f"Unique perturbations: {pert_encoder.classes_}")

    has_real = len(real_features) > 0
    has_generated = len(generated_features) > 0

    # Store results
    results = []

    # First, create a consistent train/test split for real data that will be reused
    real_train_features = real_test_features = None
    real_train_labels = real_test_labels = None
    if has_real:
        (
            real_train_features,
            real_test_features,
            real_train_labels,
            real_test_labels,
        ) = train_test_split(
            real_features,
            real_pert_encoded,
            test_size=0.3,
            random_state=seed,
            stratify=real_pert_encoded,
        )

    # 1. Real Perturbation Predictor (train on real, test on real)
    if has_real:
        result = train_and_evaluate_model(
            real_train_features,
            real_train_labels,
            real_test_features,
            real_test_labels,
            "Real→Real",
            "Perturbation Prediction",
        )
        results.append(result)

        plot_confusion_matrix(
            real_test_labels,
            result["y_pred"],
            pert_encoder.classes_,
            f"Real→Real Perturbation Prediction (Cell Type {cell_type_id})",
            f"{output_dir}/confusion_matrices/real_real_perturbation.png",
        )

    # 2. Generated Perturbation Predictor (train on generated, test on generated)
    if has_generated:
        X_train, X_test, y_train, y_test = train_test_split(
            generated_features,
            generated_pert_encoded,
            test_size=0.3,
            random_state=seed,
            stratify=generated_pert_encoded,
        )
        result = train_and_evaluate_model(
            X_train,
            y_train,
            X_test,
            y_test,
            "Generated→Generated",
            "Perturbation Prediction",
        )
        results.append(result)

        plot_confusion_matrix(
            y_test,
            result["y_pred"],
            pert_encoder.classes_,
            f"Generated→Generated Perturbation Prediction (Cell Type {cell_type_id})",
            f"{output_dir}/confusion_matrices/gen_gen_perturbation.png",
        )

    # 3. Cross-domain: Train on generated, test on real (Perturbation)
    if has_generated and has_real:
        result = train_and_evaluate_model(
            generated_features,
            generated_pert_encoded,
            real_features,
            real_pert_encoded,
            "Generated→Real",
            "Perturbation Prediction",
        )
        results.append(result)

        plot_confusion_matrix(
            result["y_true"],
            result["y_pred"],
            pert_encoder.classes_,
            f"Generated→Real Perturbation Prediction (Cell Type {cell_type_id})",
            f"{output_dir}/confusion_matrices/gen_real_perturbation.png",
        )

    # 4. Combined training: Train on real+generated, test on real
    if has_generated and has_real:
        # Only the TRAINING portion of the real data is combined with the
        # generated data; the held-out real test split from experiment 1 is
        # reused so the comparison with Real→Real is fair and leak-free.
        combined_train_features = np.vstack([real_train_features, generated_features])
        combined_train_labels = np.concatenate(
            [real_train_labels, generated_pert_encoded]
        )

        # Check for train/test overlap (called for its reporting side effects)
        analyze_train_test_overlap(
            combined_train_features,
            real_test_features,
            "Real+Generated→Real",
            output_dir,
        )

        result = train_and_evaluate_model(
            combined_train_features,
            combined_train_labels,
            real_test_features,
            real_test_labels,
            "Real+Generated→Real",
            "Perturbation Prediction",
        )
        results.append(result)

        plot_confusion_matrix(
            result["y_true"],
            result["y_pred"],
            pert_encoder.classes_,
            f"Real+Generated→Real Perturbation Prediction (Cell Type {cell_type_id})",
            f"{output_dir}/confusion_matrices/combined_real_perturbation.png",
        )

    # 5. Balanced Combined training: Train on balanced real+generated, test on real
    if has_generated and has_real:
        print("\n" + "=" * 60)
        print("BALANCED REAL+GENERATED→REAL EXPERIMENT")
        print("=" * 60)

        n_real_train = len(real_train_features)
        n_generated = len(generated_features)

        print(f"Available real training samples: {n_real_train}")
        print(f"Available generated samples: {n_generated}")

        # Use the smaller number to balance; only the larger side is actually
        # subsampled, the other one is used in full.
        n_balanced = min(n_real_train, n_generated)
        print(f"Using {n_balanced} samples from each (real and generated)")

        balanced_real_features, balanced_real_labels = _subsample_rows(
            real_train_features, real_train_labels, n_balanced, seed
        )
        balanced_gen_features, balanced_gen_labels = _subsample_rows(
            generated_features, generated_pert_encoded, n_balanced, seed
        )

        # Combine balanced datasets
        balanced_combined_features = np.vstack(
            [balanced_real_features, balanced_gen_features]
        )
        balanced_combined_labels = np.concatenate(
            [balanced_real_labels, balanced_gen_labels]
        )

        print(f"Balanced training set: {len(balanced_combined_features)} samples")
        print(f"  - Real: {len(balanced_real_features)} samples")
        print(f"  - Generated: {len(balanced_gen_features)} samples")

        # Check for train/test overlap (called for its reporting side effects)
        analyze_train_test_overlap(
            balanced_combined_features,
            real_test_features,
            "Balanced Real+Generated→Real",
            output_dir,
        )

        # Same held-out real test split as Real→Real
        result = train_and_evaluate_model(
            balanced_combined_features,
            balanced_combined_labels,
            real_test_features,
            real_test_labels,
            "Balanced Real+Generated→Real",
            "Perturbation Prediction",
        )
        results.append(result)

        plot_confusion_matrix(
            result["y_true"],
            result["y_pred"],
            pert_encoder.classes_,
            f"Balanced Real+Generated→Real Perturbation Prediction (Cell Type {cell_type_id})",
            f"{output_dir}/confusion_matrices/balanced_combined_real_perturbation.png",
        )

    # Without any experiment there is nothing to summarize; building plots from
    # an empty, column-less DataFrame would fail with a KeyError below.
    if not results:
        print(
            "\nNo real or generated features were available; "
            "no experiments were run and no summary was written."
        )
        return

    # Create summary dataframe
    summary_df = pd.DataFrame(
        [
            {
                "Model": result["model_name"],
                "Task": result["task"],
                "Test_Accuracy": result["accuracy"],
                "CV_Mean": result["cv_mean"],
                "CV_Std": result["cv_std"],
                "N_Train": result["n_train"],
                "N_Test": result["n_test"],
                "N_Classes": result["n_classes"],
            }
            for result in results
        ]
    )

    # Save summary
    summary_df.to_csv(f"{output_dir}/model_performance_summary.csv", index=False)

    # Print summary
    print("\n" + "=" * 80)
    print("FINAL RESULTS SUMMARY")
    print("=" * 80)
    print(summary_df.to_string(index=False))

    # Generate per-class accuracy analysis (called for its reporting side
    # effects; the returned frames are not used here).
    create_per_class_accuracy_report(results, pert_encoder, output_dir)

    # Visualizations and text reports
    _plot_performance_comparison(summary_df, cell_type_id, top_k, output_dir)
    _plot_sample_sizes(summary_df, cell_type_id, output_dir)
    _write_detailed_reports(
        results, pert_encoder, cell_type_id, top_k, perturbation_ids, output_dir
    )
    _write_perturbation_info(datamodule, cell_type_id, top_k, output_dir)

    print(f"\nResults saved to {output_dir}/")
    print("- model_performance_summary.csv: Summary of all model performances")
    print("- per_class_accuracies.csv: Detailed per-class accuracy results")
    print(
        "- per_class_accuracies_pivot.csv: Per-class accuracies in pivot table format"
    )
    print("- performance_comparison.png: Visualization of overall results")
    print("- per_class_accuracy_analysis.png: Per-class accuracy visualizations")
    print("- sample_sizes_comparison.png: Training/test sample size comparison")
    print("- similarity_analysis.png: Real vs Generated similarity analysis")
    print("- similarity_statistics.txt: Detailed similarity statistics")
    print("- confusion_matrices/: Individual confusion matrices")
    print("- detailed_reports.txt: Detailed classification reports")
    print("- perturbation_info.txt: Information about selected perturbations")
    print(
        f"\nAnalysis focused on Cell Type {cell_type_id} with top {top_k} perturbations by sample count"
    )

    # Final warning about suspicious results
    if any(result["accuracy"] > 0.95 for result in results):
        print(f"\n{'⚠️ '*3} WARNING {'⚠️ '*3}")
        print("Very high accuracy detected (>95%)!")
        print("Please check the similarity analysis for potential data leakage.")
        print(
            "High similarity between generated and real samples could explain these results."
        )
        print("=" * 60)

# Script entry point: run the full analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
