#!/usr/bin/env python3
"""
Compare clustering summaries generated by the R clustering script across multiple result folders.

Usage:
    python compare_clustering_results.py folder1 folder2 folder3 ...
    python compare_clustering_results.py results/*  # Compare all result folders
    python compare_clustering_results.py -g -o comparison results/*  # Generate missing metrics first
"""

import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import argparse
from PIL import Image
import numpy as np
import subprocess


def run_r_clustering_script(folder_path):
    """Run the R clustering script on a folder to generate metrics."""
    folder = Path(folder_path).resolve()  # Get absolute path

    # Use the  script
    script_name = "clustering_metrics_with_.R"

    # Find the script relative to the current working directory or absolute path
    current_dir = Path.cwd()
    script_path = current_dir / "scripts" / script_name

    # If not found, try looking in the parent directories
    if not script_path.exists():
        # Try going up directories to find the scripts folder
        test_dir = current_dir
        for _ in range(3):  # Try up to 3 levels up
            test_script = test_dir / "scripts" / script_name
            if test_script.exists():
                script_path = test_script
                break
            test_dir = test_dir.parent

    # Check if the script exists
    if not script_path.exists():
        print(f"Warning: R script not found at {script_path}")
        print(f"Current directory: {current_dir}")
        print(f"Looking for script: {script_name}")
        return False

    # Check if required input files exist
    required_files = ["denoised_embeddings.csv", "cleaned_cell_labels_meta_tea_seq.csv"]
    missing_files = [f for f in required_files if not (folder / f).exists()]

    if missing_files:
        print(f"Warning: Missing required files in {folder.name}: {missing_files}")
        return False

    print(f"Running R clustering script on {folder.name}...")

    try:
        # Run the R script - now takes folder path as argument and can run from anywhere
        result = subprocess.run(
            ["Rscript", str(script_path), str(folder)],
            cwd=current_dir,  # Run from project root
            capture_output=True,
            text=True,
            timeout=300,  # 5 minute timeout
        )

        if result.returncode == 0:
            print(f"✓ R script completed successfully for {folder.name}")

            # Check if the expected output files were actually created
            files_after = find_clustering_files(folder)
            if "metrics_csv" not in files_after:
                print(
                    f"⚠️  Warning: {script_name} ran but didn't create clustering_metrics_results.csv"
                )
                return False

            return True
        else:
            print(f"✗ R script failed for {folder.name}")
            print(f"Error output: {result.stderr}")
            if result.stdout.strip():
                print(f"Standard output: {result.stdout.strip()}")
            return False

    except subprocess.TimeoutExpired:
        print(f"✗ R script timed out for {folder.name}")
        return False
    except Exception as e:
        print(f"✗ Error running R script for {folder.name}: {e}")
        return False


def find_clustering_files(folder_path):
    """Return the clustering output files present in ``folder_path``.

    The result maps a short key ("metrics_csv", "kmeans_plot",
    "true_labels_plot") to the Path of the corresponding file. Only files
    that currently exist are included, in that key order.
    """
    base_dir = Path(folder_path)
    expected = (
        ("metrics_csv", "clustering_metrics_results.csv"),
        ("kmeans_plot", "clustering_analysis_kmeans.png"),
        ("true_labels_plot", "clustering_analysis_true_labels.png"),
    )

    present = {}
    for key, filename in expected:
        candidate = base_dir / filename
        if candidate.exists():
            present[key] = candidate
    return present


def load_metrics_data(csv_paths):
    """Load and concatenate per-folder metrics CSVs into one long DataFrame.

    Args:
        csv_paths: Mapping of folder label -> path (str or Path) of a metrics
            CSV. Paths that do not exist are skipped silently. Unreadable
            files, and files without the "Metric"/"Value" columns the
            plotting/summary code relies on, are skipped with a warning.

    Returns:
        DataFrame with each CSV's columns plus a "Folder" column holding its
        label, rows in input order. Returns an empty DataFrame if nothing
        could be loaded.
    """
    required_columns = {"Metric", "Value"}
    frames = []

    for folder_name, csv_path in csv_paths.items():
        csv_path = Path(csv_path)
        if not csv_path.exists():
            continue
        try:
            df = pd.read_csv(csv_path)
        except Exception as e:  # best-effort: one bad file must not abort the comparison
            print(f"Warning: Could not read {csv_path}: {e}")
            continue

        # Without these columns the downstream pivot/plot code raises KeyError.
        missing = required_columns - set(df.columns)
        if missing:
            print(
                f"Warning: {csv_path} is missing required column(s) "
                f"{sorted(missing)}, skipping"
            )
            continue

        df["Folder"] = folder_name
        frames.append(df)

    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)


