import json
import os
from typing import List, Optional

from matplotlib.patches import Patch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from scripts.generate_prompts import SIZE_PATTERNS


def _legend_label(label, counts, descriptions=None):
    """Return a legend label suffixed with its count and, if known, its description.

    Parameters:
    - label (str): Legend text as produced by the plot
    - counts: Mapping-like object with ``.get(label, default)`` (dict or Series)
    - descriptions (dict, optional): Extra text to show for known labels
    """
    count = counts.get(label, 0)
    if descriptions is not None and label in descriptions:
        return f"{label}: {descriptions[label]} (n={count})"
    return f"{label} (n={count})"


def _save_heatmap(table: pd.DataFrame, title: str, path: str, figsize) -> None:
    """Draw *table* as an annotated accuracy heatmap (0-1 colour scale) and save it."""
    fig, ax = plt.subplots(figsize=figsize)
    try:
        sns.heatmap(
            table,
            annot=True,
            cmap="YlGnBu",
            vmin=0,
            vmax=1,
            fmt=".2f",
            ax=ax,
        )
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path, dpi=300)
    finally:
        # Close the exact figure we opened, even if drawing or saving failed.
        plt.close(fig)


def _save_bar_chart(
    table: pd.DataFrame,
    title: str,
    legend_title: str,
    path: str,
    figsize=(14, 8),
    xlabel: Optional[str] = None,
    show_grid: bool = False,
    label_counts=None,
    label_descriptions=None,
    extra_legend_entries=(),
    **plot_kwargs,
) -> None:
    """Draw *table* as a grouped bar chart with an outside legend and save it.

    Parameters:
    - table (DataFrame): One bar group per index entry, one bar per column
    - title (str): Chart title
    - legend_title (str): Title shown above the legend
    - path (str): Destination image path
    - figsize (tuple): Figure size in inches
    - xlabel (str, optional): X axis label; left to pandas when None
    - show_grid (bool): Draw dashed horizontal grid lines
    - label_counts (optional): Mapping with ``.get``; when given, every legend
      label is rewritten to include its count
    - label_descriptions (dict, optional): Descriptions added to known labels
    - extra_legend_entries (iterable): ``(handle, label)`` pairs appended to the legend
    - **plot_kwargs: Forwarded to ``DataFrame.plot`` (e.g. ``ylim``, ``width``)
    """
    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Passing ax= keeps pandas from opening a second figure, which would
        # leave the one created above open and ignore its figsize.
        table.plot(kind="bar", ax=ax, **plot_kwargs)
        ax.set_title(title)
        ax.set_ylabel("Accuracy")
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if show_grid:
            ax.grid(True, axis="y", linestyle="--", alpha=0.7)

        handles, labels = ax.get_legend_handles_labels()
        handles = list(handles)
        labels = list(labels)
        if label_counts is not None:
            labels = [
                _legend_label(text, label_counts, label_descriptions)
                for text in labels
            ]
        for handle, text in extra_legend_entries:
            handles.append(handle)
            labels.append(text)

        # Anchor the legend outside the plot area so it never hides bars.
        ax.legend(
            handles,
            labels,
            title=legend_title,
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
        )
        fig.tight_layout()
        fig.savefig(path, dpi=300)
    finally:
        plt.close(fig)


