import collections
import json
import pathlib

import divergent_memories.data.constants as _constants


def eval_tiny_mmlu(inference_path: pathlib.Path) -> dict[str, int | float | dict[str, int]]:
    with open(inference_path, "r") as f:
        raw_results = tuple(map(json.loads, f))

    errors = collections.defaultdict(int)
    num_correct = 0
    num_non_errors = 0

    for raw_result in raw_results:
        completion = raw_result["inference_completion"].strip()
        # FIXME: could report by subject

        if len(completion) != 1:
            errors["Invalid length"] += 1
        else:
            if not (0 <= (ord(completion) - ord("A")) < 4):  # always 4 choices
                errors["Invalid option"] += 1
            else:
                num_non_errors += 1
                num_correct += completion == raw_result["expected_key"]

    return {
        "num_correct": num_correct,
        "num_non_errors": num_non_errors,
        "errors": errors,
    }


def eval_full_vision_mc(inference_path: pathlib.Path) -> dict[str, int | float | dict[str, int]]:
    with open(inference_path, "r") as f:
        raw_results = tuple(map(json.loads, f))

    errors = collections.defaultdict(int)
    num_correct = 0
    num_non_errors = 0

    for raw_result in raw_results:
        completion = raw_result["inference_completion"].strip()
        if len(completion) != 1:
            errors["Invalid length"] += 1
        else:
            # TODO: hardcoded number of choices
            if not (0 <= (ord(completion) - ord("A")) < 10):  # 10 choices
                errors["Invalid option"] += 1
            else:
                num_non_errors += 1
                num_correct += completion == raw_result["expected_key"]

    return {
        "num_correct": num_correct,
        "num_non_errors": num_non_errors,
        "errors": errors,
    }


def eval_full_vision_natural_name(inference_path: pathlib.Path) -> dict[str, int | float | dict[str, int]]:
    with open(inference_path, "r") as f:
        raw_results = tuple(map(json.loads, f))

    errors = collections.defaultdict(int)
    num_correct = 0
    num_non_errors = 0

    for raw_result in raw_results:
        completion = raw_result["inference_completion"].strip()
        if len(completion) != len(raw_result["name"]):
            errors["Invalid length"] += 1
        else:
            num_non_errors += 1
            num_correct += completion.lower() == raw_result["expected_response"].lower()

    return {
        "num_correct": num_correct,
        "num_non_errors": num_non_errors,
        "errors": errors,
    }


def eval_full_vision_natural_binary(inference_path: pathlib.Path) -> dict[str, int | float | dict[str, int]]:
    with open(inference_path, "r") as f:
        raw_results = tuple(map(json.loads, f))

    errors = collections.defaultdict(int)
    num_tp = num_tn = num_fp = num_fn = 0

    for raw_result in raw_results:
        completion = raw_result["inference_completion"].lower().strip()

        if not (completion.startswith("yes") or completion.startswith("no")):
            errors["Invalid format"] += 1
            continue

        expected = raw_result["expected_response"].lower()
        if completion.startswith("yes"):
            if expected == "yes":
                num_tp += 1
            else:
                assert expected == "no"
                num_fp += 1
        else:
            assert completion.startswith("no")
            if expected == "no":
                num_tn += 1
            else:
                assert expected == "yes"
                num_fn += 1

    return {
        "num_tp": num_tp,
        "num_tn": num_tn,
        "num_fp": num_fp,
        "num_fn": num_fn,
        "errors": errors,
    }


def eval_full_description_mc(inference_path: pathlib.Path, parse_strict: bool = True) -> dict:
    """Score multiple-choice completions per direction and concept type.

    Each line of ``inference_path`` is parsed as a JSON object from which the
    keys ``"inference_completion"``, ``"direction"`` (``"forward"`` or
    ``"reverse"``), ``"concept_type"`` and ``"expected_key"`` are read.

    Args:
        inference_path: Path of the JSON-lines file to evaluate.
        parse_strict: When true, the stripped completion must be exactly one
            uppercase option letter. When false, only the first character is
            considered and it is uppercased before validation and comparison.

    Returns:
        A dict with one entry per direction, mapping every concept type in
        ``_constants.CONCEPT_TO_SYNTHETIC_MAP`` to its ``"num_correct"`` and
        ``"num_non_errors"`` counts, plus ``"errors"`` mapping each direction
        to its error label -> count table.
    """
    with open(inference_path, "r") as f:
        raw_results = tuple(map(json.loads, f))

    directions = ("forward", "reverse")
    errors = {direction: collections.defaultdict(int) for direction in directions}
    results_per_direction = {
        direction: {
            concept_type: {"num_correct": 0, "num_non_errors": 0}
            for concept_type in _constants.CONCEPT_TO_SYNTHETIC_MAP
        }
        for direction in directions
    }

    for raw_result in raw_results:
        completion = raw_result["inference_completion"].strip()
        direction = raw_result["direction"]
        concept_type = raw_result["concept_type"]

        if parse_strict:
            # Need to exactly follow the format (have one uppercase letter as completion)
            if len(completion) != 1:
                errors[direction]["Invalid length"] += 1
                continue
            response_option = completion
        else:
            # Be more generous: look at first letter and convert to uppercase if needed
            if not completion:
                errors[direction]["Empty response"] += 1
                continue
            response_option = completion[0].upper()

        if direction == "forward":
            num_options = len(_constants.CONCEPT_TO_SYNTHETIC_MAP[concept_type])
        else:
            # TODO: hardcoded number of choices
            num_options = 10

        # str.upper() may expand one character into several (e.g. "ß" -> "SS"),
        # which ord() rejects; such a response can never be a valid option.
        if len(response_option) != 1 or not (0 <= ord(response_option) - ord("A") < num_options):
            # Errors are tracked per direction, like the other error kinds.
            errors[direction]["Invalid option"] += 1
            continue

        counts = results_per_direction[direction][concept_type]
        counts["num_non_errors"] += 1
        # Compare the normalized option so lenient parsing can actually score
        # answers such as "a" or "A) ..." as correct.
        counts["num_correct"] += response_option == raw_result["expected_key"]

    return {
        **results_per_direction,
        "errors": errors,
    }
