import argparse
from typing import List, Dict, Union
from tqdm import tqdm
from statistics import mean
import datasets
from scripts.utils import load_single_dataset


def parse_args(argv=None):
    """Parse command-line arguments for the best-of-N accuracy report.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with ``bon_dataset``, ``num_bon``, ``rm_tag``,
        ``use_decay`` (bool), ``ref_tag`` and ``score_tag``.
    """

    def _str2bool(value: str) -> bool:
        # ``type=str`` used to turn any non-empty string -- including
        # "false" and "0" -- into a truthy value, silently enabling decay.
        lowered = value.strip().lower()
        if lowered in {"1", "true", "t", "yes", "y", "on"}:
            return True
        if lowered in {"0", "false", "f", "no", "n", "off"}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    ap = argparse.ArgumentParser()
    ap.add_argument("--bon-dataset", type=str, required=True)
    ap.add_argument("--num-bon",     type=int, required=True)
    ap.add_argument("--rm-tag",      type=str, required=True)
    # ``--use-decay`` alone means true; ``--use-decay true/false`` is also accepted.
    ap.add_argument("--use-decay",   type=_str2bool, nargs="?", const=True,
                    required=False, default=False,
                    help="weight per-step scores by 0.95**step before aggregating")
    ap.add_argument("--ref-tag",     type=str, required=False, default=None)
    ap.add_argument("--score-tag",   type=str, required=False, default="scores")
    return ap.parse_args(argv)


def get_best_index(scores: List[List[float]], type, use_decay=False, decay_rate: float = 0.95):
    """Return the index of the candidate whose aggregated step scores are highest.

    Args:
        scores: One list of per-step scores per candidate.
        type: Aggregation applied to each candidate's list: ``"sum"`` or
            ``"mean"``. (The name shadows the builtin but is kept for
            backward compatibility with keyword callers.)
        use_decay: If truthy, step ``i`` is multiplied by ``decay_rate ** i``
            before aggregation.
        decay_rate: Geometric decay factor used when ``use_decay`` is truthy.

    Returns:
        Index of the first candidate achieving the maximal aggregated score.

    Raises:
        ValueError: if ``type`` is not ``"sum"``/``"mean"`` or ``scores`` is empty.
    """
    aggregators = {"sum": sum, "mean": mean}
    if type not in aggregators:
        # Previously an unknown type left the result list unbound (UnboundLocalError).
        raise ValueError(f"unknown aggregation type {type!r}; expected 'sum' or 'mean'")
    if not scores:
        raise ValueError("scores must contain at least one candidate")
    aggregate = aggregators[type]

    if use_decay:
        scores = [[s * (decay_rate ** i) for i, s in enumerate(row)] for row in scores]
    totals = [aggregate(row) for row in scores]
    # max() keeps the first index on ties.
    return max(range(len(totals)), key=totals.__getitem__)


def get_bon_index(rm_scores: Union[List[List[float]], List[float]], ref_scores: List[List[float]]=None, use_decay=False):
    """Pick the best candidate among the N samples of a single prompt.

    Args:
        rm_scores: Either one scalar score per candidate, or one list of
            per-step scores per candidate.
        ref_scores: Optional per-step reference scores with the same outer
            length as ``rm_scores``. When given and ``rm_scores`` is per-step,
            each candidate's step scores become ``rm - ref``. Ignored when
            ``rm_scores`` holds scalars.
        use_decay: Forwarded to ``get_best_index`` for per-step scores.

    Returns:
        ``{"bon_sum_index": int, "bon_mean_index": int}``. For scalar scores
        both entries are the same arg-max index.

    Raises:
        ValueError: if ``rm_scores`` is empty or its length differs from
            ``ref_scores``.
    """
    import numbers

    if ref_scores is not None and len(rm_scores) != len(ref_scores):
        raise ValueError(
            f"rm_scores and ref_scores differ in length: {len(rm_scores)} vs {len(ref_scores)}"
        )
    if len(rm_scores) == 0:
        raise ValueError("rm_scores must contain at least one candidate")

    # Scalar case: accept any real number. Previously only ``float`` was
    # recognised; an ``int`` score fell through to a branch that referenced an
    # unbound ``scores`` variable and crashed.
    if isinstance(rm_scores[0], numbers.Real):
        best = max(range(len(rm_scores)), key=lambda i: rm_scores[i])
        return {"bon_sum_index": best, "bon_mean_index": best}

    if ref_scores is not None:
        # NOTE: zip truncates each rm/ref step list pair to the shorter one.
        scores = [
            [rm - ref for rm, ref in zip(rm_row, ref_row)]
            for rm_row, ref_row in zip(rm_scores, ref_scores)
        ]
    else:
        scores = rm_scores
    return {
        "bon_sum_index": get_best_index(scores, "sum", use_decay=use_decay),
        "bon_mean_index": get_best_index(scores, "mean", use_decay=use_decay),
    }


