import os
import json
import argparse
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def print_summary_table(summary_df):
    """Print *summary_df* to stdout as a plain-text table.

    Floats are rendered with four decimal places. The pandas display options
    used here are applied only for the duration of this call (via
    ``pd.option_context``), so the caller's global pandas configuration is
    left untouched.

    Args:
        summary_df: DataFrame to print (typically the per-(space, method)
            summary built by ``summarize_to_table``).
    """
    print("\n📋 Summary Table:")
    # Scoped options: pd.set_option would silently change display settings
    # for the whole process as a side effect of printing.
    with pd.option_context(
        "display.precision",
        4,
        "display.width",
        120,
        "display.max_columns",
        None,
    ):
        print(summary_df.to_string(float_format=lambda x: f"{x:.4f}"))


def save_summary_csv(summary_df, run_summary_df, tries_df, output_dir):
    """Write the summary tables as CSV files into *output_dir*.

    *output_dir* is created if needed. ``summary.csv`` is always written;
    ``run_summary.csv`` and ``tries_data.csv`` are written only when the
    corresponding DataFrame is not None. Each file is written without the
    index, and a confirmation line is printed for every file saved.
    """
    os.makedirs(output_dir, exist_ok=True)

    # (frame, file name, label for the log line, may the frame be None?)
    targets = (
        (summary_df, "summary.csv", "Summary", False),
        (run_summary_df, "run_summary.csv", "Run summary", True),
        (tries_df, "tries_data.csv", "Try-level data", True),
    )
    for frame, filename, label, optional in targets:
        if optional and frame is None:
            continue
        destination = os.path.join(output_dir, filename)
        frame.to_csv(destination, index=False)
        print(f"✅ {label} saved to {destination}")


def _save_figure(output_dir, filename, label):
    """Save the current pyplot figure under *output_dir*, close it, and log it.

    Closing matters: pyplot keeps every figure it creates alive until it is
    explicitly closed, so looping over methods would otherwise accumulate
    open figures (memory growth plus matplotlib's "too many figures" warning).
    """
    plt.savefig(os.path.join(output_dir, filename))
    plt.close()
    print(f"📊 {label} saved to {filename}")


