#!/usr/bin/env python3
"""
Script to analyze grid search results from mixed diffusion experiments.

This script searches through result directories, parses configuration files,
and extracts ARI (Adjusted Rand Index) scores before and after denoising
for different parameter combinations.
"""

import os
import json
import pandas as pd
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import re

# Optional plotting imports
try:
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    from matplotlib.patches import Rectangle

    PLOTTING_AVAILABLE = True
except ImportError:
    PLOTTING_AVAILABLE = False


def _handle_plot_display(plots_dir: Path, filename: str, visualize: bool):
    """Show the current matplotlib figure or write it to disk.

    Args:
        plots_dir: Directory the image is written to when not visualizing.
        filename: Name of the image file inside ``plots_dir``.
        visualize: When True the figure is displayed with ``plt.show()`` and
            nothing is saved; otherwise it is saved at 300 dpi with a tight
            bounding box.
    """
    if not visualize:
        target = plots_dir / filename
        plt.savefig(target, dpi=300, bbox_inches="tight")
        return

    plt.show()


def parse_directory_name(dir_name: str) -> Dict[str, object]:
    """
    Parse the directory name to extract parameter values.

    Format: {config}_gibbs{gibbs_iter}_rho{rho_start}-{rho_end}_rep{rep}
    Example: medium_gibbs100_rho0.5-0.5_rep1
             medium_mixup_long_gibbs100_rho0.5-0.5_rep1

    Args:
        dir_name: Directory name to parse

    Returns:
        Dictionary with the keys ``config`` (str), ``gibbs_iterations`` (int),
        ``rho_start`` (float), ``rho_end`` (float) and
        ``repeated_sampling_factor`` (int). An empty dictionary is returned
        when the name does not follow the convention or when one of the
        numeric parts cannot be converted.
    """
    # The config part is matched lazily so it may itself contain underscores;
    # matching is anchored at the start of the name only.
    pattern = r"(.+?)_gibbs(\d+)_rho([0-9.]+)-([0-9.]+)_rep(\d+)"
    match = re.match(pattern, dir_name)
    if match is None:
        return {}

    config, gibbs_iter, rho_start, rho_end, rep = match.groups()
    try:
        return {
            "config": config,
            "gibbs_iterations": int(gibbs_iter),
            "rho_start": float(rho_start),
            "rho_end": float(rho_end),
            "repeated_sampling_factor": int(rep),
        }
    except ValueError:
        # "[0-9.]+" also accepts text such as "0..5" or "." that float()
        # rejects; report it as "not parseable" instead of raising.
        return {}


