import torch
import torch.utils.data as data_utils
from torchvision import transforms

from torchmeta.datasets import MiniImagenet, TieredImagenet, CUB, CIFARFS
from torchmeta.transforms import ClassSplitter, Categorical

from data.aircraft import AirCraft
from data.vggflower import VggFlower
from data.cars import CARS
from data.pose import Pascal1D
from data.shapenet1d import ShapeNet1D
from data.utils import CifarFSTrasnform, SimpleTransform, NoneTransform, RandAugTransform, AblationCifarFSTrasnform

# Root directory handed to every dataset constructor in this module; the
# regression datasets append their own sub-folder name to it.
DATA_PATH = '/input/dataset/'

class ToTensor1D:
    """Callable transform that turns a `numpy.ndarray` into a `float32` tensor.

    In contrast to torchvision's `ToTensor`, the array is converted as-is,
    whatever its number of dimensions. The input is cast to `float32`
    before being handed to `torch.from_numpy`.
    """

    def __repr__(self):
        return f'{self.__class__.__name__}()'

    def __call__(self, array):
        as_float32 = array.astype('float32')
        return torch.from_numpy(as_float32)


def resize_transform(resize_size):
    """Build the pipeline ``Resize(resize_size)`` followed by ``ToTensor()``.

    Args:
        resize_size: Size forwarded unchanged to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` applying the resize and then the tensor
        conversion, in that order.
    """
    steps = [transforms.Resize(resize_size), transforms.ToTensor()]
    return transforms.Compose(steps)


def _selfsup_transform_kwargs(P):
    """Collect from ``P`` the keyword arguments of the self-supervised transforms."""
    return {
        'brightness': P.brightness, 'contrast': P.contrast,
        'saturation': P.saturation, 'hue': P.hue,
        'color_jitter_prob': P.color_jitter_prob,
        'gray_scale_prob': P.gray_scale_prob,
        'horizontal_flip_prob': P.horizontal_flip_prob,
        'gaussian_prob': P.gaussian_prob,
        'solarization_prob': P.solarization_prob,
        'crop_size': P.crop_size, 'min_scale': P.min_scale,
        'max_scale': P.max_scale,
    }


def _miniimagenet_train_transform(P):
    """Select the miniImageNet training transform; may set ``P.transform_kwargs``."""
    if not (P.barlow or P.selfsup):
        return resize_transform(P.crop_size)
    if P.trades_only:
        return resize_transform(P.crop_size)

    if P.aug_type == 'selfsup':
        P.transform_kwargs = _selfsup_transform_kwargs(P)
        return CifarFSTrasnform(adv_train=P.adv, **P.transform_kwargs)
    if P.aug_type == 'simple':
        P.transform_kwargs = {
            'crop_size': P.crop_size, 'min_scale': P.min_scale,
            'max_scale': P.max_scale,
        }
        return SimpleTransform(adv_train=P.adv, **P.transform_kwargs)
    if P.aug_type == 'none':
        P.transform_kwargs = {'crop_size': P.crop_size}
        return NoneTransform(adv_train=P.adv, **P.transform_kwargs)
    if P.aug_type == 'randaug':
        return RandAugTransform(adv_train=P.adv)

    # Previously an unknown value left the transform unbound and surfaced
    # later as an UnboundLocalError; fail here with a readable message.
    raise NotImplementedError(f'Unsupported aug_type: {P.aug_type!r}')


def _cifar_fs_train_transform(P):
    """Select the CIFAR-FS training transform; may set ``P.transform_kwargs``."""
    if not (P.barlow or P.selfsup):
        return resize_transform(P.crop_size)
    transform_cls = AblationCifarFSTrasnform if P.ablation else CifarFSTrasnform
    P.transform_kwargs = _selfsup_transform_kwargs(P)
    return transform_cls(adv_train=P.adv, **P.transform_kwargs)


