from .base_dataset import BaseDataset
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from .dataset_wrappers import RepeatDataset
from query_strategies.utils import get_one_hot_label


class Handler(Dataset):
    """Map-style dataset view over one split of a ``BaseDataset``.

    Indexing delegates to ``dataset.prepare_data(idx, split, transform)``.
    When ``to_soft`` is true, the label in the tuple returned by
    ``prepare_data`` is converted with ``get_one_hot_label`` using
    ``len(dataset.CLASSES)`` classes; otherwise the result of
    ``prepare_data`` is returned untouched.

    Args:
        dataset: Object providing ``prepare_data``, ``DATA_INFOS`` and ``CLASSES``.
        split: Key into ``dataset.DATA_INFOS`` selecting the exposed samples.
        transform: Passed through unchanged to ``prepare_data``.
        to_soft: If true, convert the label to one-hot form.
    """

    def __init__(self, dataset: BaseDataset, split, transform=None, to_soft=False):
        self.dataset, self.split = dataset, split
        self.transform, self.to_soft = transform, to_soft

    def __getitem__(self, idx):
        sample = self.dataset.prepare_data(idx, self.split, self.transform)
        if not self.to_soft:
            return sample
        # In soft mode prepare_data must yield a 4-tuple; only the label
        # (second element) is rewritten, the other three pass through.
        x, label, no, sample_idx = sample
        n_classes = len(self.dataset.CLASSES)
        return x, get_one_hot_label(label, n_classes), no, sample_idx

    def __len__(self):
        # Length is the number of info records stored for this split.
        return len(self.dataset.DATA_INFOS[self.split])


class AugHandler(Handler):
    """Handler over a split whose records describe augmented samples.

    Each entry of ``dataset.DATA_INFOS[split]`` is read as a record with
    ``'idx'``, ``'split'`` and ``'aug_transform'`` keys. Those values, plus
    this handler's ``transform``, are forwarded to ``prepare_data``.

    In soft mode the label is one-hot encoded, the third element is replaced
    by the constant ``-1`` (presumably a marker for augmented samples —
    confirm with consumers) and the fourth by this handler's own index.
    """

    def __init__(self, dataset: BaseDataset, split, transform=None, to_soft=False):
        super().__init__(dataset, split, transform, to_soft)

    def __getitem__(self, idx):
        record = self.dataset.DATA_INFOS[self.split][idx]
        sample = self.dataset.prepare_data(record['idx'], record['split'],
                                           self.transform, record['aug_transform'])
        if not self.to_soft:
            return sample
        x, label, _, _ = sample
        return x, get_one_hot_label(label, len(self.dataset.CLASSES)), -1, idx


class MixUpHandler(Handler):
    """Handler that builds MixUp samples from pairs of stored samples.

    Each entry of ``dataset.DATA_INFOS[split]`` is a record with keys
    ``'lam'``, ``'idx_a'``, ``'idx_b'`` and ``'split'``, and optionally
    ``'split_b'``. Sample *a* is loaded from ``record['split']``. Sample *b*
    is loaded from the first available of:

    1. ``record['split_b']``,
    2. this handler's ``split_b``,
    3. ``record['split']``.

    The result is ``lam * a + (1 - lam) * b`` for both the inputs and the
    one-hot labels.

    Labels are always one-hot encoded, so ``to_soft`` has no effect on this
    handler. The returned tuple is ``(x, y, -2, idx)``. The ``-2`` is a
    constant, presumably marking mixed samples — confirm with consumers.

    Args:
        dataset: Object providing ``prepare_data``, ``DATA_INFOS`` and ``CLASSES``.
        split: Key into ``dataset.DATA_INFOS`` holding the mix records.
        transform: Passed through unchanged to ``prepare_data``.
        to_soft: Accepted for interface compatibility; unused.
        split_b: Fallback split for sample *b* when a record has no
            ``'split_b'`` key. ``None`` (default) draws *b* from the
            record's ``'split'``.
    """

    def __init__(self, dataset: BaseDataset, split, transform=None, to_soft=False, split_b=None):
        super().__init__(dataset, split, transform, to_soft)
        # Previously accepted but discarded; now used as a fallback in __getitem__.
        self.split_b = split_b

    def __getitem__(self, idx):
        record = self.dataset.DATA_INFOS[self.split][idx]
        lam = record['lam']
        split_a = record['split']
        default_b = split_a if self.split_b is None else self.split_b
        split_b = record.get('split_b', default_b)

        x1, y1, _, _ = self.dataset.prepare_data(record['idx_a'], split_a, self.transform)
        x2, y2, _, _ = self.dataset.prepare_data(record['idx_b'], split_b, self.transform)

        num_classes = len(self.dataset.CLASSES)
        y1 = get_one_hot_label(y1, num_classes)
        y2 = get_one_hot_label(y2, num_classes)

        x = lam * x1 + (1 - lam) * x2
        y = lam * y1 + (1 - lam) * y2
        return x, y, -2, idx


