import argparse
import glob
import json
import math
import os
from typing import Dict, List

import numpy as np
import pandas as pd


def load_all_results(results_dir: str) -> List[Dict]:
    """Load every per-run result JSON file found directly in ``results_dir``.

    Files whose basename contains "summary" are skipped. Files are read in
    sorted path order so the returned list does not depend on filesystem
    listing order.

    Files that cannot be read or parsed, and files whose top-level JSON value
    is not an object, are reported with a warning and skipped.

    Args:
        results_dir: Directory searched with the pattern ``*.json``.

    Returns:
        One dict per successfully loaded result file.
    """
    pattern = os.path.join(results_dir, "*.json")
    files = sorted(
        path
        for path in glob.glob(pattern)
        if "summary" not in os.path.basename(path)
    )

    results = []
    for filepath in files:
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Warning: Could not load {filepath}: {e}")
            continue

        if not isinstance(data, dict):
            # Each record is later read with .get(), so only objects are usable.
            print(
                f"Warning: Skipping {filepath}: expected a JSON object, "
                f"got {type(data).__name__}"
            )
            continue

        results.append(data)

    return results


def aggregate_results(results: List[Dict]) -> pd.DataFrame:
    """Flatten per-run result dicts into a single sorted DataFrame.

    Args:
        results: Per-run result dicts.

    Returns:
        A DataFrame with exactly the columns in ``field_names``, in that order.
        Keys a run lacks become NaN, and any other keys are dropped. Rows are
        sorted by omega, bc_type, method and seed, and the index is reset to
        start at 0. An empty, column-less DataFrame is returned when
        ``results`` is empty.
    """
    if not results:
        return pd.DataFrame()

    field_names = (
        "omega",
        "omega_deg",
        "bc_type",
        "method",
        "seed",
        "true_mu",
        "dominant_mu",
        "dominant_coeff",
        "abs_error",
        "rel_error",
        "constraint_violation",
        "l2_error",
        "rel_l2_error",
        "success",
        "training_time",
        "final_loss",
    )
    sort_keys = ["omega", "bc_type", "method", "seed"]

    table = pd.DataFrame(
        [{name: run.get(name, np.nan) for name in field_names} for run in results]
    )
    return table.sort_values(sort_keys).reset_index(drop=True)