def load_results_file(results_path: Path) -> Optional[Dict]:
    """
    Load and parse the results.json file.

    Args:
        results_path: Path to results.json file

    Returns:
        Dictionary with results, or None if the file doesn't exist, cannot be
        read or decoded, or does not contain a JSON object at the top level.
    """
    if not results_path.exists():
        return None

    try:
        with open(results_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError,
        # so a corrupt file is reported instead of aborting the caller.
        print(f"Error loading {results_path}: {e}")
        return None

    if not isinstance(data, dict):
        # Callers use dict methods on the result; anything else is unusable.
        print(
            f"Error loading {results_path}: expected a JSON object, "
            f"got {type(data).__name__}"
        )
        return None

    return data


def load_clustering_results_csv(csv_path: Path) -> Optional[Dict]:
    """
    Load and parse the clustering_metrics_results.csv file.

    The CSV is read with ``Metric`` and ``Value`` columns and, in the current
    format, a ``Dataset`` column. The ARI before/after denoising is taken from
    the ``Adjusted_Rand_Index`` rows of the ``True`` and ``Denoised`` datasets;
    if that metric is absent, the first and second ``ARI`` rows are used
    instead (old format). Every row is additionally exported under the key
    ``{dataset}_{metric}`` (lower-cased, spaces replaced by underscores).

    Args:
        csv_path: Path to clustering_metrics_results.csv file

    Returns:
        Dictionary with ``ari_original``, ``ari_denoised`` and the per-row
        metric keys, or None if the file doesn't exist or cannot be processed.
    """
    if not csv_path.exists():
        return None

    try:
        df = pd.read_csv(csv_path)
        print(f"CSV columns: {list(df.columns)}")
        print(f"CSV shape: {df.shape}")
        print(
            f"Unique datasets: {df['Dataset'].unique() if 'Dataset' in df.columns else 'No Dataset column'}"
        )
        print(
            f"Unique metrics: {df['Metric'].unique() if 'Metric' in df.columns else 'No Metric column'}"
        )

        # Convert DataFrame to dictionary format compatible with existing code
        results = {}
        original_ari = None
        denoised_ari = None

        if "Adjusted_Rand_Index" in df["Metric"].values:
            # read_csv infers a boolean column when every label is
            # "True"/"False", so compare on the string form of the label.
            dataset_labels = df["Dataset"].astype(str)
            is_ari_row = df["Metric"] == "Adjusted_Rand_Index"

            # "True" dataset holds the original (pre-denoising) score
            true_ari_rows = df[is_ari_row & (dataset_labels == "True")]
            if not true_ari_rows.empty:
                original_ari = true_ari_rows["Value"].iloc[0]

            denoised_ari_rows = df[is_ari_row & (dataset_labels == "Denoised")]
            if not denoised_ari_rows.empty:
                denoised_ari = denoised_ari_rows["Value"].iloc[0]

        # Fallback: look for ARI metric (old format); rows are positional
        elif "ARI" in df["Metric"].values:
            ari_rows = df[df["Metric"] == "ARI"]
            if len(ari_rows) >= 1:
                original_ari = ari_rows["Value"].iloc[0]
            if len(ari_rows) >= 2:
                denoised_ari = ari_rows["Value"].iloc[1]

        results["ari_original"] = original_ari
        results["ari_denoised"] = denoised_ari

        print(f"Extracted ARI - Original: {original_ari}, Denoised: {denoised_ari}")

        # Add other metrics if available
        for _, row in df.iterrows():
            metric = row["Metric"]
            value = row["Value"]
            if pd.isna(metric):
                # A row without a metric name cannot be given a key.
                continue

            dataset = row.get("Dataset", "unknown")
            if pd.isna(dataset):
                dataset = "unknown"

            # str() keeps non-string labels (e.g. booleans) from breaking
            # the key construction and discarding the whole file.
            key = f"{str(dataset).lower()}_{str(metric).lower().replace(' ', '_')}"
            results[key] = value

        return results

    except Exception as e:
        # Best effort: report the problem and let the caller fall back.
        print(f"Error loading {csv_path}: {e}")
        import traceback

        traceback.print_exc()
        return None


def load_config_file(config_path: Path) -> Optional[Dict]:
    """
    Load and parse the args.json configuration file.

    Args:
        config_path: Path to args.json file

    Returns:
        Dictionary with configuration, or None if the file doesn't exist,
        cannot be read or decoded, or its top-level JSON value is not an
        object.
    """
    if not config_path.exists():
        return None

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (ValueError, OSError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors;
        # a damaged config is reported rather than raised.
        print(f"Error loading {config_path}: {e}")
        return None

    if not isinstance(config, dict):
        # The declared contract is "dictionary or None".
        print(
            f"Error loading {config_path}: expected a JSON object, "
            f"got {type(config).__name__}"
        )
        return None

    return config


def scan_results_directory(base_path: Path, models: List[str]) -> List[Dict]:
    """
    Collect one record per experiment directory found under each model.

    For every ``base_path / model`` directory, each sub-directory whose name
    can be parsed by ``parse_directory_name`` is inspected. Results come from
    ``clustering_metrics_results.csv`` when it loads, otherwise from
    ``results.json``; selected fields of ``args.json`` are attached with a
    ``config_`` prefix when that file loads.

    Args:
        base_path: Base path to the results directory
        models: List of model names to scan

    Returns:
        List of dictionaries, one per experiment that had loadable results
    """
    # Fields copied from args.json into the record as "config_<field>".
    selected_config_fields = (
        "hidden_dim",
        "num_blocks",
        "noise_step",
        "step_size",
        "burn_in",
        "test_noise_level",
        "visualization_method",
    )
    collected = []

    for model in models:
        model_path = base_path / model
        if not model_path.exists():
            print(f"Model directory {model_path} not found, skipping...")
            continue

        print(f"Scanning model: {model}")

        for experiment_dir in model_path.iterdir():
            if not experiment_dir.is_dir():
                continue

            params = parse_directory_name(experiment_dir.name)
            if not params:
                print(f"Could not parse directory name: {experiment_dir.name}")
                continue

            csv_results_file = experiment_dir / "clustering_metrics_results.csv"
            json_results_file = experiment_dir / "results.json"

            # CSV takes precedence; JSON is only consulted when CSV gave None.
            results_data = None
            if csv_results_file.exists():
                results_data = load_clustering_results_csv(csv_results_file)
                if results_data:
                    print(f"Loaded results from CSV: {csv_results_file}")

            if results_data is None and json_results_file.exists():
                results_data = load_results_file(json_results_file)
                if results_data:
                    print(f"Loaded results from JSON: {json_results_file}")

            config_data = load_config_file(experiment_dir / "args.json")

            if results_data is None:
                print(
                    f"No results file found in {experiment_dir} (looked for both CSV and JSON)"
                )
                continue

            ari_before = results_data.get("ari_original")
            ari_after = results_data.get("ari_denoised")
            if ari_before is None or ari_after is None:
                improvement = None
            else:
                improvement = (ari_after or 0) - (ari_before or 0)

            record = {
                "model": model,
                "experiment_dir": str(experiment_dir),
                "ari_original": ari_before,
                "ari_denoised": ari_after,
                "ari_improvement": improvement,
            }
            record.update(params)

            # Remaining metrics are added without overwriting existing keys.
            for key, value in results_data.items():
                record.setdefault(key, value)

            if config_data:
                for field in selected_config_fields:
                    if field in config_data:
                        record[f"config_{field}"] = config_data[field]

            collected.append(record)

    return collected


def create_summary_statistics(df: pd.DataFrame) -> pd.DataFrame:
    """
    Aggregate every metric column per parameter combination.

    Metric columns are the ARI columns present in ``df`` plus every column
    whose name starts with ``denoised_`` or ``true_``. Each metric gets
    mean/std/min/max (rounded to 4 decimals); ``ari_original`` additionally
    gets a count so the number of rows per group is available once.

    Args:
        df: DataFrame with all experiment results

    Returns:
        DataFrame with the grouping columns followed by flattened
        ``<metric>_<statistic>`` columns
    """
    group_keys = [
        "model",
        "config",
        "gibbs_iterations",
        "rho_start",
        "rho_end",
        "repeated_sampling_factor",
    ]

    # ARI metrics first, in a fixed order, then the CSV-derived metrics.
    metric_columns = [
        name
        for name in ("ari_original", "ari_denoised", "ari_improvement")
        if name in df.columns
    ]
    for column in df.columns:
        if column.startswith(("denoised_", "true_")) and column not in metric_columns:
            metric_columns.append(column)

    print(f"Creating summary for metrics: {metric_columns}")

    agg_dict = {
        name: (
            ["mean", "std", "min", "max", "count"]
            if name == "ari_original"
            else ["mean", "std", "min", "max"]
        )
        for name in metric_columns
    }

    summary = df.groupby(group_keys).agg(agg_dict).round(4)

    # Collapse the (metric, statistic) column pairs into single names.
    summary.columns = ["_".join(pair).strip() for pair in summary.columns]

    return summary.reset_index()


def analyze_best_configurations(
    df: pd.DataFrame, metric: str = "ari_improvement"
) -> pd.DataFrame:
    """
    Rank parameter combinations by their mean performance.

    Args:
        df: DataFrame with all experiment results
        metric: Metric to rank by. ``avg_<metric>`` is used when such a column
            exists in the aggregated table (e.g. 'ari_improvement',
            'ari_denoised'); otherwise ``metric`` itself is used as the
            column name.

    Returns:
        DataFrame with one row per configuration, sorted best-first, holding
        the averaged ARI columns and ``num_runs``
    """
    group_keys = [
        "model",
        "config",
        "gibbs_iterations",
        "rho_start",
        "rho_end",
        "repeated_sampling_factor",
    ]

    # Named aggregation yields the final column names directly.
    ranked = (
        df.groupby(group_keys)
        .agg(
            avg_ari_original=("ari_original", "mean"),
            avg_ari_denoised=("ari_denoised", "mean"),
            avg_ari_improvement=("ari_improvement", "mean"),
            # Number of runs with this exact configuration
            num_runs=("experiment_dir", "count"),
        )
        .round(4)
    )

    prefixed = f"avg_{metric}"
    sort_column = metric if prefixed not in ranked.columns else prefixed
    ranked = ranked.sort_values(sort_column, ascending=False)

    return ranked.reset_index()


def save_analysis_results(
    df: pd.DataFrame,
    summary: pd.DataFrame,
    best_configs: pd.DataFrame,
    output_dir: Path,
):
    """
    Save all analysis results to CSV files.

    Writes ``complete_results.csv``, ``summary_statistics.csv`` and
    ``best_configurations.csv`` (without the index) into ``output_dir``.

    Args:
        df: Complete experiment results DataFrame
        summary: Summary statistics DataFrame
        best_configs: Best configurations DataFrame
        output_dir: Directory to save results; created together with any
            missing parent directories
    """
    # parents=True so a nested, not-yet-existing output path works too.
    output_dir.mkdir(parents=True, exist_ok=True)

    outputs = (
        (df, "complete_results.csv", "Complete results"),
        (summary, "summary_statistics.csv", "Summary statistics"),
        (best_configs, "best_configurations.csv", "Best configurations"),
    )
    for frame, filename, label in outputs:
        target = output_dir / filename
        frame.to_csv(target, index=False)
        print(f"{label} saved to: {target}")


def create_performance_plots(
    df: pd.DataFrame, output_dir: Path, visualize: bool = False
):
    """
    Create the performance plots for each model separately.

    Each distinct value of the ``model`` column gets its own sub-directory
    under ``output_dir / "plots"``, and the per-model plot functions only ever
    receive that model's rows. A cross-model comparison is produced afterwards
    by ``create_cross_model_comparison``.

    Args:
        df: Complete experiment results DataFrame
        output_dir: Directory to save plots
        visualize: If True, display plots interactively instead of just saving

    Raises:
        ImportError: If the optional plotting dependencies are unavailable.
    """
    if not PLOTTING_AVAILABLE:
        raise ImportError("matplotlib and seaborn are required for plotting")

    # Set style for better looking plots
    plt.style.use("default")
    sns.set_palette("husl")

    # parents=True so a not-yet-existing output_dir does not abort plotting.
    plots_dir = output_dir / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    # Executed in this order for every model; all share one call signature.
    per_model_plotters = (
        create_ari_heatmap_per_model,  # 1. ARI improvement heatmap by config
        create_gibbs_performance_plot_per_model,  # 2. vs Gibbs iterations
        create_rho_performance_plot_per_model,  # 3. vs rho parameters
        create_sampling_performance_plot_per_model,  # 4. vs sampling factor
        create_performance_distribution_plot_per_model,  # 5. distributions
        create_best_config_comparison_per_model,  # 6. best configurations
        create_parameter_interaction_plot_per_model,  # 7. interactions
    )

    for model in df["model"].unique():
        print(f"Creating plots for model: {model}")
        # Copy so plot helpers never touch the caller's DataFrame.
        model_df = df[df["model"] == model].copy()
        model_plots_dir = plots_dir / model
        model_plots_dir.mkdir(parents=True, exist_ok=True)

        for plotter in per_model_plotters:
            plotter(model_df, model_plots_dir, model, visualize)

    # Also create a cross-model comparison plot (but clearly separated)
    create_cross_model_comparison(df, plots_dir, visualize)

    print(f"📊 Performance plots saved to: {plots_dir} (separated by model)")


def create_ari_heatmap_per_model(
    df: pd.DataFrame, plots_dir: Path, model_name: str, visualize: bool = False
):
    """Draw a one-row heatmap of mean ARI improvement per config.

    Args:
        df: Experiment results; the caller is expected to pass rows of a
            single model, since no grouping by model happens here.
        plots_dir: Directory for ``ari_improvement_heatmap.png``.
        model_name: Model name shown in the title.
        visualize: Show the figure instead of saving it.
    """
    plt.figure(figsize=(10, 6))

    # Mean improvement per config, laid out as a single heatmap row.
    mean_by_config = df.groupby(["config"])["ari_improvement"].mean()
    single_row = mean_by_config.to_frame().T

    heatmap_options = {
        "annot": True,
        "fmt": ".4f",
        "cmap": "RdYlBu_r",
        "center": 0,
        "cbar_kws": {"label": "Mean ARI Improvement"},
    }
    sns.heatmap(single_row, **heatmap_options)

    plt.title(
        f"ARI Improvement by Configuration - {model_name}",
        fontsize=16,
        fontweight="bold",
    )
    plt.xlabel("Configuration", fontsize=12)
    plt.ylabel("")
    plt.tight_layout()

    _handle_plot_display(plots_dir, "ari_improvement_heatmap.png", visualize)
    plt.close()


def create_ari_heatmap(df: pd.DataFrame, plots_dir: Path):
    """Save a model-by-config heatmap of the mean ARI improvement.

    Args:
        df: Experiment results with ``model``, ``config`` and
            ``ari_improvement`` columns.
        plots_dir: Directory for ``ari_improvement_heatmap.png``.
    """
    plt.figure(figsize=(12, 8))

    # Rows are models, columns are configs, cells hold the group mean.
    group_means = (
        df.groupby(["model", "config"])["ari_improvement"].mean().reset_index()
    )
    grid = group_means.pivot(
        index="model", columns="config", values="ari_improvement"
    )

    heatmap_options = {
        "annot": True,
        "fmt": ".4f",
        "cmap": "RdYlBu_r",
        "center": 0,
        "cbar_kws": {"label": "Mean ARI Improvement"},
    }
    sns.heatmap(grid, **heatmap_options)

    plt.title(
        "ARI Improvement by Model and Configuration", fontsize=16, fontweight="bold"
    )
    plt.xlabel("Configuration", fontsize=12)
    plt.ylabel("Model", fontsize=12)
    plt.tight_layout()

    target = plots_dir / "ari_improvement_heatmap.png"
    plt.savefig(target, dpi=300, bbox_inches="tight")
    plt.close()


def create_gibbs_performance_plot_per_model(
    df: pd.DataFrame, plots_dir: Path, model_name: str, visualize: bool = False
):
    """Create plots showing performance vs Gibbs iterations for a specific model.

    Three panels hold box plots of ``ari_original``, ``ari_denoised`` and
    ``ari_improvement`` by ``gibbs_iterations``; the fourth shows the mean
    ARI improvement per config as a line plot.

    Args:
        df: Experiment results to plot (rows of one model).
        plots_dir: Directory for ``gibbs_performance.png``.
        model_name: Model name shown in the figure title.
        visualize: Show the figure instead of saving it.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(
        f"Performance vs Gibbs Iterations - {model_name}",
        fontsize=16,
        fontweight="bold",
    )

    metrics = ["ari_original", "ari_denoised", "ari_improvement"]

    # Box plots fill the first three panels row by row.
    for ax, metric in zip(axes.flat, metrics):
        sns.boxplot(data=df, x="gibbs_iterations", y=metric, ax=ax)
        ax.set_title(metric.replace("_", " ").title())
        ax.set_xlabel("Gibbs Iterations")
        ax.tick_params(axis="x", rotation=45)

    # Line plot showing mean performance by config
    ax = axes[1, 1]
    for config in df["config"].unique():
        config_data = (
            df[df["config"] == config]
            .groupby("gibbs_iterations")["ari_improvement"]
            .mean()
        )
        ax.plot(
            config_data.index, config_data.values, marker="o", label=config, linewidth=2
        )

    ax.set_title("Mean ARI Improvement by Configuration")
    ax.set_xlabel("Gibbs Iterations")
    ax.set_ylabel("Mean ARI Improvement")
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    _handle_plot_display(plots_dir, "gibbs_performance.png", visualize)
    # Release the figure; this function runs once per model and pyplot keeps
    # every figure alive until it is closed.
    plt.close(fig)


def create_gibbs_performance_plot(df: pd.DataFrame, plots_dir: Path):
    """Save a 2x2 figure relating the ARI metrics to Gibbs iterations.

    Args:
        df: Experiment results with ``gibbs_iterations``, ``config`` and the
            three ARI columns.
        plots_dir: Directory for ``gibbs_performance.png``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle("Performance vs Gibbs Iterations", fontsize=16, fontweight="bold")

    # One box plot per ARI metric, filling the grid row by row.
    boxed_columns = ("ari_original", "ari_denoised", "ari_improvement")
    for panel, column in zip(axes.flat, boxed_columns):
        sns.boxplot(data=df, x="gibbs_iterations", y=column, ax=panel)
        panel.set_title(column.replace("_", " ").title())
        panel.set_xlabel("Gibbs Iterations")
        panel.tick_params(axis="x", rotation=45)

    # Last panel: mean improvement per config across iteration counts.
    trend_ax = axes[1, 1]
    for name in df["config"].unique():
        subset = df[df["config"] == name]
        means = subset.groupby("gibbs_iterations")["ari_improvement"].mean()
        trend_ax.plot(
            means.index, means.values, marker="o", label=name, linewidth=2
        )

    trend_ax.set_title("Mean ARI Improvement by Configuration")
    trend_ax.set_xlabel("Gibbs Iterations")
    trend_ax.set_ylabel("Mean ARI Improvement")
    trend_ax.legend()
    trend_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    target = plots_dir / "gibbs_performance.png"
    plt.savefig(target, dpi=300, bbox_inches="tight")
    plt.close()


def create_rho_performance_plot_per_model(
    df: pd.DataFrame, plots_dir: Path, model_name: str, visualize: bool = False
):
    """Create plots showing performance vs Rho parameters for a specific model.

    Panels: box plot of ARI improvement per ``rho_start-rho_end`` pair,
    scatter plots against ``rho_start`` and ``rho_end``, and a heatmap of the
    mean improvement over the (rho_start, rho_end) grid.

    Args:
        df: Experiment results to plot (rows of one model); left unmodified.
        plots_dir: Directory for ``rho_performance.png``.
        model_name: Model name shown in the figure title.
        visualize: Show the figure instead of saving it.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(
        f"Performance vs Rho Parameters - {model_name}", fontsize=16, fontweight="bold"
    )

    # Add the helper column on a copy so the caller's frame is not mutated.
    df = df.assign(
        rho_pair=df["rho_start"].astype(str) + "-" + df["rho_end"].astype(str)
    )

    # Box plot by rho pairs
    sns.boxplot(data=df, x="rho_pair", y="ari_improvement", ax=axes[0, 0])
    axes[0, 0].set_title("ARI Improvement by Rho Pairs")
    axes[0, 0].tick_params(axis="x", rotation=45)

    # Scatter plot: rho_start vs performance
    sns.scatterplot(
        data=df,
        x="rho_start",
        y="ari_improvement",
        hue="config",
        size="gibbs_iterations",
        ax=axes[0, 1],
    )
    axes[0, 1].set_title("ARI Improvement vs Rho Start")

    # Scatter plot: rho_end vs performance
    sns.scatterplot(
        data=df,
        x="rho_end",
        y="ari_improvement",
        hue="config",
        size="gibbs_iterations",
        ax=axes[1, 0],
    )
    axes[1, 0].set_title("ARI Improvement vs Rho End")

    # Heatmap: rho_start vs rho_end
    pivot_rho = (
        df.groupby(["rho_start", "rho_end"])["ari_improvement"].mean().reset_index()
    )
    pivot_rho_table = pivot_rho.pivot(
        index="rho_start", columns="rho_end", values="ari_improvement"
    )
    sns.heatmap(pivot_rho_table, annot=True, fmt=".4f", cmap="RdYlBu_r", ax=axes[1, 1])
    axes[1, 1].set_title("Mean ARI Improvement: Rho Start vs Rho End")

    plt.tight_layout()
    _handle_plot_display(plots_dir, "rho_performance.png", visualize)
    plt.close()


def create_rho_performance_plot(
    df: pd.DataFrame, plots_dir: Path, visualize: bool = False
):
    """Create plots showing performance vs Rho parameters.

    Panels: box plot of ARI improvement per ``rho_start-rho_end`` pair,
    scatter plots against ``rho_start`` and ``rho_end``, and a heatmap of the
    mean improvement over the (rho_start, rho_end) grid.

    Args:
        df: Experiment results to plot; left unmodified.
        plots_dir: Directory for ``rho_performance.png``.
        visualize: If True, show the figure interactively instead of saving
            it (the default keeps the save-to-file behaviour).
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle("Performance vs Rho Parameters", fontsize=16, fontweight="bold")

    # Add the helper column on a copy so the caller's frame is not mutated.
    df = df.assign(
        rho_pair=df["rho_start"].astype(str) + "-" + df["rho_end"].astype(str)
    )

    # Box plot by rho pairs
    sns.boxplot(data=df, x="rho_pair", y="ari_improvement", ax=axes[0, 0])
    axes[0, 0].set_title("ARI Improvement by Rho Pairs")
    axes[0, 0].tick_params(axis="x", rotation=45)

    # Scatter plot: rho_start vs performance
    sns.scatterplot(
        data=df,
        x="rho_start",
        y="ari_improvement",
        hue="config",
        size="gibbs_iterations",
        ax=axes[0, 1],
    )
    axes[0, 1].set_title("ARI Improvement vs Rho Start")

    # Scatter plot: rho_end vs performance
    sns.scatterplot(
        data=df,
        x="rho_end",
        y="ari_improvement",
        hue="config",
        size="gibbs_iterations",
        ax=axes[1, 0],
    )
    axes[1, 0].set_title("ARI Improvement vs Rho End")

    # Heatmap: rho_start vs rho_end
    pivot_rho = (
        df.groupby(["rho_start", "rho_end"])["ari_improvement"].mean().reset_index()
    )
    pivot_rho_table = pivot_rho.pivot(
        index="rho_start", columns="rho_end", values="ari_improvement"
    )
    sns.heatmap(pivot_rho_table, annot=True, fmt=".4f", cmap="RdYlBu_r", ax=axes[1, 1])
    axes[1, 1].set_title("Mean ARI Improvement: Rho Start vs Rho End")

    plt.tight_layout()
    # Honour the visualize flag, which was previously accepted but ignored.
    _handle_plot_display(plots_dir, "rho_performance.png", visualize)
    plt.close()


def create_sampling_performance_plot_per_model(
    df: pd.DataFrame, plots_dir: Path, model_name: str, visualize: bool = False
):
    """Plot ARI improvement against the repeated sampling factor for one model.

    Args:
        df: Experiment results to plot (rows of one model).
        plots_dir: Directory for ``sampling_performance.png``.
        model_name: Model name shown in the figure title.
        visualize: Show the figure instead of saving it.
    """
    fig, (box_ax, line_ax) = plt.subplots(1, 2, figsize=(15, 6))
    fig.suptitle(
        f"Performance vs Repeated Sampling Factor - {model_name}",
        fontsize=16,
        fontweight="bold",
    )

    # Left: spread of the improvement for every sampling factor.
    sns.boxplot(data=df, x="repeated_sampling_factor", y="ari_improvement", ax=box_ax)
    box_ax.set_title("ARI Improvement Distribution")
    box_ax.set_xlabel("Repeated Sampling Factor")

    # Right: mean improvement per config as one line each.
    for name in df["config"].unique():
        subset = df[df["config"] == name]
        means = subset.groupby("repeated_sampling_factor")["ari_improvement"].mean()
        line_ax.plot(means.index, means.values, marker="o", label=name, linewidth=2)

    line_ax.set_title("Mean ARI Improvement by Configuration")
    line_ax.set_xlabel("Repeated Sampling Factor")
    line_ax.set_ylabel("Mean ARI Improvement")
    line_ax.legend()
    line_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    _handle_plot_display(plots_dir, "sampling_performance.png", visualize)
    plt.close()


def create_sampling_performance_plot(df: pd.DataFrame, plots_dir: Path):
    """Save a figure of ARI improvement against the repeated sampling factor.

    Args:
        df: Experiment results with ``repeated_sampling_factor``, ``config``
            and ``ari_improvement`` columns.
        plots_dir: Directory for ``sampling_performance.png``.
    """
    fig, (box_ax, line_ax) = plt.subplots(1, 2, figsize=(15, 6))
    fig.suptitle(
        "Performance vs Repeated Sampling Factor", fontsize=16, fontweight="bold"
    )

    # Left: spread of the improvement for every sampling factor.
    sns.boxplot(data=df, x="repeated_sampling_factor", y="ari_improvement", ax=box_ax)
    box_ax.set_title("ARI Improvement Distribution")
    box_ax.set_xlabel("Repeated Sampling Factor")

    # Right: mean improvement per config as one line each.
    for name in df["config"].unique():
        subset = df[df["config"] == name]
        means = subset.groupby("repeated_sampling_factor")["ari_improvement"].mean()
        line_ax.plot(means.index, means.values, marker="o", label=name, linewidth=2)

    line_ax.set_title("Mean ARI Improvement by Configuration")
    line_ax.set_xlabel("Repeated Sampling Factor")
    line_ax.set_ylabel("Mean ARI Improvement")
    line_ax.legend()
    line_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    target = plots_dir / "sampling_performance.png"
    plt.savefig(target, dpi=300, bbox_inches="tight")
    plt.close()


def create_performance_distribution_plot_per_model(
    df: pd.DataFrame, plots_dir: Path, model_name: str, visualize: bool = False
):
    """Create plots showing overall performance distributions for a specific model.

    Panels: histogram of ARI improvement with its mean marked, box and violin
    plots of ARI improvement per config, and a correlation heatmap of the
    numeric parameters and ARI columns.

    Args:
        df: Experiment results to plot (rows of one model).
        plots_dir: Directory for ``performance_distributions.png``.
        model_name: Model name shown in the figure title.
        visualize: Show the figure instead of saving it.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(
        f"Performance Distributions - {model_name}", fontsize=16, fontweight="bold"
    )

    improvement = df["ari_improvement"]
    mean_improvement = improvement.mean()

    # Histogram of ARI improvements. Missing values are dropped first:
    # runs without both ARI scores carry NaN here and the histogram
    # cannot bin non-finite data.
    axes[0, 0].hist(
        improvement.dropna(), bins=20, alpha=0.7, color="skyblue", edgecolor="black"
    )
    axes[0, 0].axvline(
        mean_improvement,
        color="red",
        linestyle="--",
        label=f"Mean: {mean_improvement:.4f}",
    )
    axes[0, 0].set_title("Distribution of ARI Improvements")
    axes[0, 0].set_xlabel("ARI Improvement")
    axes[0, 0].set_ylabel("Frequency")
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Box plot by config
    sns.boxplot(data=df, x="config", y="ari_improvement", ax=axes[0, 1])
    axes[0, 1].set_title("ARI Improvement by Configuration")
    axes[0, 1].tick_params(axis="x", rotation=45)

    # Violin plot by config
    sns.violinplot(data=df, x="config", y="ari_improvement", ax=axes[1, 0])
    axes[1, 0].set_title("ARI Improvement Distribution by Configuration")
    axes[1, 0].tick_params(axis="x", rotation=45)

    # Correlation heatmap
    numeric_cols = [
        "gibbs_iterations",
        "rho_start",
        "rho_end",
        "repeated_sampling_factor",
        "ari_original",
        "ari_denoised",
        "ari_improvement",
    ]
    corr_matrix = df[numeric_cols].corr()
    sns.heatmap(
        corr_matrix, annot=True, fmt=".3f", cmap="RdBu_r", center=0, ax=axes[1, 1]
    )
    axes[1, 1].set_title("Parameter Correlation Matrix")

    plt.tight_layout()
    _handle_plot_display(plots_dir, "performance_distributions.png", visualize)
    plt.close()


def create_performance_distribution_plot(df: pd.DataFrame, plots_dir: Path):
    """Create plots showing overall performance distributions.

    Panels: histogram of ARI improvement with its mean marked, box plot of
    ARI improvement per model, violin plot per config, and a correlation
    heatmap of the numeric parameters and ARI columns.

    Args:
        df: Experiment results to plot.
        plots_dir: Directory for ``performance_distributions.png``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle("Performance Distributions", fontsize=16, fontweight="bold")

    improvement = df["ari_improvement"]
    mean_improvement = improvement.mean()

    # Histogram of ARI improvements. Missing values are dropped first:
    # runs without both ARI scores carry NaN here and the histogram
    # cannot bin non-finite data.
    axes[0, 0].hist(
        improvement.dropna(), bins=30, alpha=0.7, color="skyblue", edgecolor="black"
    )
    axes[0, 0].axvline(
        mean_improvement,
        color="red",
        linestyle="--",
        label=f"Mean: {mean_improvement:.4f}",
    )
    axes[0, 0].set_title("Distribution of ARI Improvements")
    axes[0, 0].set_xlabel("ARI Improvement")
    axes[0, 0].set_ylabel("Frequency")
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Box plot by model
    sns.boxplot(data=df, x="model", y="ari_improvement", ax=axes[0, 1])
    axes[0, 1].set_title("ARI Improvement by Model")
    axes[0, 1].tick_params(axis="x", rotation=45)

    # Violin plot by config
    sns.violinplot(data=df, x="config", y="ari_improvement", ax=axes[1, 0])
    axes[1, 0].set_title("ARI Improvement Distribution by Configuration")
    axes[1, 0].tick_params(axis="x", rotation=45)

    # Correlation heatmap
    numeric_cols = [
        "gibbs_iterations",
        "rho_start",
        "rho_end",
        "repeated_sampling_factor",
        "ari_original",
        "ari_denoised",
        "ari_improvement",
    ]
    corr_matrix = df[numeric_cols].corr()
    sns.heatmap(
        corr_matrix, annot=True, fmt=".3f", cmap="RdBu_r", center=0, ax=axes[1, 1]
    )
    axes[1, 1].set_title("Parameter Correlation Matrix")

    plt.tight_layout()
    plt.savefig(
        plots_dir / "performance_distributions.png", dpi=300, bbox_inches="tight"
    )
    plt.close()


def create_best_config_comparison_per_model(
    df: pd.DataFrame, plots_dir: Path, model_name: str, visualize: bool = False
):
    """Create bar plot comparing top configurations for a specific model.

    The ten parameter combinations with the highest mean ARI improvement are
    shown as bars, each annotated with its value.

    Args:
        df: Experiment results to plot (rows of one model).
        plots_dir: Directory for ``best_configs_comparison.png``.
        model_name: Model name shown in the title.
        visualize: Show the figure instead of saving it.
    """
    # Get top 10 configurations by mean ARI improvement for this model
    top_configs = (
        df.groupby(
            [
                "config",
                "gibbs_iterations",
                "rho_start",
                "rho_end",
                "repeated_sampling_factor",
            ]
        )["ari_improvement"]
        .mean()
        .nlargest(10)
        .reset_index()
    )

    # Create labels for x-axis
    top_configs["config_label"] = (
        top_configs["config"]
        + "_G"
        + top_configs["gibbs_iterations"].astype(str)
        + "_R"
        + top_configs["rho_start"].astype(str)
        + "-"
        + top_configs["rho_end"].astype(str)
        + "_S"
        + top_configs["repeated_sampling_factor"].astype(str)
    )

    plt.figure(figsize=(15, 8))
    bars = plt.bar(
        range(len(top_configs)),
        top_configs["ari_improvement"],
        color=plt.cm.viridis(np.linspace(0, 1, len(top_configs))),
    )

    plt.title(
        f"Top 10 Configurations by ARI Improvement - {model_name}",
        fontsize=16,
        fontweight="bold",
    )
    plt.xlabel("Configuration", fontsize=12)
    plt.ylabel("Mean ARI Improvement", fontsize=12)
    plt.xticks(
        range(len(top_configs)), top_configs["config_label"], rotation=45, ha="right"
    )

    # Add value labels at the outer end of each bar: above positive bars,
    # below negative ones, so the text never overlaps the bar itself.
    for bar, value in zip(bars, top_configs["ari_improvement"]):
        above = value >= 0
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + (0.001 if above else -0.001),
            f"{value:.4f}",
            ha="center",
            va="bottom" if above else "top",
            fontsize=9,
        )

    plt.grid(True, alpha=0.3, axis="y")
    plt.tight_layout()
    _handle_plot_display(plots_dir, "best_configs_comparison.png", visualize)
    plt.close()


def create_best_config_comparison(
    df: pd.DataFrame, plots_dir: Path, visualize: bool = False
):
    """
    Create bar plot comparing top configurations across all models.

    Rows are grouped by model, config, Gibbs iterations, rho range and
    repeated sampling factor; the (at most) ten groups with the largest mean
    ``ari_improvement`` are plotted, one bar per group.

    Args:
        df: Results containing the grouping columns and ``ari_improvement``.
        plots_dir: Directory that receives ``best_configs_comparison.png``
            when the figure is saved.
        visualize: If True, show the figure instead of saving it. Defaults to
            False, which keeps the previous always-save behaviour.
    """
    # Get top 10 configurations by mean ARI improvement
    top_configs = (
        df.groupby(
            [
                "model",
                "config",
                "gibbs_iterations",
                "rho_start",
                "rho_end",
                "repeated_sampling_factor",
            ]
        )["ari_improvement"]
        .mean()
        .nlargest(10)
        .reset_index()
    )

    # Create labels for x-axis: <model>_<config>_G<gibbs>_R<start>-<end>_S<factor>
    top_configs["config_label"] = (
        top_configs["model"]
        + "_"
        + top_configs["config"]
        + "_G"
        + top_configs["gibbs_iterations"].astype(str)
        + "_R"
        + top_configs["rho_start"].astype(str)
        + "-"
        + top_configs["rho_end"].astype(str)
        + "_S"
        + top_configs["repeated_sampling_factor"].astype(str)
    )

    positions = range(len(top_configs))

    plt.figure(figsize=(15, 8))
    bars = plt.bar(
        positions,
        top_configs["ari_improvement"],
        color=plt.cm.viridis(np.linspace(0, 1, len(top_configs))),
    )

    plt.title(
        "Top 10 Configurations by ARI Improvement", fontsize=16, fontweight="bold"
    )
    plt.xlabel("Configuration", fontsize=12)
    plt.ylabel("Mean ARI Improvement", fontsize=12)
    plt.xticks(positions, top_configs["config_label"], rotation=45, ha="right")

    # Add value labels just above each bar
    for bar, value in zip(bars, top_configs["ari_improvement"]):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.001,
            f"{value:.4f}",
            ha="center",
            va="bottom",
            fontsize=9,
        )

    plt.grid(True, alpha=0.3, axis="y")
    plt.tight_layout()
    # Use the shared helper so show-vs-save is handled the same way as in the
    # other plotting functions of this module.
    _handle_plot_display(plots_dir, "best_configs_comparison.png", visualize)
    plt.close()