def create_metrics_comparison_plot(df, output_path):
    """Draw one bar chart per metric, comparing its value across folders.

    Args:
        df: Long-format DataFrame with "Metric", "Value" and "Folder" columns
            (one row per folder/metric pair).
        output_path: Where the combined figure is saved. The format is
            inferred by matplotlib from the extension.
    """
    if df.empty:
        print("No metrics data to plot.")
        return

    metrics = df["Metric"].unique()
    n_metrics = len(metrics)

    # Two panels per row and as many rows as needed. A fixed 2x2 grid would
    # silently drop every metric after the fourth.
    n_cols = 2
    n_rows = max(1, (n_metrics + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 6 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, metric in zip(axes, metrics):
        metric_data = df[df["Metric"] == metric]
        values = metric_data["Value"].to_numpy()
        positions = np.arange(len(metric_data))

        bars = ax.bar(
            positions,
            values,
            color=plt.cm.Set3(np.linspace(0, 1, len(metric_data))),
        )

        ax.set_title(f"{metric}", fontsize=12, fontweight="bold")
        ax.set_ylabel("Value")
        ax.set_xticks(positions)
        ax.set_xticklabels(metric_data["Folder"], rotation=45, ha="right")
        ax.grid(True, alpha=0.3)

        # Offset labels in points, not data units, so they sit just outside the
        # bar whatever the metric's scale, and below bars with negative values.
        for bar, value in zip(bars, values):
            above = value >= 0
            ax.annotate(
                f"{value:.3f}",
                xy=(bar.get_x() + bar.get_width() / 2.0, bar.get_height()),
                xytext=(0, 3 if above else -3),
                textcoords="offset points",
                ha="center",
                va="bottom" if above else "top",
                fontsize=10,
            )

    # Hide the unused trailing panel when the metric count is odd.
    for ax in axes[n_metrics:]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.suptitle(
        "Clustering Metrics Comparison Across Experiments",
        fontsize=16,
        fontweight="bold",
        y=0.98,
    )
    fig.subplots_adjust(top=0.92)

    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Metrics comparison plot saved to: {output_path}")
    plt.close(fig)


def create_summary_table(df, output_path):
    """Pivot metrics into a Folder x Metric table and save it as CSV and heatmap.

    Args:
        df: Long-format DataFrame with "Folder", "Metric" and "Value" columns.
            Each (Folder, Metric) pair must occur once, otherwise
            ``DataFrame.pivot`` raises ValueError.
        output_path: Base path for the outputs. The CSV is written to
            ``output_path`` with its suffix replaced by ".csv". The heatmap
            is written to ``<stem>_heatmap.png`` in the same directory.

    Returns:
        The pivoted table rounded to 4 decimals, or None if ``df`` is empty.
    """
    if df.empty:
        print("No metrics data for summary table.")
        return None

    # Rows are folders, columns are metrics.
    pivot_df = df.pivot(index="Folder", columns="Metric", values="Value").round(4)

    # Derive sibling paths from the final suffix only. A str.replace(".png", ...)
    # would also rewrite ".png" inside directory names. It would also overwrite
    # output_path itself when it has no ".png" suffix.
    output_path = Path(output_path)
    csv_path = output_path.with_suffix(".csv")
    heatmap_path = output_path.with_name(f"{output_path.stem}_heatmap.png")

    pivot_df.to_csv(csv_path)
    print(f"Summary table saved to: {csv_path}")

    fig = plt.figure(figsize=(12, 8))
    sns.heatmap(
        pivot_df,
        annot=True,
        cmap="RdYlBu_r",
        center=0.5,
        cbar_kws={"label": "Metric Value"},
        fmt=".3f",
    )
    plt.title("Clustering Metrics Heatmap", fontsize=16, fontweight="bold")
    plt.xlabel("Metrics")
    plt.ylabel("Experiments")
    plt.xticks(rotation=45, ha="right")
    plt.yticks(rotation=0)
    plt.tight_layout()

    plt.savefig(heatmap_path, dpi=300, bbox_inches="tight")
    print(f"Metrics heatmap saved to: {heatmap_path}")
    plt.close(fig)

    return pivot_df


def create_image_comparison(image_paths, image_type, output_path):
    """Tile one image per folder into a single side-by-side comparison figure.

    Args:
        image_paths: Mapping of folder label -> image path (str or Path).
            Entries whose file does not exist are skipped.
        image_type: Identifier used in the figure title ("_" becomes a space,
            title-cased) and in log messages.
        output_path: Where the combined figure is saved.
    """
    valid_images = [
        (name, Path(path)) for name, path in image_paths.items() if Path(path).exists()
    ]

    if not valid_images:
        print(f"No {image_type} images found for comparison.")
        return

    n_images = len(valid_images)
    cols = min(4, n_images)  # Max 4 columns
    rows = (n_images + cols - 1) // cols

    # squeeze=False always returns a 2-D array of axes, so 1x1 and 1xN grids
    # need no special-casing.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
    flat_axes = axes.flatten()

    for ax, (folder_name, img_path) in zip(flat_axes, valid_images):
        ax.set_title(folder_name, fontsize=12, fontweight="bold")
        ax.axis("off")
        try:
            # The context manager closes the file handle that PIL's lazy
            # Image.open() otherwise leaves open.
            with Image.open(img_path) as img:
                ax.imshow(np.array(img))
        except Exception as e:  # best-effort: one unreadable image must not abort the figure
            print(f"Warning: Could not load image {img_path}: {e}")

    # Hide grid cells beyond the last image.
    for ax in flat_axes[n_images:]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.suptitle(
        f'{image_type.replace("_", " ").title()} Comparison',
        fontsize=16,
        fontweight="bold",
        y=0.98,
    )
    fig.subplots_adjust(top=0.92)

    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"{image_type} comparison saved to: {output_path}")
    plt.close(fig)


def main():
    parser = argparse.ArgumentParser(
        description="Compare R script generated clustering summaries across result folders"
    )
    parser.add_argument("folders", nargs="+", help="Result folders to compare")
    parser.add_argument(
        "--output-dir",
        "-o",
        default="clustering_comparison",
        help="Output directory for comparison results",
    )
    parser.add_argument(
        "--auto-generate",
        "-g",
        action="store_true",
        help="Automatically run R clustering script if metrics CSV is missing",
    )

    args = parser.parse_args()

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True)

    print(f"Comparing clustering results from {len(args.folders)} folders...")
    print(f"Output directory: {output_dir}")

    # Find clustering files in each folder
    folder_files = {}
    csv_paths = {}
    kmeans_paths = {}
    true_labels_paths = {}

    for folder in args.folders:
        folder_path = Path(folder)
        if not folder_path.exists():
            print(f"Warning: Folder {folder} does not exist, skipping...")
            continue

        folder_name = folder_path.name
        files = find_clustering_files(folder_path)

        # Check if metrics CSV is missing and auto-generate is enabled
        missing_csv = "metrics_csv" not in files

        if missing_csv and args.auto_generate:
            print(f"Metrics CSV missing in {folder_name}, attempting to generate...")
            success = run_r_clustering_script(folder_path)

            if success:
                # Re-check for files after running R script
                files = find_clustering_files(folder_path)
            else:
                print(f"Failed to generate metrics for {folder_name}")

        if files:
            folder_files[folder_name] = files
            if "metrics_csv" in files:
                csv_paths[folder_name] = files["metrics_csv"]
            if "kmeans_plot" in files:
                kmeans_paths[folder_name] = files["kmeans_plot"]
            if "true_labels_plot" in files:
                true_labels_paths[folder_name] = files["true_labels_plot"]

            print(f"Found {len(files)} clustering files in {folder_name}")
        else:
            if missing_csv and not args.auto_generate:
                print(
                    f"No clustering files found in {folder_name} (use --auto-generate to create them)"
                )
            else:
                print(f"No clustering files found in {folder_name}")

    if not folder_files:
        print("No clustering analysis files found in any folder.")
        return

    # Load and compare metrics data
    if csv_paths:
        print(f"\nProcessing metrics from {len(csv_paths)} folders...")
        metrics_df = load_metrics_data(csv_paths)

        if not metrics_df.empty:
            # Create comparison plots
            metrics_plot_path = output_dir / "metrics_comparison.png"
            create_metrics_comparison_plot(metrics_df, metrics_plot_path)

            # Create summary table and heatmap
            summary_table_path = output_dir / "metrics_summary.png"
            summary_df = create_summary_table(metrics_df, summary_table_path)

            print(f"\nSummary of clustering metrics:")
            print(summary_df.to_string())

    # Create image comparisons
    if kmeans_paths:
        print(
            f"\nCreating K-means clustering comparison from {len(kmeans_paths)} images..."
        )
        kmeans_comparison_path = output_dir / "kmeans_clustering_comparison.png"
        create_image_comparison(
            kmeans_paths, "kmeans_clustering", kmeans_comparison_path
        )

    if true_labels_paths:
        print(
            f"\nCreating true labels comparison from {len(true_labels_paths)} images..."
        )
        true_labels_comparison_path = output_dir / "true_labels_comparison.png"
        create_image_comparison(
            true_labels_paths, "true_labels", true_labels_comparison_path
        )

    print(f"\nComparison complete! Results saved in: {output_dir}")

    # Print summary of what was found
    print(f"\nSummary:")
    print(f"- Folders analyzed: {len(folder_files)}")
    print(f"- Metrics CSV files: {len(csv_paths)}")
    print(f"- K-means plot files: {len(kmeans_paths)}")
    print(f"- True labels plot files: {len(true_labels_paths)}")


if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None means success).
    sys.exit(main())