def compute_statistics(df: pd.DataFrame) -> Dict:
    """Summarize per-run results into a flat dictionary of statistics.

    Args:
        df: Per-run table. This function reads the ``success``, ``method``,
            ``bc_type``, ``omega``, ``rel_error``, ``constraint_violation``
            and ``rel_l2_error`` columns.

    Returns:
        Dict mapping ``"<group>_<metric>"`` names to values. Success rates are
        percentages (the mean of ``success`` times 100). It covers these groups:

        - overall;
        - per method;
        - per BC type;
        - per (method, BC type) pair;
        - per method, split into re-entrant (``omega > pi``) and convex
          (``omega <= pi``) corners.

        A (method, BC type) pair with no rows reports NaN for both metrics.
        The re-entrant and convex keys are only emitted for non-empty subsets.
        An empty ``df`` yields only the three overall keys.
    """
    stats = {}

    if df.empty:
        # Also covers a column-less frame, where df["success"] would raise KeyError.
        stats["total_experiments"] = 0
        stats["total_success"] = 0
        stats["overall_success_rate"] = np.nan
        return stats

    stats["total_experiments"] = len(df)
    stats["total_success"] = df["success"].sum()
    stats["overall_success_rate"] = df["success"].mean() * 100

    # NaN labels never compare equal (df["method"] == nan is all False), so
    # they would only produce empty, meaningless "nan_*" groups.
    methods = df["method"].dropna().unique()
    bc_types = df["bc_type"].dropna().unique()

    for method in methods:
        method_df = df[df["method"] == method]
        prefix = f"{method}"

        stats[f"{prefix}_count"] = len(method_df)
        stats[f"{prefix}_success_rate"] = method_df["success"].mean() * 100
        stats[f"{prefix}_mean_rel_error"] = method_df["rel_error"].mean()
        stats[f"{prefix}_std_rel_error"] = method_df["rel_error"].std()
        stats[f"{prefix}_median_rel_error"] = method_df["rel_error"].median()
        stats[f"{prefix}_p90_rel_error"] = method_df["rel_error"].quantile(0.9)
        stats[f"{prefix}_mean_constraint_viol"] = method_df[
            "constraint_violation"
        ].mean()
        stats[f"{prefix}_mean_l2_error"] = method_df["rel_l2_error"].mean()

    for bc_type in bc_types:
        bc_df = df[df["bc_type"] == bc_type]
        prefix = f"{bc_type}"

        stats[f"{prefix}_success_rate"] = bc_df["success"].mean() * 100
        stats[f"{prefix}_mean_rel_error"] = bc_df["rel_error"].mean()

    for method in methods:
        for bc_type in bc_types:
            subset = df[(df["method"] == method) & (df["bc_type"] == bc_type)]
            prefix = f"{method}_{bc_type}"

            if len(subset) > 0:
                success_rate = subset["success"].mean() * 100
                mean_error = subset["rel_error"].mean()
            else:
                # Combination never ran: NaN marks "no data" so it is not
                # mistaken for a genuine 0% success rate.
                success_rate = np.nan
                mean_error = np.nan

            stats[f"{prefix}_success_rate"] = success_rate
            stats[f"{prefix}_mean_rel_error"] = mean_error

    # Corners with an opening angle above pi are re-entrant; the rest are convex.
    pi_val = math.pi
    reentrant_df = df[df["omega"] > pi_val]
    convex_df = df[df["omega"] <= pi_val]

    for method in methods:
        reent_method = reentrant_df[reentrant_df["method"] == method]
        if len(reent_method) > 0:
            stats[f"{method}_reentrant_success_rate"] = (
                reent_method["success"].mean() * 100
            )
            stats[f"{method}_reentrant_mean_error"] = reent_method["rel_error"].mean()

        conv_method = convex_df[convex_df["method"] == method]
        if len(conv_method) > 0:
            stats[f"{method}_convex_success_rate"] = conv_method["success"].mean() * 100
            stats[f"{method}_convex_mean_error"] = conv_method["rel_error"].mean()

    return stats


def _improvement_ratio(naive_err, constraint_err) -> float:
    """Return ``naive_err / constraint_err``, handling missing and zero errors."""
    # A missing error on either side leaves the comparison undefined. Without
    # this check, NaN > 0 is False and a missing constraint error would be
    # reported as an infinite improvement.
    if pd.isna(naive_err) or pd.isna(constraint_err):
        return float("nan")
    if constraint_err > 0:
        return naive_err / constraint_err
    # constraint_err == 0: any positive naive error is an unbounded improvement;
    # two zero errors are treated as a tie.
    return float("inf") if naive_err > 0 else 1.0


def compute_improvement_factor(df: pd.DataFrame) -> pd.DataFrame:
    """Compare paired naive and constraint runs on their relative error.

    Rows are grouped by (omega, bc_type, seed). A group contributes a row only
    when it contains exactly one "naive" run and exactly one "constraint" run.
    Groups whose omega, bc_type or seed is NaN are dropped by ``groupby``.

    Args:
        df: Per-run table with ``omega``, ``omega_deg``, ``bc_type``, ``seed``,
            ``method`` and ``rel_error`` columns.

    Returns:
        One row per pair, with columns ``omega``, ``omega_deg``, ``bc_type``,
        ``seed``, ``naive_error``, ``constraint_error`` and
        ``improvement_factor``. The factor is ``naive_error /
        constraint_error``. It is ``inf`` when only the constraint error is 0,
        1.0 when both are 0, and NaN when either error is missing. When no
        pairs exist, an empty frame with the same columns is returned.
    """
    columns = [
        "omega",
        "omega_deg",
        "bc_type",
        "seed",
        "naive_error",
        "constraint_error",
        "improvement_factor",
    ]
    if df.empty:
        # Also avoids a KeyError from groupby on a column-less frame.
        return pd.DataFrame(columns=columns)

    improvement_data = []
    for (omega, bc_type, seed), group in df.groupby(["omega", "bc_type", "seed"]):
        naive_row = group[group["method"] == "naive"]
        constraint_row = group[group["method"] == "constraint"]

        # Skip groups where the pairing is missing or ambiguous (duplicates).
        if len(naive_row) != 1 or len(constraint_row) != 1:
            continue

        naive_err = naive_row["rel_error"].values[0]
        constraint_err = constraint_row["rel_error"].values[0]

        improvement_data.append(
            {
                "omega": omega,
                "omega_deg": naive_row["omega_deg"].values[0],
                "bc_type": bc_type,
                "seed": seed,
                "naive_error": naive_err,
                "constraint_error": constraint_err,
                "improvement_factor": _improvement_ratio(naive_err, constraint_err),
            }
        )

    return pd.DataFrame(improvement_data, columns=columns)


