"""
Read the evaluation results from the jsonl files and print the mean, std, and histogram of the scores.

Example usage:
python scripts/helper/read_eval_results.py <the-evaluation-result-prefix> [--eval-types <type> ...] [--bin-threshold <int>]
"""

import argparse
import collections
import os

import numpy as np
from toolemu.evaluators import EVALUATORS, EVALUATORS2METRICS
from toolemu.utils import METRICS2BIN_THRESHOLDS, convert_to_score, read_file

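# Binarization threshold: unless --bin-threshold is given, the per-metric value
# from METRICS2BIN_THRESHOLDS is used. The threshold itself belongs to the upper
# bin, i.e., scores in [0, threshold) map to 0 and scores in [threshold, 3] map to 1.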

# The evaluator names double as the set of accepted --eval-types values.
EVAL_TYPES = EVALUATORS.keys()

parser = argparse.ArgumentParser()
# Positional: common prefix of the "<prefix>_eval_<type>.jsonl" result files.
parser.add_argument("eval_prefix", type=str)
# NOTE(review): --dataset is accepted but not read in the visible part of this script.
parser.add_argument("--dataset", "-d", default=None, type=str)
# Defaults to every known evaluator when the flag is omitted.
parser.add_argument(
    "--eval-types",
    "-m",
    type=str,
    nargs="+",
    default=EVAL_TYPES,
    choices=EVAL_TYPES,
)
# None means "use the per-metric default threshold" when binarizing scores.
parser.add_argument("--bin-threshold", "-bt", default=None, type=int)


# Parsed at import time: this module is meant to be run as a script.
args = parser.parse_args()


def get_pred_res(preds, key):
    """Look up one metric's result in a prediction record.

    Args:
        preds: A mapping that holds the per-metric results under the
            "eval_scores" entry.
        key: Name of the metric to fetch.

    Returns:
        The value stored for ``key`` inside ``preds["eval_scores"]``.

    Raises:
        ValueError: If ``key`` is not present in ``preds["eval_scores"]``.
    """
    eval_scores = preds["eval_scores"]
    if key in eval_scores:
        return eval_scores[key]
    raise ValueError(f"Key {key} not found in preds")


def print_scores(scores):
    """Print the mean, standard deviation, and histogram of ``scores``.

    Output format (one item per line on stdout):
        mean: <mean>, std: <std>      (both with four decimals)
        score <value>: count <n>      (one line per distinct value, ascending)
        total count: <len(scores)>

    Args:
        scores: A sized collection of hashable, mutually orderable numeric
            values. May be empty.
    """
    if len(scores) == 0:
        # np.mean/np.std of an empty sequence evaluate to nan but also emit
        # RuntimeWarnings; report the same nan values without the noise.
        mean = std = float("nan")
    else:
        mean, std = np.mean(scores), np.std(scores)
    print(f"mean: {mean:.4f}, std: {std:.4f}")

    # Histogram over the distinct score values, in ascending order.
    histogram = collections.Counter(scores)
    for k, v in sorted(histogram.items()):
        print(f"score {k}: count {v}")
    print("total count:", len(scores))


# For every requested evaluator, load its result file (if present) and report
# score statistics per metric: first on the converted scores, then on the
# binarized scores.
for eval_type in args.eval_types:
    input_file = f"{args.eval_prefix}_eval_{eval_type}.jsonl"
    if not os.path.exists(input_file):
        print(f"[Info] no eval results for {eval_type}")
        continue

    eval_preds = read_file(input_file)
    print(f"[Info] Loaded {len(eval_preds)} eval results from {input_file}")
    # Records without an "eval_scores" entry are dropped; note that the count
    # printed above is taken before this filter.
    eval_preds = [p for p in eval_preds if "eval_scores" in p]
    eval_metrics = EVALUATORS2METRICS[eval_type]

    for k in eval_metrics:
        print(f"=== {eval_type}: {k} ===")
        all_scores = []
        for pred in eval_preds:
            try:
                res = convert_to_score(get_pred_res(pred, k))
            except Exception:
                # Best effort: report the offending record and keep going.
                # Catching Exception (not a bare except) lets KeyboardInterrupt
                # and SystemExit still stop the script.
                print(f"Error processing {pred}")
            else:
                all_scores.append(res)
        print_scores(all_scores)

        # An explicit --bin-threshold overrides the per-metric default.
        if args.bin_threshold is None:
            threshold = METRICS2BIN_THRESHOLDS[k]
        else:
            threshold = args.bin_threshold
        all_scores = [convert_to_score(s, binarize_thres=threshold) for s in all_scores]
        print("-- After binarization: --")
        print_scores(all_scores)
