import bisect
import numpy as np
import torch
from copy import deepcopy

from typing import List, Tuple
from sklearn.model_selection import train_test_split
from torchvision.transforms import Compose
from torch.utils.data import ConcatDataset

from xad.datasets.bases import TorchvisionDataset, DataLoader, BaseLoader, RandomSampler, torchvision_to_kornia
from xad.utils.logger import Logger


class ConcatOEDataset(TorchvisionDataset):
    """
    Presents the training splits of several TorchvisionDataset instances as one
    combined dataset. It only provides training (and optional validation) data;
    there is no test set.
    """

    def __init__(self, datasets: List[TorchvisionDataset], balance=False,):
        """
        Concatenates TorchvisionDataset datasets.
        @param datasets: the datasets. All of them need to share the same nominal label,
            train transform, and GPU train transform (the transforms are compared via
            their string representation).
        @param balance: if True, the training indices of every dataset that is smaller than
            the largest one are cyclically repeated until all datasets have the same size.
            Note that this modifies `ds.train_set.indices` of the passed datasets in place.
        """
        assert all([datasets[0].nominal_label == d.nominal_label for d in datasets])
        assert all([str(datasets[0].train_transform) == str(d.train_transform) for d in datasets])
        assert all([str(datasets[0].gpu_train_transform) == str(d.gpu_train_transform) for d in datasets])
        # The builtin sum is used because passing a generator to np.sum is deprecated;
        # np.sum falls back to the builtin in that case anyway.
        super().__init__(
            None, [0], datasets[0].nominal_label, datasets[0].train_transform, Compose([]), 1,
            None, datasets[0].logger, sum(ds.limit_samples for ds in datasets),
        )

        self.datasets = datasets
        self.balanced = balance
        if balance:
            subset_sizes = [len(ds.train_set.indices) for ds in datasets]
            target_size = max(subset_sizes)
            for subset, size in zip((ds.train_set for ds in datasets), subset_sizes):
                if size < target_size:
                    # Ceil division: number of full copies needed to reach target_size.
                    # An empty subset raises ZeroDivisionError here since it cannot be repeated.
                    n_copies = -(-target_size // size)
                    # np.tile cycles through ALL indices ([a, b, c, a, b, ...]) so that truncating
                    # keeps every sample at least once. np.repeat would yield [a, a, b, b, ...]
                    # and the truncation would silently drop the last samples.
                    subset.indices = np.tile(np.asarray(subset.indices), n_copies)[:target_size]
            self.logger.print(
                f'Repeated OE samples to balance OE datasets. Initial OE dataset sizes were {subset_sizes} for {datasets}.'
            )

        self._train_set: ConcatSubset = ConcatSubset([ds.train_set for ds in datasets])
        self._val_set: ConcatSubset = None
        self._test_set = None
        self.gpu_train_transform = datasets[0].gpu_train_transform
        self.gpu_test_transform = Compose([])

    def train_split(self, ratio=0.1):
        """
        Splits a validation set off the training set, separately for each concatenated subset.
        @param ratio: the fraction of each subset's indices that is moved to the validation set
            (passed as `test_size` to sklearn's train_test_split). Nothing happens for ratio <= 0.
        @raise ValueError: if a validation split already exists, or if the dataset was balanced
            or any subset contains duplicate indices (the same sample would otherwise be able
            to end up in both the training and the validation split).
        """
        if ratio <= 0.0:
            return
        if self.val_set is not None:
            raise ValueError('Validation split already exists!')
        subsets = self._train_set.subsets
        # Validate all subsets before touching any of them, so that a failure does not leave
        # some subsets already split while self._train_set still holds the old lengths.
        if self.balanced or any(len(subset.indices) != len(set(subset.indices)) for subset in subsets):
            raise ValueError('Cannot perform val split on ConcatSubset dataset with subsets that contain duplicates.')
        train_subsets, val_subsets = [], []
        for subset in subsets:
            train_indices, val_indices = train_test_split(subset.indices, test_size=ratio)
            # The existing subset object is shrunk in place to the remaining training indices.
            subset.indices = train_indices
            train_subsets.append(subset)
            val_subsets.append(torch.utils.data.Subset(subset.dataset, val_indices))
        # Rebuild the concatenated views so that their cumulative lengths match the new sizes.
        self._train_set = ConcatSubset(train_subsets)
        self._val_set = ConcatSubset(val_subsets)

    def val_loader(self, batch_size: int, shuffle=False, num_workers: int = 0, device=torch.device('cuda:0')) -> DataLoader:
        """ Forwards all arguments unchanged to the parent class' val_loader. """
        return super().val_loader(batch_size, shuffle, num_workers, device)

    def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, replacement=False,
                num_workers: int = 0, persistent=False, device=torch.device('cuda:0')) -> Tuple[DataLoader, DataLoader]:
        """
        Forwards all arguments to the parent class' loaders but keeps only the first returned loader.
        @return: a tuple (train loader, None); the second entry is always None.
        """
        # classes = None means all classes
        return super().loaders(
            batch_size, shuffle_train, shuffle_test, replacement, num_workers, persistent, device
        )[0], None

    def n_normal_anomalous(self, train=True) -> dict:
        """
        Sums the counters returned by n_normal_anomalous() of all concatenated datasets per key.
        @param train: needs to be True, otherwise the assertion fails.
        @return: a dictionary mapping each key found in any of the datasets' counters to the
            integer sum of its values over all datasets (a missing key counts as 0).
        """
        assert train, 'ConcatOE has no test set.'
        counters = [ds.n_normal_anomalous() for ds in self.datasets]
        keys = {k for c in counters for k in c}
        # `get` instead of `pop` so that the dictionaries returned by the datasets are not modified.
        return {k: int(sum(c.get(k, 0) for c in counters)) for k in keys}

    def _get_raw_train_set(self):
        """ Not supported for the concatenated dataset; always raises NotImplementedError. """
        raise NotImplementedError("Since this is only used as OE, this should never be used...")