def _plot_benchmark_charts(
    benchmark, benchmark_df: pd.DataFrame, benchmark_dir: str, min_sample_size: int
) -> None:
    """Write the per-benchmark heatmaps and bar charts into *benchmark_dir*."""
    # 1.1. Accuracy by size_pattern | system_prompt, graph_type and model.
    # assign() returns a new frame, so the caller's slice is never mutated.
    keyed = benchmark_df.assign(
        pattern_system=benchmark_df["size_pattern"]
        + " | "
        + benchmark_df["system_prompt"]
    )
    row_keys = ["pattern_system", "graph_type"]
    accuracy = keyed.pivot_table(
        index=row_keys, columns="model", values="correct", aggfunc="mean"
    )
    if not accuracy.empty:
        # Blank out every (row, model) cell backed by fewer than
        # min_sample_size rows, then drop rows with nothing left to show.
        counts = keyed.groupby(row_keys + ["model"]).size().unstack("model")
        enough = counts.reindex_like(accuracy) >= min_sample_size
        accuracy = accuracy.where(enough).dropna(how="all")

    if accuracy.empty:
        print(
            f"⚠️ No groups have enough samples (min={min_sample_size}) for benchmark {benchmark}. Skipping detailed heatmap."
        )
    else:
        _save_heatmap(
            accuracy,
            f"{benchmark} - Accuracy by Pattern, System Prompt, Graph Type and Model",
            os.path.join(benchmark_dir, f"{benchmark}_accuracy_heatmap.png"),
            figsize=(16, 12),
        )

    # 1.2. Accuracy by encoding and model.
    by_encoding = benchmark_df.pivot_table(
        index="encoding", columns="model", values="correct", aggfunc="mean"
    )
    if not by_encoding.empty:
        _save_heatmap(
            by_encoding,
            f"{benchmark} - Accuracy by Encoding and Model",
            os.path.join(benchmark_dir, f"{benchmark}_encoding.png"),
            figsize=(14, 10),
        )

    # 1.3-1.5. One grouped bar chart per breakdown column. The last field is
    # the minimum number of distinct values needed for the chart to be useful.
    breakdowns = (
        ("graph_type", "Graph Type", "graph_type_performance", 1),
        ("system_prompt", "System Prompt", "system_prompt_performance", 2),
        ("size_pattern", "Size Pattern", "pattern_performance", 2),
    )
    for column, label, stem, min_columns in breakdowns:
        table = benchmark_df.groupby(["model", column])["correct"].mean().unstack()
        if table.empty or table.shape[1] < min_columns:
            continue
        _save_bar_chart(
            table,
            f"{benchmark} - Model Performance by {label}",
            label,
            os.path.join(benchmark_dir, f"{benchmark}_{stem}.png"),
            figsize=(14, 8),
            xlabel="Model",
            ylim=[0, 1],
        )


def _plot_overall_charts(df: pd.DataFrame, output_dir: str) -> None:
    """Write the cross-benchmark comparison charts into *output_dir*."""
    # 2.1. Model performance by benchmark, columns ordered hardest first.
    by_benchmark = df.pivot_table(
        index="model", columns="benchmark", values="correct", aggfunc="mean"
    )
    difficulty = df.groupby("benchmark")["correct"].mean().sort_values()
    # pivot_table can drop an all-NaN benchmark column that the groupby keeps,
    # so only select the benchmarks that are actually present.
    by_benchmark = by_benchmark[
        [name for name in difficulty.index if name in by_benchmark.columns]
    ]

    if not by_benchmark.empty:
        # n = number of distinct (graph_type, size_pattern, system_prompt,
        # encoding) combinations observed for each benchmark.
        combo_columns = ["graph_type", "size_pattern", "system_prompt", "encoding"]
        combo_counts = {
            name: df[df["benchmark"] == name].groupby(combo_columns).size().shape[0]
            for name in df["benchmark"].unique()
        }
        total_samples = sum(combo_counts.values())

        # Invisible patches give the legend a separator line and a total line.
        separator = Patch(fill=False, edgecolor="none", visible=False)
        total_patch = Patch(fill=False, edgecolor="none")

        _save_bar_chart(
            by_benchmark,
            "Model Performance Across Benchmarks",
            "Benchmark",
            os.path.join(output_dir, "model_benchmark_performance.png"),
            figsize=(16, 10),
            xlabel="Model",
            label_counts=combo_counts,
            extra_legend_entries=[
                (separator, "──────────────"),
                (total_patch, f"Total samples: {total_samples}"),
            ],
            width=0.8,
        )

    # 2.2-2.4. Impact of one factor across models; skipped when the factor
    # has a single value because there is nothing to compare.
    factors = (
        (
            "system_prompt",
            "Impact of System Prompts on Model Performance",
            "System Prompt",
            "system_prompt_impact.png",
            None,
        ),
        (
            "size_pattern",
            "Impact of Size Patterns on Model Performance",
            "Size Pattern",
            "size_pattern_impact.png",
            SIZE_PATTERNS,
        ),
        (
            "encoding",
            "Impact of Encoding Types on Model Performance",
            "Encoding",
            "encoding_impact.png",
            None,
        ),
    )
    for column, title, legend_title, filename, descriptions in factors:
        impact = df.pivot_table(
            index="model", columns=column, values="correct", aggfunc="mean"
        )
        if impact.empty or impact.shape[1] <= 1:
            continue
        _save_bar_chart(
            impact,
            title,
            legend_title,
            os.path.join(output_dir, filename),
            figsize=(14, 8),
            show_grid=True,
            label_counts=df.groupby(column).size(),
            label_descriptions=descriptions,
            ylim=[0, 1],
        )


