"""Read eval results into copy-pastable format."""
import argparse
from collections import defaultdict
import json
import os
from llm_rules import scenarios


# Command-line interface. Arguments are parsed at module level because the
# rest of the script reads ``args`` directly.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--output_dir",
    type=str,
    default="data/systematic/outputs",
    help=(
        "Directory whose entries are per-model result directories; every "
        "entry is summarized (default: %(default)s)."
    ),
)
parser.add_argument(
    "--single_dir",
    type=str,
    default="",
    help=(
        "If given, summarize only this result directory and ignore "
        "--output_dir."
    ),
)
args = parser.parse_args()


class AccuracyMeter:
    """Running count of outcomes and how many of them were successes.

    Attributes are plain counters: ``correct`` is the sum of the recorded
    results converted with ``int()``, ``total`` is the number of results
    recorded so far.
    """

    def __init__(self):
        self.total = 0
        self.correct = 0

    def update(self, result):
        """Record one outcome; *result* is converted with ``int()``."""
        # Convert first so that a value ``int()`` rejects leaves both
        # counters untouched.
        hit = int(result)
        self.correct += hit
        self.total += 1

    @property
    def accuracy(self):
        """Fraction ``correct / total``, or ``0`` when nothing was recorded."""
        if not self.total:
            return 0
        return self.correct / self.total


# Result files are expected to be JSON Lines files named "<prefix>.jsonl".
_RESULT_SUFFIX = ".jsonl"


def _select_result_files(filenames, scenario_names):
    """Return the result files to read, grouped in scenario order.

    Args:
        filenames: Candidate file names, in the order they should appear
            within each scenario group.
        scenario_names: Iterable of scenario name prefixes.

    Returns:
        A list of file names that end in ``.jsonl`` and start with one of the
        scenario names. A file matching several scenario prefixes is listed
        once (under the first matching scenario) so its rows are not counted
        twice.
    """
    selected = []
    seen = set()
    for scenario in scenario_names:
        for filename in filenames:
            if filename in seen or not filename.endswith(_RESULT_SUFFIX):
                continue
            if filename.startswith(scenario):
                seen.add(filename)
                selected.append(filename)
    return selected


def _load_jsonl(path):
    """Parse every non-blank line of the JSON Lines file at *path*."""
    with open(path, encoding="utf-8") as f:
        # Blank lines (e.g. a stray trailing newline) are skipped instead of
        # being handed to json.loads, which would raise on them.
        return [json.loads(line) for line in f if line.strip()]


def _format_results(results, results_by_category, categories):
    """Render the accuracy table as comma-separated text.

    Args:
        results: Mapping of result name to its overall meter.
        results_by_category: Mapping of ``(name, category)`` to its meter.
        categories: Sorted list of category names, used as column order.

    Returns:
        One header line followed by one line per result name. Accuracies are
        percentages with one decimal; a missing name/category pair is "-".
    """
    lines = ["name,Average," + ",".join(categories)]
    for name, meter in results.items():
        cells = [name, f"{100 * meter.accuracy:.1f}"]
        for category in categories:
            key = (name, category)
            # Membership test first: indexing a defaultdict would create an
            # empty meter and report 0.0 instead of "-".
            if key in results_by_category:
                cells.append(f"{100 * results_by_category[key].accuracy:.1f}")
            else:
                cells.append("-")
        lines.append(",".join(cells))
    return "\n".join(lines) + "\n"


if args.single_dir:
    model_dirs = [args.single_dir]
else:
    # Only sub-directories are model outputs; stray files in output_dir would
    # make os.listdir() below fail. Sorted so the report order is stable.
    model_dirs = sorted(
        path
        for path in (
            os.path.join(args.output_dir, entry)
            for entry in os.listdir(args.output_dir)
        )
        if os.path.isdir(path)
    )

for model_dir in model_dirs:
    print("\n" + model_dir)

    # Iterating scenarios.SCENARIOS is expected to yield scenario name
    # strings (the original code used them as file-name prefixes).
    filelist = _select_result_files(
        sorted(os.listdir(model_dir)), scenarios.SCENARIOS
    )
    all_categories = set()
    results = defaultdict(AccuracyMeter)
    # Keyed by (file stem, category) tuples so that underscores in either
    # part cannot make two different pairs collide.
    results_by_category = defaultdict(AccuracyMeter)
    for filename in filelist:
        fullname = filename[: -len(_RESULT_SUFFIX)]
        for output in _load_jsonl(os.path.join(model_dir, filename)):
            passed = output["result"]["passed"]
            category = output["category"]
            all_categories.add(category)
            results[fullname].update(passed)
            results_by_category[(fullname, category)].update(passed)

    # Print results in copy-pastable format: for each result file, the
    # average followed by one column per category.
    result_str = _format_results(
        results, results_by_category, sorted(all_categories)
    )

    print("\ncopypaste:")
    print(result_str)
    # NOTE: pyperclip.copy(result_str) can be used here to place the table on
    # the clipboard; it is left disabled because pyperclip is not imported.
