import glob
import json
import os
import sys
import argparse
import itertools


def accuracy_for_file(path: str) -> float:
    """Return the accuracy recorded in a JSONL file.

    Each non-blank line is parsed as JSON. A line counts toward the total only
    if it is a JSON object with a non-null ``correct`` key whose value can be
    converted with ``int()``. Booleans count as 0/1. That integer is added to
    the correct count.

    Skipped lines are reported on stderr as ``path:lineno: reason`` and left
    out of both counts. Lines are skipped when they are blank, invalid JSON,
    non-object JSON, have a missing/null ``correct``, or have a non-integer
    ``correct``.

    Args:
        path: Path of the JSONL file to read.

    Returns:
        ``correct / total``, or ``0.0`` when the file has no usable records.
    """
    total = 0
    correct = 0

    def warn(lineno: int, msg: str) -> None:
        # stderr keeps diagnostics out of the results table printed on stdout.
        print(f"{path}:{lineno}: {msg}", file=sys.stderr)

    # utf-8-sig transparently drops a leading byte-order mark, which
    # json.loads would otherwise reject on the first line.
    with open(path, "r", encoding="utf-8-sig") as f:
        for lineno, ln in enumerate(f, start=1):
            ln = ln.strip()
            if not ln:
                warn(lineno, "blank line")
                continue
            try:
                obj = json.loads(ln)
            except json.JSONDecodeError as err:
                warn(lineno, f"JSONDecodeError: {err}")
                continue
            # Valid JSON that is not an object (list, number, string) has no .get().
            if not isinstance(obj, dict):
                warn(lineno, "not a JSON object")
                continue
            flag = obj.get("correct")
            if flag is None:
                warn(lineno, "No correct")
                continue
            try:
                value = int(flag)
            except (TypeError, ValueError):
                warn(lineno, f"unusable 'correct' value: {flag!r}")
                continue

            total += 1
            correct += value
    return (correct / total) if total else 0.0


if __name__ == '__main__':
    # CLI entry point: print "<name>  <acc>%" for every JSONL file matched by
    # the given glob patterns.
    parser = argparse.ArgumentParser(
        description="Print accuracy for each JSONL file matching one or more glob patterns (suffix .jsonl is added automatically)."
    )
    parser.add_argument(
        '-p', '--pattern',
        action='append',
        help='Base filename glob(s) (without .jsonl). Can be used multiple times, e.g. -p "*Random*" -p "*baseline*"'
    )
    parser.add_argument(
        '--sort',
        choices=['acc', 'abc'],
        default='abc',
        help='Sort by accuracy (acc) or by filename (abc)'
    )
    args = parser.parse_args()

    # With no -p, match every .jsonl file in the current directory. The suffix
    # is appended only when missing, so "-p foo" and "-p foo.jsonl" behave the same.
    raw = args.pattern or ["*"]
    patterns = [p if p.endswith('.jsonl') else f"{p}.jsonl" for p in raw]

    # The set drops files matched by more than one pattern; sorting by name
    # here also gives the 'abc' output order.
    files = sorted(
        set(itertools.chain.from_iterable(glob.glob(pat) for pat in patterns))
    )

    if not files:
        # stderr keeps stdout limited to the results table.
        print(f"No files match patterns: {raw}", file=sys.stderr)
        sys.exit(1)

    results = [(fname, accuracy_for_file(fname)) for fname in files]

    if args.sort == 'acc':
        # Highest accuracy first; equal accuracies fall back to filename order.
        results.sort(key=lambda item: (-item[1], item[0]))
    # For 'abc', `results` already follows the name order of `files`.

    suffix = '.jsonl'
    bases = [
        fname[:-len(suffix)] if fname.endswith(suffix) else fname
        for fname, _ in results
    ]
    # `files` is non-empty here, so `bases` is too; default guards anyway.
    width = max((len(b) for b in bases), default=0)

    for base, (_, acc) in zip(bases, results):
        print(f"{base:<{width}}  {acc * 100:5.1f}%")