import argparse
from typing import List, Dict, Union
from statistics import mean
import datasets
from scripts.utils import load_single_dataset


def parse_args():
    """Parse command-line options for the best-of-N evaluation.

    Returns:
        argparse.Namespace with ``bon_dataset``, ``num_bon``, ``rm_tag``,
        ``ref_tag`` and ``score_tag``.

    Exits (via ``ArgumentParser.error``) when ``--num-bon`` is not positive,
    because an empty or negatively-sliced candidate list would otherwise
    crash or silently select from the wrong candidates downstream.
    """
    ap = argparse.ArgumentParser(
        description="Report per-task accuracy of best-of-N selection by reward-model scores.")
    ap.add_argument("--bon-dataset", type=str, required=True,
                    help="Dataset path/name passed to load_single_dataset.")
    ap.add_argument("--num-bon",     type=int, required=True,
                    help="Only the first NUM_BON candidates of each row are considered.")
    ap.add_argument("--rm-tag",      type=str, required=True,
                    help="Dataset column holding the reward-model score(s) for each candidate.")
    ap.add_argument("--ref-tag",     type=str, required=False, default=None,
                    help="Optional dataset column of reference scores subtracted from the --rm-tag scores.")
    ap.add_argument("--score-tag",   type=str, required=False, default="scores",
                    help="Dataset column with per-candidate scores; a score above 0.5 counts as correct.")
    args = ap.parse_args()
    if args.num_bon < 1:
        ap.error("--num-bon must be a positive integer")
    return args


def get_best_index(scores: List[List[float]], type):
    """Return the index of the candidate whose aggregated score is highest.

    Args:
        scores: one list of per-candidate values per candidate.
        type: aggregation applied to each candidate's list, ``"sum"`` or
            ``"mean"``. (The name shadows the builtin but is kept so callers
            passing ``type=`` keep working.)

    Returns:
        Index of the first candidate with the maximal aggregate (``max``
        keeps the first maximum on ties).

    Raises:
        ValueError: if ``type`` is not ``"sum"``/``"mean"``, or ``scores`` is empty.
    """
    aggregators = {"sum": sum, "mean": mean}
    if type not in aggregators:
        # Previously fell through and raised UnboundLocalError on score_list.
        raise ValueError(f"unknown aggregation type {type!r}; expected 'sum' or 'mean'")
    aggregate = aggregators[type]
    score_list = [aggregate(s) for s in scores]
    return max(range(len(score_list)), key=score_list.__getitem__)


def get_bon_index(rm_scores: Union[List[List[float]], List[float]], ref_scores: List[List[float]]=None):
    """Select the best-of-N candidate index under the "bon_sum" and "bon_mean" criteria.

    Args:
        rm_scores: one entry per candidate. Either a scalar score (int/float),
            in which case both criteria pick the same argmax, or a sequence
            whose element 0 ranks "bon_sum" and element 1 ranks "bon_mean"
            (presumably [summed, mean] log-probs — confirm against the data).
        ref_scores: optional per-candidate sequences subtracted element-wise
            from the sequence-style ``rm_scores`` before ranking.
            NOTE(review): ignored when ``rm_scores`` holds scalars, as before.

    Returns:
        ``{"bon_sum": idx, "bon_mean": idx}``; ties go to the lowest index.

    Raises:
        ValueError: if ``rm_scores`` is empty or lengths of the two inputs differ.
    """
    if ref_scores is not None and len(rm_scores) != len(ref_scores):
        raise ValueError(
            f"rm_scores and ref_scores differ in length: {len(rm_scores)} != {len(ref_scores)}")
    if len(rm_scores) == 0:
        raise ValueError("rm_scores must contain at least one candidate")

    def argmax(values):
        # max() returns the first maximal index, preserving tie-breaking.
        return max(range(len(values)), key=values.__getitem__)

    if isinstance(rm_scores[0], (int, float)):
        # Scalar scores: a single ranking serves both criteria. Ints are
        # included here; previously they hit a branch using an undefined name.
        best = argmax(rm_scores)
        return {"bon_sum": best, "bon_mean": best}

    if ref_scores is not None:
        scores = [
            [rm - ref for rm, ref in zip(rm_entry, ref_entry)]
            for rm_entry, ref_entry in zip(rm_scores, ref_scores)
        ]
    else:
        # Any indexable pair (list or tuple) is accepted.
        scores = rm_scores
    return {
        "bon_sum": argmax([s[0] for s in scores]),
        "bon_mean": argmax([s[1] for s in scores]),
    }


