import argparse
import pandas as pd
import numpy as np
import json


def _parse_k_passes(raw):
    """Convert one ``K_Passes`` cell into a list of ints.

    Accepts the comma-separated form stored in the CSV, with or without
    surrounding square brackets (e.g. ``"[1, 0, 1]"``). ``True``/``False``
    tokens are read as 1/0, and a cell pandas already parsed as a scalar is
    stringified first.

    Raises:
        ValueError: if the cell holds no attempts, or a token is neither an
            integer nor a boolean literal.
    """
    inner = str(raw).strip().strip("[]").strip()
    if not inner:
        raise ValueError(f"K_Passes entry {raw!r} contains no attempts")

    outcomes = []
    for token in inner.split(","):
        lowered = token.strip().lower()
        if lowered in ("true", "false"):
            outcomes.append(int(lowered == "true"))
        else:
            # int() tolerates surrounding whitespace and raises ValueError on
            # anything that is not an integer literal (including "").
            outcomes.append(int(token))
    return outcomes


def calculate_stats_json(input_path) -> dict:
    """Aggregate pass statistics per scene and model from a results CSV.

    The CSV must provide the columns ``SceneName``, ``ModelName`` and
    ``K_Passes``; each ``K_Passes`` cell is a list of per-attempt outcomes
    (presumably 0/1 flags -- confirm against the producer of the CSV).

    For every (scene, model) group three rates are computed over its rows:

    * ``pass_1``      -- mean of the first attempt's outcome,
    * ``pass_k``      -- share of rows with at least one passing attempt,
    * ``reliability`` -- share of rows whose attempts all passed, judged
      against each row's own number of attempts.

    Args:
        input_path: Anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        A dict with an ``"overall"`` entry mapping each model to the
        unweighted mean of its per-scene rates plus the total row count,
        followed by one entry per scene (sorted) mapping each model to that
        scene's rates and row count. An input without data rows yields
        ``{"overall": {}}``.

    Raises:
        KeyError: if a required column is missing.
        ValueError: if a ``K_Passes`` cell cannot be parsed.
    """
    df = pd.read_csv(input_path)
    if df.empty:
        return {"overall": {}}

    results = []
    # groupby sorts by (scene, model), which fixes the output ordering below.
    for (scene, model), group in df.groupby(["SceneName", "ModelName"]):
        attempts = [_parse_k_passes(cell) for cell in group["K_Passes"]]

        pass_at_1 = np.mean([passes[0] for passes in attempts])
        pass_at_k = np.mean([1 if sum(passes) > 0 else 0 for passes in attempts])
        # Compare against the row's own length: rows in one group may carry a
        # different number of attempts.
        k_k_reliability = np.mean(
            [1 if sum(passes) == len(passes) else 0 for passes in attempts]
        )

        results.append(
            {
                "SceneName": scene,
                "Model": model,
                "pass@1": pass_at_1,
                "pass@k": pass_at_k,
                "k/k reliability": k_k_reliability,
                "Count": len(group),
            }
        )

    if not results:
        # Every row was dropped by the grouping (e.g. missing scene/model).
        return {"overall": {}}

    results_df = pd.DataFrame(results)

    overall_stats = {}
    for model in results_df["Model"].unique():
        model_data = results_df[results_df["Model"] == model]
        # Rates are averaged across scenes without weighting by row count;
        # only "count" is a total.
        overall_stats[model] = {
            "pass_1": float(model_data["pass@1"].mean()),
            "pass_k": float(model_data["pass@k"].mean()),
            "reliability": float(model_data["k/k reliability"].mean()),
            "count": int(model_data["Count"].sum()),
        }

    stats_json = {"overall": overall_stats}

    # NOTE: a scene literally named "overall" replaces the aggregate entry.
    for scene in sorted(results_df["SceneName"].unique()):
        scene_data = results_df[results_df["SceneName"] == scene]
        stats_json[scene] = {
            row["Model"]: {
                "pass_1": float(row["pass@1"]),
                "pass_k": float(row["pass@k"]),
                "reliability": float(row["k/k reliability"]),
                "count": int(row["Count"]),
            }
            for _, row in scene_data.iterrows()
        }

    return stats_json


if __name__ == "__main__":
    # Command-line entry point: read the evaluation CSV named by --input,
    # compute the statistics and write them as indented JSON to --output.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--input",
        required=True,
        type=str,
        help="Path to evaluation results CSV",
    )
    cli.add_argument(
        "--output",
        default="stats.json",
        type=str,
        help="Path to output JSON file (default: stats.json)",
    )
    options = cli.parse_args()
    output_path = options.output

    computed = calculate_stats_json(options.input)

    with open(output_path, "w") as out_file:
        json.dump(computed, out_file, indent=2)

    print(f"Statistics saved to {output_path}")
