from abc import ABC
from typing import Iterable

import numpy as np

from utils.types import JSON
from utils.utils import get_class_name
from .datasets import Dataset, ConcatDataset


class NumpyDataset(Dataset[np.ndarray], ABC):
    """Abstract dataset whose items are NumPy arrays of one fixed shape and scalar type.

    Subclasses provide the item storage (``__len__`` / ``__getitem__``); this
    base records the per-item ``data_shape`` and ``data_type``, knows how to
    merge a batch of items into one array, and contributes both attributes to
    the JSON description.
    """

    def __init__(self, data_shape: tuple[int, ...], data_type: type) -> None:
        """Record the per-item shape and scalar type.

        Args:
            data_shape: shape of a single item, as a tuple of ints.
            data_type: scalar type of an item's elements; the subclasses in
                this module pass ``ndarray.dtype.type`` (e.g. ``np.float64``).
        """
        assert isinstance(data_shape, tuple), f'data_shape must be tuple: {data_shape}'
        assert all(isinstance(val, int) for val in data_shape), f'data_shape values must be int: {data_shape}'
        assert isinstance(data_type, type), f'data_type must be type: {data_type}'
        self.data_shape: tuple[int, ...] = data_shape
        self.data_type: type = data_type

    @classmethod
    def merge_data_helper(cls, data: Iterable[np.ndarray]) -> np.ndarray:
        """Merge item arrays into a single array via ``np.array``.

        For equally shaped items this stacks them along a new leading axis.
        An empty input yields ``np.array([])`` (shape ``(0,)``).

        ``data`` is materialized into a list exactly once. Validating a
        one-shot iterable (e.g. a generator) with ``all(...)`` would otherwise
        exhaust it. ``np.array`` of the spent generator would then produce a
        0-d object array instead of the stacked data.

        Args:
            data: iterable of ``np.ndarray`` items.

        Returns:
            The merged array.
        """
        assert isinstance(data, Iterable), f'data must be Iterable: {data}'
        items = list(data)  # materialize once so validation does not consume the data
        assert all(isinstance(item, np.ndarray) for item in items), f'data items must be np.ndarray: {data}'
        return np.array(items)

    def merge_data(self, data: Iterable[np.ndarray]) -> np.ndarray:
        """Merge a batch of items; delegates to :meth:`merge_data_helper`."""
        return self.merge_data_helper(data)

    def get_json(self) -> dict[str, JSON]:
        """Extend the base description with ``data_shape`` (as a list) and the
        name of ``data_type`` as produced by ``get_class_name``."""
        return {
            **super().get_json(),
            'data_shape': list(self.data_shape),
            'data_type': get_class_name(self.data_type)
        }


class ConcatNumpyDataset(ConcatDataset[np.ndarray]):
    """Concatenation of NumpyDatasets that all share one item shape and scalar type.

    The common ``data_shape`` / ``data_type`` are taken from the first
    dataset and exposed on the concatenation itself, after the base
    ``ConcatDataset`` has been initialised.
    """

    def __init__(self, datasets: list[NumpyDataset]) -> None:
        """Validate ``datasets`` and pass them on to ``ConcatDataset``.

        Args:
            datasets: non-empty list of NumpyDataset instances whose
                ``data_shape`` and ``data_type`` are all equal.
        """
        assert isinstance(datasets, list), f'datasets must be list: {datasets}'
        # An empty list passes the element-type check trivially, so checking
        # emptiness first reports exactly the same failure as before.
        assert datasets, f'datasets must have at least 1 element: {datasets}'
        assert all(isinstance(ds, NumpyDataset) for ds in datasets), \
            f'all datasets must be NumpyDataset: {datasets}'

        reference = datasets[0]
        shape, dtype = reference.data_shape, reference.data_type
        assert all(ds.data_shape == shape for ds in datasets), \
            f'all datasets must have the same shape: {datasets}'
        assert all(ds.data_type == dtype for ds in datasets), \
            f'all datasets must have the same type: {datasets}'

        super().__init__(datasets)
        self.data_shape: tuple[int, ...] = shape
        self.data_type: type = dtype

    def merge_data(self, data: Iterable[np.ndarray]) -> np.ndarray:
        """Merge a batch of items with the shared ``NumpyDataset`` helper."""
        merge = NumpyDataset.merge_data_helper
        return merge(data)


