from torch.utils.data import DataLoader
from utils.build_transforms import build_transforms
from data.benchmark_mir import CLDataLoader, get_permuted_mnist, get_split_mnist, get_miniimagenet, get_rotated_mnist, \
    get_split_cifar10, IIDDataset

from utils.utils import DotDict

import random

# Module-level mapping; nothing is registered in it within the code shown here.
_dataset = dict()

class BatchCollator:
    """Collate a list of samples into a dict of parallel per-field lists.

    Each sample is indexed with the keys ``'image_id'``, ``'image'``,
    ``'gt_bboxes'`` and ``'cropped_image'``; ``sample['gt_bboxes']`` must
    expose an ``extra_fields`` mapping with ``'attributes'`` and ``'labels'``
    entries. The returned ``'info'`` list is always left empty, and ``'raw'``
    keeps the untouched samples in batch order.
    """

    def __init__(self):
        # Stateless: there is nothing to configure.
        pass

    def __call__(self, batch):
        collated = {
            name: []
            for name in (
                'image_ids', 'images', 'gt_bboxes', 'info',
                'attribute_labels', 'object_labels', 'cropped_image', 'raw',
            )
        }
        for sample in batch:
            image_id = sample['image_id']
            image = sample['image']
            boxes = sample['gt_bboxes']
            extra = boxes.extra_fields
            attributes = extra['attributes']
            labels = extra['labels']
            crop = sample['cropped_image']

            collated['image_ids'].append(image_id)
            collated['images'].append(image)
            collated['gt_bboxes'].append(boxes)
            collated['attribute_labels'].append(attributes)
            collated['object_labels'].append(labels)
            collated['cropped_image'].append(crop)
            collated['raw'].append(sample)
        return collated

# Position of each split inside the (train, val, test) loader triples cached below.
_SPLIT_POSITIONS = {'train': 0, 'val': 1, 'test': 2}


def _resolve_request(split, filter_obj):
    """Validate a loader request and return ``(split_position, loader_index)``.

    ``filter_obj[0]`` is used to index the per-split loader sequence
    (presumably a task id -- confirm against callers).

    Raises:
        ValueError: if ``split`` is not ``'train'``, ``'val'`` or ``'test'``.
        TypeError: if ``filter_obj`` is ``None``.
    """
    if split not in _SPLIT_POSITIONS:
        # Previously an unknown split silently returned None.
        raise ValueError(
            "split must be one of 'train', 'val', 'test'; got %r" % (split,)
        )
    if filter_obj is None:
        raise TypeError(
            "filter_obj is required: filter_obj[0] selects the loader to return"
        )
    return _SPLIT_POSITIONS[split], filter_obj[0]


def _wrap_splits(data, batch_size):
    """Wrap the (train, val, test) splits of ``data`` in ``CLDataLoader`` objects.

    The first split is built with ``train=True``, the other two with
    ``train=False``. Returns a 3-tuple of loaders.
    """
    train_loader, val_loader, test_loader = [
        CLDataLoader(elem, batch_size, train=is_train)
        for elem, is_train in zip(data, (True, False, False))
    ]
    return train_loader, val_loader, test_loader


_smnist_loaders = None
def get_split_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
    """Return the Split-MNIST loader for ``split`` at index ``filter_obj[0]``.

    ``cfg``, ``*args`` and ``**kwargs`` are accepted for interface
    compatibility and ignored.

    NOTE: loaders are built once on the first call and cached in
    ``_smnist_loaders``; ``batch_size`` passed on later calls has no effect.
    """
    global _smnist_loaders
    split_pos, loader_idx = _resolve_request(split, filter_obj)
    if not _smnist_loaders:
        _smnist_loaders = _wrap_splits(get_split_mnist(DotDict()), batch_size)
    return _smnist_loaders[split_pos][loader_idx]


_rmnist_loaders = None
def get_rotated_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, task_num=10, *args, **kwargs):
    """Return the Rotated-MNIST loader for ``split`` at index ``filter_obj[0]``.

    ``cfg``, ``task_num``, ``*args`` and ``**kwargs`` are accepted for
    interface compatibility and ignored.

    NOTE: loaders are built once on the first call and cached in
    ``_rmnist_loaders``; ``batch_size`` passed on later calls has no effect.
    """
    global _rmnist_loaders
    split_pos, loader_idx = _resolve_request(split, filter_obj)
    if not _rmnist_loaders:
        _rmnist_loaders = _wrap_splits(get_rotated_mnist(DotDict()), batch_size)
    return _rmnist_loaders[split_pos][loader_idx]


_pmnist_loaders = None
def get_permute_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, task_num=10, *args, **kwargs):
    """Return the Permuted-MNIST loader for ``split`` at index ``filter_obj[0]``.

    ``cfg``, ``task_num``, ``*args`` and ``**kwargs`` are accepted for
    interface compatibility and ignored.

    NOTE: loaders are built once on the first call and cached in
    ``_pmnist_loaders``; ``batch_size`` passed on later calls has no effect.
    """
    global _pmnist_loaders
    split_pos, loader_idx = _resolve_request(split, filter_obj)
    if not _pmnist_loaders:
        _pmnist_loaders = _wrap_splits(get_permuted_mnist(DotDict()), batch_size)
    return _pmnist_loaders[split_pos][loader_idx]


_cache_cifar = None
def get_split_cifar_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
    """Return the Split-CIFAR10 loader for ``split`` at index ``filter_obj[0]``.

    ``cfg`` is forwarded to ``get_split_cifar10`` when the cache is built;
    ``*args`` and ``**kwargs`` are ignored.

    NOTE: loaders are built once on the first call and cached in
    ``_cache_cifar``; ``cfg`` and ``batch_size`` passed on later calls have
    no effect.
    """
    global _cache_cifar
    split_pos, loader_idx = _resolve_request(split, filter_obj)
    if not _cache_cifar:
        _cache_cifar = _wrap_splits(get_split_cifar10(DotDict(), cfg), batch_size)
    return _cache_cifar[split_pos][loader_idx]


_cache_mini_imagenet = None
def get_split_mini_imagenet_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
    """Return the mini-ImageNet loader for ``split`` at index ``filter_obj[0]``.

    ``cfg``, ``*args`` and ``**kwargs`` are accepted for interface
    compatibility and ignored.

    NOTE: loaders are built once on the first call and cached in
    ``_cache_mini_imagenet``; ``batch_size`` passed on later calls has no
    effect.
    """
    global _cache_mini_imagenet
    split_pos, loader_idx = _resolve_request(split, filter_obj)
    if not _cache_mini_imagenet:
        _cache_mini_imagenet = _wrap_splits(get_miniimagenet(DotDict()), batch_size)
    return _cache_mini_imagenet[split_pos][loader_idx]