def get_bon_indices(
        all_rm_scores: List[List[List[float]]],
        all_ref_scores: List[List[List[float]]]=None,
        num_bon: int=64, use_decay=False) -> Dict[str, List[int]]:
    """Select the best-of-``num_bon`` candidate for every prompt.

    Args:
        all_rm_scores: Per-prompt candidate scores. Each entry is either a list
            of scalar scores or a list of per-step score lists (see
            ``get_bon_index``).
        all_ref_scores: Optional per-prompt reference scores aligned with
            ``all_rm_scores``. ``None`` disables reference subtraction.
        num_bon: Only the first ``num_bon`` candidates of each prompt are
            considered.
        use_decay: Forwarded to ``get_bon_index``.

    Returns:
        Dict with keys ``"bon_sum_index"`` and ``"bon_mean_index"``, each a list
        holding one selected candidate index per prompt.

    Raises:
        ValueError: if ``num_bon`` < 1 or the rm/ref collections differ in length.
    """
    if num_bon < 1:
        # A zero/negative slice would yield no candidates and fail later with IndexError.
        raise ValueError(f"num_bon must be >= 1, got {num_bon}")
    num_prompts = len(all_rm_scores)
    if all_ref_scores is None:
        all_ref_scores = [None] * num_prompts
    elif len(all_ref_scores) != num_prompts:
        # zip() would otherwise silently drop the unmatched prompts.
        raise ValueError(
            f"all_rm_scores and all_ref_scores differ in length: {num_prompts} vs {len(all_ref_scores)}"
        )

    bon_indices = {"bon_sum_index": [], "bon_mean_index": []}
    # zip() has no len(), so pass total= explicitly for a proper progress bar.
    for rm_scores_, ref_scores_ in tqdm(zip(all_rm_scores, all_ref_scores), total=num_prompts):
        rm_scores: Union[List[List[float]], List[float]] = rm_scores_[: num_bon]
        ref_scores = ref_scores_[: num_bon] if ref_scores_ is not None else None
        bon_index = get_bon_index(rm_scores, ref_scores, use_decay=use_decay)
        bon_indices["bon_sum_index"].append(bon_index["bon_sum_index"])
        bon_indices["bon_mean_index"].append(bon_index["bon_mean_index"])
    return bon_indices


def _collect_selected_scores(score_column, bon_indices):
    """Map each selection strategy to the scores of the candidates it picked, one per prompt."""
    return {
        strategy: [row_scores[idx] for row_scores, idx in zip(score_column, indices)]
        for strategy, indices in bon_indices.items()
    }


def _print_accuracy_report(strategy, tasks, selected_scores, threshold=0.5):
    """Print per-task and overall accuracy for one strategy; a score > ``threshold`` counts as correct."""
    score_dict = {}
    count_dict = {}
    for task, score in zip(tasks, selected_scores):
        count_dict[task] = count_dict.get(task, 0) + 1
        if score > threshold:
            score_dict[task] = score_dict.get(task, 0) + 1
    print(f"-------{strategy}-------")
    print(score_dict)
    print(count_dict)
    for task, cnt in count_dict.items():
        print(task, round(score_dict.get(task, 0) / cnt * 100, 2), "%")
    total = sum(count_dict.values())
    # An empty dataset previously raised ZeroDivisionError here; report nan instead.
    print("all", sum(score_dict.values()) / total if total else float("nan"))
    print("----------------------\n\n\n")


def main():
    """Select best-of-N candidates from the dataset and print per-task accuracy for each strategy."""
    args = parse_args()
    dataset: datasets.Dataset = load_single_dataset(args.bon_dataset)
    print(args.bon_dataset)
    print(args.num_bon)
    bon_indices: Dict[str, List[int]] = get_bon_indices(
        all_rm_scores=dataset[args.rm_tag],
        all_ref_scores=dataset[args.ref_tag] if args.ref_tag is not None else None,
        num_bon=args.num_bon, use_decay=args.use_decay)

    # Read each needed column once instead of iterating every row once per strategy.
    selected = _collect_selected_scores(dataset[args.score_tag], bon_indices)
    tasks = dataset["data_source"]
    for strategy in bon_indices:
        _print_accuracy_report(strategy, tasks, selected[strategy])


if __name__ == "__main__":
    main()



"""

Example invocations (paths are specific to the author's environment):

~/verl_cs/.conda/bin/python ~/verl_cs/scripts/bon3_print.py \
    --bon-dataset ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/842_bo64_llama321b_qrm.json \
    --num-bon 4 \
    --rm-tag rmlogp

~/verl_cs/.conda/bin/python ~/verl_cs/scripts/bon3_print.py \
    --bon-dataset ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/842_bo64_qwen38b_qrm.json \
    --num-bon 4 \
    --rm-tag rmlogp

~/verl_cs/.conda/bin/python ~/verl_cs/scripts/bon3_print.py \
    --bon-dataset ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/842_bo64_llama323b_ref1.json \
    --num-bon 4 \
    --use-decay true \
    --rm-tag reflogp


~/verl_cs/.conda/bin/python ~/verl_cs/scripts/bon3_print2.py --bon-dataset ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/842_bo64_llama321b_qrm_100.json --num-bon 64 --rm-tag rmlogp
     

"""
