import random
from copy import deepcopy
from typing import Callable, Generator, Optional, Type

from serde import Serializable

from .task import Task, TaskInstance


class HuggingFaceTask(Task):
    """A task based on a HuggingFace dataset.

    The dataset is loaded lazily: nothing is fetched until the ``dataset``
    property (or anything that uses it, such as iteration or indexing) is
    first accessed. Each row of the dataset is turned into a task instance
    by calling ``instance_initializer`` with that row.
    """

    # Passed as the first argument to ``datasets.load_dataset``. It is None
    # here and must be non-None by the time the dataset is loaded.
    DATASET_LINK: Optional[str] = None
    # Not referenced anywhere in this class.
    # NOTE(review): presumably set by subclasses to the TaskInstance type they
    # produce -- confirm against the subclasses.
    TASK_INSTANCE_CLASS: Optional[Type[TaskInstance]] = None

    def __init__(
        self,
        instance_initializer: Callable[..., TaskInstance],
        size: Optional[int],
        **dataset_kwargs,
    ):
        """Store the configuration; the dataset itself is not loaded yet.

        Args:
            instance_initializer: Called with a single dataset row to build
                the task instance for that row.
            size: If not None, only the first ``size`` rows are kept when the
                dataset is loaded. None keeps the whole dataset.
            **dataset_kwargs: Forwarded unchanged to ``datasets.load_dataset``.
        """
        self.size = size
        self.dataset_kwargs = dataset_kwargs
        self.instance_initializer = instance_initializer
        # Populated on first access of the ``dataset`` property.
        self._dataset = None

    @property
    def dataset(self):
        """Return the underlying dataset, loading it on first access."""
        if self._dataset is None:
            self.load_dataset()
        # Internal invariant: load_dataset() either sets _dataset or raises.
        assert self._dataset is not None
        return self._dataset

    def split(self, split_factor: float) -> "HuggingFaceTask":
        """Split this task in two, shrinking this task in place.

        With ``size`` being ``self.size`` (or the dataset length when
        ``self.size`` is None), the returned task receives the last
        ``round(size * split_factor)`` rows and this task keeps the rows
        before them. The ``size`` attribute of both tasks is updated.

        Args:
            split_factor: Fraction of the rows, between 0 and 1 inclusive,
                that is moved into the returned task.

        Returns:
            A deep copy of this task restricted to the split-off rows.

        Raises:
            ValueError: If ``split_factor`` is not within [0, 1].
        """
        # Outside [0, 1] the row ranges below would be negative or run past
        # the end of the dataset. NaN also fails this comparison.
        if not 0.0 <= split_factor <= 1.0:
            raise ValueError(
                f"split_factor must be between 0 and 1, got {split_factor!r}"
            )

        # Load once, before copying, so the copy does not trigger a second
        # load of the same data.
        dataset = self.dataset
        size = len(dataset) if self.size is None else self.size
        new_task_size = round(size * split_factor)
        new_self_size = size - new_task_size

        new_task = deepcopy(self)
        self._dataset = dataset.select(range(new_self_size))
        new_task._dataset = dataset.select(range(new_self_size, size))

        new_task.size = new_task_size
        self.size = new_self_size
        return new_task

    def load_dataset(self):
        """Load the dataset named by ``DATASET_LINK`` into ``self._dataset``.

        ``self.dataset_kwargs`` are forwarded to ``datasets.load_dataset``.
        When ``self.size`` is not None, only the first ``self.size`` rows are
        kept.

        Raises:
            ValueError: If the class does not define ``DATASET_LINK``.
            TypeError: If the loaded object is not a ``datasets.Dataset``.
        """
        # Imported here rather than at module level, so `datasets` is only
        # needed once a dataset is actually loaded.
        from datasets import Dataset, load_dataset

        link = type(self).DATASET_LINK
        if link is None:
            raise ValueError(
                f"{type(self).__name__}.DATASET_LINK must be set before the "
                "dataset can be loaded"
            )

        loaded = load_dataset(link, **self.dataset_kwargs)
        if not isinstance(loaded, Dataset):
            raise TypeError(
                f"expected load_dataset to return a datasets.Dataset, got "
                f"{type(loaded).__name__}; check dataset_kwargs "
                "(e.g. that a single `split` is requested)"
            )
        self._dataset = (
            loaded if self.size is None else loaded.select(range(self.size))
        )

    def __iter__(self) -> Generator[TaskInstance, None, None]:
        """Iterate over the task instances for this task."""
        # A generator expression evaluates `self.dataset` immediately, so the
        # dataset is loaded when iteration is requested, not on first `next`.
        return (self.instance_initializer(data) for data in self.dataset)

    def __getitem__(self, idx: int) -> TaskInstance:
        """Get the task instance built from the row at index ``idx``."""
        return self.instance_initializer(self.dataset[idx])

    def get_random_instance(self) -> TaskInstance:
        """Return the task instance for a uniformly chosen row.

        Uses the module-level ``random`` generator.

        Raises:
            ValueError: If the dataset is empty.
        """
        dataset = self.dataset
        return self.instance_initializer(dataset[random.randrange(len(dataset))])

    def serialize(self) -> dict[str, Serializable]:
        """Return the serialized form of this task: only its ``size``."""
        return {
            "size": self.size,
        }
