import abc
import dataclasses
import typing as tp

import datasets as ds

from llm_inference.eval_utils import EvaluationConfig


@dataclasses.dataclass
class Task(abc.ABC):
  """A task represents a dataset and associated evaluation metrics.

  This class encapsulates all task specific logic, such as loading and preprocessing
  the dataset, getting the instruction and reference solution from an example,
  and evaluating predictions.
  """

  @property
  @abc.abstractmethod
  def output_keys(self) -> list[str]:
    """Output keys declared by this task.

    NOTE(review): how these keys are consumed is not visible in this file;
    confirm against the callers.
    """
    ...

  @property
  @abc.abstractmethod
  def stop_tokens(self) -> list[str]:
    """Stop tokens declared by this task.

    NOTE(review): presumably strings that terminate generation; confirm
    against the inference code that reads this property.
    """
    ...

  @abc.abstractmethod
  def load_dataset(self, **kwargs) -> ds.Dataset:
    """Load and return the dataset for this task.

    Args:
        **kwargs: Implementation-specific loading options.
    """
    ...

  def get_reference_solution(self, example: dict) -> str:
    """Return the first reference solution for ``example``.

    Args:
        example: A single dataset example, passed through unchanged to
            ``get_reference_solutions``.

    Returns:
        The first element of ``get_reference_solutions(example)``.

    Raises:
        IndexError: If ``get_reference_solutions`` returns no solutions.
    """
    solutions = self.get_reference_solutions(example)
    try:
      return solutions[0]
    except IndexError as err:
      # Same exception type as plain indexing, but the message names the task
      # so an empty reference list is easy to trace back to its source.
      raise IndexError(
        f"{type(self).__name__}.get_reference_solutions() returned no reference solutions"
      ) from err

  @abc.abstractmethod
  def get_reference_solutions(self, example: dict) -> list[str]:
    """Return all reference solutions for ``example``.

    ``get_reference_solution`` returns the first element of this list, so
    implementations should return at least one solution.
    """
    ...

  @abc.abstractmethod
  def get_evaluation_cfg(
    self,
  ) -> EvaluationConfig:
    """Load the evaluation config for this task.

    Returns:
        The ``EvaluationConfig`` to use when evaluating this task.
    """
    ...


@dataclasses.dataclass(kw_only=True)
class HFTask(Task, abc.ABC):
  """Task whose dataset is obtained through ``ds.load_dataset``.

  The fields below are forwarded to ``ds.load_dataset`` by ``load_dataset``;
  subclasses may override ``preprocess`` to transform the loaded dataset.
  """

  # Forwarded as the first positional argument of ``ds.load_dataset``.
  dataset_path: str
  # Forwarded as ``name=``.
  dataset_name: str | None = None
  # Forwarded as ``split=``.
  dataset_split: str | None = None

  def load_dataset(self, **kwargs) -> ds.Dataset:
    """Load the configured dataset and run it through ``preprocess``.

    Args:
        **kwargs: Extra keyword arguments passed through to ``ds.load_dataset``.

    Returns:
        The result of ``preprocess`` applied to the loaded dataset.

    Raises:
        ValueError: If ``ds.load_dataset`` returns something other than a
            ``ds.Dataset`` (presumably e.g. when no split is selected; confirm
            against the ``datasets`` documentation).
    """
    loaded = ds.load_dataset(
      self.dataset_path,
      name=self.dataset_name,
      split=self.dataset_split,
      **kwargs,
    )
    if isinstance(loaded, ds.Dataset):
      return self.preprocess(loaded)

    raise ValueError(f"Task must return a Dataset, but got {type(loaded)}")

  def preprocess(self, dataset: ds.Dataset) -> ds.Dataset:
    """Hook applied to the freshly loaded dataset; returns it unchanged by default."""
    return dataset