def _write_summary(df: pd.DataFrame, output_dir: str) -> dict:
    """Write summary_statistics.json and summary_report.txt; return the stats dict."""
    dimensions = ["model", "benchmark", "system_prompt", "size_pattern", "encoding"]

    summary_stats = {"overall_accuracy": df["correct"].mean()}
    # n_pairs gets an accuracy breakdown but no sample-count entry.
    for column in dimensions + ["n_pairs"]:
        summary_stats[f"{column}_accuracy"] = (
            df.groupby(column)["correct"].mean().to_dict()
        )
    sample_counts = {"total": len(df)}
    for column in dimensions:
        sample_counts[f"by_{column}"] = df.groupby(column).size().to_dict()
    summary_stats["sample_counts"] = sample_counts

    with open(
        os.path.join(output_dir, "summary_statistics.json"), "w", encoding="utf-8"
    ) as f:
        json.dump(summary_stats, f, indent=2)

    # Human-readable report: one section per dimension, best accuracy first.
    sections = (
        ("MODEL PERFORMANCE", "model"),
        ("BENCHMARK PERFORMANCE", "benchmark"),
        ("SYSTEM PROMPT PERFORMANCE", "system_prompt"),
        ("SIZE PATTERN PERFORMANCE", "size_pattern"),
        ("ENCODING PERFORMANCE", "encoding"),
    )
    lines = [
        "=== PERFORMANCE SUMMARY REPORT ===\n\n",
        f"Overall accuracy: {summary_stats['overall_accuracy']:.4f}\n",
        f"Total samples: {sample_counts['total']}\n",
    ]
    for heading, column in sections:
        lines.append(f"\n=== {heading} ===\n")
        counts = sample_counts[f"by_{column}"]
        ranked = sorted(
            summary_stats[f"{column}_accuracy"].items(),
            key=lambda item: item[1],
            reverse=True,
        )
        for name, acc in ranked:
            shown = f"{name}"
            # Named size patterns are followed by their description; unknown
            # patterns are written once rather than repeated.
            if column == "size_pattern" and name in SIZE_PATTERNS:
                shown = f"{name} {SIZE_PATTERNS[name]}"
            lines.append(f"{shown}: {acc:.4f} (n={counts.get(name, 0)})\n")

    with open(
        os.path.join(output_dir, "summary_report.txt"), "w", encoding="utf-8"
    ) as f:
        f.writelines(lines)

    return summary_stats


