from abc import ABC, abstractmethod

import numpy as np

from datasets.numpy import NumpyDataset
from src.datasets.source_target_label_sigma import SourceTargetLabelSigmaDataset
from src.datasets.source_target_sigma import SourceTargetSigmaDataset


class LinesUnconditionalDataset(SourceTargetSigmaDataset, ABC):
    """Abstract dataset of (noisy source, clean target, sigma) samples.

    Subclasses implement ``image`` (the clean sample) and ``noise`` (the
    noise sample). ``sigma`` is not defined in this class and must be
    available on the instance (presumably declared by the parent class —
    confirm). The noisy source is formed as ``image + sigma * noise``, and
    the clean image is used as the target. Source and target share the
    same shape and element type.
    """

    def __init__(
            self,
            image_shape: tuple[int, ...],
            sigma_shape: tuple[int, ...],
            image_type: type,
            sigma_type: type = float
    ) -> None:
        # Source and target are both images, so the image shape and type are
        # each forwarded twice: once for the source, once for the target.
        shapes = (image_shape, image_shape, sigma_shape)
        types = (image_type, image_type, sigma_type)
        super().__init__(*shapes, *types)

    @abstractmethod
    def noise(self, item: int) -> np.ndarray:
        """Return the noise sample for ``item``; it is scaled by ``sigma``."""
        raise NotImplementedError('noise must be implemented')

    @abstractmethod
    def image(self, item: int) -> np.ndarray:
        """Return the clean image for ``item``."""
        raise NotImplementedError('image must be implemented')

    def source(self, item: int) -> np.ndarray:
        """Return the noisy input ``image + sigma * noise`` for ``item``."""
        clean = self.image(item)
        scale = self.sigma(item)
        # NOTE(review): assumes sigma broadcasts against noise — confirm shapes.
        return clean + scale * self.noise(item)

    def target(self, item: int) -> np.ndarray:
        """Return the clean image for ``item``, used as the target."""
        clean = self.image(item)
        return clean

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Return ``{'source', 'target', 'sigma'}`` for ``item``.

        Each accessor is called exactly once. The noisy source is therefore
        built from the same image, noise and sigma values that are returned.
        """
        # Keep the fetch order noise -> image -> sigma, in case an accessor
        # consumes shared random state.
        drawn_noise = self.noise(item)
        clean = self.image(item)
        scale = self.sigma(item)
        corrupted = clean + scale * drawn_noise
        return dict(source=corrupted, target=clean, sigma=scale)


class LinesConditionalDataset(SourceTargetLabelSigmaDataset, ABC):
    """Abstract dataset of (noisy source, clean target, label, sigma) samples.

    Subclasses implement ``image`` (the clean sample) and ``noise`` (the
    noise sample). ``label`` and ``sigma`` are not defined in this class
    and must be available on the instance (presumably declared by the
    parent class — confirm). The noisy source is ``image + sigma * noise``,
    and the clean image is the target. Source and target share the same
    shape and element type.
    """

    def __init__(
            self,
            image_shape: tuple[int, ...],
            label_shape: tuple[int, ...],
            sigma_shape: tuple[int, ...],
            image_type: type,
            label_type: type = int,
            sigma_type: type = float
    ) -> None:
        # The image shape/type is forwarded twice: once for the noisy source
        # and once for the clean target.
        shapes = (image_shape, image_shape, label_shape, sigma_shape)
        types = (image_type, image_type, label_type, sigma_type)
        super().__init__(*shapes, *types)

    @abstractmethod
    def noise(self, item: int) -> np.ndarray:
        """Return the noise sample for ``item``; it is scaled by ``sigma``."""
        raise NotImplementedError('noise must be implemented')

    @abstractmethod
    def image(self, item: int) -> np.ndarray:
        """Return the clean image for ``item``."""
        raise NotImplementedError('image must be implemented')

    def source(self, item: int) -> np.ndarray:
        """Return the noisy input ``image + sigma * noise`` for ``item``."""
        clean = self.image(item)
        scale = self.sigma(item)
        # NOTE(review): assumes sigma broadcasts against noise — confirm shapes.
        return clean + scale * self.noise(item)

    def target(self, item: int) -> np.ndarray:
        """Return the clean image for ``item``, used as the target."""
        clean = self.image(item)
        return clean

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Return ``{'source', 'target', 'label', 'sigma'}`` for ``item``.

        Each accessor is called exactly once. The noisy source is therefore
        built from the same image, noise and sigma values that are returned.
        """
        # Keep the fetch order noise -> image -> label -> sigma, in case an
        # accessor consumes shared random state.
        drawn_noise = self.noise(item)
        clean = self.image(item)
        cond = self.label(item)
        scale = self.sigma(item)
        corrupted = clean + scale * drawn_noise
        return dict(source=corrupted, target=clean, label=cond, sigma=scale)


