from abc import ABC
from typing import Iterable

import numpy as np

from utils.types import JSON
from utils.utils import get_class_name
from .datasets import Dataset, ConcatDataset
from .numpy import NumpyDataset


class DictDataset(Dataset[dict[str, np.ndarray]], ABC):
    """Abstract dataset whose samples are dicts mapping string keys to numpy arrays.

    Each key has a declared per-sample shape and element type. Both declarations are
    validated and stored at construction time and are later used to check samples
    before they are merged into per-key arrays.
    """

    def __init__(self, data_shape: dict[str, tuple[int, ...]], data_type: dict[str, type]) -> None:
        """Validate and record the per-key shape and type declarations.

        Args:
            data_shape: Maps each key to the expected shape (tuple of int) of that key's array.
            data_type: Maps each key to a type; during merging, an array's ``dtype.type``
                must be a subclass of it.

        Raises:
            AssertionError: If either argument is malformed or the two key sets differ.
        """
        assert isinstance(data_shape, dict), f'data_shape must be dict: {data_shape}'
        assert all(isinstance(key, str) for key in data_shape.keys()), f'data_shape keys must be str: {data_shape}'
        assert all(isinstance(val, tuple) for val in data_shape.values()), \
            f'data_shape values must be tuple: {data_shape}'
        assert all(all(isinstance(val, int) for val in shape) for shape in data_shape.values()), \
            f'data_shape values must be tuple of int: {data_shape}'
        assert isinstance(data_type, dict), f'data_type must be dict: {data_type}'
        assert all(isinstance(key, str) for key in data_type.keys()), f'data_type keys must be str: {data_type}'
        assert all(isinstance(val, type) for val in data_type.values()), \
            f'data_type values must be type: {data_type}'
        assert set(data_shape.keys()) == set(data_type.keys()), \
            f'data_shape keys must be the same as data_type keys: {data_shape.keys()}, {data_type.keys()}'
        self.keys: set[str] = set(data_shape.keys())
        self.data_shape: dict[str, tuple[int, ...]] = data_shape
        self.data_type: dict[str, type] = data_type

    @classmethod
    def merge_data_helper(
            cls,
            data: Iterable[dict[str, np.ndarray]],
            keys: set[str],
            data_shape: dict[str, tuple[int, ...]],
            data_type: dict[str, type]
    ) -> dict[str, np.ndarray]:
        """Validate a collection of samples and combine them into one array per key.

        Args:
            data: Samples to merge; each must be a dict whose key set equals ``keys``.
                Any iterable is accepted, including one-shot iterators such as generators.
            keys: The exact key set every sample must have.
            data_shape: Expected array shape for each key.
            data_type: Expected type for each key; ``dtype.type`` must be a subclass of it.

        Returns:
            A dict that maps each key to ``np.array`` built from the list of that key's
            per-sample arrays, in iteration order of ``data``.

        Raises:
            AssertionError: If ``data`` is not iterable or any sample fails validation.
        """
        assert isinstance(data, Iterable), f'data must be Iterable: {data}'
        # Materialize once: every check below and the final stacking iterate the samples.
        # Without this, a one-shot iterator would be exhausted by the first check, the
        # remaining checks would pass vacuously, and the merged arrays would be empty.
        data = list(data)
        assert all(isinstance(item, dict) for item in data), f'data items must be dict: {data}'
        assert all(all(isinstance(key, str) for key in item.keys()) for item in data), \
            f'data items keys must be str: {data}'
        assert all(set(item.keys()) == keys for item in data), \
            f'data items keys must be the same as keys: {data}'
        assert all(all(isinstance(item[key], np.ndarray) for key in keys) for item in data), \
            f'data items values must be np.ndarray: {data}'
        assert all(all(item[key].shape == data_shape[key] for key in keys) for item in data), \
            f'data items shapes must be the same as shapes: {data}'
        assert all(all(issubclass(item[key].dtype.type, data_type[key]) for key in keys) for item in data), \
            f'data items types must be the same as types: {data}'
        return {key: np.array([item[key] for item in data]) for key in keys}

    def merge_data(self, data: Iterable[dict[str, np.ndarray]]) -> dict[str, np.ndarray]:
        """Merge samples using this dataset's own keys, shapes and types."""
        return self.merge_data_helper(data, self.keys, self.data_shape, self.data_type)

    def get_json(self) -> dict[str, JSON]:
        """Return the parent's JSON description extended with keys, shapes and type names."""
        return {
            **super().get_json(),
            # Sets are unordered; sort so the emitted key list is deterministic.
            'keys': sorted(self.keys),
            'data_shape': {key: list(shape) for key, shape in self.data_shape.items()},
            'data_type': {key: get_class_name(val) for key, val in self.data_type.items()}
        }