def create_parameter_interaction_plot_per_model(
    df: pd.DataFrame, plots_dir: Path, model_name: str, visualize: bool = False
):
    """
    Create plots showing parameter interaction effects for a specific model.

    Builds a 2x2 figure: three heatmaps of mean ``ari_improvement`` for pairs
    of parameters, plus a scatter plot of ``rho_start`` against ``rho_end``
    in which marker size encodes ``ari_improvement`` and hue encodes ``config``.

    Args:
        df: Results to plot; presumably already restricted to ``model_name``
            by the caller (no filtering happens here).
        plots_dir: Directory that receives ``parameter_interactions.png``
            when the figure is saved.
        model_name: Model name shown in the figure title.
        visualize: If True, show the figure instead of saving it.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(
        f"Parameter Interaction Effects - {model_name}", fontsize=16, fontweight="bold"
    )

    # (target axes, heatmap row parameter, heatmap column parameter, title)
    heatmap_specs = [
        (axes[0, 0], "gibbs_iterations", "rho_start", "Gibbs Iterations vs Rho Start"),
        (
            axes[0, 1],
            "config",
            "repeated_sampling_factor",
            "Configuration vs Repeated Sampling Factor",
        ),
        (
            axes[1, 0],
            "config",
            "gibbs_iterations",
            "Configuration vs Gibbs Iterations",
        ),
    ]
    for ax, row_key, col_key, title in heatmap_specs:
        # Mean improvement per parameter pair, reshaped into a rows x columns grid
        means = (
            df.groupby([row_key, col_key])["ari_improvement"].mean().reset_index()
        )
        table = means.pivot(index=row_key, columns=col_key, values="ari_improvement")
        sns.heatmap(table, annot=True, fmt=".4f", cmap="RdYlBu_r", ax=ax)
        ax.set_title(title)

    # Rho Start vs Rho End vs Config (using scatter plot with size)
    sns.scatterplot(
        data=df,
        x="rho_start",
        y="rho_end",
        size="ari_improvement",
        hue="config",
        sizes=(50, 200),
        ax=axes[1, 1],
    )
    axes[1, 1].set_title("Rho Parameters Interaction (size = ARI improvement)")

    plt.tight_layout()
    # Honour the ``visualize`` flag; previously the figure was always saved
    # and the flag was silently ignored.
    _handle_plot_display(plots_dir, "parameter_interactions.png", visualize)
    plt.close()


def create_parameter_interaction_plot(
    df: pd.DataFrame, plots_dir: Path, visualize: bool = False
):
    """
    Create plots showing parameter interaction effects across all models.

    Builds a 2x2 figure: three heatmaps of mean ``ari_improvement`` for pairs
    of parameters (including model vs Gibbs iterations), plus a scatter plot
    of ``rho_start`` against ``rho_end`` in which marker size encodes
    ``ari_improvement`` and hue encodes ``config``.

    Args:
        df: Results containing the parameter columns, ``model`` and
            ``ari_improvement``.
        plots_dir: Directory that receives ``parameter_interactions.png``
            when the figure is saved.
        visualize: If True, show the figure instead of saving it. Defaults to
            False, which keeps the previous always-save behaviour.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle("Parameter Interaction Effects", fontsize=16, fontweight="bold")

    # (target axes, heatmap row parameter, heatmap column parameter, title)
    heatmap_specs = [
        (axes[0, 0], "gibbs_iterations", "rho_start", "Gibbs Iterations vs Rho Start"),
        (
            axes[0, 1],
            "config",
            "repeated_sampling_factor",
            "Configuration vs Repeated Sampling Factor",
        ),
        (axes[1, 0], "model", "gibbs_iterations", "Model vs Gibbs Iterations"),
    ]
    for ax, row_key, col_key, title in heatmap_specs:
        # Mean improvement per parameter pair, reshaped into a rows x columns grid
        means = (
            df.groupby([row_key, col_key])["ari_improvement"].mean().reset_index()
        )
        table = means.pivot(index=row_key, columns=col_key, values="ari_improvement")
        sns.heatmap(table, annot=True, fmt=".4f", cmap="RdYlBu_r", ax=ax)
        ax.set_title(title)

    # Rho Start vs Rho End vs Config (using scatter plot with size)
    sns.scatterplot(
        data=df,
        x="rho_start",
        y="rho_end",
        size="ari_improvement",
        hue="config",
        sizes=(50, 200),
        ax=axes[1, 1],
    )
    axes[1, 1].set_title("Rho Parameters Interaction (size = ARI improvement)")

    plt.tight_layout()
    _handle_plot_display(plots_dir, "parameter_interactions.png", visualize)
    # Close exactly once: a second close would act on whatever figure happens
    # to be current afterwards, which may belong to other code.
    plt.close()