def _subsampled_splits(helper, P, transform, test_transform):
    """Build (train, val, test) with a torchmeta helper; train/val are wrapped in ``Subset``."""
    common = dict(ways=P.num_ways, shots=P.num_shots,
                  test_shots=P.num_shots_test, download=True)

    train = helper(DATA_PATH, meta_split='train', transform=transform, **common)
    train = data_utils.Subset(train, torch.arange(int(P.batch_size * 200)))

    val = helper(DATA_PATH, meta_split='val', transform=test_transform, **common)
    val = data_utils.Subset(val, torch.arange(200))

    # Tasks here hold ``P.num_ways`` classes, so the label mapping must be
    # sized with ``P.num_ways`` too (it used to be the training way count,
    # which differs in protonet mode).
    test = helper(DATA_PATH, meta_split='test', transform=test_transform,
                  target_transform=Categorical(P.num_ways), **common)
    return train, val, test


def _standard_splits(dataset_cls, P, train_num_ways, transform, test_transform,
                     train_splitter, test_splitter):
    """Build (train, val, test) datasets of a torchmeta dataset class."""
    train = dataset_cls(DATA_PATH,
                        transform=transform,
                        target_transform=Categorical(train_num_ways),
                        num_classes_per_task=train_num_ways,
                        meta_train=True,
                        dataset_transform=train_splitter,
                        download=True)
    val = dataset_cls(DATA_PATH,
                      transform=test_transform,
                      target_transform=Categorical(P.num_ways),
                      num_classes_per_task=P.num_ways,
                      meta_val=True,
                      dataset_transform=test_splitter)
    test = dataset_cls(DATA_PATH,
                       transform=test_transform,
                       target_transform=Categorical(P.num_ways),
                       num_classes_per_task=P.num_ways,
                       meta_test=True,
                       dataset_transform=test_splitter)
    return train, val, test


def _test_only_split(dataset_cls, P, transform, test_splitter):
    """Build the meta-test dataset of an evaluation-only dataset class."""
    return dataset_cls(DATA_PATH,
                       transform=transform,
                       target_transform=Categorical(P.num_ways),
                       num_classes_per_task=P.num_ways,
                       meta_test=True,
                       dataset_transform=test_splitter)


def _regression_split(dataset_cls, folder, source, P):
    """Build one split (``source``) of a regression dataset stored under ``folder``."""
    return dataset_cls(path=f'{DATA_PATH}/{folder}',
                       img_size=[128, 128, 1],
                       seed=P.seed,
                       source=source,
                       shot=P.num_shots,
                       tasks_per_batch=P.batch_size)