def generate_summary_tables(df: pd.DataFrame, stats: Dict) -> str:
    """Render a statistics dict as a fixed-layout plain-text report.

    Only the methods "naive" and "constraint" and the BC types "DD", "NN",
    "DN" and "ND" are reported, in that order.

    Args:
        df: Accepted for interface compatibility; not read.
        stats: Flat dict of ``"<group>_<metric>"`` values. A per-method block
            is printed only if its ``<method>_count`` key exists. A BC block or
            table row is printed only if its success-rate key exists.
            Re-entrant and convex lines are printed only when their success
            rate is a float.

    Returns:
        The report as a single newline-joined string (no trailing newline).
    """
    method_order = ("naive", "constraint")
    bc_order = ("DD", "NN", "DN", "ND")
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    def section(title):
        # Each section is a blank spacer line plus a ruled title.
        return ["", light_rule, title, light_rule]

    out = [
        heavy_rule,
        "Exp7: Large-Scale Wedge/Corner Sweep - Results Summary",
        heavy_rule,
        "",
        f"Total experiments: {stats['total_experiments']}",
        f"Overall success rate (rel_error < 5%): {stats['overall_success_rate']:.1f}%",
    ]

    out.extend(section("Results by Method:"))
    for name in method_order:
        if f"{name}_count" not in stats:
            continue
        p = name + "_"
        out += [
            f"\n{name.upper()}:",
            f"  Count: {stats[p + 'count']}",
            f"  Success rate: {stats[p + 'success_rate']:.1f}%",
            f"  Mean relative error: {stats[p + 'mean_rel_error']:.4f}%",
            f"  Std relative error: {stats[p + 'std_rel_error']:.4f}%",
            f"  Median relative error: {stats[p + 'median_rel_error']:.4f}%",
            f"  90th percentile error: {stats[p + 'p90_rel_error']:.4f}%",
            f"  Mean constraint violation: {stats[p + 'mean_constraint_viol']:.6f}",
        ]

    out.extend(section("Results by BC Type:"))
    for bc in bc_order:
        if f"{bc}_success_rate" in stats:
            out.append(f"\n{bc}:")
            out.append(f"  Success rate: {stats[bc + '_success_rate']:.1f}%")
            out.append(f"  Mean relative error: {stats[bc + '_mean_rel_error']:.4f}%")

    out.extend(section("Results by Method and BC Type:"))
    out.append("")
    out.append(f"{'Method':<12} {'BC':<6} {'Success%':<10} {'MeanErr%':<12}")
    out.append("-" * 40)
    combos = [(name, bc) for name in method_order for bc in bc_order]
    for name, bc in combos:
        rate_key = f"{name}_{bc}_success_rate"
        if rate_key in stats:
            err = stats[f"{name}_{bc}_mean_rel_error"]
            out.append(f"{name:<12} {bc:<6} {stats[rate_key]:<10.1f} {err:<12.4f}")

    out.extend(section("Re-entrant vs Convex Corners:"))
    corner_suffixes = (
        "reentrant_success_rate",
        "reentrant_mean_error",
        "convex_success_rate",
        "convex_mean_error",
    )
    for name in method_order:
        re_sr, re_err, cv_sr, cv_err = (
            stats.get(f"{name}_{suffix}", "N/A") for suffix in corner_suffixes
        )
        out.append(f"\n{name.upper()}:")
        # A missing key leaves the "N/A" placeholder, so the line is skipped.
        if isinstance(re_sr, float):
            out.append(
                f"  Re-entrant (omega > pi): {re_sr:.1f}% success, {re_err:.4f}% mean error"
            )
        if isinstance(cv_sr, float):
            out.append(
                f"  Convex (omega <= pi): {cv_sr:.1f}% success, {cv_err:.4f}% mean error"
            )

    out += ["", heavy_rule]
    return "\n".join(out)