def create_cross_model_comparison(
    df: pd.DataFrame, plots_dir: Path, visualize: bool = False
):
    """
    Plot a 2x2 overview comparing ``ari_improvement`` between models.

    Panels: a box plot per model, mean improvement per model/config pair,
    mean with standard-deviation error bars per model, and one line per model
    of mean improvement against Gibbs iterations.

    Args:
        df: Results containing ``model``, ``config``, ``gibbs_iterations``
            and ``ari_improvement`` columns.
        plots_dir: Directory that receives ``cross_model_comparison.png``
            when the figure is saved.
        visualize: If True, show the figure instead of saving it.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    (ax_dist, ax_config), (ax_stats, ax_gibbs) = axes
    fig.suptitle(
        "Cross-Model Comparison (Models Analyzed Separately)",
        fontsize=16,
        fontweight="bold",
    )

    # Top-left: distribution of improvements for each model
    sns.boxplot(data=df, x="model", y="ari_improvement", ax=ax_dist)
    ax_dist.set_title("ARI Improvement Distribution by Model")
    ax_dist.tick_params(axis="x", rotation=45)

    # Top-right: mean improvement of every (model, config) pair
    mean_by_model_config = df.groupby(["model", "config"])["ari_improvement"].mean()
    sns.barplot(
        data=mean_by_model_config.reset_index(),
        x="model",
        y="ari_improvement",
        hue="config",
        ax=ax_config,
    )
    ax_config.set_title("Mean Performance by Model and Configuration")
    ax_config.tick_params(axis="x", rotation=45)
    ax_config.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    # Bottom-left: per-model mean with standard deviation as the error bar
    per_model = df.groupby("model")["ari_improvement"].agg(
        ["mean", "std", "min", "max"]
    )
    per_model = per_model.reset_index()
    bar_positions = range(len(per_model))
    ax_stats.bar(
        bar_positions,
        per_model["mean"],
        yerr=per_model["std"],
        capsize=5,
        alpha=0.7,
        color=["skyblue", "lightcoral"],
    )
    ax_stats.set_xticks(bar_positions)
    ax_stats.set_xticklabels(per_model["model"], rotation=45)
    ax_stats.set_title("Mean ARI Improvement ± Std by Model")
    ax_stats.set_ylabel("ARI Improvement")

    # Bottom-right: one line per model, mean improvement per Gibbs setting
    for model_name in df["model"].unique():
        rows_for_model = df[df["model"] == model_name]
        mean_by_gibbs = rows_for_model.groupby("gibbs_iterations")[
            "ari_improvement"
        ].mean()
        ax_gibbs.plot(
            mean_by_gibbs.index,
            mean_by_gibbs.values,
            marker="o",
            label=model_name,
            linewidth=2,
        )

    ax_gibbs.set_title("Gibbs Iterations Effect by Model")
    ax_gibbs.set_xlabel("Gibbs Iterations")
    ax_gibbs.set_ylabel("Mean ARI Improvement")
    ax_gibbs.legend()
    ax_gibbs.grid(True, alpha=0.3)

    plt.tight_layout()
    _handle_plot_display(plots_dir, "cross_model_comparison.png", visualize)
    plt.close()


def print_quick_summary(df: pd.DataFrame):
    """
    Print a quick summary of the results to console.

    The summary lists the parameter values present in ``df``, mean/std of
    every available metric, and the single row with the highest
    ``ari_improvement``. Metrics other than ARI are discovered from column
    pairs named ``true_<metric>`` / ``denoised_<metric>``.

    The ARI columns are optional: when ``ari_improvement`` is missing or holds
    no non-NaN value, the best-result section is replaced by a short notice
    instead of raising.

    Args:
        df: One row per experiment. Must contain ``model``, ``config``,
            ``gibbs_iterations``, ``rho_start``, ``rho_end`` and
            ``repeated_sampling_factor``.
    """

    def mean_std(series) -> str:
        """Format a column as 'Mean: x, Std: y' with four decimals."""
        return f"Mean: {series.mean():.4f}, Std: {series.std():.4f}"

    banner = "=" * 80
    print("\n" + banner)
    print("GRID SEARCH ANALYSIS SUMMARY")
    print(banner)

    print(f"\nTotal experiments found: {len(df)}")
    print(f"Models analyzed: {df['model'].unique()}")
    print(f"Configurations tested: {df['config'].unique()}")
    print(f"Gibbs iterations tested: {sorted(df['gibbs_iterations'].unique())}")
    rho_pairs = sorted(df[["rho_start", "rho_end"]].drop_duplicates().values.tolist())
    print(f"Rho value pairs tested: {rho_pairs}")
    sampling_factors = sorted(df["repeated_sampling_factor"].unique())
    print(f"Repeated sampling factors tested: {sampling_factors}")

    # Calculate and display statistics for all available metrics
    print("\nOverall Metrics Statistics:")

    # ARI metrics first (if available)
    if "ari_original" in df.columns and "ari_denoised" in df.columns:
        print(f"ARI Original - {mean_std(df['ari_original'])}")
        print(f"ARI Denoised - {mean_std(df['ari_denoised'])}")
        if "ari_improvement" in df.columns:
            print(f"ARI Improvement - {mean_std(df['ari_improvement'])}")

    # Other metrics: every true_<metric> column that has a denoised_<metric>
    # partner. ARI is skipped because it is reported separately above.
    metric_pairs = {}
    for col in df.columns:
        if not col.startswith("true_"):
            continue
        metric_name = col[len("true_"):]
        denoised_col = f"denoised_{metric_name}"
        if denoised_col in df.columns and metric_name != "adjusted_rand_index":
            metric_pairs[metric_name] = (col, denoised_col)

    for metric_name, (true_col, denoised_col) in metric_pairs.items():
        print(f"{metric_name.replace('_', ' ').title()}:")
        print(f"  True - {mean_std(df[true_col])}")
        print(f"  Denoised - {mean_std(df[denoised_col])}")
        # Improvement is computed on the fly as denoised minus true
        print(f"  Improvement - {mean_std(df[denoised_col] - df[true_col])}")

    # Show silhouette benchmark if available
    has_benchmark = "denoised_silhouette_benchmark" in df.columns
    if has_benchmark:
        print(f"Silhouette Benchmark - {mean_std(df['denoised_silhouette_benchmark'])}")

    # Best single result (based on ARI improvement). Without at least one
    # usable value there is nothing to rank, so report that and stop.
    if "ari_improvement" not in df.columns or not df["ari_improvement"].notna().any():
        print("\nBest Single Result: unavailable (no ARI improvement values found)")
        return

    best_idx = df["ari_improvement"].idxmax()
    best_result = df.loc[best_idx]

    print("\nBest Single Result (highest ARI improvement):")
    print(f"Model: {best_result['model']}, Config: {best_result['config']}")
    print(
        f"Gibbs: {best_result['gibbs_iterations']}, Rho: {best_result['rho_start']}-{best_result['rho_end']}, Repeated Sampling Factor: {best_result['repeated_sampling_factor']}"
    )

    # Show the ARI values of the best result, skipping any column that is absent
    ari_fields = (
        ("ARI Original", "ari_original"),
        ("ARI Denoised", "ari_denoised"),
        ("ARI Improvement", "ari_improvement"),
    )
    for label, col in ari_fields:
        if col in df.columns:
            print(f"{label}: {best_result[col]:.4f}")

    # Show all other metrics, but only those with a precomputed
    # <metric>_improvement column
    for metric_name, (true_col, denoised_col) in metric_pairs.items():
        improvement_col = f"{metric_name}_improvement"
        if improvement_col not in df.columns:
            continue
        title = metric_name.replace("_", " ").title()
        print(f"{title} Original: {best_result[true_col]:.4f}")
        print(f"{title} Denoised: {best_result[denoised_col]:.4f}")
        print(f"{title} Improvement: {best_result[improvement_col]:.4f}")

    # Show silhouette benchmark if available
    if has_benchmark:
        print(
            f"Silhouette Benchmark: {best_result['denoised_silhouette_benchmark']:.4f}"
        )


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the analysis script."""
    parser = argparse.ArgumentParser(
        description="Analyze grid search results from mixed diffusion experiments"
    )
    parser.add_argument(
        "--results_dir", type=str, default="results", help="Path to results directory"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        default=["", "synthetic_easy"],
        help="List of model names to analyze",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="analysis_results",
        help="Directory to save analysis results",
    )
    parser.add_argument(
        "--metric",
        type=str,
        default="ari_improvement",
        choices=["ari_improvement", "ari_denoised", "ari_original"],
        help="Metric to use for finding best configurations",
    )
    parser.add_argument(
        "--plot",
        action="store_true",
        help="Generate performance visualization plots",
    )
    parser.add_argument(
        "--no-plot",
        action="store_true",
        help="Skip generating plots (useful when matplotlib is not available)",
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Display plots interactively instead of just saving them",
    )
    return parser


