import re
import string
from typing import List

import numpy as np
import transformers

"""
Copied from: https://github.com/EleutherAI/lm-evaluation-harness/blob/2f03271d25db3c19e5552e19f59816bcbba07357/lm_eval/models/utils.py#L216
"""
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Stopping criterion that fires once every batch row has emitted a stop string.

    The check is done on decoded text rather than on token ids, so a stop
    string that the model produced with a different tokenization than
    ``tokenizer.encode(sequence)`` is still detected.

    Parameters:
        sequence (str): The stop string to look for in the generated text.
        tokenizer: Tokenizer used to encode ``sequence`` and to decode the generated ids.
        initial_decoder_input_length (int): Number of leading positions in ``input_ids``
            that are excluded from the search (they are sliced off before decoding).
        batch_size (int): Number of rows tracked in ``done_tracker``.
    """

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ) -> None:
        self.tokenizer = tokenizer
        self.sequence = sequence
        self.initial_decoder_input_length = initial_decoder_input_length
        # One "already stopped" flag per batch row; a row never flips back to False.
        self.done_tracker = [False] * batch_size
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # The lookback window is 2 tokens longer than the encoded stop string:
        # a model may emit e.g. ['\n', '\n'] while `sequence` encodes to ['\n\n'],
        # and the (string) stop sequence must still be caught in that case.
        #
        # NOTE: the extra lookback is kept small on purpose. __call__ additionally
        # slices off the first `initial_decoder_input_length` positions so the
        # window never reaches back into the model inputs.
        self.sequence_id_len = 2 + len(self.sequence_ids)

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Update the per-row flags and report whether every row is done.

        Only the last ``sequence_id_len`` generated ids of each row are decoded
        and searched for ``sequence``.
        """
        # Drop the leading (input) positions, then keep just the trailing window.
        generated_ids = input_ids[:, self.initial_decoder_input_length :]
        tail_ids = generated_ids[:, -self.sequence_id_len :]
        decoded_tails = self.tokenizer.batch_decode(tail_ids)

        for row in range(len(self.done_tracker)):
            if self.done_tracker[row]:
                # Already stopped on an earlier step; keep it that way.
                continue
            self.done_tracker[row] = self.sequence in decoded_tails[row]

        return False not in self.done_tracker

"""
Copied from: https://github.com/EleutherAI/lm-evaluation-harness/blob/2f03271d25db3c19e5552e19f59816bcbba07357/lm_eval/models/utils.py#L256
"""
def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int, # pass prompt_input_ids.shape[1] here
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    """Build a ``StoppingCriteriaList`` holding one ``MultiTokenEOSCriteria`` per stop string.

    Parameters:
        tokenizer: Tokenizer forwarded to every criterion.
        stop_sequences (list of str): Stop strings; one criterion is created for each, in order.
        initial_decoder_input_length (int): Forwarded unchanged to every criterion.
        batch_size (int): Forwarded unchanged to every criterion.

    Returns:
        transformers.StoppingCriteriaList: The assembled criteria.
    """
    criteria = [
        MultiTokenEOSCriteria(
            stop_sequence, tokenizer, initial_decoder_input_length, batch_size
        )
        for stop_sequence in stop_sequences
    ]
    return transformers.StoppingCriteriaList(criteria)

# Stop strings for GSM8K generation.
# NOTE(review): presumably passed as `stop_sequences` to `stop_sequences_criteria` — confirm against callers.
GSM8K_STOP_SEQUENCES = [
    '<|eot_id|>',
    '<|start_header_id|>user<|end_header_id|>',
    'Q:',
    '</s>',
    '<|im_end|>',
]

class GSM8KParser():
    """
    A parser class for extracting numeric answers from model-generated responses using regular expressions.
    
    Inspired by:
    (0) https://github.com/EleutherAI/lm-evaluation-harness/blob/a96085f1515fe87af44350b01094bee248515356/lm_eval/tasks/gsm8k/gsm8k-cot-llama.yaml#L45
    (1) https://github.com/EleutherAI/lm-evaluation-harness/blob/2f03271d25db3c19e5552e19f59816bcbba07357/lm_eval/filters/extraction.py#L10

    Parameters:
        regex_pattern (str): The regular expression pattern used to extract the final numeric answer.
        group_select (int): Index into the list of regex matches found in a response. Use -1 for last match.
        fallback (str): Fallback value to return if no match is found, if the selected index
            does not exist, or if the selected match has only empty groups.
    """
    def __init__(self, 
        regex_pattern=r"The final answer is ((-?[$0-9.,]{2,})|(-?[0-9]+))",
        group_select=-1,
        fallback="[invalid]"):
        """
        Default values are taken from the lm-eval gsm8k-cot-llama.yaml config.
        """
        
        self.regex_pattern = regex_pattern
        # Compiled once so every response reuses the same pattern object.
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = fallback
    
    def _call(self, response: str) -> str:
        """
        Parses a single response string to extract the final numeric answer using the regex pattern.

        Parameters:
            response (str): The raw text response generated by the model.

        Returns:
            str: Extracted and whitespace-stripped final answer. The fallback value is returned
            when nothing matches or when ``group_select`` points past the available matches.
        """
        matches = self.regex.findall(response)
        if not matches:
            return self.fallback

        try:
            selected = matches[self.group_select]
        except IndexError:
            # Fewer matches than group_select asks for: treat it like "no match"
            # instead of crashing the whole batch.
            return self.fallback

        if isinstance(selected, tuple):
            # With several capture groups findall yields tuples; take the first
            # group that actually captured something.
            captured = [group for group in selected if group]
            selected = captured[0] if captured else self.fallback

        return selected.strip()

    def __call__(self, responses: List[str]) -> List[str]:
        """
        Parses a batch of responses and extracts final answers from each.

        Parameters:
            responses (list of str): A list of raw response strings.

        Returns:
            list of str: A list of parsed answers, one for each response.
        """
        return [self._call(resp) for resp in responses]

