import os
import json
import argparse
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# How many queries each try issues. The result JSON files do not record this,
# so it is hard-coded here (it could be written into them by evaluate.py).
QUERIES_PER_TRY = 5
# Whether box plots draw outlier points (fliers).
SHOW_OUTLIERS = False
# When True, the error plots are additionally drawn as one subplot per space,
# each subplot with its own y-axis scale.
SEPARATE_ERROR_AXES = True

# Global matplotlib font sizes applied to every figure produced by this script.
plt.rc("font", size=14)  # default text size
plt.rc("axes", labelsize=16, titlesize=18)  # axis labels / subplot titles
plt.rc("xtick", labelsize=14)  # x tick labels
plt.rc("ytick", labelsize=14)  # y tick labels
plt.rc("legend", fontsize=14)  # legend entries


def print_summary_table(summary_df):
    """Print the summary table to the terminal in a readable format.

    Floats are rendered with 4 decimals. The ``error_std_within_run_mean``
    column is omitted (when present) to keep the table narrow.

    The pandas display options are applied only for the duration of this call
    (via ``pd.option_context``), so printing the table does not change how
    DataFrames are displayed elsewhere in the process.

    Args:
        summary_df: per-space/method summary DataFrame (as returned by
            ``summarize_to_table``).
    """
    print("\n📋 Summary Table:")

    # errors="ignore": a summary without this column should still print.
    print_df = summary_df.drop(columns=["error_std_within_run_mean"], errors="ignore")

    with pd.option_context(
        "display.precision", 4,
        "display.width", 120,
        "display.max_columns", None,
    ):
        print(print_df.to_string(float_format=lambda x: f"{x:.4f}"))


def save_summary_csv(summary_df, run_summary_df, tries_df, output_dir):
    """Write the summary tables as CSV files (without index) into ``output_dir``.

    The directory is created if needed. ``summary.csv`` is always written;
    ``run_summary.csv`` and ``tries_data.csv`` are written only when the
    corresponding DataFrame is not ``None``. Each written path is reported.
    """
    os.makedirs(output_dir, exist_ok=True)

    outputs = (
        (summary_df, "summary.csv", "Summary"),
        (run_summary_df, "run_summary.csv", "Run summary"),
        (tries_df, "tries_data.csv", "Try-level data"),
    )
    for position, (frame, file_name, label) in enumerate(outputs):
        # Only the overall summary (position 0) is mandatory.
        if position > 0 and frame is None:
            continue
        target = os.path.join(output_dir, file_name)
        frame.to_csv(target, index=False)
        print(f"✅ {label} saved to {target}")


def _save_current_figure(output_dir, filename, message, **savefig_kwargs):
    """Save the active pyplot figure to ``output_dir/filename``, print ``message``, close it.

    pyplot keeps every figure it creates open until it is explicitly closed,
    and ``plot_boxplots`` creates many figures per call, so each one is closed
    as soon as it has been written.
    """
    plt.savefig(os.path.join(output_dir, filename), **savefig_kwargs)
    print(message)
    plt.close()


def _plot_error_per_space_panels(
    plot_df, method_color_map, plot_fn, title, clamp_at_zero=False, **plot_kwargs
):
    """Draw one error subplot per space, each with its own y-axis scale.

    ``plot_fn`` is a seaborn plotting function (``sns.boxplot`` or
    ``sns.violinplot``) called with ``hue="method"``; extra ``plot_kwargs`` are
    forwarded to it. Only the last subplot keeps a legend. The new figure is
    left as the active pyplot figure so the caller can save it.
    """
    spaces = plot_df["space"].unique()
    n_spaces = len(spaces)
    # squeeze=False always returns a 2-D array, so one space needs no special case.
    _, axes = plt.subplots(1, n_spaces, figsize=(6 * n_spaces, 6), squeeze=False)
    row = axes[0]
    for ax, space in zip(row, spaces):
        plot_fn(
            data=plot_df[plot_df["space"] == space],
            x="space",
            y="error",
            hue="method",
            palette=method_color_map,
            ax=ax,
            **plot_kwargs,
        )
        ax.set_title(f"Space: {space}", pad=20)
        ax.set_xlabel("")
        ax.set_ylabel("Error", labelpad=10)
        ax.tick_params(axis="x", rotation=0)
        if clamp_at_zero:
            ax.set_ylim(0, None)
        if ax is row[-1]:
            ax.legend(title="Method", bbox_to_anchor=(1.05, 1), loc="upper left")
        else:
            # One shared legend (on the last panel) is enough; guard in case
            # seaborn did not create a legend for this axis.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

    plt.suptitle(title, y=1.05, fontsize=18)
    plt.tight_layout()