def get_meta_dataset(P, dataset, only_test=False):
    """Build the meta-learning datasets for ``dataset``.

    Args:
        P: Configuration object. Shot/way counts, augmentation settings and
            mode flags are read from its attributes.
        dataset: One of ``'miniimagenet'``, ``'tieredimagenet'``, ``'cub'``,
            ``'cars'``, ``'vggflower'``, ``'aircraft'``, ``'shapenet'``,
            ``'pose'`` or ``'cifar_fs'``.
        only_test: When true, return only the meta-test dataset. Must be
            true for ``'cub'``, ``'cars'``, ``'vggflower'`` and
            ``'aircraft'``, for which only a meta-test dataset is built.

    Returns:
        The meta-test dataset if ``only_test`` is true, otherwise the tuple
        ``(meta_train_dataset, meta_val_dataset)``.

    Raises:
        NotImplementedError: If ``dataset`` is unknown, or if
            ``P.aug_type`` is not a supported value when miniImageNet
            self-supervised augmentation is requested.
        AssertionError: If a test-only dataset is requested with
            ``only_test`` false.

    Side effects:
        ``P.transform_kwargs`` may be assigned for miniImageNet/CIFAR-FS.
        For ``'shapenet'`` and ``'pose'``, ``P.regression`` is set to true
        and ``P.num_ways`` is overwritten (2 and 1 respectively).
    """

    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=P.num_shots,
                                      num_test_per_class=P.num_shots_test + P.num_shots_global)
    dataset_transform_test = ClassSplitter(shuffle=True,
                                           num_train_per_class=P.num_shots,
                                           num_test_per_class=P.num_shots_test)

    if 'protonet' in P.mode:
        train_num_ways = P.train_num_ways
    else:
        train_num_ways = P.num_ways

    if dataset == 'miniimagenet':
        transform = _miniimagenet_train_transform(P)
        test_transform = resize_transform(P.crop_size)

        if P.subsampling:
            from torchmeta.datasets.helpers import miniimagenet

            meta_train_dataset, meta_val_dataset, meta_test_dataset = _subsampled_splits(
                miniimagenet, P, transform, test_transform)
        else:
            # Validation/test use the evaluation transform, as the cifar_fs
            # branch does; they previously received the training transform.
            meta_train_dataset, meta_val_dataset, meta_test_dataset = _standard_splits(
                MiniImagenet, P, train_num_ways, transform, test_transform,
                dataset_transform, dataset_transform_test)

    elif dataset == 'tieredimagenet':
        transform = resize_transform(P.crop_size)
        meta_train_dataset, meta_val_dataset, meta_test_dataset = _standard_splits(
            TieredImagenet, P, train_num_ways, transform, transform,
            dataset_transform, dataset_transform_test)

    elif dataset == 'cub':
        assert only_test, "'cub' is only available with only_test=True"
        transform = transforms.Compose([
            transforms.Resize(int(P.crop_size * 1.5)),
            transforms.CenterCrop(P.crop_size),
            transforms.ToTensor()
        ])
        meta_test_dataset = _test_only_split(CUB, P, transform, dataset_transform_test)

    elif dataset == 'cars':
        assert only_test, "'cars' is only available with only_test=True"
        meta_test_dataset = _test_only_split(
            CARS, P, resize_transform(P.crop_size), dataset_transform_test)

    elif dataset == 'vggflower':
        assert only_test, "'vggflower' is only available with only_test=True"
        meta_test_dataset = _test_only_split(
            VggFlower, P, resize_transform(P.crop_size), dataset_transform_test)

    elif dataset == 'aircraft':
        assert only_test, "'aircraft' is only available with only_test=True"
        meta_test_dataset = _test_only_split(
            AirCraft, P, resize_transform(P.crop_size), dataset_transform_test)

    elif dataset == 'shapenet':
        P.regression = True
        P.num_ways = 2
        meta_train_dataset = _regression_split(ShapeNet1D, 'ShapeNet1D', 'train', P)
        meta_val_dataset = _regression_split(ShapeNet1D, 'ShapeNet1D', 'val', P)
        meta_test_dataset = _regression_split(ShapeNet1D, 'ShapeNet1D', 'test', P)

    elif dataset == 'pose':
        P.regression = True
        P.num_ways = 1
        meta_train_dataset = _regression_split(Pascal1D, 'Pascal1D', 'train', P)
        meta_val_dataset = _regression_split(Pascal1D, 'Pascal1D', 'val', P)
        # No separate test split is built here: the validation dataset is reused.
        meta_test_dataset = meta_val_dataset

    elif dataset == 'cifar_fs':
        transform = _cifar_fs_train_transform(P)
        test_transform = resize_transform(P.crop_size)

        if P.subsampling:
            from torchmeta.datasets.helpers import cifar_fs

            meta_train_dataset, meta_val_dataset, meta_test_dataset = _subsampled_splits(
                cifar_fs, P, transform, test_transform)
        else:
            meta_train_dataset, meta_val_dataset, meta_test_dataset = _standard_splits(
                CIFARFS, P, train_num_ways, transform, test_transform,
                dataset_transform, dataset_transform_test)

    else:
        raise NotImplementedError(f'Unsupported dataset: {dataset!r}')

    if only_test:
        return meta_test_dataset

    return meta_train_dataset, meta_val_dataset
