# based on the code from https://github.com/bigcode-project/bigcode-evaluation-harness
from datasets import load_dataset
from transformers import StoppingCriteria, StoppingCriteriaList
from code_eval import compute_code_eval


class HumanEvalTask:
    """HumanEval task wrapper: builds prompts, trims generations and scores them.

    Attributes:
        dataset_path: dataset identifier handed to `load_dataset`.
        stop_words: substrings at which a generated completion is truncated.
        k, num_workers, timeout: forwarded unchanged to `compute_code_eval`.
        dataset: the "test" split returned by `load_dataset`.
        requires_execution: always True; not read inside this class
            (presumably consumed by the calling harness -- confirm).
        new_line: when truthy, prompts get a trailing newline.
    """

    def __init__(self, new_line):
        """Load the "test" split and record the evaluation settings.

        Args:
            new_line: whether `get_prompt` appends "\\n" to the stripped prompt.
        """
        self.new_line = new_line
        self.requires_execution = True
        self.stop_words = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif", "\n```", "<file_sep>"]
        self.k = [1, 10, 100]
        self.num_workers = 16
        self.timeout = 3.0
        self.dataset_path = "openai_humaneval"
        self.dataset = load_dataset(path=self.dataset_path, split="test")

    @staticmethod
    def _stop_at_stop_token(decoded_string, stop_tokens):
        """Return `decoded_string` cut just before the earliest stop token.

        The string is returned unchanged when no stop token occurs in it.
        """
        positions = [decoded_string.find(token) for token in stop_tokens]
        cut = min(
            (position for position in positions if position != -1),
            default=len(decoded_string),
        )
        return decoded_string[:cut]

    def get_dataset(self):
        """Return the dataset loaded in `__init__`."""
        return self.dataset

    def get_prompt(self, example):
        """Return the example's prompt, stripped, plus "\\n" if `new_line` is set."""
        stripped = example["prompt"].strip()
        return stripped + "\n" if self.new_line else stripped

    def get_reference(self, example):
        """Return the example's test code followed by a `check(<entry_point>)` call."""
        test_body = example["test"]
        check_call = "check({})".format(example["entry_point"])
        # The leading empty item yields the newline that precedes the test code.
        return "\n".join(("", test_body, check_call))

    def postprocess_generation(self, generation, idx):
        """Return prompt + completion, with the completion cut at the first stop word.

        Args:
            generation: full generated text; its first `len(prompt)` characters
                are dropped and replaced by the prompt of example `idx`.
            idx: index into `self.dataset` of the example that was prompted.
        """
        prompt = self.get_prompt(self.dataset[idx])
        completion = generation[len(prompt):]
        trimmed = self._stop_at_stop_token(completion, self.stop_words)
        return prompt + trimmed

    def process_results(self, generations, references, return_errs=False):
        """Score `generations` against `references` with `compute_code_eval`.

        Returns:
            The first value returned by `compute_code_eval`, or both values as
            a tuple when `return_errs` is truthy.
        """
        scores, details = compute_code_eval(
            references=references,
            predictions=generations,
            k=self.k,
            num_workers=self.num_workers,
            timeout=self.timeout,
        )
        return (scores, details) if return_errs else scores


class MBPPTask():
    """MBPP task wrapper: builds few-shot prompts, trims generations and scores them.

    Every prompt consists of one or two solved examples followed by the
    instruction for the task at hand, all rendered from one shared template.

    Attributes:
        dataset_path: dataset identifier handed to `load_dataset`.
        stop_words: stop strings stored for this task; not read inside this class.
        requires_execution: always True; not read inside this class
            (presumably consumed by the calling harness -- confirm).
        dataset: the "test" split returned by `load_dataset`.
        new_line: when truthy, the task instruction gets a trailing newline.
        two_shot: when truthy, two solved examples are prepended instead of one.
    """

    # Solved examples used as the few-shot prefix, in order:
    # (task description, test assertions, reference solution).
    _FEW_SHOT_EXAMPLES = (
        (
            "Write a function to find the similar elements from the given two tuple lists.",
            (
                "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)",
                "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)",
                "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)",
            ),
            "def similar_elements(test_tup1, test_tup2):\r\n  res = tuple(set(test_tup1) & set(test_tup2))\r\n  return (res) ",
        ),
        (
            "Write a python function to identify non-prime numbers.",
            (
                "assert is_not_prime(2) == False",
                "assert is_not_prime(10) == True",
                "assert is_not_prime(35) == True",
            ),
            "import math\r\ndef is_not_prime(n):\r\n    result = False\r\n    for i in range(2,int(math.sqrt(n)) + 1):\r\n        if n % i == 0:\r\n            result = True\r\n    return result",
        ),
    )

    def __init__(self, new_line, two_shot=False):
        """Load the "test" split and record the prompt settings.

        Args:
            new_line: whether `get_prompt` appends "\\n" after "[BEGIN]".
            two_shot: whether prompts use two solved examples instead of one.
        """
        self.dataset_path = "mbpp"
        self.stop_words = ["\nclass", "\nassert", '\n"""', "\nprint", "\nif", "\n<|/", "\n```"]
        self.requires_execution = True
        self.dataset = load_dataset(path=self.dataset_path, split="test")
        self.new_line = new_line
        self.two_shot = two_shot

    @staticmethod
    def _stop_at_stop_token(decoded_string, stop_tokens):
        """Return `decoded_string` cut just before the earliest stop token.

        The string is returned unchanged when no stop token occurs in it.
        """
        positions = [decoded_string.find(token) for token in stop_tokens]
        cut = min(
            (position for position in positions if position != -1),
            default=len(decoded_string),
        )
        return decoded_string[:cut]

    @staticmethod
    def _render_instruction(description, tests):
        """Return the instruction text for one task, ending at the "[BEGIN]" marker."""
        return (
            f"You are an expert Python programmer, and here is your task: {description} "
            f"Your code should pass these tests:\n\n{tests}\n[BEGIN]"
        )

    @classmethod
    def _render_example(cls, example):
        """Return one solved few-shot example: instruction, solution, "[DONE]" marker."""
        description, tests, code = example
        # The stored solutions use CRLF line endings; drop the carriage
        # returns and expand tabs so the prompt contains plain "\n" code.
        code = code.replace("\r", "").replace("\t", "    ")
        instruction = cls._render_instruction(description, "\n".join(tests))
        return f"{instruction}\n{code}\n[DONE]\n"

    def get_dataset(self):
        """Return the dataset loaded in `__init__`."""
        return self.dataset

    def get_one_shot(self,):
        """Return the few-shot prefix containing only the first solved example."""
        return self._render_example(self._FEW_SHOT_EXAMPLES[0])

    def get_two_shot(self,):
        """Return the few-shot prefix containing both solved examples.

        The examples are joined by "\\n"; since each already ends with a
        newline, this leaves one blank line between them.
        """
        return "\n".join(
            self._render_example(example) for example in self._FEW_SHOT_EXAMPLES[:2]
        )

    def get_prompt(self, doc):
        """Return the full prompt for `doc`: few-shot prefix plus its instruction.

        Args:
            doc: mapping with "text" (task description), "test_list"
                (list of assertion strings) and "task_id".
        """
        description = doc["text"]
        shown_tests = doc["test_list"]
        if doc["task_id"] == 493:
            # Only the last test is shown for task 493.
            # NOTE(review): the reason is not visible here (presumably the
            # other tests are too long for the prompt) -- confirm.
            shown_tests = shown_tests[-1:]
        prompt = self._render_instruction(description, "\n".join(shown_tests))
        if self.new_line:
            prompt += "\n"
        # Build only the prefix that is actually used.
        few_shot = self.get_two_shot() if self.two_shot else self.get_one_shot()
        return few_shot + prompt

    def get_reference(self, doc):
        """Return all of `doc`'s test assertions joined by newlines."""
        return "\n".join(doc["test_list"])

    def get_valid_output(self, generation, idx):
        """Return prompt + completion, keeping text up to and including the first "[DONE]".

        The completion is kept whole when it contains no "[DONE]" marker.

        Args:
            generation: full generated text; its first `len(prompt)` characters
                are dropped and replaced by the prompt of example `idx`.
            idx: index into `self.dataset` of the example that was prompted.
        """
        prompt = self.get_prompt(self.dataset[idx])
        completion = generation[len(prompt):]
        # partition() returns an empty separator when the marker is absent,
        # so the completion is then kept unchanged.
        body, marker, _ = completion.partition("[DONE]")
        return prompt + body + marker

    def postprocess_generation(self, generation, idx):
        """Return only the generated solution: the text before "[DONE]", stripped.

        Args:
            generation: full generated text; its first `len(prompt)` characters
                (the prompt of example `idx`) are dropped.
            idx: index into `self.dataset` of the example that was prompted.
        """
        prompt = self.get_prompt(self.dataset[idx])
        completion = generation[len(prompt):]
        return completion.partition("[DONE]")[0].strip()

    def process_results(self, generations, references, return_errs=False):
        """Score `generations` against `references` with `compute_code_eval`.

        Only references and predictions are passed, so `compute_code_eval`
        uses its own defaults for every other setting.

        Returns:
            The first value returned by `compute_code_eval`, or both values as
            a tuple when `return_errs` is truthy.
        """
        results, results_ext = compute_code_eval(
            references=references,
            predictions=generations,
        )
        if return_errs:
            return results, results_ext
        return results


