import dataclasses
from llm_inference import output_parsers
from llm_inference.metrics import code_eval
import itertools

import datasets as ds
import numpy as np
import pandas as pd
from datasets.arrow_dataset import Dataset

from llm_inference.tasks.task import HFTask


@dataclasses.dataclass
class MBPP(HFTask):
  """Code-generation task over the Hugging Face `mbpp` dataset.

  Model completions are evaluated with `code_eval.CodeEval` against a
  reference program built by `create_test_program` from each example's
  `test_imports` and `test_list` entries.
  """

  # Hugging Face dataset identifier, config name and split to load.
  # The special split value "all" concatenates the "train" and "test" splits.
  dataset_path: str = "mbpp"
  dataset_name: str = "sanitized"
  dataset_split: str = "test"

  @property
  def output_keys(self):
    """Names of the example fields this task exposes as outputs."""
    keys = ["description", "test_list"]
    return keys

  @property
  def stop_tokens(self):
    """Strings at which a generated completion is truncated.

    The first group marks the start of new top-level code after the
    generated function (class, assert, docstring, print, if). The second
    group holds fence/format markers and the "<|endoftext|>" token.
    """
    top_level_starts = ["\nclass", "\nassert", '\n"""', "\nprint", "\nif"]
    format_markers = ["\n<|/", "\n```", "<|endoftext|>"]
    return top_level_starts + format_markers

  def load_dataset(self) -> Dataset:
    """Load the configured split, or train+test concatenated for "all"."""

    def _load_split(split_name):
      return ds.load_dataset(
        self.dataset_path, name=self.dataset_name, split=split_name
      )

    if self.dataset_split != "all":
      return _load_split(self.dataset_split)

    # "train" is loaded before "test", so train rows come first.
    parts = [_load_split(split_name) for split_name in ("train", "test")]
    return ds.concatenate_datasets(parts)

  def get_reference_solutions(self, example: dict):
    """Return the example's `code` field wrapped in a one-element list."""
    reference_code = example["code"]
    return [reference_code]

  def get_evaluation_cfg(self):
    """Build the evaluation config used to score completions for this task.

    Completions are post-processed with `extract_first_function`, compared
    against the program produced by `create_test_program`, and evaluated
    with `code_eval.CodeEval` using the "thread" execution strategy and a
    batch size of 1.
    """
    # Imported inside the method rather than at module top; presumably to
    # avoid an import cycle with eval_utils — confirm before hoisting.
    from llm_inference import eval_utils

    parsers = [output_parsers.extract_first_function]
    return eval_utils.EvaluationConfig(
      metric=code_eval.CodeEval(),
      get_reference=create_test_program,
      output_parser=parsers,
      execution_strategy="thread",
      batch_size=1,
    )


def create_test_program(example):
  """Build the reference test program for one MBPP example.

  The program is the example's `test_imports` lines, a blank line, then its
  `test_list` lines, all newline-joined.

  Args:
    example: Mapping with a `test_list` sequence of source lines and,
      optionally, a `test_imports` sequence of import lines.

  Returns:
    The test program as a single string.
  """
  # Not every MBPP config provides `test_imports`; treat a missing or None
  # value as "no imports" instead of raising KeyError.
  test_imports = example.get("test_imports") or []
  return "\n".join(test_imports) + "\n\n" + "\n".join(example["test_list"])
