import abc
import dataclasses
import typing as tp

import datasets as ds
import numpy as np

# Generic placeholder for the value type produced by a `Factory`.
T = tp.TypeVar("T")
# A zero-argument callable that returns a `T` when called.
Factory = tp.Callable[[], T]
# A callable that takes a dataset and returns a dataset.
Preprocessor = tp.Callable[[ds.Dataset], ds.Dataset]


def _chain_preprocessors(*preprocessors: Preprocessor) -> Preprocessor:
  """Compose several preprocessors into a single one.

  Args:
    *preprocessors: Callables applied one after another, in the order given;
      each receives the value returned by the previous one.

  Returns:
    A callable that runs its argument through every preprocessor and returns
    the final result. With no preprocessors the argument is returned as-is.
  """

  def _chain(dataset: ds.Dataset) -> ds.Dataset:
    processed = dataset
    for stage in preprocessors:
      processed = stage(processed)
    return processed

  return _chain


class Metric(abc.ABC):
  """Interface for a metric: a per-example `compute` step and a `reduce` step.

  NOTE(review): neither method is decorated with `abc.abstractmethod`, so this
  class can be instantiated directly and the placeholder bodies below simply
  return None. Subclasses are presumably expected to override both — confirm.
  """

  def compute(
    self,
    references: tp.Sequence[str],
    predictions: tp.Sequence[tp.Sequence[str]],
  ) -> np.ndarray:
    """Compute metric values from references and predictions.

    Args:
      references: Sequence of reference strings.
      predictions: Sequence of string sequences (presumably the candidate
        predictions for each reference — TODO confirm alignment with callers).

    Returns:
      Annotated as a numpy array; this base implementation has no body and
      returns None.
    """

  def reduce(self, metric_values: tp.Sequence[tp.Any]) -> tp.Any:
    """Combine a sequence of metric values into a single result.

    This base implementation has no body and returns None.
    """


@dataclasses.dataclass
class Task:
  """Bundle of everything that describes one task.

  Fields:
    dataset: Zero-argument callable that returns the task's dataset.
    preprocessor: Callable mapping a dataset to a dataset.
    output_parser: Opaque object; its contract is not visible in this module.
    metrics: Mapping from a name to a `Metric`.
    stop_tokens: Sequence of strings, empty by default (presumably the
      strings at which generation should stop — TODO confirm with callers).
  """

  # Required fields, in positional order.
  dataset: Factory[ds.Dataset]
  preprocessor: Preprocessor
  output_parser: tp.Any
  metrics: tp.Mapping[str, Metric]
  # Optional; an immutable empty tuple is safe to share as the default.
  stop_tokens: tp.Sequence[str] = dataclasses.field(default=())


@dataclasses.dataclass
class TaskRegistry:
  """Registry mapping task names to `Task` definitions.

  All state lives in the class-level `_REGISTRY` dict, so registered tasks are
  shared by every instance. Both `define` and `get_task` can be called on the
  class itself; calling `get_task` on an instance keeps working as well.
  """

  # ClassVar, so the dataclass machinery does not treat it as a field.
  _REGISTRY: tp.ClassVar[dict[str, Task]] = {}

  @staticmethod
  def define(
    name: str,
    dataset: Factory[ds.Dataset],
    preprocessors: tp.Sequence[Preprocessor] | Preprocessor,
    output_parser: tp.Any,
    metrics: tp.Mapping[str, Metric],
    stop_tokens: tp.Sequence[str] = (),
  ) -> None:
    """Register a new task under `name`.

    Args:
      name: Unique key for the task.
      dataset: Zero-argument callable returning the task's dataset.
      preprocessors: Either a single preprocessor callable, or a sequence of
        them that is chained in order into one callable.
      output_parser: Stored on the task unchanged.
      metrics: Mapping from a name to a `Metric`, stored unchanged.
      stop_tokens: Sequence of strings stored unchanged; empty by default.

    Raises:
      ValueError: If a task named `name` is already registered.
    """
    if name in TaskRegistry._REGISTRY:
      raise ValueError(f"Task {name} already defined")

    # A lone callable is used directly; anything else is treated as a
    # sequence of preprocessors and folded into a single callable.
    if callable(preprocessors):
      preprocessor = preprocessors
    else:
      preprocessor = _chain_preprocessors(*preprocessors)

    TaskRegistry._REGISTRY[name] = Task(
      dataset=dataset,
      preprocessor=preprocessor,
      output_parser=output_parser,
      metrics=metrics,
      stop_tokens=stop_tokens,
    )

  @classmethod
  def get_task(cls, name: str) -> Task:
    """Return the task registered under `name`.

    Works both as `TaskRegistry.get_task(name)` and on an instance, since the
    registry itself is class-level state.

    Raises:
      ValueError: If no task named `name` has been registered.
    """
    try:
      return cls._REGISTRY[name]
    except KeyError:
      # Keep the ValueError contract; the KeyError adds nothing for callers.
      raise ValueError(f"Task {name} not found in registry") from None
