import json
import logging
import os
import re
import string
import warnings
from collections import defaultdict
from datetime import datetime
from typing import List, Sequence


class GAIA:
    """
    Helpers for producing and scoring model answers on the GAIA benchmark.

    Typical workflow:
    1. ``make_model_answers`` runs a model over every question listed in
       ``<path_to_questions>/metadata.jsonl`` and writes a jsonl of answers.
    2. ``score_model`` merges those answers into the questions, scores each
       answer against its ground truth and returns the accuracy per level.
    """

    def __init__(
        self,
        path_to_questions: str,
    ) -> None:
        """
        Parameters:
        - path_to_questions: str, directory holding the benchmark questions
        with ground truth. It must contain a ``metadata.jsonl`` file, and the
        questions' ``file_name`` entries are resolved relative to it.
        """

        self.path_to_questions = path_to_questions

    def _metadata_path(self) -> str:
        """Return the path of the questions' ``metadata.jsonl`` file."""
        return os.path.join(self.path_to_questions, "metadata.jsonl")

    @staticmethod
    def _load_jsonl(path: str) -> List[dict]:
        """Read a jsonl file into a list of records, skipping blank lines."""
        with open(path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    @staticmethod
    def _write_jsonl(path: str, records: List[dict], default=None) -> None:
        """
        Write ``records`` to ``path``, one JSON object per line.

        ``default`` is forwarded to ``json.dumps`` for objects that are not
        natively JSON serializable.
        """
        with open(path, "w", encoding="utf-8") as f:
            for record in records:
                f.write(json.dumps(record, default=default) + "\n")

    def score_model(
        self,
        path_to_model_answers: str,
        path_to_model_scored: str,
    ) -> dict:
        """
        Score a model on GAIA.

        Merges the model answers into the benchmark questions, scores every
        question, and returns the accuracy per level.

        Parameters:
        - path_to_model_answers: str, path to a jsonl containing at least a
        'task_id' entry and corresponding 'model_answer' entry per line
        - path_to_model_scored: str, output path; it receives the questions
        with their model answers and scores. A ``.log`` file with the same
        stem is written next to it.

        Returns:
        - dict, mapping each 'Level' value to the fraction of questions of
        that level that were answered correctly
        """

        self.populate_with_model_answers(
            path_to_model_answers,
            path_to_model_scored,
        )
        self.populate_with_scores(
            path_to_model_scored,
        )
        return self.get_scores(path_to_model_scored)

    def populate_with_model_answers(
        self,
        path_to_model_answers: str,
        path_to_model_scored: str,
    ) -> None:
        """
        Add model answers to the jsonl of questions for further processing.

        Parameters:
        - path_to_model_answers: str, path to jsonl containing at least a
        'task_id' entry and corresponding 'model_answer' entry
        - path_to_model_scored: str, path for output

        Raises:
        - AssertionError, if a question's task_id has no model answer
        """

        questions = GAIA._load_jsonl(self._metadata_path())
        model_answers_dict = {
            item["task_id"]: item["model_answer"]
            for item in GAIA._load_jsonl(path_to_model_answers)
        }

        # every benchmark question must have been answered
        for question in questions:
            assert (
                question["task_id"] in model_answers_dict
            ), "task_id not found in model answers"
            question["model_answer"] = model_answers_dict[question["task_id"]]

        GAIA._write_jsonl(path_to_model_scored, questions)

    def populate_with_scores(
        self,
        path_to_model_scored: str,
    ) -> None:
        """
        Add a 0/1 'score' to each record of a questions+answers jsonl, in place.

        Each scored record is also logged to ``<stem>.log``, next to
        ``path_to_model_scored``, for debugging.

        Parameters:
        - path_to_model_scored: str, path to questions with model answers
        and ground truth ('Final answer')

        Raises:
        - AssertionError, if a record has no 'model_answer'
        """
        # os.path.splitext only strips the final extension, so dots in
        # directory names do not truncate the log path.
        path_to_log = os.path.splitext(path_to_model_scored)[0] + ".log"

        # A dedicated logger with its own handler: logging.basicConfig is a
        # no-op once the root logger is configured, so it would silently
        # skip the per-run log file on every call after the first.
        handler = logging.FileHandler(path_to_log, mode="w", encoding="utf-8")
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        )
        logger = logging.getLogger(f"{__name__}.GAIA.scoring")
        logger.setLevel(logging.INFO)
        # file-only: keep these records out of handlers on the root logger
        logger.propagate = False
        logger.addHandler(handler)

        try:
            questions_with_model_answers = GAIA._load_jsonl(path_to_model_scored)

            for question in questions_with_model_answers:
                logger.info("=" * 100)
                assert "model_answer" in question, "no model_answer"
                question["score"] = (
                    1
                    if GAIA.score(question["model_answer"], question["Final answer"])
                    else 0
                )
                logger.info(
                    "%s %s %s %s",
                    question["task_id"],
                    question["model_answer"],
                    question["Final answer"],
                    question["score"],
                )
        finally:
            # release the log file even if scoring fails
            logger.removeHandler(handler)
            handler.close()

        GAIA._write_jsonl(path_to_model_scored, questions_with_model_answers)

    def get_scores(self, path_to_model_scored: str) -> dict:
        """
        Compute the accuracy per level from a scored jsonl.

        Parameters:
        - path_to_model_scored: str, jsonl whose records carry 'Level' and
        'score' entries

        Returns:
        - dict, mapping each 'Level' value to (#truthy scores / #questions).
        Levels with no correct answer are reported as 0.0.
        """
        nb_correct_per_lvl = defaultdict(int)
        nb_questions_per_lvl = defaultdict(int)

        for question in GAIA._load_jsonl(path_to_model_scored):
            nb_questions_per_lvl[question["Level"]] += 1
            if question["score"]:
                nb_correct_per_lvl[question["Level"]] += 1

        # Iterate over every level seen, not only those with a correct
        # answer, so that a level scoring zero is not dropped.
        return {
            lvl: nb_correct_per_lvl[lvl] / nb_total
            for lvl, nb_total in nb_questions_per_lvl.items()
        }

    @staticmethod
    def normalize_number_str(number_str: str) -> float:
        """
        Convert an answer string to a float for numeric comparison.

        '$', '%' and ',' are removed first, so e.g. "$1,000" -> 1000.0.
        Returns float("inf") when the remaining text is not a number.
        """
        # we replace these common units and commas to allow
        # conversion to float
        for char in ["$", "%", ","]:
            number_str = number_str.replace(char, "")
        try:
            return float(number_str)
        except ValueError:
            return float("inf")

    @staticmethod
    def split_string(
        s: str,
        char_list: Sequence[str] = (",", ";"),
    ) -> List[str]:
        """
        Split ``s`` on any of the characters in ``char_list``.

        The characters are regex-escaped before being placed in a character
        class, so separators such as '^' or ']' are taken literally. An empty
        ``char_list`` returns ``[s]`` unsplit.
        """
        if not char_list:
            return [s]
        pattern = "[" + "".join(re.escape(char) for char in char_list) + "]"
        return re.split(pattern, s)

    @staticmethod
    def score(
        model_answer: str,
        ground_truth: str,
    ) -> bool:
        """
        Return whether ``model_answer`` matches ``ground_truth``.

        - Numeric ground truth: the answer is compared as a float after
          stripping '$', '%' and ',' (see ``normalize_number_str``).
        - Ground truth containing ',' or ';': both sides are split into
          lists. The lists must have equal length, and elements are
          compared pairwise, numerically or as normalized strings (with
          punctuation kept).
        - Otherwise: strings are compared after removing whitespace and
          punctuation and lowercasing.

        Non-string inputs (e.g. ``None`` or a number read back from JSON)
        are compared through their ``str()`` form.
        """

        def is_float(element: object) -> bool:
            try:
                float(element)
            except (TypeError, ValueError):
                return False
            return True

        # The string helpers below call str methods, so coerce first.
        model_answer = str(model_answer)
        ground_truth = str(ground_truth)

        # if gt is a number
        if is_float(ground_truth):
            normalized_answer = GAIA.normalize_number_str(model_answer)
            return normalized_answer == float(ground_truth)

        # if gt is a list
        elif any(char in ground_truth for char in [",", ";"]):
            # question with the fish: normalization removes punct
            gt_elems = GAIA.split_string(ground_truth)
            ma_elems = GAIA.split_string(model_answer)

            # check length is the same
            if len(gt_elems) != len(ma_elems):
                warnings.warn(
                    "Answer lists have different lengths, returning False.", UserWarning
                )
                return False

            # compare each element as float or str
            comparisons = []
            for ma_elem, gt_elem in zip(ma_elems, gt_elems):
                if is_float(gt_elem):
                    normalized_ma_elem = GAIA.normalize_number_str(ma_elem)
                    comparisons.append(normalized_ma_elem == float(gt_elem))
                else:
                    # we do not remove punct since comparisons can include punct
                    comparisons.append(
                        GAIA.normalize_str(ma_elem, remove_punct=False)
                        == GAIA.normalize_str(gt_elem, remove_punct=False)
                    )
            return all(comparisons)

        # if gt is a str
        else:
            return GAIA.normalize_str(model_answer) == GAIA.normalize_str(ground_truth)

    @staticmethod
    def normalize_str(input_str: str, remove_punct: bool = True) -> str:
        """
        Normalize a string by:
        - Removing all white spaces
        - Optionally removing punctuation (if remove_punct is True)
        - Converting to lowercase

        Parameters:
        - input_str: str, the string to normalize
        - remove_punct: bool, whether to remove punctuation (default: True)

        Returns:
        - str, the normalized string
        """
        # Remove all white spaces. Required e.g for seagull vs. sea gull
        no_spaces = re.sub(r"\s", "", input_str)

        # Remove punctuation, if specified.
        if remove_punct:
            translator = str.maketrans("", "", string.punctuation)
            return no_spaces.lower().translate(translator)
        else:
            return no_spaces.lower()

    @staticmethod
    def extract_answer(
        input_str: str,
        prompt_sep: str,
    ) -> str:
        """
        Return the stripped text after the last occurrence of ``prompt_sep``.

        If ``prompt_sep`` does not occur, the whole stripped input is returned.
        """
        answer = input_str.split(prompt_sep)[-1].strip()
        return answer

    def make_model_answers(
        self,
        model,
        path_to_model_answers: str,
        prompt_sep: str = "FINAL ANSWER:",
    ) -> None:
        """
        In-house script to produce answers for a model.

        Parameters:
        - model: callable invoked as ``model(question=..., file_path=...)``
        that returns the raw answer text. It must also expose a ``prefix``
        attribute, which is recorded with each answer.
        - path_to_model_answers: str, output jsonl path. Each line holds
        'task_id', 'model_answer', 'raw_model_answer', 'timestamp' and 'prefix'.
        - prompt_sep: str, marker after which the final answer appears in the
        raw model output
        """

        print(f"Starting {str(model)} evaluation.")

        questions = GAIA._load_jsonl(self._metadata_path())

        model_answers = []
        for question in questions:
            # NOTE(review): a question without an attachment presumably has
            # an empty 'file_name', in which case file_path is the questions
            # directory itself — confirm the model handles that.
            raw_model_answer = model(
                question=question["Question"],
                file_path=os.path.join(self.path_to_questions, question["file_name"]),
            )
            print(raw_model_answer)
            model_answer = GAIA.extract_answer(raw_model_answer, prompt_sep)
            print(model_answer)
            model_answers.append(
                {
                    "task_id": question["task_id"],
                    "model_answer": model_answer,
                    "raw_model_answer": raw_model_answer,
                    "timestamp": datetime.today(),
                    "prefix": model.prefix,
                }
            )

        # we use str by default since datetime is not serializable
        GAIA._write_jsonl(path_to_model_answers, model_answers, default=str)