class EndOfFunctionCriteria(StoppingCriteria):
    """Stopping criterion that holds once every sequence in the batch passes `check_fn`.

    With the default `check_fn`, a sequence passes when its decoded
    generated part contains at least one of the `eof_strings`.
    """

    def __init__(self, start_length, eof_strings, tokenizer, check_fn=None):
        """Store the decoding settings and select the per-sequence check.

        Args:
            start_length: number of leading positions of `input_ids` that are
                skipped before decoding.
            eof_strings: strings whose presence marks a generation as finished
                (used by the default check only).
            tokenizer: object providing `batch_decode`.
            check_fn: optional callable taking one decoded string and returning
                a truth value; defaults to the `eof_strings` containment test.
        """
        self.start_length = start_length
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer

        def contains_eof_string(decoded_generation):
            # `self.eof_strings` is read on every call, so later changes to
            # the attribute are picked up by the default check.
            matches = [
                stop_string in decoded_generation
                for stop_string in self.eof_strings
            ]
            return any(matches)

        self.check_fn = contains_eof_string if check_fn is None else check_fn

    def __call__(self, input_ids, scores, **kwargs):
        """Return True if every decoded generation in the batch passes `check_fn`."""
        generated_ids = input_ids[:, self.start_length :]
        decoded_generations = self.tokenizer.batch_decode(generated_ids)
        # Evaluate the check for every sequence before combining the results.
        verdicts = [self.check_fn(decoded) for decoded in decoded_generations]
        return all(verdicts)


class TooLongFunctionCriteria(StoppingCriteria):
    """Stopping criterion that holds once the sequence outgrows a multiple of the input length."""

    def __init__(self, input_length, multiplier):
        """Store the reference length and the allowed growth factor.

        Args:
            input_length: length that the limit is computed from.
            multiplier: factor applied to `input_length` to obtain the limit.
        """
        self.multiplier = multiplier
        self.input_length = input_length

    def __call__(self, input_ids, scores, **kwargs):
        """Return True if `input_ids.shape[1]` exceeds int(input_length * multiplier)."""
        current_length = input_ids.shape[1]
        # The limit is recomputed on each call from the current attribute
        # values and truncated to an integer.
        max_length = int(self.input_length * self.multiplier)
        return current_length > max_length