class ConcatDictDataset(ConcatDataset[dict[str, np.ndarray]]):
    """Concatenation of ``DictDataset`` instances that agree on keys, shapes and types."""

    def __init__(self, datasets: list[DictDataset]) -> None:
        """Check that all datasets are compatible, then hand them to the parent constructor.

        Args:
            datasets: Non-empty list of ``DictDataset`` instances. The first element
                supplies the reference keys, shapes and types the others must match.

        Raises:
            AssertionError: If the list is malformed, empty, or the datasets disagree.
        """
        assert isinstance(datasets, list), f'datasets must be list: {datasets}'
        assert all(isinstance(member, DictDataset) for member in datasets), \
            f'all datasets must be DictDataset: {datasets}'
        assert len(datasets) > 0, f'datasets must have at least 1 element: {datasets}'

        # The first dataset is the reference that every other one is compared against.
        reference = datasets[0]
        keys = reference.keys
        data_shape = reference.data_shape
        data_type = reference.data_type

        assert all(member.keys == keys for member in datasets), \
            f'all datasets must have the same keys: {datasets}'
        assert all(member.data_shape[name] == data_shape[name] for member in datasets for name in keys), \
            f'all datasets must have the same shape: {datasets}'
        assert all(member.data_type[name] == data_type[name] for member in datasets for name in keys), \
            f'all datasets must have the same type: {datasets}'

        super().__init__(datasets)
        self.keys: set[str] = keys
        self.data_shape: dict[str, tuple[int, ...]] = data_shape
        self.data_type: dict[str, type] = data_type

    def merge_data(self, data: Iterable[dict[str, np.ndarray]]) -> dict[str, np.ndarray]:
        """Merge samples via ``DictDataset.merge_data_helper`` with the shared keys, shapes and types."""
        return DictDataset.merge_data_helper(
            data=data,
            keys=self.keys,
            data_shape=self.data_shape,
            data_type=self.data_type,
        )


class NumpyDictDataset(DictDataset):
    """``DictDataset`` backed by one equally long ``NumpyDataset`` per key."""

    def __init__(self, datasets: dict[str, NumpyDataset]) -> None:
        """Validate the per-key datasets and derive shapes, types and the common length.

        Args:
            datasets: Non-empty dict mapping string keys to ``NumpyDataset`` instances
                that all have the same length.

        Raises:
            AssertionError: If the dict is malformed, empty, or the lengths differ.
        """
        assert isinstance(datasets, dict), f'datasets must be dict: {datasets}'
        assert all(isinstance(name, str) for name in datasets), f'datasets keys must be str: {datasets}'
        assert all(isinstance(member, NumpyDataset) for member in datasets.values()), \
            f'datasets values must be NumpyDataset: {datasets}'
        assert len(datasets) > 0, f'datasets must not be empty: {datasets}'

        # Take the reference length from the dataset under the smallest key.
        first_name = sorted(datasets)[0]
        length = len(datasets[first_name])
        assert all(len(member) == length for member in datasets.values()), \
            f'all datasets must have the same length: {datasets}'

        shapes: dict[str, tuple[int, ...]] = {}
        for name, member in datasets.items():
            shapes[name] = member.data_shape
        types: dict[str, type] = {}
        for name, member in datasets.items():
            types[name] = member.data_type

        super().__init__(shapes, types)
        self.datasets: dict[str, NumpyDataset] = datasets
        self.length: int = length

    def __len__(self) -> int:
        """Return the common length of the underlying datasets."""
        return self.length

    def __getitem__(self, item: int) -> dict[str, np.ndarray]:
        """Return a dict holding, for each key, that key's dataset indexed at ``item``."""
        sample: dict[str, np.ndarray] = {}
        for name, member in self.datasets.items():
            sample[name] = member[item]
        return sample

    def get_json(self) -> dict[str, JSON]:
        """Return the parent's JSON description plus each sub-dataset's JSON and the length."""
        base = super().get_json()
        nested = {}
        for name, member in self.datasets.items():
            nested[name] = member.get_json()
        return {**base, 'datasets': nested, 'length': self.length}