def main():
    """
    Command-line entry point.

    Scans the results directory for the requested models, prints a console
    summary, saves the analysis outputs, creates plots unless ``--no-plot``
    is given, and finally prints the top five configurations for the chosen
    metric. Returns early (after a message) when the results directory does
    not exist or no results are found.
    """
    args = _build_arg_parser().parse_args()

    # Work with Path objects from here on
    results_dir = Path(args.results_dir)
    output_dir = Path(args.output_dir)

    if not results_dir.exists():
        print(f"Results directory {results_dir} does not exist!")
        return

    print(f"Analyzing results in: {results_dir}")
    print(f"Models to analyze: {args.models}")

    all_results = scan_results_directory(results_dir, args.models)
    if not all_results:
        print("No results found!")
        return

    df = pd.DataFrame(all_results)
    summary = create_summary_statistics(df)
    best_configs = analyze_best_configurations(df, args.metric)

    print_quick_summary(df)
    save_analysis_results(df, summary, best_configs, output_dir)

    # Plots are attempted by default; --plot only changes how loudly a
    # missing plotting dependency is reported, --no-plot disables them.
    if not args.no_plot:
        try:
            create_performance_plots(df, output_dir, args.visualize)
        except ImportError as err:
            if args.plot:
                print(f"⚠️  Could not create plots: {err}")
                print(
                    "Install matplotlib and seaborn to enable plotting: pip install matplotlib seaborn"
                )
            else:
                print("📊 Skipping plots (matplotlib/seaborn not available)")

    print(f"\n✅ Analysis complete! Results saved to: {output_dir}")

    # Report the five best-ranked configurations
    print(f"\nTop 5 configurations by {args.metric}:")
    print("=" * 80)
    for rank, (_, row) in enumerate(best_configs.head().iterrows(), start=1):
        print(f"{rank}. Model: {row['model']}, Config: {row['config']}")
        print(
            f"   Gibbs: {row['gibbs_iterations']}, Rho: {row['rho_start']}-{row['rho_end']}, Repeated Sampling Factor: {row['repeated_sampling_factor']}"
        )
        print(f"   Avg ARI Original: {row['avg_ari_original']:.4f}")
        print(f"   Avg ARI Denoised: {row['avg_ari_denoised']:.4f}")
        print(f"   Avg ARI Improvement: {row['avg_ari_improvement']:.4f}")
        print(f"   Number of runs: {row['num_runs']}")
        print()


if __name__ == "__main__":
    main()
