from abc import ABC, abstractmethod

import numpy as np

from datasets.dict import DictDataset
from datasets.numpy import NumpyDataset


class SourceTargetLabelSigmaDataset(DictDataset, ABC):
    """Abstract dict dataset with ``source``, ``target``, ``label`` and ``sigma`` fields.

    Subclasses implement the four per-field accessors (``source``, ``target``,
    ``label``, ``sigma``); ``__getitem__`` assembles their results into a single
    dict. The declared shape and type of each field are forwarded to
    ``DictDataset.__init__`` as two dicts keyed by field name.
    """

    def __init__(
            self,
            source_shape: tuple[int, ...],
            target_shape: tuple[int, ...],
            label_shape: tuple[int, ...],
            sigma_shape: tuple[int, ...],
            source_type: type,
            target_type: type,
            label_type: type = np.int64,
            sigma_type: type = np.float64
    ) -> None:
        """Validate the label/sigma field specs and register all four fields.

        Args:
            source_shape: Shape declared for the ``source`` field.
            target_shape: Shape declared for the ``target`` field.
            label_shape: Shape declared for the ``label`` field; must be ``()``.
            sigma_shape: Shape declared for the ``sigma`` field; must be ``()``.
            source_type: Type declared for the ``source`` field.
            target_type: Type declared for the ``target`` field.
            label_type: Type declared for the ``label`` field; must be a
                subclass of ``np.int64``. Defaults to ``np.int64``.
            sigma_type: Type declared for the ``sigma`` field; must be a
                subclass of ``np.float64``. Defaults to ``np.float64``.

        Raises:
            AssertionError: If ``label_shape``/``sigma_shape`` is not ``()`` or
                ``label_type``/``sigma_type`` is not a subclass of
                ``np.int64``/``np.float64``.
        """
        # label and sigma are required to be scalar-shaped fields.
        assert label_shape == (), f'label_shape must be (): {label_shape}'
        assert sigma_shape == (), f'sigma_shape must be (): {sigma_shape}'
        # The defaults used to be the builtins ``int``/``float``, which can never
        # satisfy these checks (``issubclass(int, np.int64)`` is False), so the
        # defaults are now the numpy types the checks actually require.
        assert issubclass(label_type, np.int64), f'label_type must be int: {label_type}'
        assert issubclass(sigma_type, np.float64), f'sigma_type must be float: {sigma_type}'
        super().__init__(
            {
                'source': source_shape,
                'target': target_shape,
                'label': label_shape,
                'sigma': sigma_shape,
            },
            {
                'source': source_type,
                'target': target_type,
                'label': label_type,
                'sigma': sigma_type,
            }
        )

    @abstractmethod
    def source(self, item: int) -> np.ndarray:
        """Return the ``source`` value of the item at index ``item``."""
        raise NotImplementedError('source must be implemented')

    @abstractmethod
    def target(self, item: int) -> np.ndarray:
        """Return the ``target`` value of the item at index ``item``."""
        raise NotImplementedError('target must be implemented')

    @abstractmethod
    def label(self, item: int) -> np.ndarray:
        """Return the ``label`` value of the item at index ``item``."""
        raise NotImplementedError('label must be implemented')

    @abstractmethod
    def sigma(self, item: int) -> np.ndarray:
        """Return the ``sigma`` value of the item at index ``item``."""
        raise NotImplementedError('sigma must be implemented')

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Return all four fields of item ``item`` as a dict keyed by field name."""
        return {
            'source': self.source(item),
            'target': self.target(item),
            'label': self.label(item),
            'sigma': self.sigma(item)
        }


class NumpySourceTargetLabelSigmaDataset(SourceTargetLabelSigmaDataset):
    """``SourceTargetLabelSigmaDataset`` backed by four parallel ``NumpyDataset`` objects.

    Item ``i`` is built from ``source_dataset[i]``, ``target_dataset[i]``,
    ``label_dataset[i]`` and ``sigma_dataset[i]``. Each field's shape and type
    come from the corresponding dataset's ``data_shape`` / ``data_type``.
    """

    def __init__(
            self,
            source_dataset: NumpyDataset,
            target_dataset: NumpyDataset,
            label_dataset: NumpyDataset,
            sigma_dataset: NumpyDataset
    ) -> None:
        """Wrap four equally long datasets as one four-field dataset.

        Args:
            source_dataset: Provides the ``source`` field.
            target_dataset: Provides the ``target`` field.
            label_dataset: Provides the ``label`` field; the base class requires
                its ``data_shape`` to be ``()`` and its ``data_type`` to
                subclass ``np.int64``.
            sigma_dataset: Provides the ``sigma`` field; the base class requires
                its ``data_shape`` to be ``()`` and its ``data_type`` to
                subclass ``np.float64``.

        Raises:
            AssertionError: If the datasets differ in length, or if the base
                class rejects the label/sigma shapes or types.
        """
        lengths = {
            'source': len(source_dataset),
            'target': len(target_dataset),
            'label': len(label_dataset),
            'sigma': len(sigma_dataset),
        }
        # Report every length by name; the previous ``a != b != c != d`` message
        # also printed ``!=`` between lengths that were in fact equal.
        assert len(set(lengths.values())) == 1, (
            f'all datasets must have the same length: {lengths}'
        )
        super().__init__(
            source_dataset.data_shape,
            target_dataset.data_shape,
            label_dataset.data_shape,
            sigma_dataset.data_shape,
            source_dataset.data_type,
            target_dataset.data_type,
            label_dataset.data_type,
            sigma_dataset.data_type
        )
        self.source_dataset: NumpyDataset = source_dataset
        self.target_dataset: NumpyDataset = target_dataset
        self.label_dataset: NumpyDataset = label_dataset
        self.sigma_dataset: NumpyDataset = sigma_dataset

    def __len__(self) -> int:
        """Return the number of items (all four datasets share this length)."""
        return len(self.source_dataset)

    def source(self, item: int) -> np.ndarray:
        """Return ``source_dataset[item]``."""
        return self.source_dataset[item]

    def target(self, item: int) -> np.ndarray:
        """Return ``target_dataset[item]``."""
        return self.target_dataset[item]

    def label(self, item: int) -> np.ndarray:
        """Return ``label_dataset[item]``."""
        return self.label_dataset[item]

    def sigma(self, item: int) -> np.ndarray:
        """Return ``sigma_dataset[item]``."""
        return self.sigma_dataset[item]
