from abc import ABC, abstractmethod

import numpy as np

from datasets.dict import DictDataset
from datasets.numpy import NumpyDataset


class NoiseLabelDataset(DictDataset, ABC):
    """Abstract dataset whose items are ``{'noise': ..., 'label': ...}`` dicts.

    Subclasses implement :meth:`noise` and :meth:`label` for a single index;
    :meth:`__getitem__` combines the two into one dict. The per-key shapes and
    types are forwarded to :class:`DictDataset` under the keys ``'noise'`` and
    ``'label'``.
    """

    def __init__(
            self,
            noise_shape: tuple[int, ...],
            label_shape: tuple[int, ...],
            noise_type: type,
            label_type: type = int,
    ) -> None:
        """Validate the label spec and register the per-key shapes and types.

        Args:
            noise_shape: Shape of a single noise sample, forwarded unchanged.
            label_shape: Shape of a single label; must be ``()`` (a scalar label).
            noise_type: Type of the noise samples, forwarded unchanged.
            label_type: Type of the labels; must describe an integer type
                (Python ``int``, a NumPy integer scalar type such as
                ``np.int64``, or an integer ``np.dtype``). Defaults to ``int``.

        Raises:
            AssertionError: If ``label_shape`` is not ``()`` or ``label_type``
                does not describe an integer type.
        """
        assert label_shape == (), f'label_shape must be (): {label_shape}'
        # The former check, ``issubclass(label_type, np.int64)``, rejected the
        # default ``int`` (Python int is not a subclass of np.int64) and raised
        # TypeError for ``np.dtype`` instances. Normalizing through ``np.dtype``
        # accepts every integer spelling while still rejecting bool/float/etc.
        try:
            is_integer = np.issubdtype(np.dtype(label_type), np.integer)
        except TypeError:
            # Not interpretable as a NumPy dtype at all -> not an integer type.
            is_integer = False
        assert is_integer, f'label_type must be int: {label_type}'
        super().__init__(
            {
                'noise': noise_shape,
                'label': label_shape
            },
            {
                'noise': noise_type,
                'label': label_type
            }
        )

    @abstractmethod
    def noise(self, item: int) -> np.ndarray:
        """Return the noise sample at index ``item``."""
        raise NotImplementedError('noise must be implemented')

    @abstractmethod
    def label(self, item: int) -> np.ndarray:
        """Return the label at index ``item``."""
        raise NotImplementedError('label must be implemented')

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Return ``{'noise': self.noise(item), 'label': self.label(item)}``."""
        return {
            'noise': self.noise(item),
            'label': self.label(item)
        }


class NumpyNoiseLabelDataset(NoiseLabelDataset):
    """Noise/label dataset backed by two index-aligned ``NumpyDataset`` objects.

    Index ``i`` yields ``noise_dataset[i]`` as the noise and ``label_dataset[i]``
    as the label. Shapes and types are read from each wrapped dataset's
    ``data_shape`` and ``data_type`` attributes.
    """

    def __init__(
            self,
            noise_dataset: NumpyDataset,
            label_dataset: NumpyDataset
    ) -> None:
        """Wrap a noise dataset and a label dataset of equal length.

        Args:
            noise_dataset: Source of the noise samples.
            label_dataset: Source of the labels; must have as many samples as
                ``noise_dataset``.

        Raises:
            AssertionError: If the two datasets differ in length, or if the
                parent constructor rejects the label shape/type.
        """
        noise_count = len(noise_dataset)
        label_count = len(label_dataset)
        assert noise_count == label_count, \
            f'noise and label must have the same number of samples: {noise_count} != {label_count}'
        super().__init__(
            noise_shape=noise_dataset.data_shape,
            label_shape=label_dataset.data_shape,
            noise_type=noise_dataset.data_type,
            label_type=label_dataset.data_type,
        )
        # Keep references to the wrapped datasets for per-item lookups.
        self.noise_dataset: NumpyDataset = noise_dataset
        self.label_dataset: NumpyDataset = label_dataset

    def __len__(self) -> int:
        """Number of samples; the constructor guarantees both sources agree."""
        return len(self.noise_dataset)

    def noise(self, item: int) -> np.ndarray:
        """Fetch the noise sample at ``item`` from the wrapped noise dataset."""
        source = self.noise_dataset
        return source[item]

    def label(self, item: int) -> np.ndarray:
        """Fetch the label at ``item`` from the wrapped label dataset."""
        source = self.label_dataset
        return source[item]
