import argparse
import pandas as pd
import numpy as np


def _parse_k_passes(raw):
    """Parse one serialized ``K_Passes`` cell into a list of ints.

    Accepts forms such as ``"[1, 0, 1]"``, ``" [1,0] "`` or ``"1,0"``.
    Raises ValueError if the cell holds no attempts (e.g. ``"[]"``) or a
    token is not an integer.
    """
    # str() covers cells pandas already parsed as numbers (e.g. a bare 1);
    # the first strip() removes outer whitespace so strip("[]") reaches the
    # brackets.
    tokens = str(raw).strip().strip("[]").split(",")
    passes = [int(token) for token in tokens if token.strip()]
    if not passes:
        raise ValueError(f"K_Passes entry has no attempts: {raw!r}")
    return passes


def _group_metrics(passes_lists):
    """Return ``(pass@1, pass@k, k/k reliability)`` for one group of rows.

    Each element of *passes_lists* is one row's list of per-attempt outcomes
    (presumably 0/1 per attempt -- TODO confirm against the CSV producer).
    """
    # pass@1 looks only at the first attempt of each row.
    pass_at_1 = np.mean([passes[0] for passes in passes_lists])
    # pass@k: fraction of rows where at least one attempt passed.
    pass_at_k = np.mean([1 if sum(passes) > 0 else 0 for passes in passes_lists])
    # k/k: fraction of rows where every attempt passed. Compare each row
    # against its own attempt count (not the first row's), so rows with
    # differing numbers of attempts are scored correctly.
    k_k_reliability = np.mean(
        [1 if sum(passes) == len(passes) else 0 for passes in passes_lists]
    )
    return pass_at_1, pass_at_k, k_k_reliability


def calculate_stats(input_path):
    """Print pass@1 / pass@k / k/k reliability statistics from a results CSV.

    The CSV must contain ``SceneName``, ``ModelName`` and ``K_Passes``
    columns, where ``K_Passes`` is a serialized list of per-attempt outcomes
    such as ``"[1, 0, 1]"``. Metrics are computed per (scene, model) pair.
    The "Overall" table averages the per-scene values with equal weight per
    scene (a macro-average) and sums the row counts.

    Args:
        input_path: Path to the evaluation results CSV.

    Returns:
        DataFrame with one row per (scene, model) and columns ``SceneName``,
        ``Model``, ``pass@1``, ``pass@k``, ``k/k reliability``, ``Count``.

    Raises:
        ValueError: If a ``K_Passes`` cell is empty or not a list of ints.
    """
    metric_cols = ["pass@1", "pass@k", "k/k reliability"]

    df = pd.read_csv(input_path)
    df["K_Passes"] = df["K_Passes"].apply(_parse_k_passes)

    results = []
    for (scene, model), group in df.groupby(["SceneName", "ModelName"]):
        pass_at_1, pass_at_k, k_k_reliability = _group_metrics(
            group["K_Passes"].tolist()
        )
        results.append(
            {
                "SceneName": scene,
                "Model": model,
                "pass@1": pass_at_1,
                "pass@k": pass_at_k,
                "k/k reliability": k_k_reliability,
                "Count": len(group),
            }
        )

    # Explicit columns keep the schema stable even when there are no rows.
    results_df = pd.DataFrame(
        results, columns=["SceneName", "Model", *metric_cols, "Count"]
    )
    if results_df.empty:
        print("\nNo results found.")
        return results_df

    print("\nOverall Statistics:")
    by_model = results_df.groupby("Model")
    overall = by_model[metric_cols].mean()
    overall["Count"] = by_model["Count"].sum().astype(int)
    print(overall)

    print("\nPer-Scene Statistics:")
    for scene in sorted(results_df["SceneName"].unique()):
        print(f"\n{scene}:")
        scene_stats = results_df[results_df["SceneName"] == scene].sort_values(
            "pass@1", ascending=False
        )
        print(scene_stats[["Model", *metric_cols, "Count"]])

    return results_df


if __name__ == "__main__":
    # Command-line entry point: read --input and report statistics for it.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to evaluation results CSV",
    )
    options = cli.parse_args()
    calculate_stats(options.input)