class ConcatSubset(torch.utils.data.Subset):
    """
    Exposes a list of torch Subsets as one map-style dataset in which the samples of the
    subsets follow each other in the order of the list.

    Note that Subset.__init__ is deliberately not invoked, so instances have no `dataset`
    and no `indices` attribute of their own; these are only available per entry of `subsets`.
    """

    def __init__(self, subsets: List[torch.utils.data.Subset]):
        """
        @param subsets: the subsets to concatenate. The list is stored as is (no copy).
        """
        self.subsets = subsets
        self._refresh_cumlen()

    def _refresh_cumlen(self):
        """ Recomputes `subsets_cumlen` = [0, len(s0), len(s0) + len(s1), ...] from the current subsets. """
        lengths = [len(subset) for subset in self.subsets]
        # tolist() yields plain python ints, which keeps bisect and len() working on builtin types.
        self.subsets_cumlen = [0] + np.cumsum(lengths, dtype=int).tolist()

    def repeat_indices(self, by: int = 2):
        """
        Replaces the indices of each subset with `by` consecutive copies of the complete index
        sequence (i.e., [a, b] becomes [a, b, a, b] for by=2) and updates the cumulative lengths.
        The subsets are modified in place and their indices become python lists.
        @param by: the number of copies of the index sequence.
        """
        for subset in self.subsets:
            subset.indices = np.tile(np.asarray(subset.indices).reshape(-1), by).tolist()
        self._refresh_cumlen()

    def __getitem__(self, idx):
        """
        @param idx: position in the concatenated sequence. Negative values count from the end.
        @return: whatever the responsible subset returns for the corresponding local position.
        @raise ValueError: if idx is negative and its absolute value exceeds the length.
        @raise IndexError: if idx is greater than or equal to the length.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        if idx >= len(self):
            raise IndexError(f"index {idx} is out of range for ConcatSubset of length {len(self)}")
        # bisect_right returns the position behind the last boundary <= idx. Hence, the result
        # minus one is the subset containing idx, and empty subsets are skipped automatically.
        dataset_idx = bisect.bisect_right(self.subsets_cumlen, idx) - 1
        # subsets_cumlen[0] is always 0, so the first subset needs no special treatment.
        sample_idx = idx - self.subsets_cumlen[dataset_idx]
        return self.subsets[dataset_idx][sample_idx]

    def __getitems__(self, indices):
        """
        Batched item access. Recent PyTorch versions equip Subset with a `__getitems__` that
        reads `self.dataset` and `self.indices`, and data loader fetchers prefer that method
        when it is present. Since this class sets neither attribute, the inherited version
        would fail; therefore it is overridden to go through __getitem__ instead.
        @param indices: an iterable of positions in the concatenated sequence.
        @return: a list with the corresponding samples.
        """
        return [self[idx] for idx in indices]

    def __len__(self):
        """ @return: the total number of samples over all subsets. """
        return self.subsets_cumlen[-1]
