import pandas as pd
import sys

def analyze_group(csv_files, group_name):
    """Load one group's CSV files and summarise its stepping times.

    All files are read with ``pd.read_csv`` and stacked into one frame
    with a fresh index. The summary is computed only over rows where at
    least one of the two triangle columns is positive.

    Args:
        csv_files: Iterable of CSV paths (or anything ``pd.read_csv``
            accepts). Each file must contain the columns
            ``physics_stepping_time_ms``, ``triangles_bounding_mesh`` and
            ``triangles_neural_sdf``.
        group_name: Label stored under the ``"group"`` key.

    Returns:
        A ``(df, stats)`` tuple. ``df`` is the full, *unfiltered*
        concatenated frame; ``stats`` is a dict of summary values whose
        ``time_ratio``, ``tri_ratio`` and ``total_time`` entries are left
        as ``None`` for ``compute_ratios`` to fill in.
    """
    df = pd.concat(list(map(pd.read_csv, csv_files)), ignore_index=True)

    # A row counts as "active" when either triangle counter is positive.
    has_triangles = (
        df["triangles_bounding_mesh"].gt(0)
        | df["triangles_neural_sdf"].gt(0)
    )
    active = df.loc[has_triangles]
    step_ms = active["physics_stepping_time_ms"]

    # Key order matters: it becomes the column order of the output CSV.
    stats = {"group": group_name}
    stats["min_time"] = step_ms.min()
    stats["median_time"] = step_ms.median()
    stats["max_time"] = step_ms.max()
    # Placeholders, populated later by compute_ratios.
    stats["time_ratio"] = None
    stats["tri_ratio"] = None
    stats["total_time"] = None
    stats["total_bm"] = active["triangles_bounding_mesh"].sum()
    stats["total_nsdf"] = active["triangles_neural_sdf"].sum()

    return df, stats


def compute_ratios(dfA, dfB, statsA, statsB):
    """Fill in the ratio and total-time fields of both stats dicts.

    The two frames are inner-joined on ``frame_number``. Only joined rows
    where *both* sides have a positive value in at least one triangle
    column are used. Group A is the baseline (ratios fixed at 1.0);
    group B receives the smallest per-row B/A ratio.

    Args:
        dfA: Reference frame with columns ``frame_number``,
            ``physics_stepping_time_ms``, ``triangles_bounding_mesh`` and
            ``triangles_neural_sdf``.
        dfB: Test frame with the same columns.
        statsA: Stats dict for group A; updated in place.
        statsB: Stats dict for group B; updated in place.

    Returns:
        ``[statsA, statsB]`` -- the same dict objects that were passed in,
        with ``time_ratio``, ``tri_ratio`` and ``total_time`` set.
        ``tri_ratio`` for B is NaN when no valid row has a non-zero
        ``triangles_neural_sdf`` value in A.
    """
    # NOTE(review): pd.merge pairs every A row with every B row that
    # shares a frame_number. If frame numbers repeat inside a frame (for
    # example several runs concatenated together) the rows are
    # cross-paired and the totals below count rows more than once --
    # confirm this is the intended pairing.
    merged = pd.merge(dfA, dfB, on="frame_number", suffixes=("_A", "_B"))

    # Keep only frames where *both* groups tested triangles.
    a_active = (
        (merged["triangles_bounding_mesh_A"] > 0)
        | (merged["triangles_neural_sdf_A"] > 0)
    )
    b_active = (
        (merged["triangles_bounding_mesh_B"] > 0)
        | (merged["triangles_neural_sdf_B"] > 0)
    )
    valid = merged[a_active & b_active]

    time_A = valid["physics_stepping_time_ms_A"]
    time_B = valid["physics_stepping_time_ms_B"]

    # Per-row ratios (B/A).
    time_ratio = time_B / time_A

    # Mask zero denominators with NaN via where() so the column stays a
    # numeric float dtype. Replacing zeros with pd.NA would turn an integer
    # column into object dtype and push the division and min() onto
    # Python objects.
    nsdf_A = valid["triangles_neural_sdf_A"]
    nsdf_ratio = valid["triangles_neural_sdf_B"] / nsdf_A.where(nsdf_A != 0)

    # Group A is the baseline.
    statsA["time_ratio"] = 1.0
    statsA["tri_ratio"] = 1.0
    # Total stepping time covers only the frames kept above.
    statsA["total_time"] = time_A.sum()

    # Smallest ratios (acceleration); min() skips NaN entries.
    statsB["time_ratio"] = time_ratio.min()
    statsB["tri_ratio"] = nsdf_ratio.min()
    statsB["total_time"] = time_B.sum()

    return [statsA, statsB]


def main():
    """Command-line entry point.

    Expects an even number (at least two) of input CSV paths followed by
    one output path. The first half of the inputs forms group A (the
    reference), the second half forms group B (the test). Prints usage
    and exits with status 1 when the argument count does not fit.
    """
    cli_args = sys.argv[1:]
    # Everything except the trailing output path is an input file.
    n_inputs = len(cli_args) - 1
    if n_inputs < 2 or n_inputs % 2 != 0:
        print("Usage: python analyze_pairs.py groupA1.csv ... groupB1.csv ... output.csv")
        print("  (Must provide an even number of input CSVs + one output path at the end)")
        sys.exit(1)

    input_files, out_csv = cli_args[:-1], cli_args[-1]
    half = n_inputs // 2
    groupA_files, groupB_files = input_files[:half], input_files[half:]

    print(f"Group A (reference): {groupA_files}")
    print(f"Group B (test): {groupB_files}")
    print(f"Output CSV: {out_csv}")

    ref_df, ref_stats = analyze_group(groupA_files, "Group A (reference)")
    test_df, test_stats = analyze_group(groupB_files, "Group B (test)")

    summary_rows = compute_ratios(ref_df, test_df, ref_stats, test_stats)

    # One output row per group, columns taken from the stats dict keys.
    pd.DataFrame(summary_rows).to_csv(out_csv, index=False)
    print(f"\n✅ Saved summary to {out_csv}")


if __name__ == "__main__":
    main()