def _improvement_path(output_path: str) -> str:
    """Return the path of the improvement CSV written next to ``output_path``.

    Only the final extension is replaced, so "exp7.csv" becomes
    "exp7_improvement.csv" and ".csv" elsewhere in the path is left alone. A
    path without an extension gets ".csv", so the result can never equal
    ``output_path`` and overwrite the main CSV.
    """
    root, ext = os.path.splitext(output_path)
    return f"{root}_improvement{ext or '.csv'}"


def _to_json_safe(value):
    """Convert NumPy scalars to native Python scalars for ``json.dump``.

    ``np.int64`` becomes ``int`` rather than ``float``. ``np.bool_`` becomes
    ``bool``; ``json`` cannot serialize it directly. Other values are returned
    unchanged.
    """
    return value.item() if isinstance(value, np.generic) else value


def main():
    """Command-line entry point: aggregate Exp7 result JSON files into reports.

    Writes the per-run CSV (``--output``) and a statistics JSON
    (``--stats_json``). When paired naive/constraint runs exist, it also writes
    an improvement-factor CSV next to ``--output``. Finally it writes a
    plain-text summary (``--summary``), which is also printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Aggregate Exp7 results")
    parser.add_argument(
        "--results_dir",
        type=str,
        default="results",
        help="Directory containing result JSON files",
    )
    parser.add_argument(
        "--output", type=str, default="exp7.csv", help="Output CSV file"
    )
    parser.add_argument(
        "--summary",
        type=str,
        default="exp7_summary.txt",
        help="Output summary text file",
    )
    parser.add_argument(
        "--stats_json",
        type=str,
        default="exp7_stats.json",
        help="Output statistics JSON file",
    )

    args = parser.parse_args()

    print(f"Loading results from {args.results_dir}...")
    results = load_all_results(args.results_dir)
    print(f"Loaded {len(results)} result files")

    if not results:
        print("No results found. Exiting.")
        return

    df = aggregate_results(results)
    print(f"Created DataFrame with {len(df)} rows")

    df.to_csv(args.output, index=False)
    print(f"Saved results to {args.output}")

    stats = compute_statistics(df)

    stats_serializable = {k: _to_json_safe(v) for k, v in stats.items()}
    with open(args.stats_json, "w", encoding="utf-8") as f:
        json.dump(stats_serializable, f, indent=2)
    print(f"Saved statistics to {args.stats_json}")

    improvement_df = compute_improvement_factor(df)
    if len(improvement_df) > 0:
        improvement_file = _improvement_path(args.output)
        improvement_df.to_csv(improvement_file, index=False)
        print(f"Saved improvement analysis to {improvement_file}")

        # Infinite factors (zero constraint error) would dominate the mean, so
        # they are excluded from both averages.
        finite_factors = improvement_df["improvement_factor"].replace(
            [np.inf, -np.inf], np.nan
        )
        avg_improvement = finite_factors.mean()
        median_improvement = finite_factors.median()
        print(
            f"\nAverage improvement factor (constraint over naive): {avg_improvement:.2f}x"
        )
        print(f"Median improvement factor: {median_improvement:.2f}x")

    summary_text = generate_summary_tables(df, stats)
    with open(args.summary, "w", encoding="utf-8") as f:
        f.write(summary_text)
    print(f"Saved summary to {args.summary}")

    print("\n" + summary_text)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
