from abc import ABC, abstractmethod

import numpy as np

from datasets.dict import DictDataset
from datasets.numpy import NumpyDataset


class ImageLabelDataset(DictDataset, ABC):
    """Abstract dataset whose samples are ``{'image': ..., 'label': ...}`` dicts.

    Subclasses implement :meth:`image` and :meth:`label`; :meth:`__getitem__`
    combines the two into a single sample dict. The per-key shapes and types
    are forwarded to ``DictDataset.__init__`` (presumably as the dataset's
    declared per-key schema -- confirm against ``DictDataset``).
    """

    def __init__(
            self,
            image_shape: tuple[int, ...],
            label_shape: tuple[int, ...],
            image_type: type,
            label_type: type = np.int64,
    ) -> None:
        """Validate the label spec and register the image/label shapes and types.

        Args:
            image_shape: Shape of a single image sample.
            label_shape: Shape of a single label; must be ``()`` (a scalar label).
            image_type: Element type of an image sample.
            label_type: Element type of a label; must be ``np.int64`` or a
                subclass of it.

        Raises:
            AssertionError: If ``label_shape`` is not ``()`` or ``label_type``
                is not a subclass of ``np.int64``.
        """
        # The default is np.int64 rather than the builtin ``int``:
        # ``issubclass(int, np.int64)`` is False, so an ``int`` default could
        # never pass the check below.
        assert label_shape == (), f'label_shape must be (): {label_shape}'
        assert issubclass(label_type, np.int64), f'label_type must be np.int64: {label_type}'
        super().__init__(
            {
                'image': image_shape,
                'label': label_shape
            },
            {
                'image': image_type,
                'label': label_type
            }
        )

    @abstractmethod
    def image(self, item: int) -> np.ndarray:
        """Return the image of sample ``item``."""
        raise NotImplementedError('image must be implemented')

    @abstractmethod
    def label(self, item: int) -> np.ndarray:
        """Return the label of sample ``item``."""
        raise NotImplementedError('label must be implemented')

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Return sample ``item`` as ``{'image': self.image(item), 'label': self.label(item)}``."""
        return {
            'image': self.image(item),
            'label': self.label(item)
        }


class NumpyImageLabelDataset(ImageLabelDataset):
    """Image/label dataset backed by two parallel ``NumpyDataset`` objects.

    Sample ``i`` pairs ``image_dataset[i]`` with ``label_dataset[i]``; both
    wrapped datasets must report the same length.
    """

    def __init__(
            self,
            image_dataset: NumpyDataset,
            label_dataset: NumpyDataset
    ) -> None:
        """Pair ``image_dataset`` with ``label_dataset`` sample-by-sample.

        Shapes and element types are read from each dataset's ``data_shape``
        and ``data_type`` attributes and handed to ``ImageLabelDataset.__init__``,
        which validates the label part.

        Raises:
            AssertionError: If the two datasets differ in length, or if the
                base class rejects the label shape/type.
        """
        num_images = len(image_dataset)
        num_labels = len(label_dataset)
        assert num_images == num_labels, \
            f'image and label must have the same number of samples: {num_images} != {num_labels}'

        super().__init__(
            image_shape=image_dataset.data_shape,
            label_shape=label_dataset.data_shape,
            image_type=image_dataset.data_type,
            label_type=label_dataset.data_type,
        )

        # Keep references to the wrapped datasets for index-based access.
        self.image_dataset: NumpyDataset = image_dataset
        self.label_dataset: NumpyDataset = label_dataset

    def __len__(self) -> int:
        """Number of paired samples (taken from the image dataset)."""
        images = self.image_dataset
        return len(images)

    def image(self, item: int) -> np.ndarray:
        """Return the entry at index ``item`` of the image dataset."""
        source = self.image_dataset
        return source[item]

    def label(self, item: int) -> np.ndarray:
        """Return the entry at index ``item`` of the label dataset."""
        source = self.label_dataset
        return source[item]