def plot_boxplots(df, tries_df, run_summary_df, output_dir, space_name_map=None):
    """Render error and runtime distribution plots (PNG) for every space/method.

    Args:
        df: fallback data, used only when ``tries_df`` is None. The plotting code
            reads its ``space``, ``method``, ``error`` and ``runtime`` columns.
        tries_df: try-level data (one row per try); preferred when available.
        run_summary_df: run-level summary. The run-variability plots are drawn
            only when both it and ``tries_df`` are given.
        output_dir: directory for the PNG files (created if missing).
        space_name_map: optional ``{old_space: new_space}`` mapping. Spaces
            mapped to ``"DROP"`` are excluded and the others are renamed, in
            every plot produced here.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Plot with the most detailed data available
    plot_df = tries_df if tries_df is not None else df

    # Apply space name mappings if provided: drop "DROP" spaces, rename the rest.
    if space_name_map:
        dropped = [k for k, v in space_name_map.items() if v == "DROP"]
        plot_df = plot_df[~plot_df["space"].isin(dropped)].copy()
        plot_df["space"] = plot_df["space"].map(lambda x: space_name_map.get(x, x))

    if plot_df.empty:
        # e.g. every space was mapped to DROP; plt.subplots(1, 0) would fail.
        print("⚠️ No data left to plot (check --space_names DROP mappings); skipping plots.")
        return

    # One fixed colour per method so it is identical across all figures.
    methods = plot_df["method"].unique()
    colors = sns.color_palette("husl", len(methods))
    method_color_map = dict(zip(methods, colors))
    high_res = {"dpi": 300, "bbox_inches": "tight"}

    if SEPARATE_ERROR_AXES:
        # Per-space subplots with individual y-axis scales.
        _plot_error_per_space_panels(
            plot_df,
            method_color_map,
            sns.boxplot,
            "Error Distribution by Method per Space (Separate Scales)",
            showfliers=SHOW_OUTLIERS,
        )
        _save_current_figure(
            output_dir,
            "error_boxplot_separate_scales.png",
            "📊 Boxplot with separate scales saved to error_boxplot_separate_scales.png",
            **high_res,
        )

        _plot_error_per_space_panels(
            plot_df,
            method_color_map,
            sns.violinplot,
            "Error Distribution by Method per Space (Violin Plot, Separate Scales)",
            clamp_at_zero=True,
            split=False,
            inner="quartile",
            bw_adjust=1.0,
            cut=0,
        )
        _save_current_figure(
            output_dir,
            "error_violinplot_separate_scales.png",
            "📊 Violin plot with separate scales saved to error_violinplot_separate_scales.png",
            **high_res,
        )

    # Combined plots (all spaces on one shared axis) are always generated,
    # regardless of SEPARATE_ERROR_AXES.
    # 1. Combined error box plot for all methods
    plt.figure(figsize=(16, 6))
    sns.boxplot(
        data=plot_df,
        x="space",
        y="error",
        hue="method",
        palette=method_color_map,
        showfliers=SHOW_OUTLIERS,
    )
    plt.title("Error Distribution by Method per Space", pad=20)
    plt.xlabel("Space", labelpad=10)
    plt.ylabel("Error", labelpad=10)
    plt.legend(title="Method", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.xticks(rotation=0)  # Horizontal labels
    plt.tight_layout()
    _save_current_figure(
        output_dir,
        "error_boxplot_combined.png",
        "📊 Combined boxplot saved to error_boxplot_combined.png",
        **high_res,
    )

    # 1.1 Combined error violin plot for all methods
    plt.figure(figsize=(16, 6))
    sns.violinplot(
        data=plot_df,
        x="space",
        y="error",
        hue="method",
        palette=method_color_map,
        split=False,
        inner="quartile",
        bw_adjust=1.0,
        cut=0,  # Don't extend the density past the data points
    )
    plt.title("Error Distribution by Method per Space (Violin Plot)", pad=20)
    plt.xlabel("Space", labelpad=10)
    plt.ylabel("Error", labelpad=10)
    plt.legend(title="Method", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.xticks(rotation=0)  # Horizontal labels
    plt.ylim(0, None)  # Errors are plotted from 0 upwards
    plt.tight_layout()
    _save_current_figure(
        output_dir,
        "error_violinplot_combined.png",
        "📊 Combined violin plot saved to error_violinplot_combined.png",
        **high_res,
    )

    # 2. Separate error plots for each method
    for method in methods:
        method_data = plot_df[plot_df["method"] == method]

        # 2.1 Boxplot
        plt.figure(figsize=(12, 6))
        sns.boxplot(
            data=method_data,
            x="space",
            y="error",
            color=method_color_map[method],
            showfliers=SHOW_OUTLIERS,
        )
        plt.title(f"Error Distribution for Method: {method}")
        plt.xticks(rotation=45)
        plt.ylim(0, None)
        plt.tight_layout()
        _save_current_figure(
            output_dir,
            f"error_boxplot_{method}.png",
            f"📊 Method boxplot saved to error_boxplot_{method}.png",
        )

        # 2.2 Violin Plot
        plt.figure(figsize=(12, 6))
        sns.violinplot(
            data=method_data,
            x="space",
            y="error",
            color=method_color_map[method],
            inner="quartile",
            bw_adjust=1.0,
            cut=0,  # Don't extend the density past the data points
        )
        plt.title(f"Error Distribution for Method: {method} (Violin Plot)")
        plt.xticks(rotation=45)
        plt.ylim(0, None)
        plt.tight_layout()
        _save_current_figure(
            output_dir,
            f"error_violinplot_{method}.png",
            f"📊 Method violin plot saved to error_violinplot_{method}.png",
        )

    # Runtime plots
    # 1. Combined runtime box plot
    plt.figure(figsize=(14, 8))
    sns.boxplot(
        data=plot_df,
        x="space",
        y="runtime",
        hue="method",
        palette=method_color_map,
        showfliers=SHOW_OUTLIERS,
    )
    plt.title("Runtime Distribution by Method per Space")
    plt.ylabel("Runtime (seconds)")
    plt.legend(title="Method", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.xticks(rotation=45)
    plt.tight_layout()
    _save_current_figure(
        output_dir,
        "runtime_boxplot_combined.png",
        "📊 Combined runtime boxplot saved to runtime_boxplot_combined.png",
    )

    # 1.1 Combined runtime violin plot
    plt.figure(figsize=(14, 8))
    sns.violinplot(
        data=plot_df,
        x="space",
        y="runtime",
        hue="method",
        palette=method_color_map,
        split=False,
        inner="quartile",
    )
    plt.title("Runtime Distribution by Method per Space (Violin Plot)")
    plt.ylabel("Runtime (seconds)")
    plt.legend(title="Method", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.xticks(rotation=45)
    plt.tight_layout()
    _save_current_figure(
        output_dir,
        "runtime_violinplot_combined.png",
        "📊 Combined runtime violin plot saved to runtime_violinplot_combined.png",
    )

    # 2. Separate runtime plots for each method
    for method in methods:
        method_data = plot_df[plot_df["method"] == method]

        # 2.1 Boxplot
        plt.figure(figsize=(12, 6))
        sns.boxplot(
            data=method_data,
            x="space",
            y="runtime",
            color=method_color_map[method],
            showfliers=SHOW_OUTLIERS,
        )
        plt.title(f"Runtime Distribution for Method: {method}")
        plt.ylabel("Runtime (seconds)")
        plt.xticks(rotation=45)
        plt.tight_layout()
        _save_current_figure(
            output_dir,
            f"runtime_boxplot_{method}.png",
            f"📊 Method runtime boxplot saved to runtime_boxplot_{method}.png",
        )

        # 2.2 Violin Plot
        plt.figure(figsize=(12, 6))
        sns.violinplot(
            data=method_data,
            x="space",
            y="runtime",
            color=method_color_map[method],
            inner="quartile",
        )
        plt.title(f"Runtime Distribution for Method: {method} (Violin Plot)")
        plt.ylabel("Runtime (seconds)")
        plt.xticks(rotation=45)
        plt.tight_layout()
        _save_current_figure(
            output_dir,
            f"runtime_violinplot_{method}.png",
            f"📊 Method runtime violin plot saved to runtime_violinplot_{method}.png",
        )

    # Run-variability plots need try-level data.
    if tries_df is not None and run_summary_df is not None:
        # plot_df is tries_df after the DROP/rename mapping, so these plots
        # honour space_name_map like every other figure above.
        unique_spaces = plot_df["space"].unique()
        num_spaces = len(unique_spaces)

        # 1. Combined run variability plot: one subplot per space.
        plt.figure(figsize=(16, 8))
        for i, space in enumerate(unique_spaces):
            plt.subplot(1, num_spaces, i + 1)
            sns.violinplot(
                data=plot_df[plot_df["space"] == space],
                x="run",
                y="error",
                hue="method",
                palette=method_color_map,
                bw_adjust=1.0,
                cut=0,  # Don't extend the density past the data points
            )
            plt.title(f"{space}")
            plt.xlabel("Run")
            plt.ylabel("Error")
            plt.ylim(0, None)
            if i == num_spaces - 1:  # Only show legend for the last subplot
                plt.legend(title="Method", bbox_to_anchor=(1.05, 1), loc="upper left")
            else:
                plt.legend([], [], frameon=False)

        plt.tight_layout()
        _save_current_figure(
            output_dir,
            "run_variability_combined.png",
            "📊 Combined run variability plot saved to run_variability_combined.png",
        )

        # 2. Separate run variability plots for each method
        for method in methods:
            plt.figure(figsize=(16, 8))
            method_data = plot_df[plot_df["method"] == method]

            for i, space in enumerate(unique_spaces):
                # The subplot is created even when empty so panel positions
                # line up across the per-method figures.
                plt.subplot(1, num_spaces, i + 1)
                space_method_data = method_data[method_data["space"] == space]
                if space_method_data.empty:
                    continue
                sns.violinplot(
                    data=space_method_data,
                    x="run",
                    y="error",
                    color=method_color_map[method],
                    bw_adjust=1.0,
                    cut=0,  # Don't extend the density past the data points
                )
                plt.title(f"{space} - {method}")
                plt.xlabel("Run")
                plt.ylabel("Error")
                plt.ylim(0, None)

            plt.tight_layout()
            _save_current_figure(
                output_dir,
                f"run_variability_{method}.png",
                f"📊 Method run variability plot saved to run_variability_{method}.png",
            )


def summarize_to_table(results, queries_per_try=None):
    """Aggregate raw result records into try-, run- and space/method-level tables.

    Args:
        results: list of result dicts (as produced by ``load_all_results``).
            Each record must provide ``space``, ``seed``, ``run``,
            ``source_dir`` and the parallel per-try lists ``errors``,
            ``failures_all`` and ``runtimes``.
        queries_per_try: number of queries issued per try. Defaults to the
            module-level ``QUERIES_PER_TRY``, which is read at call time.

    Returns:
        ``(tries_df, run_summary, summary)``:

        - ``tries_df``: one row per try with columns space, seed, run, try,
          method, error, failures, runtime. A try in which all
          ``queries_per_try`` queries failed keeps its row (so it still counts
          towards the failure statistics), but its error and runtime are NaN.
        - ``run_summary``: one row per (space, method, seed, run), with the
          mean/std of error, failures and runtime over tries, ``num_tries``
          and ``source_dir``.
        - ``summary``: one row per (space, method) with statistics across
          runs. ``failures_total`` is the *sum over runs of the per-run mean
          failures per try*, and ``total_queries`` is ``runs *
          queries_per_try``. Their ratio, ``failure_rate``, is therefore the
          mean over runs of the fraction of failed queries per try.

    Raises:
        ValueError: if a ``source_dir`` names no known model, or more than one.
    """
    if queries_per_try is None:
        queries_per_try = QUERIES_PER_TRY

    known_models = frozenset({"CausalNF", "DCM", "DeCaFlow", "NCM", "VACA"})

    def find_model(source_dir):
        # The model name is not stored in the result records, so it is
        # recovered from the directory path. Path(...).parts splits on the
        # platform's separators (a hard-coded "/" fails for Windows paths).
        matches = known_models.intersection(Path(source_dir).parts)
        if not matches:
            raise ValueError(f"No known model found in '{source_dir}'")
        if len(matches) > 1:
            raise ValueError(
                f"Ambiguous model in '{source_dir}': {sorted(matches)}"
            )
        return next(iter(matches))

    df = pd.DataFrame(results)
    df["method"] = df["source_dir"].apply(find_model)

    pd.set_option("display.max_columns", None)
    pd.set_option("display.max_rows", None)

    # Expand the per-try lists of every run into one row per try.
    tries_data = []
    for _, row in df.iterrows():
        per_try = zip(row["errors"], row["failures_all"], row["runtimes"])
        for try_idx, (error, failure, runtime) in enumerate(per_try):
            if failure == queries_per_try:
                # Every query of this try failed: keep the row, mark the
                # measurements as missing so they are skipped by the means.
                error = np.nan
                runtime = np.nan
            # NOTE(review): tries reporting ~0 error despite some failed
            # queries are kept unchanged; confirm whether that is intended.
            tries_data.append(
                {
                    "space": row["space"],
                    "seed": row["seed"],
                    "run": row["run"],
                    "try": try_idx,
                    "method": row["method"],
                    "error": error,
                    "failures": failure,
                    "runtime": runtime,
                }
            )

    tries_df = pd.DataFrame(tries_data)

    # A run is uniquely identified by these keys; its tries are aggregated.
    run_keys = ["space", "method", "seed", "run"]
    run_stats = (
        tries_df.groupby(run_keys)
        .agg(
            {
                "error": ["mean", "std"],
                "failures": ["mean", "std"],
                "runtime": ["mean", "std"],
            }
        )
        .reset_index()
    )

    # Flatten the (column, statistic) MultiIndex, e.g. ("error", "mean") -> "error_mean".
    run_stats.columns = [
        "_".join(col).strip("_") if isinstance(col, tuple) else col
        for col in run_stats.columns.values
    ]

    run_stats = run_stats.rename(
        columns={
            "error_mean": "run_error_mean",
            "error_std": "run_error_std",
            "failures_mean": "run_failures_mean",
            "failures_std": "run_failures_std",
            "runtime_mean": "run_runtime_mean",
            "runtime_std": "run_runtime_std",
        }
    )

    run_tries_count = tries_df.groupby(run_keys).size().reset_index(name="num_tries")
    run_stats = pd.merge(run_stats, run_tries_count, on=run_keys)

    # Attach the directory each run was loaded from.
    source_dirs = df[run_keys + ["source_dir"]].drop_duplicates()
    run_summary = pd.merge(run_stats, source_dirs, on=run_keys)

    # Statistics across runs for each (space, method).
    summary = (
        run_summary.groupby(["space", "method"])
        .agg(
            {
                "run_error_mean": ["mean", "std", "min", "max"],
                "run_error_std": ["mean"],  # average within-run std
                "run_failures_mean": ["mean", "sum", "count"],
                "run_runtime_mean": ["mean", "std", "sum"],
            }
        )
        .reset_index()
    )

    # Positional rename; must follow the agg spec order above.
    summary.columns = [
        "space",
        "method",
        "error_mean",
        "error_std",
        "error_min",
        "error_max",
        "error_std_within_run_mean",  # average std dev within runs (across tries)
        "failures_mean",
        "failures_total",  # sum over runs of per-run mean failures per try
        "runs",
        "runtime_mean",
        "runtime_std",
        "runtime_total",  # sum over runs of per-run mean runtime
    ]

    summary["total_queries"] = summary["runs"] * queries_per_try
    summary["failure_rate"] = summary["failures_total"] / summary["total_queries"]
    return tries_df, run_summary, summary


def load_all_results(result_dirs, spaces_of_interest=None):
    """Load every ``result_*.json`` record from the given result directories.

    Args:
        result_dirs: iterable of directory paths (str or Path).
        spaces_of_interest: optional list of space names. When given, only
            files matching ``result_{space}_*.json`` are loaded, and a warning is
            printed for each space with no matching file in a directory. Note
            that this pattern also matches spaces that merely start with
            ``{space}_``.

    Returns:
        list of dicts, one per JSON file, each with an added ``source_dir``
        key holding the directory it was loaded from. Each file is loaded at
        most once, even when several patterns match it (repeated or
        prefix-overlapping space names) or a directory is listed twice, so
        runs are never double-counted. Files are read in sorted order for
        reproducibility.
    """
    all_results = []
    seen_files = set()

    def _load(file, result_dir):
        # Skip files already loaded through another pattern or directory entry.
        key = file.resolve()
        if key in seen_files:
            return
        seen_files.add(key)
        # JSON is UTF-8; don't depend on the platform's default encoding.
        with open(file, "r", encoding="utf-8") as f:
            result = json.load(f)
        result["source_dir"] = str(result_dir)
        all_results.append(result)

    for result_dir in result_dirs:
        result_dir = Path(result_dir)
        print(f"🔍 Loading from {result_dir}")

        if spaces_of_interest:
            for soi in spaces_of_interest:
                matching_files = sorted(result_dir.glob(f"result_{soi}_*.json"))
                if not matching_files:
                    print(f"⚠️ No results found for space '{soi}' in {result_dir}")
                    continue
                for file in matching_files:
                    _load(file, result_dir)
        else:
            for file in sorted(result_dir.glob("result_*.json")):
                _load(file, result_dir)

    return all_results


def run_analysis(result_dirs, output_dir, spaces_of_interest=None, space_name_map=None):
    """Run the full pipeline: load results, summarize, save CSVs, plot, print.

    Args:
        result_dirs: directories to read ``result_*.json`` files from.
        output_dir: destination for the CSV files and plots.
        spaces_of_interest: optional list of space names to restrict loading to.
        space_name_map: optional ``{old_space: new_space}`` mapping used for plots.

    Prints a message and returns early when no result file was found.
    """
    if spaces_of_interest:
        joined_spaces = ", ".join(spaces_of_interest)
        print(f"🔍 Filtering for spaces of interest: {joined_spaces}")

    loaded = load_all_results(result_dirs, spaces_of_interest)
    if not loaded:
        print("❌ No results found. Please check your paths and space filters.")
        return None

    per_try, per_run, overall = summarize_to_table(loaded)

    save_summary_csv(overall, per_run, per_try, output_dir)
    plot_boxplots(overall, per_try, per_run, output_dir, space_name_map)

    # Finish with the human-readable table on the terminal.
    print_summary_table(overall)


def parse_space_name_map(mappings):
    """Parse ``old_name:new_name`` strings into a ``{old_name: new_name}`` dict.

    Only the first ``:`` separates the two names, so a new name may itself
    contain colons. Malformed entries (no ``:``, or an empty old or new name)
    are reported and skipped. A new name of ``DROP`` marks the space for
    exclusion from the plots.

    Args:
        mappings: list of mapping strings, or None.

    Returns:
        ``None`` when ``mappings`` is empty/None; otherwise the parsed dict,
        which may be empty if every entry was malformed.
    """
    if not mappings:
        return None
    space_name_map = {}
    for mapping in mappings:
        old_name, sep, new_name = mapping.partition(":")
        if not sep or not old_name or not new_name:
            print(
                f"⚠️ Invalid space name mapping format: {mapping}. Expected 'old_name:new_name'"
            )
            continue
        space_name_map[old_name] = new_name
    return space_name_map


# Usage:
#   python summarize_results.py results/method1 results/method2 --output_dir analysis/ \
#       --SoIs space1 space2 --space_names space1:Linear space2:DROP
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "result_dirs", nargs="+", help="Paths to one or more result directories"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="analysis",
        help="Where to save summary and plots",
    )
    parser.add_argument(
        "--SoIs",
        nargs="+",
        type=str,
        help="Optional list of spaces of interest to filter results",
    )
    parser.add_argument(
        "--space_names",
        nargs="+",
        type=str,
        help="Space name mappings in format 'old_name:new_name'. Use 'DROP' as new_name to exclude a space.",
    )
    args = parser.parse_args()

    run_analysis(
        args.result_dirs,
        args.output_dir,
        args.SoIs,
        parse_space_name_map(args.space_names),
    )