class NumpyLinesUnconditionalDataset(LinesUnconditionalDataset):
    """Unconditional lines dataset backed by three parallel ``NumpyDataset`` objects.

    Item ``i`` of this dataset is assembled from ``noise_dataset[i]``,
    ``image_dataset[i]`` and ``sigma_dataset[i]``.
    """

    def __init__(
            self,
            noise_dataset: NumpyDataset,
            image_dataset: NumpyDataset,
            sigma_dataset: NumpyDataset
    ) -> None:
        """Wrap the three datasets after checking that they are compatible.

        Args:
            noise_dataset: per-item noise samples. Its ``data_shape`` and
                ``data_type`` must match those of ``image_dataset``.
            image_dataset: per-item clean images.
            sigma_dataset: per-item noise levels.

        Raises:
            AssertionError: if the noise and image datasets differ in shape
                or type, or if the three datasets differ in length.
        """
        # NOTE: these checks use ``assert`` and are skipped under ``python -O``.
        assert noise_dataset.data_shape == image_dataset.data_shape, (
            f'noise dataset shape must be equal to image dataset shape: '
            f'{noise_dataset.data_shape} != {image_dataset.data_shape}'
        )
        assert noise_dataset.data_type == image_dataset.data_type, (
            f'noise dataset type must be equal to image dataset type: '
            f'{noise_dataset.data_type} != {image_dataset.data_type}'
        )
        # Every dataset is indexed with the same ``item``, but ``__len__``
        # reports only the noise dataset's length. A mismatch would either
        # fail when indexing or silently ignore trailing samples. The same
        # check is performed by NumpyLinesConditionalDataset.
        assert len(noise_dataset) == len(image_dataset) == len(sigma_dataset), (
            f'all datasets must have the same length: '
            f'noise={len(noise_dataset)}, image={len(image_dataset)}, '
            f'sigma={len(sigma_dataset)}'
        )
        # Noise and image were asserted to share shape/type, so the noise
        # dataset's metadata describes both source and target.
        super().__init__(
            noise_dataset.data_shape, sigma_dataset.data_shape,
            noise_dataset.data_type, sigma_dataset.data_type
        )
        self.noise_dataset: NumpyDataset = noise_dataset
        self.image_dataset: NumpyDataset = image_dataset
        self.sigma_dataset: NumpyDataset = sigma_dataset

    def __len__(self) -> int:
        """Return the number of items (all wrapped datasets share this length)."""
        return len(self.noise_dataset)

    def noise(self, item: int) -> np.ndarray:
        """Return ``noise_dataset[item]``."""
        return self.noise_dataset[item]

    def image(self, item: int) -> np.ndarray:
        """Return ``image_dataset[item]``."""
        return self.image_dataset[item]

    def sigma(self, item: int) -> np.ndarray:
        """Return ``sigma_dataset[item]``."""
        return self.sigma_dataset[item]


class NumpyLinesConditionalDataset(LinesConditionalDataset):
    """Conditional lines dataset backed by four parallel ``NumpyDataset`` objects.

    Item ``i`` of this dataset is assembled from ``noise_dataset[i]``,
    ``image_dataset[i]``, ``label_dataset[i]`` and ``sigma_dataset[i]``.
    """

    def __init__(
            self,
            noise_dataset: NumpyDataset,
            image_dataset: NumpyDataset,
            label_dataset: NumpyDataset,
            sigma_dataset: NumpyDataset
    ) -> None:
        """Wrap the four datasets after checking that they are compatible.

        Args:
            noise_dataset: per-item noise samples. Its ``data_shape`` and
                ``data_type`` must match those of ``image_dataset``.
            image_dataset: per-item clean images.
            label_dataset: per-item conditioning labels.
            sigma_dataset: per-item noise levels.

        Raises:
            AssertionError: if the noise and image datasets differ in shape
                or type, or if the four datasets differ in length.
        """
        # NOTE: these checks use ``assert`` and are skipped under ``python -O``.
        assert noise_dataset.data_shape == image_dataset.data_shape, (
            f'noise dataset shape must be equal to image dataset shape: '
            f'{noise_dataset.data_shape} != {image_dataset.data_shape}'
        )
        assert noise_dataset.data_type == image_dataset.data_type, (
            f'noise dataset type must be equal to image dataset type: '
            f'{noise_dataset.data_type} != {image_dataset.data_type}'
        )
        # Report each length with its name. The previous "a != b != c != d"
        # message printed false inequalities when only one length differed.
        assert len(noise_dataset) == len(image_dataset) == len(label_dataset) == len(sigma_dataset), (
            f'all datasets must have the same length: '
            f'noise={len(noise_dataset)}, image={len(image_dataset)}, '
            f'label={len(label_dataset)}, sigma={len(sigma_dataset)}'
        )
        # Noise and image were asserted to share shape/type, so the noise
        # dataset's metadata describes both source and target.
        super().__init__(
            noise_dataset.data_shape, label_dataset.data_shape, sigma_dataset.data_shape,
            noise_dataset.data_type, label_dataset.data_type, sigma_dataset.data_type
        )
        self.noise_dataset: NumpyDataset = noise_dataset
        self.image_dataset: NumpyDataset = image_dataset
        self.label_dataset: NumpyDataset = label_dataset
        self.sigma_dataset: NumpyDataset = sigma_dataset

    def __len__(self) -> int:
        """Return the number of items (all wrapped datasets share this length)."""
        return len(self.noise_dataset)

    def noise(self, item: int) -> np.ndarray:
        """Return ``noise_dataset[item]``."""
        return self.noise_dataset[item]

    def image(self, item: int) -> np.ndarray:
        """Return ``image_dataset[item]``."""
        return self.image_dataset[item]

    def label(self, item: int) -> np.ndarray:
        """Return ``label_dataset[item]``."""
        return self.label_dataset[item]

    def sigma(self, item: int) -> np.ndarray:
        """Return ``sigma_dataset[item]``."""
        return self.sigma_dataset[item]
