
import numpy as np
from evalplus.data import get_mbpp_plus

from inference_rlhf.code.query_builders.factory import query_builder_factory

# Template wrapping a problem statement in a bare Markdown code fence.
# Written line-by-line so the leading and trailing newlines are explicit.
QUESTION_FORMAT = (
    "\n"
    "```\n"
    "{question}\n"
    "```\n"
)

# Template introducing a solution inside a ```python fence.
# NOTE(review): the fence is never closed here -- presumably on purpose so
# the closing fence comes from elsewhere; confirm before "fixing" it.
ANSWER_FORMAT = (
    "\n"
    "Below is a Python script with a self-contained function that solves "
    "the problem and passes corresponding tests:\n"
    "```python\n"
    "{answer}\n"
)

# Separator handed to the query builder (empty string).
SEP = ""

# Few-shot example pools handed to the query builder; both are empty here.
QUESTION_EXAMPLES = []
ANSWER_EXAMPLES = []

class MBPPDataLoader:
    """
    DataLoader for the MBPP dataset.

    Loads the problems returned by ``get_mbpp_plus(version="default")``,
    keeps their prompts and entry points in dataset iteration order, selects
    which problems to use (all of them, or a random subset), and turns the
    selected prompts into queries through a query builder.

    Attributes set by ``__init__``:
        questions: stripped ``'prompt'`` text of every problem.
        entrypoints: ``'entry_point'`` value of every problem.
        idx_to_task_idx: maps a position in ``questions`` to the dataset key
            of that problem.
        qb: query builder returned by ``query_builder_factory``.
        num: number of problems that will be used.
        idxs: sorted positions (into ``questions``) of the problems used.
    """

    def __init__(self, cfg):
        """
        Load the dataset and prepare the query builder and sample indices.

        Args:
            cfg: configuration object. This constructor reads
                ``cfg.policy`` (and ``cfg.policy.name``),
                ``cfg.task.TASK_DESC``, ``cfg.task.max_samples`` and
                ``cfg.shots``. A ``max_samples`` value ``<= 0`` means
                "use every problem"; larger values are capped at the
                dataset size.
        """
        dataset_dict = get_mbpp_plus(version="default")

        # Materialize the key order once so the three per-problem containers
        # below are guaranteed to be aligned position by position.
        task_ids = list(dataset_dict)

        self.questions = [dataset_dict[task_id]['prompt'].strip() for task_id in task_ids]
        self.idx_to_task_idx = dict(enumerate(task_ids))
        self.entrypoints = [dataset_dict[task_id]['entry_point'] for task_id in task_ids]

        self.qb = query_builder_factory(
            cfg.policy.name,
            cfg=cfg.policy,
            task_desc=cfg.task.TASK_DESC,
            shots=cfg.shots,
            question_format=QUESTION_FORMAT,
            answer_format=ANSWER_FORMAT,
            sep=SEP,
            question_examples=QUESTION_EXAMPLES,
            answer_examples=ANSWER_EXAMPLES,
        )

        total = len(self.questions)
        max_samples = cfg.task.max_samples
        self.num = total if max_samples <= 0 else min(max_samples, total)
        self.idxs = self.get_idxs()

    def get_idxs(self):
        """
        Choose which question positions to use.

        Returns:
            A sorted list of plain Python ``int`` positions into
            ``self.questions``. When ``self.num`` equals the number of
            questions this is simply ``0 .. num - 1``; otherwise it is a
            random sample of ``self.num`` distinct positions drawn with
            ``np.random.choice`` (NumPy's global random state, so the
            selection depends on whatever seed that state currently has).
        """
        total = len(self.questions)
        if self.num == total:
            return list(range(total))

        sampled = np.random.choice(total, self.num, replace=False)
        # ``.tolist()`` converts the NumPy scalars to built-in ints so both
        # branches return the same element type (e.g. JSON-serializable).
        return sorted(sampled.tolist())

    def build_queries(self, apply_chat_template: bool = True):
        """
        Build one query per selected question.

        Args:
            apply_chat_template: forwarded unchanged to
                ``self.qb.build_query`` for every question.

        Returns:
            A list with the result of ``self.qb.build_query`` for each
            position in ``self.idxs``, in that order.
        """
        return [
            self.qb.build_query(self.questions[idx], apply_chat_template=apply_chat_template)
            for idx in self.idxs
        ]
        
