"""Read eval results into copy-pastable format."""
import argparse
from collections import defaultdict
import json
import numpy as np
import os
from llm_rules import scenarios


# Command-line interface. The two options are alternatives: when --single_dir
# is non-empty it takes precedence and --output_dir is not consulted.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--output_dir",
    type=str,
    default="data/systematic/outputs",
    help="Directory whose entries are each treated as one model's output "
    "directory (default: %(default)s). Ignored when --single_dir is given.",
)
parser.add_argument(
    "--single_dir",
    type=str,
    default="",
    help="If set, summarize only this one model output directory instead of "
    "every entry under --output_dir.",
)
args = parser.parse_args()


class AccuracyMeter:
    """Running tally of correct outcomes versus all outcomes recorded so far."""

    def __init__(self):
        # Both counters start empty; `update` is the only method that grows them.
        self.correct, self.total = 0, 0

    def update(self, result):
        """Record one outcome.

        `result` is converted with int() and added to the correct count, so
        True/1 counts as a hit and False/0 as a miss.
        """
        # Convert first so a value int() rejects leaves both counters untouched.
        hit = int(result)
        self.correct += hit
        self.total += 1

    @property
    def accuracy(self):
        """Fraction of recorded outcomes that were correct; 0 if none recorded."""
        if not self.total:
            return 0
        return self.correct / self.total


# Extension of the per-scenario result files this script reads.
_RESULT_SUFFIX = ".jsonl"


def _list_model_dirs(output_dir, single_dir):
    """Return the model output directories to summarize.

    A non-empty `single_dir` is returned as the only entry. Otherwise every
    sub-directory of `output_dir` is returned, in sorted order so the report
    is deterministic.
    """
    if single_dir:
        return [single_dir]
    candidates = (
        os.path.join(output_dir, entry) for entry in sorted(os.listdir(output_dir))
    )
    # Plain files sitting next to the model directories would make the later
    # os.listdir(model_dir) call fail, so keep directories only.
    return [path for path in candidates if os.path.isdir(path)]


def _select_result_files(model_dir, scenario_names):
    """Return the result files in `model_dir`, grouped in scenario order.

    A file is selected when its name starts with one of `scenario_names` and
    ends with ".jsonl". Each file is listed once, even if it matches more
    than one scenario name by prefix, so its rows are never counted twice.
    """
    output_files = sorted(os.listdir(model_dir))
    selected = []
    seen = set()
    for scenario_name in scenario_names:
        for filename in output_files:
            if filename in seen:
                continue
            if filename.startswith(scenario_name) and filename.endswith(_RESULT_SUFFIX):
                seen.add(filename)
                selected.append(filename)
    return selected


def _load_results(model_dir, filenames):
    """Read (prediction, label) pairs from each JSONL file.

    Returns a mapping from the file name without its ".jsonl" suffix to the
    list of pairs found in that file. Files with no records do not appear
    in the mapping.
    """
    results = defaultdict(list)
    for filename in filenames:
        fullname = filename[: -len(_RESULT_SUFFIX)]
        with open(os.path.join(model_dir, filename), encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line:
                    continue
                output = json.loads(line)
                results[fullname].append((output["prediction"], output["label"]))
    return results


def _format_results(results):
    """Format one CSV-style row of accuracy and confusion counts per entry.

    Assumes predictions and labels are booleans or 0/1 values; they are
    converted to a boolean array so that logical negation is well defined
    (the `~` operator on an integer array is a bitwise NOT, which maps both
    0 and 1 to truthy values and corrupts the TN/FP/FN counts).
    """
    rows = ["name,Accuracy,TP,TN,FP,FN"]
    for name, pairs in results.items():
        data = np.array(pairs, dtype=bool)
        predictions = data[:, 0]
        labels = data[:, 1]
        acc = 100 * np.mean(predictions == labels)
        tp = np.sum(predictions & labels)
        tn = np.sum(~predictions & ~labels)
        fp = np.sum(predictions & ~labels)
        fn = np.sum(~predictions & labels)
        rows.append(f"{name},{acc:.1f},{tp},{tn},{fp},{fn}")
    return "\n".join(rows) + "\n"


model_dirs = _list_model_dirs(args.output_dir, args.single_dir)

for model_dir in model_dirs:
    print("\n" + model_dir)

    filelist = _select_result_files(model_dir, scenarios.SCENARIOS)
    results = _load_results(model_dir, filelist)

    # Print results in copy-pastable format: a header line followed by one
    # row per result file.
    result_str = _format_results(results)

    print("\ncopypaste:")
    print(result_str)
