from pm_kvq.evaluation.eval_ifeval.evaluation_main import InputExample, test_instruction_following_loose, test_instruction_following_strict


def score(predictions, references):
    prompt_strict_correct, prompt_strict_total = 0, 0
    inst_strict_correct, inst_strict_total = 0, 0
    prompt_loose_correct, prompt_loose_total = 0, 0
    inst_loose_correct, inst_loose_total = 0, 0
    details = {}
    for index, (pred, refer) in enumerate(zip(predictions, references)):
        input = InputExample(key=refer["key"], instruction_id_list=refer["instruction_id_list"], prompt=refer["prompt"], kwargs=refer["kwargs"])
        for kwarg in input.kwargs:
            for k in list(kwarg.keys()):
                if kwarg[k] is None:
                    kwarg.pop(k, None)

        # strict
        example = test_instruction_following_strict(input, pred)
        follow_instruction_list = example.follow_instruction_list
        instruction_id_list = example.instruction_id_list
        prompt_strict_total += 1
        is_strict_correct = all(follow_instruction_list)
        prompt_strict_correct += is_strict_correct
        inst_strict_total += len(instruction_id_list)
        inst_strict_correct += sum(follow_instruction_list)

        # loose
        example = test_instruction_following_loose(input, pred)
        follow_instruction_list = example.follow_instruction_list
        instruction_id_list = example.instruction_id_list
        prompt_loose_total += 1
        is_loose_correct = all(follow_instruction_list)
        prompt_loose_correct += is_loose_correct
        inst_loose_total += len(instruction_id_list)
        inst_loose_correct += sum(follow_instruction_list)

        if is_strict_correct:
            grade = "strict"
        elif is_loose_correct:
            grade = "loose"
        else:
            grade = "none"

        details[str(index)] = {"pred": pred, "refer": refer, "is_strict_correct": is_strict_correct, "is_loose_correct": is_loose_correct, "is_correct": is_strict_correct, "grade": grade}

    results = {
        "Prompt-level-strict-accuracy": prompt_strict_correct / prompt_strict_total * 100,
        "Inst-level-strict-accuracy": inst_strict_correct / inst_strict_total * 100,
        "Prompt-level-loose-accuracy": prompt_loose_correct / prompt_loose_total * 100,
        "Inst-level-loose-accuracy": inst_loose_correct / inst_loose_total * 100,
        # "details": details,
    }
    # print(results)
    return results

