#!/usr/bin/env python3
import json
import argparse
from pathlib import Path
from statistics import mean

def parse_args(args=None):
    """Parse the command-line options of the score-averaging script.

    Args:
        args: Optional list of argument strings. When None, argparse falls
            back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``model``, ``pred_dir``,
        ``use_wandb``, ``wandb_project`` and ``wandb_run_name``.
    """
    # (flags, keyword settings) pairs, registered in declaration order.
    option_table = (
        (("--model",), {"type": str, "default": None}),
        (("--pred_dir",), {"type": str, "default": "pred"}),
        (
            ("--use_wandb",),
            {"action": "store_true", "default": False, "help": "Enable logging to wandb."},
        ),
        (
            ("--wandb_project",),
            {"type": str, "default": "PCC", "help": "Wandb project name."},
        ),
        (
            ("--wandb_run_name",),
            {"type": str, "default": "longbench_avg", "help": "Wandb run name."},
        ),
    )

    cli = argparse.ArgumentParser()
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args(args)

# Command-line options are parsed once, at import time; the paths below are
# derived from them.
args = parse_args()

# Input / Output paths: both files live side by side in <pred_dir>/<model>/.
INPUT_PATH = Path(f"{args.pred_dir}/{args.model}") / "result.json"
OUTPUT_PATH = INPUT_PATH.with_name("result_avg.json")

# Optional rounding: number of decimal places kept in every average
# (e.g., 2), or None to leave the averages unrounded.
ROUND_DIGITS = 2

# Category-to-files mapping.
# Every entry is "<dataset>-full.jsonl"; the short dataset names are listed
# per category below and the shared suffix is appended once. Category order
# (and the order of datasets inside a category) is preserved.
CATEGORIES = {
    category: [f"{dataset}-full.jsonl" for dataset in datasets]
    for category, datasets in (
        (
            "Single-Document QA",
            ("narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh"),
        ),
        (
            "Multi-Document QA",
            ("hotpotqa", "2wikimqa", "musique", "dureader"),
        ),
        (
            "Summarization",
            ("gov_report", "qmsum", "multi_news", "vcsum"),
        ),
        (
            "Few-shot Learning",
            ("trec", "triviaqa", "samsum", "lsht"),
        ),
        (
            "Synthetic Tasks",
            ("passage_count", "passage_retrieval_en", "passage_retrieval_zh"),
        ),
        (
            "Code Completion",
            ("lcc", "repobench_p"),
        ),
    )
}

def _is_score(value):
    """Return True when *value* is a real number usable as a score.

    ``bool`` is rejected explicitly: it is a subclass of ``int``, so a JSON
    ``true``/``false`` would otherwise be averaged in as 1.0/0.0.
    """
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _rounded_mean(values):
    """Return the mean of *values*, rounded to ROUND_DIGITS unless that is None."""
    avg = mean(values)
    if ROUND_DIGITS is not None:
        avg = round(avg, ROUND_DIGITS)
    return avg


def _log_to_wandb(metrics):
    """Best-effort upload of *metrics* to wandb.

    A missing ``wandb`` package or any error raised while logging is reported
    on stdout instead of being propagated, so the caller can still write the
    output file.
    """
    try:
        import wandb
    except ImportError:
        print("wandb is not installed. Please install it with 'pip install wandb' to log results.")
        return

    try:
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, reinit=True)
        try:
            wandb.log(metrics)
        finally:
            # Close the run even when logging fails so it is not left open.
            wandb.finish()
        print(f"Logged results to wandb project {args.wandb_project}, run {args.wandb_run_name}")
    except Exception as e:
        print(f"Could not log to wandb: {e}")


def main():
    """Average the scores in INPUT_PATH per category and write them to OUTPUT_PATH.

    The input is a JSON object mapping result names to numeric scores. For
    every category in CATEGORIES the mean of the scores that are present is
    stored (None when none of its entries is present or numeric), and missing
    entries are reported on stdout. "Overall Average" is the mean of every
    numeric value in the input and is omitted when there is none. When
    ``args.use_wandb`` is set, the averages are also logged to wandb before
    the output file is written.

    Raises:
        FileNotFoundError: if INPUT_PATH does not exist.
        ValueError: if the top-level JSON value is not an object.
    """
    if not INPUT_PATH.exists():
        raise FileNotFoundError(f"Input file not found: {INPUT_PATH}")

    with INPUT_PATH.open("r", encoding="utf-8") as f:
        results = json.load(f)  # expected: dict[str, float]

    # Fail early with a clear message instead of an AttributeError/TypeError
    # further down when the file holds a list, string or number.
    if not isinstance(results, dict):
        raise ValueError(
            f"Expected a JSON object in {INPUT_PATH}, got {type(results).__name__}"
        )

    category_avgs = {}
    for cat, files in CATEGORIES.items():
        vals = []
        missing = []
        for fn in files:
            # Absent keys yield None, which is not a score, so they count as missing.
            score = results.get(fn)
            if _is_score(score):
                vals.append(float(score))
            else:
                missing.append(fn)

        # A category with no usable score is kept in the output as None.
        category_avgs[cat] = _rounded_mean(vals) if vals else None

        if missing:
            print(f"[Info] {cat}: missing {len(missing)} item(s): {missing}")

    # The overall average covers every numeric entry of the input, not only
    # the ones listed in CATEGORIES.
    all_scores = [v for v in results.values() if _is_score(v)]
    if all_scores:
        category_avgs["Overall Average"] = _rounded_mean(all_scores)

    if args.use_wandb:
        _log_to_wandb(category_avgs)

    with OUTPUT_PATH.open("w", encoding="utf-8") as f:
        json.dump(category_avgs, f, ensure_ascii=False, indent=2)

    print(f"Wrote category averages to: {OUTPUT_PATH.resolve()}")

# Run the aggregation only when this file is executed as a script.
if __name__ == "__main__":
    main()