def create_benchmark_visualizations(
    evaluation_file: str,
    output_dir: str = "evaluation_data/visualizations",
    models: Optional[List[str]] = None,
    system_prompts: Optional[List[str]] = None,
    patterns: Optional[List[str]] = None,
    min_sample_size: int = 1,  # Reduced default to ensure visualizations can be generated
) -> None:
    """
    Creates detailed visualizations for each benchmark based on evaluation results.

    Writes one sub-directory of charts per benchmark, cross-benchmark comparison
    charts, ``summary_statistics.json`` and ``summary_report.txt`` into
    ``output_dir``. Returns early with a warning when no rows remain after
    loading or filtering.

    Parameters:
    - evaluation_file (str): Path to the evaluation results JSON file
    - output_dir (str): Directory to save visualization images
    - models (list, optional): List of model names to include in visualizations
    - system_prompts (list, optional): List of system prompts to include
    - patterns (list, optional): List of size patterns to include
    - min_sample_size (int): Minimum number of samples a (pattern, system prompt,
      graph type, model) group needs for its cell to appear in the detailed heatmap
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(evaluation_file, "r", encoding="utf-8") as f:
        evaluation_results = json.load(f)

    # Convert to DataFrame for easier manipulation
    df = pd.DataFrame(evaluation_results)
    if df.empty:
        # An empty frame has no columns, so the lookups below would raise KeyError.
        print(f"⚠️ No evaluation results found in {evaluation_file}")
        return

    # Apply filters if specified
    if models:
        df = df[df["model"].isin(models)]
        if df.empty:
            print(f"⚠️ No data found for the specified models: {', '.join(models)}")
            return
        print(f"Creating visualizations for models: {', '.join(df['model'].unique())}")
    else:
        print(
            f"Creating visualizations for all models: {', '.join(df['model'].unique())}"
        )

    if system_prompts:
        df = df[df["system_prompt"].isin(system_prompts)]
        if df.empty:
            print(
                f"⚠️ No data found for the specified system prompts: {', '.join(system_prompts)}"
            )
            return

    if patterns:
        df = df[df["size_pattern"].isin(patterns)]
        if df.empty:
            print(f"⚠️ No data found for the specified patterns: {', '.join(patterns)}")
            return

    # 1. Per-benchmark plots, each benchmark in its own sub-directory
    print(f"Creating benchmark visualizations in {output_dir}...")
    for benchmark in df["benchmark"].unique():
        benchmark_dir = os.path.join(output_dir, str(benchmark))
        os.makedirs(benchmark_dir, exist_ok=True)
        _plot_benchmark_charts(
            benchmark,
            df[df["benchmark"] == benchmark],
            benchmark_dir,
            min_sample_size,
        )

    # 2. Overall model performance across all benchmarks
    print("Creating model performance visualizations...")
    _plot_overall_charts(df, output_dir)

    # 3. Overall summary statistics
    summary_stats = _write_summary(df, output_dir)

    print(f"✅ All visualizations created in {output_dir}")
    print(f"Overall accuracy: {summary_stats['overall_accuracy']:.4f}")


if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the options and forward them unchanged
    # to create_benchmark_visualizations.
    cli = argparse.ArgumentParser(
        description="Create visualizations from evaluation results"
    )

    cli.add_argument(
        "evaluation_file", help="Path to the evaluation results JSON file"
    )
    cli.add_argument(
        "--output_dir",
        default="evaluation_data/visualizations",
        help="Directory to save visualizations",
    )

    # The three filters share one shape: an optional flag taking one or more
    # values (left as None when the flag is omitted).
    filter_flags = (
        ("--models", "Filter visualizations to specific models"),
        ("--system_prompts", "Filter visualizations to specific system prompts"),
        ("--patterns", "Filter visualizations to specific size patterns"),
    )
    for flag, help_text in filter_flags:
        cli.add_argument(flag, nargs="+", help=help_text)

    cli.add_argument(
        "--min_samples",
        type=int,
        default=1,
        help="Minimum number of samples required for inclusion in charts",
    )

    options = cli.parse_args()

    # --min_samples maps onto the function's min_sample_size parameter.
    create_benchmark_visualizations(
        evaluation_file=options.evaluation_file,
        output_dir=options.output_dir,
        models=options.models,
        system_prompts=options.system_prompts,
        patterns=options.patterns,
        min_sample_size=options.min_samples,
    )
