"""Program Synthesis with Large Language Models
XXXX

The benchmark consists of around 1,000 crowd-sourced Python programming problems, 
designed to be solvable by entry level programmers, covering programming fundamentals, 
standard library functionality, and so on. Each problem consists of a task description, 
code solution and 3 automated test cases. As described in the paper, a subset of the data
has been hand-verified by the authors.

Homepage: XXXX
"""

import re

import torch
from evaluate import load

from lm_eval.base import Task

# BibTeX entry for the paper that introduced this benchmark
# ("Program Synthesis with Large Language Models", Austin et al., 2021).
# Not referenced elsewhere in the visible part of this file.
_CITATION = """
@article{austin2021program,
  title={Program Synthesis with Large Language Models},
  author={Austin, Jacob and Odena, Augustus and Nye, Maxwell and Bosma, Maarten and Michalewski, Henryk and Dohan, David and Jiang, Ellen and Cai, Carrie and Terry, Michael and Le, Quoc and others},
  journal={arXiv preprint arXiv:2108.07732},
  year={2021}
}
"""


class MBPP(Task):
    """Few-shot MBPP code-generation task.

    Bundles everything needed to evaluate a model on the benchmark described
    in the module docstring: dataset access, few-shot prompt construction,
    post-processing of raw generations and execution-based scoring.

    ``self.dataset`` and ``self.stop_words`` are read here but never assigned
    in this class; presumably the ``Task`` base class sets them from
    ``DATASET_PATH`` and the ``stop_words`` constructor argument -- confirm
    against ``lm_eval.base``.
    """

    DATASET_PATH = "mbpp"

    def __init__(self):
        # Re-seeded with the problem's task_id in get_prompt, so the few-shot
        # examples drawn for a given problem are the same on every call.
        self.rng = torch.Generator()
        # Number of train-split examples prepended to every prompt.
        self.few_shot = 3
        super().__init__(
            stop_words=["\nclass", "\nassert", '\n"""', "\nprint", "\nif", "\n<|/"],
            requires_execution=True,
        )

    def get_dataset(self):
        """Return the test split that is iterated over during evaluation.

        :raises AssertionError: if the split does not contain exactly 500
            problems, which indicates that a stale copy of the dataset was
            loaded from the cache.
        """
        dataset = self.dataset["test"]
        # the wrong split of mbpp can be loaded with old datasets cache
        # Raised explicitly (rather than via ``assert``) so that the check is
        # not stripped when Python runs with -O.
        if len(dataset) != 500:
            raise AssertionError(
                "please ensure you have the latest version of MBPP dataset, try deleting its old cache"
            )
        return dataset

    def get_prompt(self, doc):
        """Build the few-shot prompt the LM generates from.

        The prompt follows the InCoder (Fried et al.) approach: each problem
        is rendered as a triple-quoted docstring holding the task description
        and its test cases. ``self.few_shot`` solved examples from the train
        split (docstring followed by the reference code) come first, then the
        docstring of ``doc`` itself, which the model is expected to complete.

        :param doc: dict-like
            problem with the keys ``task_id``, ``text`` and ``test_list``
        :return: str
            the prompt, ending right after the closing triple quote of the
            target problem's docstring
        """
        # Seeding with the task id makes the choice of examples deterministic
        # per problem.
        self.rng.manual_seed(doc["task_id"])
        train_split = self.dataset["train"]
        shot_indices = torch.randperm(len(train_split), generator=self.rng)[
            : self.few_shot
        ].tolist()

        parts = []
        for train_idx in shot_indices:
            example = train_split[train_idx]
            parts.append(
                f'"""\n{example["text"]} Your code should satisfy these tests:\n'
            )
            parts.append("\n".join(example["test_list"]) + '\n"""\n')
            parts.append(example["code"] + "\n")

        # Task 493 only shows its third test in the prompt.
        # NOTE(review): presumably because its other tests are too long for
        # the prompt -- confirm against the dataset.
        if doc["task_id"] != 493:
            shown_tests = doc["test_list"]
        else:
            shown_tests = doc["test_list"][2:3]

        parts.append(f'"""\n{doc["text"]} Your code should satisfy these tests:\n')
        parts.append("\n".join(shown_tests) + '\n"""\n')

        return "".join(parts)

    def get_solutions(self, doc):
        """Return the reference code stored with the problem."""
        return doc["code"]

    def get_full_data(self, doc):
        """Return the prompt for ``doc`` followed by its reference code."""
        return self.get_prompt(doc) + self.get_solutions(doc)

    def get_reference(self, doc):
        """Return the problem's test cases joined into one newline-separated string.

        This is what generations are evaluated against in process_results,
        not the reference solution code (see get_solutions for that).
        """
        return "\n".join(doc["test_list"])

    @staticmethod
    def first_block(string, stop_words):
        """Return the text before the first stop word, right-stripped.

        Stop words are matched literally. They are escaped before being
        combined into a regular expression because some of them contain regex
        metacharacters: unescaped, the ``|`` in ``"\\n<|/"`` would turn that
        stop word into the alternation "newline + ``<``" or "``/``", cutting
        every generation at its first slash.

        :param string: str
            text to truncate
        :param stop_words: list(str)
            literal markers that end the block; empty entries are ignored
        :return: str
            ``string`` up to (not including) the earliest stop word, or all of
            ``string`` when no stop word occurs, with trailing whitespace
            removed
        """
        # An empty alternative would match at position 0 and discard the
        # whole text, so drop empty entries before building the pattern.
        literals = [re.escape(word) for word in stop_words if word]
        if not literals:
            return string.rstrip()
        return re.split("|".join(literals), string, maxsplit=1)[0].rstrip()

    def postprocess_generation(self, generation, idx):
        """Trim a raw LM generation down to the prompt plus the first code block.

        :param generation: str
            code generation from LM; assumed to start with the prompt, since
            the prompt is removed by slicing off ``len(prompt)`` characters
        :param idx: int
            index of doc in the dataset to which the generation belongs
        :return: str
            the prompt followed by the completion cut at the first stop word
        """
        prompt = self.get_prompt(self.get_dataset()[idx])
        output = generation[len(prompt) :]
        return prompt + self.first_block(output, self.stop_words)

    def process_results(self, generations, references):
        """Evaluate the LM generations against the reference tests.

        Loads the ``code_eval`` metric and runs it with a single worker.

        :param generations: list(list(str))
            list of lists containing generations
        :param references: list(str)
            list of str containing references
        :return: tuple
            the two values returned by the metric's ``compute`` call,
            unpacked here as ``(results, pass_info)``
        """
        code_metric = load("code_eval")
        results, pass_info = code_metric.compute(
            references=references,
            predictions=generations,
            num_workers=1,
        )
        return results, pass_info
