import bisect
import itertools
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Iterable

import torch.utils.data

from utils.base_object import BaseObject
from utils.types import JSON

# Covariant type variable for the item type of the datasets in this module.
# Declared with a plain assignment: annotating a TypeVar with itself is not a
# valid type and is rejected by static type checkers.
T_co = TypeVar('T_co', covariant=True)


class Dataset(torch.utils.data.Dataset[T_co], Generic[T_co], BaseObject, ABC):
    """Abstract dataset of ``T_co`` items that can describe itself as JSON.

    Concrete subclasses must provide ``__len__``, ``__getitem__`` and
    ``merge_data``; ``get_json`` extends the parent description with the
    dataset length.
    """

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of items in the dataset."""
        raise NotImplementedError('__len__ method must be implemented')

    @abstractmethod
    def __getitem__(self, item: int) -> T_co:
        """Return the item stored at integer position ``item``."""
        raise NotImplementedError('__getitem__ method must be implemented')

    @abstractmethod
    def merge_data(self, data: Iterable[T_co]) -> T_co:
        """Combine an iterable of items into a single ``T_co`` value.

        NOTE(review): presumably used to collate items into a batch -- confirm
        against the concrete subclasses.
        """
        raise NotImplementedError('merge_data method must be implemented')

    def get_json(self) -> dict[str, JSON]:
        """Return the parent JSON description with ``'__len__'`` set to ``len(self)``."""
        description = {**super().get_json()}
        description['__len__'] = len(self)
        return description


class _DatasetIndexError(IndexError, ValueError):
    """Raised when an index falls outside every dataset of a ConcatDataset.

    Subclassing IndexError lets Python's fallback sequence-iteration protocol
    (``iter()``/``for`` on an object that only defines ``__getitem__``) stop
    cleanly at the end of the data. Subclassing ValueError keeps callers that
    caught the previously raised ValueError working.
    """


class ConcatDataset(Dataset[T_co], Generic[T_co], ABC):
    """Dataset that chains several datasets end to end.

    Global index ``i`` is served by the dataset whose half-open range
    ``[starts[k], ends[k])`` contains ``i``. ``merge_data`` is left abstract
    for subclasses.

    Args:
        datasets: The datasets to concatenate, in order.
    """

    def __init__(self, datasets: list[Dataset[T_co]]) -> None:
        # NOTE(review): assert is stripped under ``python -O``; kept so callers
        # still see the same AssertionError on a non-list argument.
        assert isinstance(datasets, list), f'datasets must be list: {datasets}'
        # NOTE(review): super().__init__() is not called -- confirm BaseObject
        # needs no initialisation.
        self.datasets: list[Dataset[T_co]] = datasets
        self.lengths: list[int] = [len(dataset) for dataset in datasets]
        # Running totals give every exclusive end in one O(n) pass instead of
        # re-summing a prefix slice per dataset. Attributes are still assigned
        # in the original order (starts, ends, length).
        ends = list(itertools.accumulate(self.lengths))
        self.starts: list[int] = [end - length for end, length in zip(ends, self.lengths)]
        self.ends: list[int] = ends
        self.length: int = sum(self.lengths)

    def __len__(self) -> int:
        """Return the total number of items across all datasets."""
        return self.length

    def get_dataset_and_index(self, index: int) -> tuple[int, int]:
        """Map a global index to ``(dataset position, index within that dataset)``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: If ``index`` is out of range. The exception is also a
                ValueError, for backward compatibility.
        """
        position = index + self.length if index < 0 else index
        # First dataset whose exclusive end lies beyond ``position``. Empty
        # datasets share their predecessor's end, so they are skipped.
        dataset_index = bisect.bisect_right(self.ends, position)
        if position < 0 or dataset_index == len(self.ends):
            raise _DatasetIndexError(f'item {index} not in any dataset')
        return dataset_index, position - self.starts[dataset_index]

    def __getitem__(self, item: int) -> T_co:
        """Return the item at global index ``item`` from the owning dataset."""
        dataset_index, dataset_item = self.get_dataset_and_index(item)
        return self.datasets[dataset_index][dataset_item]

    def get_json(self) -> dict[str, JSON]:
        """Extend the parent JSON with child descriptions and the index layout."""
        return {
            **super().get_json(),
            'datasets': [dataset.get_json() for dataset in self.datasets],
            'lengths': self.lengths,
            'starts': self.starts,
            'ends': self.ends,
            'length': self.length
        }