class ArrayNumpyDataset(NumpyDataset):
    """Dataset backed by one in-memory array, indexed along its first axis.

    Item ``i`` is ``array[i]``, so the per-item shape is ``array.shape[1:]``
    and the scalar type is ``array.dtype.type``.
    """

    def __init__(self, array: np.ndarray) -> None:
        """Wrap ``array``; it must have at least one dimension (the item axis)."""
        assert isinstance(array, np.ndarray), f'array must be np.ndarray: {array}'
        assert array.ndim >= 1, f'array must have at least 1 dimension: {array.shape}'
        item_shape = array.shape[1:]
        super().__init__(item_shape, array.dtype.type)
        self.array: np.ndarray = array

    def __len__(self) -> int:
        """Number of items, i.e. the size of the first axis."""
        return len(self.array)

    def __getitem__(self, item: int) -> np.ndarray:
        """Return ``array[item]`` using ordinary NumPy indexing (a view, not a copy)."""
        return self.array[item]

    def get_json(self) -> dict[str, JSON]:
        """Extend the base description with the full array as nested lists."""
        return dict(super().get_json(), array=self.array.tolist())


class ConstantNumpyDataset(NumpyDataset):
    """Dataset of ``length`` items that are all the same array ``value``."""

    def __init__(self, value: np.ndarray, length: int) -> None:
        """Store the constant item and the number of times it repeats.

        Args:
            value: the array returned for every index; its shape and
                ``dtype.type`` become ``data_shape`` / ``data_type``.
            length: number of items (>= 0).
        """
        assert isinstance(value, np.ndarray), f'value must be np.ndarray: {value}'
        assert isinstance(length, int), f'length must be int: {length}'
        assert length >= 0, f'length must be greater than or equal to 0: {length}'
        super().__init__(value.shape, value.dtype.type)
        self.value: np.ndarray = value
        self.length: int = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, item: int) -> np.ndarray:
        """Return ``value`` for any valid index in ``[-length, length)``.

        Every index returns the *same* array object, so mutating a returned
        item mutates all of them.

        Raises:
            IndexError: if ``item`` is out of range. Without this, iteration
                that relies on the ``__getitem__`` sequence protocol would
                never terminate.
        """
        if not -self.length <= item < self.length:
            raise IndexError(f'index {item} out of range for dataset of length {self.length}')
        return self.value

    def get_json(self) -> dict[str, JSON]:
        """Extend the base description with the constant value and the length."""
        return {
            **super().get_json(),
            'value': self.value.tolist(),
            'length': self.length
        }


class FolderNumpyDataset(NumpyDataset):
    """Dataset whose items are individual ``.npy`` files in a folder.

    Item ``i`` is loaded from ``{folder}/{start_index + i}.npy`` with
    ``np.load`` on every access (nothing is cached). ``data_shape`` and
    ``data_type`` are supplied by the caller and are not checked against the
    loaded files.
    """

    def __init__(
            self,
            data_shape: tuple[int, ...],
            data_type: type,
            folder: str,
            num_samples: int,
            start_index: int = 0
    ) -> None:
        """Describe a contiguous run of numbered ``.npy`` files.

        Args:
            data_shape: declared shape of a single item.
            data_type: declared scalar type of a single item.
            folder: directory containing the ``<n>.npy`` files.
            num_samples: number of files/items (>= 0).
            start_index: number of the file that holds item 0 (>= 0).
        """
        assert isinstance(folder, str), f'folder must be str: {folder}'
        assert isinstance(num_samples, int), f'num_samples must be int: {num_samples}'
        assert num_samples >= 0, f'num_samples must be greater than or equal to 0: {num_samples}'
        assert isinstance(start_index, int), f'start_index must be int: {start_index}'
        assert start_index >= 0, f'start_index must be greater than or equal to 0: {start_index}'
        super().__init__(data_shape, data_type)
        self.folder: str = folder
        self.num_samples: int = num_samples
        self.start_index: int = start_index

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, item: int) -> np.ndarray:
        """Load item ``item`` (Python-style negative indices allowed).

        Raises:
            IndexError: if ``item`` is outside ``[-num_samples, num_samples)``.
                Without the check, an out-of-range index would load a file
                outside the declared range, and a negative index would load a
                file before ``start_index``.
        """
        if not -self.num_samples <= item < self.num_samples:
            raise IndexError(f'index {item} out of range for dataset of length {self.num_samples}')
        if item < 0:
            item += self.num_samples
        return np.load(f'{self.folder}/{item + self.start_index}.npy')

    def get_json(self) -> dict[str, JSON]:
        """Extend the base description with the folder location and file range."""
        return {
            **super().get_json(),
            'folder': self.folder,
            'num_samples': self.num_samples,
            'start_index': self.start_index
        }