def plot_boxplots(df, tries_df, run_summary_df, output_dir):
    """Render error/runtime box plots and, when possible, run-variability plots.

    Produces, in *output_dir* (created if missing):
      * ``error_boxplot_combined.png`` and ``error_boxplot_<method>.png``
      * ``runtime_boxplot_combined.png`` and ``runtime_boxplot_<method>.png``
      * ``run_variability_combined.png`` and ``run_variability_<method>.png``
        (only when both *tries_df* and *run_summary_df* are given)

    Args:
        df: Fallback DataFrame, used only when *tries_df* is None; it must
            then provide ``space``, ``method``, ``error`` and ``runtime``.
        tries_df: Try-level DataFrame with ``space``, ``method``, ``run``,
            ``error`` and ``runtime`` columns, or None.
        run_summary_df: Run-level DataFrame with a ``space`` column, or None.
        output_dir: Directory receiving the PNG files.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Plot with the most detailed data available.
    plot_df = tries_df if tries_df is not None else df

    # One consistent color per method across every figure.
    methods = plot_df["method"].unique()
    colors = sns.color_palette("husl", len(methods))
    method_color_map = dict(zip(methods, colors))

    # --- Error distributions -------------------------------------------
    plt.figure(figsize=(14, 8))
    sns.boxplot(
        data=plot_df, x="space", y="error", hue="method", palette=method_color_map
    )
    plt.title("Error Distribution by Method per Space (All Methods)")
    plt.legend(title="Method", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.xticks(rotation=45)
    plt.tight_layout()
    _save_figure(output_dir, "error_boxplot_combined.png", "Combined boxplot")

    for method in methods:
        method_data = plot_df[plot_df["method"] == method]
        plt.figure(figsize=(12, 6))
        sns.boxplot(
            data=method_data, x="space", y="error", color=method_color_map[method]
        )
        plt.title(f"Error Distribution for Method: {method}")
        plt.xticks(rotation=45)
        plt.tight_layout()
        _save_figure(output_dir, f"error_boxplot_{method}.png", "Method boxplot")

    # --- Runtime distributions -----------------------------------------
    plt.figure(figsize=(14, 8))
    sns.boxplot(
        data=plot_df, x="space", y="runtime", hue="method", palette=method_color_map
    )
    plt.title("Runtime Distribution by Method per Space")
    plt.ylabel("Runtime (seconds)")
    plt.legend(title="Method", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.xticks(rotation=45)
    plt.tight_layout()
    _save_figure(
        output_dir, "runtime_boxplot_combined.png", "Combined runtime boxplot"
    )

    for method in methods:
        method_data = plot_df[plot_df["method"] == method]
        plt.figure(figsize=(12, 6))
        sns.boxplot(
            data=method_data, x="space", y="runtime", color=method_color_map[method]
        )
        plt.title(f"Runtime Distribution for Method: {method}")
        plt.ylabel("Runtime (seconds)")
        plt.xticks(rotation=45)
        plt.tight_layout()
        _save_figure(
            output_dir, f"runtime_boxplot_{method}.png", "Method runtime boxplot"
        )

    # --- Run variability (needs try-level and run-level data) -----------
    if tries_df is None or run_summary_df is None:
        return

    unique_spaces = run_summary_df["space"].unique()
    num_spaces = len(unique_spaces)

    # Combined: one subplot per space, one violin per (run, method).
    plt.figure(figsize=(16, 8))
    for i, space in enumerate(unique_spaces):
        plt.subplot(1, num_spaces, i + 1)
        sns.violinplot(
            data=tries_df[tries_df["space"] == space],
            x="run",
            y="error",
            hue="method",
            palette=method_color_map,
        )
        plt.title(f"{space}")
        plt.xlabel("Run")
        plt.ylabel("Error")
        if i == num_spaces - 1:  # Only the last subplot keeps a legend.
            plt.legend(title="Method", bbox_to_anchor=(1.05, 1), loc="upper left")
        else:
            plt.legend([], [], frameon=False)
    plt.tight_layout()
    _save_figure(
        output_dir, "run_variability_combined.png", "Combined run variability plot"
    )

    # Per method: same layout; panels without data for that space stay blank.
    for method in methods:
        plt.figure(figsize=(16, 8))
        method_data = tries_df[tries_df["method"] == method]
        for i, space in enumerate(unique_spaces):
            plt.subplot(1, num_spaces, i + 1)
            space_method_data = method_data[method_data["space"] == space]
            if space_method_data.empty:
                continue
            sns.violinplot(
                data=space_method_data,
                x="run",
                y="error",
                color=method_color_map[method],
            )
            plt.title(f"{space} - {method}")
            plt.xlabel("Run")
            plt.ylabel("Error")
        plt.tight_layout()
        _save_figure(
            output_dir,
            f"run_variability_{method}.png",
            "Method run variability plot",
        )


def summarize_to_table(results):
    """Build try-level, run-level and (space, method)-level tables.

    Args:
        results: List of result dicts (as returned by ``load_all_results``).
            Each dict must provide ``space``, ``seed``, ``run``, ``method``,
            the per-try lists ``errors``, ``failures_all`` and ``runtimes``,
            the pre-computed run statistics ``run_error_mean``,
            ``run_error_std``, ``run_failures_mean``, ``run_failures_std``,
            ``run_runtime_mean`` and ``run_runtime_std``, plus ``num_tries``
            and ``source_dir``.

    Returns:
        ``(tries_df, run_summary, summary)``:

        * ``tries_df`` has one row per try, with columns ``space``, ``seed``,
          ``run``, ``try``, ``method``, ``error``, ``failures`` and ``runtime``.
        * ``run_summary`` has one row per result, holding the pre-computed
          run statistics.
        * ``summary`` has one row per (space, method), aggregating the
          run-level statistics across runs.

        For empty *results*, three empty DataFrames with those columns are
        returned instead of raising.
    """
    try_columns = [
        "space",
        "seed",
        "run",
        "try",
        "method",
        "error",
        "failures",
        "runtime",
    ]
    run_columns = [
        "space",
        "method",
        "seed",
        "run",
        "run_error_mean",
        "run_error_std",
        "run_failures_mean",
        "run_failures_std",
        "run_runtime_mean",
        "run_runtime_std",
        "num_tries",
        "source_dir",
    ]
    summary_columns = [
        "space",
        "method",
        "error_mean",
        "error_std",
        "error_min",
        "error_max",
        "error_std_within_run_mean",
        "failures_mean",
        "failures_total",
        "runs",
        "runtime_mean",
        "runtime_std",
        "failure_rate",
    ]

    df = pd.DataFrame(results)
    if df.empty:
        # Selecting the run columns from an empty frame would raise KeyError.
        return (
            pd.DataFrame(columns=try_columns),
            pd.DataFrame(columns=run_columns),
            pd.DataFrame(columns=summary_columns),
        )

    # Expand the per-try lists into one row per try. zip() stops at the
    # shortest of the three lists, so any extra entries are dropped.
    tries_data = []
    for _, row in df.iterrows():
        for try_idx, (error, failure, runtime) in enumerate(
            zip(row["errors"], row["failures_all"], row["runtimes"])
        ):
            tries_data.append(
                {
                    "space": row["space"],
                    "seed": row["seed"],
                    "run": row["run"],
                    "try": try_idx,
                    "method": row["method"],
                    "error": error,
                    "failures": failure,
                    "runtime": runtime,
                }
            )
    # Explicit columns keep the schema stable even if no tries were found.
    tries_df = pd.DataFrame(tries_data, columns=try_columns)

    # Run-level statistics are pre-computed in the input; just select them.
    run_summary = df[run_columns].copy()

    # Aggregate run-level statistics per (space, method). Named aggregation
    # binds each output column to its source, instead of relying on a
    # positional rename of a MultiIndex.
    summary = (
        df.groupby(["space", "method"])
        .agg(
            error_mean=("run_error_mean", "mean"),  # mean of run means
            error_std=("run_error_mean", "std"),  # spread across runs
            error_min=("run_error_mean", "min"),
            error_max=("run_error_mean", "max"),
            error_std_within_run_mean=("run_error_std", "mean"),
            failures_mean=("run_failures_mean", "mean"),
            failures_total=("run_failures_mean", "sum"),  # sum of run means
            runs=("run_failures_mean", "count"),
            runtime_mean=("run_runtime_mean", "mean"),
            runtime_std=("run_runtime_mean", "std"),
        )
        .reset_index()
    )

    # NOTE(review): failures_total / runs is sum/count of run_failures_mean,
    # i.e. exactly failures_mean. If a per-try failure rate was intended, it
    # should be computed from tries_df instead. Kept as-is for output
    # compatibility.
    summary["failure_rate"] = summary["failures_total"] / summary["runs"]
    return tries_df, run_summary, summary


def load_all_results(result_dirs):
    """Load every ``result_*.json`` file from each directory in *result_dirs*.

    Within each directory, files are read in sorted path order. ``Path.glob``
    yields entries in arbitrary directory order, so sorting makes the combined
    list, and every table and CSV derived from it, reproducible. Paths that
    are not existing directories are reported and skipped.

    Args:
        result_dirs: Iterable of directory paths (``str`` or ``Path``).

    Returns:
        List of decoded JSON objects in directory order, then file order.
        Each object has a ``source_dir`` key holding its directory as a string.
    """
    all_results = []
    for result_dir in map(Path, result_dirs):
        print(f"🔍 Loading from {result_dir}")
        if not result_dir.is_dir():
            # A mistyped CLI path would otherwise silently contribute nothing.
            print(f"⚠️ Skipping {result_dir}: not a directory")
            continue
        for file in sorted(result_dir.glob("result_*.json")):
            with open(file, "r", encoding="utf-8") as f:
                result = json.load(f)
            result["source_dir"] = str(result_dir)
            all_results.append(result)
    return all_results


def run_analysis(result_dirs, output_dir):
    """Load results, then print, save and plot the summary tables.

    Prints a message and returns early when no result files are found, since
    there is nothing to tabulate or plot.

    Args:
        result_dirs: Directories to scan for ``result_*.json`` files.
        output_dir: Directory receiving the CSV files and PNG plots.
    """
    results = load_all_results(result_dirs)
    if not results:
        # Table construction and plotting assume at least one result row.
        print("⚠️ No result files found; nothing to summarize.")
        return

    tries_df, run_summary, summary = summarize_to_table(results)

    # Print summary table to terminal
    print_summary_table(summary)

    save_summary_csv(summary, run_summary, tries_df, output_dir)
    # plot_boxplots uses its first argument only as a fallback when tries_df
    # is None; summarize_to_table always returns a DataFrame for tries_df.
    plot_boxplots(summary, tries_df, run_summary, output_dir)


# Command-line entry point. Example:
#   python summarize_results.py results/method1 results/method2 --output_dir analysis/
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "result_dirs",
        nargs="+",
        help="Paths to one or more result directories",
    )
    cli.add_argument(
        "--output_dir",
        default="analysis",
        type=str,
        help="Where to save summary and plots",
    )
    cli_args = cli.parse_args()
    run_analysis(cli_args.result_dirs, cli_args.output_dir)