# Registry mapping the ``loader_name`` accepted by GetHandler/GetDataLoader
# to the Handler class that serves records of that kind.
loader_dict = dict(
    base=Handler,
    aug=AugHandler,
    mixup=MixUpHandler,
)


def _per_split_values(value, default, n_splits, name):
    """Expand a per-split argument of GetHandler into one entry per split.

    ``None`` becomes ``[default] * n_splits``. A scalar is repeated for every
    split; this includes a ``str``, which would otherwise be iterated
    character by character. Any other iterable is materialised as a list and
    must contain exactly ``n_splits`` entries.

    Raises:
        ValueError: if an iterable ``value`` has the wrong length.
    """
    from collections.abc import Iterable

    if value is None:
        return [default] * n_splits
    if isinstance(value, str) or not isinstance(value, Iterable):
        return [value] * n_splits
    values = list(value)
    if len(values) != n_splits:
        raise ValueError(f"{name} has {len(values)} entries but {n_splits} splits were given")
    return values


def GetHandler(dataset: BaseDataset, split: str, transform=None,
               repeat=None, to_soft=None, loader_name=None):
    """Build the dataset object for one or several splits of ``dataset``.

    Args:
        dataset: Source dataset passed to every handler.
        split: A single split name, or a list/tuple/set of split names.
            A ``set`` has no defined iteration order, so per-split lists only
            line up reliably with list/tuple splits.
        transform: Passed to every handler.
        repeat: Repeat count for ``RepeatDataset`` (default 1). With several
            splits this is either one value per split or a single value
            applied to all of them.
        to_soft: One-hot label flag (default False). Same per-split rules as
            ``repeat``.
        loader_name: Key into ``loader_dict`` (default ``'base'``). Same
            per-split rules as ``repeat``.

    Returns:
        For a single split, a ``RepeatDataset`` wrapping the handler. For
        several splits, a ``ConcatDataset`` of one ``RepeatDataset`` per split.

    Raises:
        TypeError: if ``split`` is not a str, list, tuple or set.
        ValueError: if a per-split list does not match the number of splits.
        KeyError: if a loader name is not in ``loader_dict``.
    """
    if isinstance(split, str):
        handler_cls = loader_dict['base' if loader_name is None else loader_name]
        handler = handler_cls(dataset, split, transform, False if to_soft is None else to_soft)
        return RepeatDataset(handler, 1 if repeat is None else repeat)

    if isinstance(split, (list, tuple, set)):
        splits = list(split)  # fix one iteration order for zipping below
        n_splits = len(splits)
        repeats = _per_split_values(repeat, 1, n_splits, 'repeat')
        soft_flags = _per_split_values(to_soft, False, n_splits, 'to_soft')
        names = _per_split_values(loader_name, 'base', n_splits, 'loader_name')
        handlers = [
            RepeatDataset(loader_dict[name](dataset, split_elem, transform, soft), times)
            for split_elem, times, soft, name in zip(splits, repeats, soft_flags, names)
        ]
        return ConcatDataset(handlers)

    raise TypeError(f"Not supported type! split must be str, list, tuple or set, "
                    f"got {type(split).__name__}")


def GetDataLoader(dataset: BaseDataset, split: str, transform=None,
                  repeat=None, to_soft=None, loader_name=None, **kwargs):
    """Wrap ``GetHandler(...)`` in a torch ``DataLoader``.

    ``dataset``, ``split``, ``transform``, ``repeat``, ``to_soft`` and
    ``loader_name`` are forwarded to ``GetHandler``. All remaining keyword
    arguments (e.g. batch size, shuffling) go straight to ``DataLoader``.
    """
    handler = GetHandler(dataset, split, transform,
                         repeat=repeat, to_soft=to_soft, loader_name=loader_name)
    return DataLoader(handler, **kwargs)
