from abc import ABC, abstractmethod

import numpy as np

from datasets.dict import DictDataset
from datasets.numpy import NumpyDataset


class ImageDataset(DictDataset, ABC):
    """Abstract dataset whose items are dicts holding a single ``'image'`` entry.

    Subclasses implement :meth:`image` to produce the array for a given
    index; ``__getitem__`` wraps that array as ``{'image': array}``.
    """

    def __init__(
            self,
            image_shape: tuple[int, ...],
            image_type: type
    ) -> None:
        """Register the ``'image'`` field with the underlying DictDataset.

        Args:
            image_shape: Shape handed to DictDataset for the ``'image'``
                field (presumably the per-item array shape — confirm
                against DictDataset).
            image_type: Type handed to DictDataset for the ``'image'``
                field (presumably the array element type — confirm
                against DictDataset).
        """
        super().__init__(
            {'image': image_shape},
            {'image': image_type}
        )

    @abstractmethod
    def image(self, item: int) -> np.ndarray:
        """Return the image array stored at index ``item``.

        Raises:
            NotImplementedError: always, for this abstract base version.
        """
        # Name the actual abstract method so the error tells the
        # implementer which method to override.
        raise NotImplementedError('image must be implemented')

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Return ``{'image': self.image(item)}`` for index ``item``."""
        return {'image': self.image(item)}


class NumpyImageDataset(ImageDataset):
    """Expose each item of a NumpyDataset as the ``'image'`` of an ImageDataset."""

    def __init__(self, image_dataset: NumpyDataset) -> None:
        """Wrap ``image_dataset``, reusing its ``data_shape`` and ``data_type``.

        Args:
            image_dataset: Source dataset; its items are returned as images.
        """
        shape = image_dataset.data_shape
        dtype = image_dataset.data_type
        super().__init__(shape, dtype)
        # Every length/item lookup below is delegated to this dataset.
        self.image_dataset: NumpyDataset = image_dataset

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset."""
        source = self.image_dataset
        return len(source)

    def image(self, item: int) -> np.ndarray:
        """Return the wrapped dataset's element at index ``item``."""
        source = self.image_dataset
        return source[item]
