from hr2r.evaluate.matheval import evaluator_map
import json
import csv

def detect_dataset_name(base_path, dataset_options):
    """Return the first entry of *dataset_options* that occurs in *base_path*.

    The match is a plain substring test against the path string. Returns
    ``None`` when no option occurs in the path.
    """
    for option in dataset_options:
        if option in base_path:
            return option
    return None


def accuracy_line(correct, total):
    """Return the accuracy summary line that the script prints.

    With ``total == 0`` there is nothing to divide by, so a "no samples"
    message is returned instead of raising ``ZeroDivisionError``.
    """
    if total == 0:
        return "Accuracy: n/a (no samples found)"
    return f"Accuracy: {correct / total}"


def math_verify_cell(verify_map, problem_id, sample_idx):
    """Return the ``math_verify`` CSV cell for one detailed-results row.

    *verify_map* is keyed by ``(str(problem_id), int(sample_idx))``. The
    result is ``"True"`` or ``"False"`` when the pair was judged, and ``""``
    when *sample_idx* is not an integer or the pair is unknown.
    """
    try:
        sidx = int(sample_idx)
    except (TypeError, ValueError):
        return ""
    val = verify_map.get((str(problem_id), sidx))
    if val is True:
        return "True"
    if val is False:
        return "False"
    return ""


def aggregate_per_problem(verify_map):
    """Collapse per-sample verdicts into per-problem counts.

    Returns ``{problem_id: {"correct": int, "total": int}}`` where
    *problem_id* is the string form of the first element of each key.
    """
    per_problem_stats = {}
    for (pid, _sample_idx), is_ok in verify_map.items():
        stats = per_problem_stats.setdefault(str(pid), {"correct": 0, "total": 0})
        stats["total"] += 1
        if is_ok:
            stats["correct"] += 1
    return per_problem_stats


def refresh_eval_rows(eval_rows, per_problem_stats):
    """Rewrite the count/accuracy cells of evaluation-stats rows in place.

    Every row whose ``problem_id`` has an entry in *per_problem_stats* gets
    new ``correct_count``, ``total_samples`` and ``accuracy`` values. The
    row whose ``problem_id`` is ``"Total Accuracy"`` (the last one, if it
    occurs more than once) receives the sums over the rows that were
    updated. Rows without stats are left untouched.

    Returns ``(total_correct, total_samples)`` summed over updated rows.
    """
    total_row = None
    agg_total_correct = 0
    agg_total_samples = 0
    for row in eval_rows:
        pid = row.get("problem_id")
        if pid == "Total Accuracy":
            total_row = row
            continue
        stats = per_problem_stats.get(str(pid))
        if stats and stats["total"] > 0:
            row["correct_count"] = str(stats["correct"])
            row["total_samples"] = str(stats["total"])
            row["accuracy"] = f"{(stats['correct'] / stats['total']):.3f}"
            agg_total_correct += stats["correct"]
            agg_total_samples += stats["total"]

    # The summary row is only rewritten when at least one problem row was
    # refreshed; otherwise its existing values are kept.
    if total_row is not None and agg_total_samples > 0:
        total_row["correct_count"] = str(agg_total_correct)
        total_row["total_samples"] = str(agg_total_samples)
        total_row["accuracy"] = f"{(agg_total_correct / agg_total_samples):.3f}"
    return agg_total_correct, agg_total_samples


def main():
    """Re-judge ``samples.jsonl`` and write the verdicts back to the CSVs.

    Steps:
      1. Judge every sample with the evaluator selected by the dataset name
         found in ``base_path`` and print the overall accuracy.
      2. Add/refresh the ``math_verify`` column of ``detailed_results.csv``.
      3. If ``evaluation_stats.csv`` exists, refresh its per-problem and
         total counts from the new verdicts.

    Raises:
        ValueError: if no known dataset name occurs in ``base_path``.
    """
    # NOTE: base_path is used as a raw string prefix for the file names
    # below, so it must end with a path separator.
    base_path = "/path/to/output"
    dataset_options = ["amc23", "aime25", "gsm8k", "math500", "olympiadbench", "chmath"]
    dataset_name = detect_dataset_name(base_path, dataset_options)
    if dataset_name is None:
        # Fail early with a clear message instead of a NameError later on.
        raise ValueError(
            f"Could not infer dataset from base_path {base_path!r}; "
            f"expected it to contain one of {dataset_options}"
        )
    print(f"Dataset name: {dataset_name}")

    jsonl_path = base_path + "samples.jsonl"
    csv_path = base_path + "detailed_results.csv"
    evaluator = evaluator_map[dataset_name]

    # build mapping from (problem_id, sample_idx) to math_verify result
    verify_map = {}
    correct = 0
    total = 0
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # tolerate blank lines (e.g. a trailing newline at EOF)
                continue
            data = json.loads(line)
            is_correct, extracted = evaluator.rule_judge(
                data["output"], data["correct_answer"]
            )
            verify_map[(str(data["id"]), int(data["sample"]))] = bool(is_correct)
            if is_correct:
                correct += 1
            total += 1
            print(extracted)
    print(accuracy_line(correct, total))

    # update csv detailed results with math_verify column
    with open(csv_path, "r", newline="") as rf:
        reader = csv.DictReader(rf)
        fieldnames = list(reader.fieldnames or [])
        if "math_verify" not in fieldnames:
            fieldnames.append("math_verify")
        rows = []
        for row in reader:
            row["math_verify"] = math_verify_cell(
                verify_map, row.get("problem_id"), row.get("sample_idx")
            )
            rows.append(row)
    with open(csv_path, "w", newline="") as wf:
        writer = csv.DictWriter(wf, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    # aggregate verify_map by problem_id to update evaluation_stats.csv
    per_problem_stats = aggregate_per_problem(verify_map)

    eval_stats_path = base_path + "evaluation_stats.csv"
    try:
        with open(eval_stats_path, "r", newline="") as rf:
            reader = csv.DictReader(rf)
            eval_fieldnames = list(reader.fieldnames or [])
            eval_rows = list(reader)
    except FileNotFoundError:
        # evaluation_stats.csv is optional: nothing to update, skip silently
        return

    _, agg_total_samples = refresh_eval_rows(eval_rows, per_problem_stats)
    if agg_total_samples > 0:
        # Make sure every column we just wrote is part of the header.
        # DictWriter rejects unknown keys, and it would do so only after the
        # file had already been truncated for writing.
        for column in ("correct_count", "total_samples", "accuracy"):
            if column not in eval_fieldnames:
                eval_fieldnames.append(column)

    with open(eval_stats_path, "w", newline="") as wf:
        writer = csv.DictWriter(wf, fieldnames=eval_fieldnames)
        writer.writeheader()
        writer.writerows(eval_rows)


if __name__ == "__main__":
    main()