from abc import ABC, abstractmethod

import numpy as np

from datasets.dict import DictDataset
from datasets.numpy import NumpyDataset


class SourceTargetSigmaDataset(DictDataset, ABC):
    """Abstract dataset whose items are dicts with 'source', 'target' and 'sigma' entries.

    Subclasses implement :meth:`source`, :meth:`target` and :meth:`sigma`;
    :meth:`__getitem__` assembles their results into a single dict.
    """

    def __init__(
            self,
            source_shape: tuple[int, ...],
            target_shape: tuple[int, ...],
            sigma_shape: tuple[int, ...],
            source_type: type,
            target_type: type,
            sigma_type: type = float
    ) -> None:
        """Validate the sigma description and forward shapes/types to ``DictDataset``.

        Args:
            source_shape: Shape registered under the 'source' key.
            target_shape: Shape registered under the 'target' key.
            sigma_shape: Shape registered under the 'sigma' key; must be ``()``.
            source_type: Type registered under the 'source' key.
            target_type: Type registered under the 'target' key.
            sigma_type: Type registered under the 'sigma' key; must be ``float``
                or a subclass of it (``np.float64`` qualifies).

        Raises:
            AssertionError: If ``sigma_shape`` is not ``()`` or ``sigma_type``
                is not a ``float`` subclass.
        """
        assert sigma_shape == (), f'sigma_shape must be (): {sigma_shape}'
        # np.float64 derives from the builtin float, so checking against float
        # accepts both. The previous check against np.float64 rejected the
        # declared default (float) itself.
        assert issubclass(sigma_type, float), f'sigma_type must be float: {sigma_type}'
        super().__init__(
            {
                'source': source_shape,
                'target': target_shape,
                'sigma': sigma_shape,
            },
            {
                'source': source_type,
                'target': target_type,
                'sigma': sigma_type,
            }
        )

    @abstractmethod
    def source(self, item: int) -> np.ndarray:
        """Return the 'source' entry for index ``item``."""
        raise NotImplementedError('source must be implemented')

    @abstractmethod
    def target(self, item: int) -> np.ndarray:
        """Return the 'target' entry for index ``item``."""
        raise NotImplementedError('target must be implemented')

    @abstractmethod
    def sigma(self, item: int) -> np.ndarray:
        """Return the 'sigma' entry for index ``item``."""
        raise NotImplementedError('sigma must be implemented')

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Return the dict of 'source', 'target' and 'sigma' values for ``item``."""
        return {
            'source': self.source(item),
            'target': self.target(item),
            'sigma': self.sigma(item)
        }


class NumpySourceTargetSigmaDataset(SourceTargetSigmaDataset):
    """Source/target/sigma dataset backed by three ``NumpyDataset`` instances."""

    def __init__(
            self,
            source_dataset: NumpyDataset,
            target_dataset: NumpyDataset,
            sigma_dataset: NumpyDataset
    ) -> None:
        """Register each dataset's shape and type with the parent, then keep the datasets.

        Args:
            source_dataset: Dataset indexed by :meth:`source`.
            target_dataset: Dataset indexed by :meth:`target`.
            sigma_dataset: Dataset indexed by :meth:`sigma`.
        """
        # Shapes and types come from each dataset's data_shape / data_type.
        super().__init__(
            source_shape=source_dataset.data_shape,
            target_shape=target_dataset.data_shape,
            sigma_shape=sigma_dataset.data_shape,
            source_type=source_dataset.data_type,
            target_type=target_dataset.data_type,
            sigma_type=sigma_dataset.data_type,
        )
        self.source_dataset: NumpyDataset = source_dataset
        self.target_dataset: NumpyDataset = target_dataset
        self.sigma_dataset: NumpyDataset = sigma_dataset

    def __len__(self) -> int:
        """Return the length of the source dataset.

        The target and sigma datasets are not consulted here.
        """
        return len(self.source_dataset)

    def source(self, item: int) -> np.ndarray:
        """Return ``source_dataset[item]``."""
        return self.source_dataset[item]

    def target(self, item: int) -> np.ndarray:
        """Return ``target_dataset[item]``."""
        return self.target_dataset[item]

    def sigma(self, item: int) -> np.ndarray:
        """Return ``sigma_dataset[item]``."""
        return self.sigma_dataset[item]
