import tools
import json
import argparse
SAMPLE_NUM = 20


def score_all_result(result_jsons, group_size=None):
    """Split interleaved judge results into one score list per sample.

    ``result_jsons`` is read as consecutive rounds of ``group_size`` records:
    record ``k`` belongs to sample ``k % group_size``. Entry ``j`` of the
    returned list holds the ``"score"`` values of every record for sample
    ``j``, in input order.

    Args:
        result_jsons: list of dicts, each with a ``"score"`` key (must
            support ``len()`` and slicing).
        group_size: number of samples per round. ``None`` (the default)
            means the module constant ``SAMPLE_NUM``, looked up at call time.

    Returns:
        A list of ``group_size`` lists of scores.

    Raises:
        ValueError: if ``group_size`` is not positive, or if
            ``len(result_jsons)`` is not a multiple of it.
        KeyError: if a record has no ``"score"`` key.
    """
    if group_size is None:
        group_size = SAMPLE_NUM
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")
    if len(result_jsons) % group_size != 0:
        raise ValueError(
            "Judger all test jsons length is not a multiple of group size")
    # The stride slice picks records j, j + group_size, j + 2 * group_size, ...
    return [[record["score"] for record in result_jsons[j::group_size]]
            for j in range(group_size)]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--score_name", type=str, required=True)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument('--traindataset', type=str, required=True,
                        choices=["datasets--allenai--reward-bench-2", "datasets--RUC-NLPIR--FlashRAG_datasets@hotpotqa_RAG"],)
    parser.add_argument("--subset", type=str, required=True)

    args = parser.parse_args()
    score_name = args.score_name

    # Judge scores and eval outputs share the same dataset/subset/model
    # layout under the machine-specific root; build these path pieces once.
    # Assumes tools.machine_pather() returns the same root on every call.
    root = tools.machine_pather()
    run_dir = f"{args.traindataset}/{args.subset}/{args.model}"
    score_file = f"{root}/works/DPO/judge/output/{run_dir}/{score_name}/score.jsonl"

    print(f"Score name: {score_name}", score_file)
    all_jsons = tools.read_jsonl(score_file)
    all_result = score_all_result(all_jsons)

    # Explicit raises (not ``assert``) so these checks still run under -O.
    if len(all_result) != SAMPLE_NUM:
        raise ValueError(
            "Judger all test jsons length should be equal to SAMPLE_NUM")

    # Pass 1: load and validate every eval file before writing any of them,
    # so a length mismatch never leaves the output set half-updated.
    updates = []
    for each in range(SAMPLE_NUM):
        output_file = f"{root}/works/DPO/output/{run_dir}/output_{each}_eval.json"
        with open(output_file, "r", encoding="utf-8") as f:
            dump_json = json.load(f)
        dump_json[score_name] = list(all_result[each])
        if len(dump_json[score_name]) != len(dump_json["correctness"]):
            raise ValueError(
                f"{output_file}: {len(dump_json[score_name])} scores but "
                f"{len(dump_json['correctness'])} correctness entries")
        updates.append((output_file, dump_json))

    # Pass 2: write back. UTF-8 is required because ensure_ascii=False may
    # emit non-ASCII characters.
    for output_file, dump_json in updates:
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(dump_json, f, indent=2, ensure_ascii=False)
    print("scores added to output files.")