def get_bon_indices(
        all_rm_scores: List[List[List[float]]], 
        all_ref_scores: List[List[List[float]]]=None,
        num_bon: int=64) -> Dict[str, List[int]]:
    """Run best-of-N selection over every row, truncating each to ``num_bon`` candidates.

    Args:
        all_rm_scores: per-row candidate scores (see ``get_bon_index``).
        all_ref_scores: optional per-row reference scores, aligned with
            ``all_rm_scores``; ``None`` means no reference for any row.
        num_bon: number of leading candidates kept from each row.

    Returns:
        ``{"bon_sum": [...], "bon_mean": [...]}`` with one chosen index per row.
    """
    per_row_refs = all_ref_scores if all_ref_scores is not None else [None] * len(all_rm_scores)
    picks = [
        get_bon_index(
            row_rm[: num_bon],
            None if row_ref is None else row_ref[: num_bon],
        )
        for row_rm, row_ref in zip(all_rm_scores, per_row_refs)
    ]
    return {
        "bon_sum": [pick["bon_sum"] for pick in picks],
        "bon_mean": [pick["bon_mean"] for pick in picks],
    }


def prepare_eval_row(row):
    """Project a rollout row onto the fields an evaluator consumes.

    Reads ``row["response"]``, ``row["reward_model"]["ground_truth"]`` and
    ``row["data_source"]`` and returns them under the keys
    ``sequences_str``, ``ground_truth`` and ``data_source``.
    """
    return dict(
        sequences_str=row["response"],
        ground_truth=row["reward_model"]["ground_truth"],
        data_source=row["data_source"],
    )


def main():
    """Print per-task best-of-N accuracy for each selection strategy.

    Loads ``--bon-dataset``, picks one candidate per row with
    ``get_bon_indices``, looks up that candidate's entry in the
    ``--score-tag`` column, and prints, for every ``data_source``, the
    percentage of rows whose selected score exceeds 0.5.
    """
    args = parse_args()
    dataset: datasets.Dataset = load_single_dataset(args.bon_dataset)
    print(args.bon_dataset)
    print(args.num_bon)
    bon_indices: Dict[str, List[int]] = get_bon_indices(
        all_rm_scores=dataset[args.rm_tag],
        all_ref_scores=dataset[args.ref_tag] if args.ref_tag is not None else None,
        num_bon=args.num_bon)

    # Fetch only the needed columns once; iterating `dataset` row by row
    # decoded every column (including all N responses) once per strategy.
    all_scores = dataset[args.score_tag]
    data_sources = dataset["data_source"]

    # Score of the selected candidate for every row, per strategy.
    selected_scores = {
        k: [row_scores[idx] for row_scores, idx in zip(all_scores, indices)]
        for k, indices in bon_indices.items()
    }

    for k in bon_indices:
        correct: Dict[str, int] = {}
        total: Dict[str, int] = {}
        for task, score in zip(data_sources, selected_scores[k]):
            total[task] = total.get(task, 0) + 1
            if score > 0.5:  # a selected score above 0.5 counts as correct
                correct[task] = correct.get(task, 0) + 1
        print(k, end=" ")
        for task, cnt in total.items():
            print(task, round(correct.get(task, 0) / cnt * 100, 2), "%")
        print("----------------------\n")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()



"""
Example usage:

~/verl_cs/.conda/bin/python ~/verl_cs/scripts/bon3_print2.py \
    --bon-dataset ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/842_bo64_qwen38b_qrm_100.json \
    --num-bon 64 \
    --rm-tag rmlogp
"""
