from abc import ABC, abstractmethod

import numpy as np

from datasets.dict import DictDataset
from datasets.numpy import NumpyDataset


class NoiseDataset(DictDataset, ABC):
    """Abstract dataset whose items are dicts with a single ``'noise'`` entry.

    Subclasses implement :meth:`noise`. :meth:`__getitem__` wraps that result
    as ``{'noise': ...}``. The shape and type given to the constructor are
    forwarded to ``DictDataset`` under the same ``'noise'`` key.
    """

    def __init__(
            self,
            noise_shape: tuple[int, ...],
            noise_type: type
    ) -> None:
        """Register the ``'noise'`` field with ``DictDataset``.

        Args:
            noise_shape: Shape of a noise sample, passed to ``DictDataset``
                as the value for the ``'noise'`` key of its first argument.
            noise_type: Type of a noise sample, passed to ``DictDataset``
                as the value for the ``'noise'`` key of its second argument.
        """
        # NOTE(review): presumably DictDataset takes (shapes, types) mappings
        # keyed by field name — confirm against datasets.dict.
        super().__init__(
            {'noise': noise_shape},
            {'noise': noise_type}
        )

    @abstractmethod
    def noise(self, item: int) -> np.ndarray:
        """Return the noise sample at index *item*.

        Subclasses must override this method.

        Raises:
            NotImplementedError: If this base implementation is invoked
                (e.g. via ``super().noise(item)``).
        """
        raise NotImplementedError('noise must be implemented')

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Return ``{'noise': self.noise(item)}`` for index *item*."""
        return {'noise': self.noise(item)}


class NumpyNoiseDataset(NoiseDataset):
    """Noise dataset backed by an existing ``NumpyDataset``.

    Each item's ``'noise'`` entry is the element at the same index of the
    wrapped dataset. The shape and type metadata are copied from the wrapped
    dataset at construction time.
    """

    def __init__(self, noise_dataset: NumpyDataset) -> None:
        """Wrap *noise_dataset*, reusing its ``data_shape`` and ``data_type``.

        Args:
            noise_dataset: Source of the noise samples returned by
                :meth:`noise`.
        """
        shape = noise_dataset.data_shape
        dtype = noise_dataset.data_type
        super().__init__(shape, dtype)
        # Retained so length queries and per-index lookups delegate to it.
        self.noise_dataset: NumpyDataset = noise_dataset

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset."""
        return len(self.noise_dataset)

    def noise(self, item: int) -> np.ndarray:
        """Return element *item* of the wrapped dataset as-is."""
        sample = self.noise_dataset[item]
        return sample