class GSM8KExactMatcher():
    """
    A metric class to compute exact match accuracy between predicted and reference answers
    with optional normalization (regex removal, lowercasing, punctuation/number stripping).
    
    Inspired by:
    (0) https://github.com/EleutherAI/lm-evaluation-harness/blob/a96085f1515fe87af44350b01094bee248515356/lm_eval/tasks/gsm8k/gsm8k-cot-llama.yaml#L70
    (1) https://github.com/EleutherAI/lm-evaluation-harness/blob/2f03271d25db3c19e5552e19f59816bcbba07357/lm_eval/api/metrics.py#L197

    Parameters:
        regexes_to_ignore (iterable of str or None): Regex patterns to remove from both predictions
            and references, applied in order. ``None`` disables regex removal.
        ignore_case (bool): If True, lowercases all strings before comparison.
        ignore_punctuation (bool): If True, removes punctuation before comparison.
        ignore_numbers (bool): If True, removes digits before comparison.
    """
    def __init__(self,
        regexes_to_ignore=(',', '\\$', '(?s).*#### ', '\\.$'),
        ignore_case=True,
        ignore_punctuation=False,
        ignore_numbers=False):
        """
        Default values are taken from the lm-eval gsm8k-cot-llama.yaml config.
        """
        
        # Stored as a fresh list so instances never share (or mutate) the default.
        self.regexes_to_ignore = (
            None if regexes_to_ignore is None else list(regexes_to_ignore)
        )
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.ignore_numbers = ignore_numbers

    def __call__(self, predictions, references):  
        """
        Computes the exact match accuracy between predicted and reference answers
        after applying optional normalization steps.

        Parameters:
            predictions (list of str): Model-generated answers.
            references (list of str): Ground truth answers.

        Returns:
            float: Proportion of exact matches after preprocessing; NaN when both inputs are empty.

        Raises:
            ValueError: If ``predictions`` and ``references`` differ in length.
        """
        predictions = list(predictions)
        references = list(references)

        # Without this check NumPy would either broadcast a length-1 input
        # against the other side or fail with an unrelated-looking error.
        if len(predictions) != len(references):
            raise ValueError(
                f"predictions and references must have the same length, "
                f"got {len(predictions)} and {len(references)}"
            )
        if not predictions:
            # Accuracy over zero samples is undefined.
            return np.float64(np.nan)

        # Apply every removal pattern on plain lists, then build the arrays once.
        for pattern in self.regexes_to_ignore or ():
            predictions = [re.sub(pattern, "", x) for x in predictions]
            references = [re.sub(pattern, "", x) for x in references]

        # Always compare as arrays so `==` below is element-wise.
        predictions = np.asarray(predictions)
        references = np.asarray(references)

        if self.ignore_case:
            predictions = np.char.lower(predictions)
            references = np.char.lower(references)

        if self.ignore_punctuation:
            repl_table = str.maketrans("", "", string.punctuation)
            predictions = np.char.translate(predictions, table=repl_table)
            references = np.char.translate(references, table=repl_table)

        if self.ignore_numbers:
            repl_table = str.maketrans("", "", string.digits)
            predictions = np.char.translate(predictions, table=repl_table)
            references = np.char.translate(references, table=repl_table)

        score_list = predictions == references

        return np.mean(score_list)

class GSM8KEvaluator():
    """
    Convenience wrapper that chains a default GSM8KParser and a default GSM8KExactMatcher.
    """
    def __init__(self):
        # Both components are built with their default settings.
        self.parser = GSM8KParser()
        self.matcher = GSM8KExactMatcher()

    def __call__(self, generations, references):
        """
        Parses the raw generations and scores the parsed answers against the references.

        Parameters:
            generations: Raw model outputs, handed to the parser.
            references: Ground truth answers, handed to the matcher unchanged.

        Returns:
            The value returned by the matcher.
        """
        return self.matcher(self.parser(generations), references)