__author__ = "Anon"
__version__ = "0.1"

import torchvision
import torch
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
from PIL import Image
from utils import NORM
from dataloaders.bmw10 import BMW10
from dataloaders.inn_loaders import *
from utils import BinImageIntensities, SwitchChannels, RandomSwitchChannels, NegativeImages, PermutePixels, innet_collate, RepeatChannels
from torch.utils.data._utils.collate import default_collate


def get_train_valid_loader(root='./data', dataset='mnist', batch_size=64, num_workers=8, pin_memory=True, transform=None, target_transform=None, shuffle=True, opts=None, is_mgpu=False):
    """Build the training ``DataLoader`` for ``dataset``.

    Args:
        root: Data root. Only the ``inn_stl10``, ``inn_cifar100``, ``inn_cifar10``,
            ``inn_cub200`` and ``inn_pets``/``inn_opets`` branches read it; every
            other dataset is loaded from a hard-coded path under ``./data``.
        dataset: Case-insensitive dataset key (see the branches below).
        batch_size: Samples per batch.
        num_workers: Worker processes used by the ``DataLoader``.
        pin_memory: Forwarded to the ``DataLoader`` in both the single- and
            multi-GPU paths.
        transform: Image transform handed to the dataset constructor.
        target_transform: Currently unused; kept for interface compatibility.
        shuffle: Reshuffle each epoch. In the multi-GPU path it is forwarded to
            ``DistributedSampler`` (the ``DataLoader`` itself must not shuffle
            when a sampler is given).
        opts: Mapping with ``'num_negatives'`` and ``'neg_red_factor'``;
            required by the ``inn_*`` datasets.
        is_mgpu: If true, sample through a ``DistributedSampler``.

    Returns:
        ``(loader, sampler)``; ``sampler`` is ``None`` unless ``is_mgpu`` is set.

    Raises:
        AssertionError: ``dataset`` is not a supported key.
        TypeError: an ``inn_*`` dataset is requested with ``opts=None``.
    """
    key = dataset.lower()
    inn_keys = ('inn_stl10', 'inn_cifar100', 'inn_imgnet', 'inn_cifar10', 'inn_cub200',
                'inn_cub20', 'inn_opets', 'inn_pets', 'inn_bmw10')

    collate = default_collate
    neg = {}
    if key in inn_keys:
        # Fail with a clear message instead of "'NoneType' object is not subscriptable".
        if opts is None:
            raise TypeError('opts with "num_negatives" and "neg_red_factor" is required for dataset {}'.format(dataset))
        neg = dict(num_negatives=opts['num_negatives'], neg_factor=opts['neg_red_factor'])
        # All inn_* datasets are batched with innet_collate instead of default_collate.
        collate = innet_collate

    if key == 'mnist':
        train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    elif key == 'cifar10':
        train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    elif key == 'cifar100':
        train_set = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    elif key == 'fashionmnist':
        train_set = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
    elif key in ('cub200', 'cub'):
        train_set = Cub2011(root='./data/data', train=True, transform=transform)
    elif key == 'cub20':
        train_set = Cub2011_MOD(root='./data/cub20', train=True, transform=transform, download=False)
    elif key == 'stl10':
        train_set = torchvision.datasets.STL10(root='./data', split='train', download=True, transform=transform)
    elif key == 'tiny':
        train_set = torchvision.datasets.ImageFolder(root='./data/tiny-64/train', transform=transform)
    elif key == 'bmw10':
        train_set = BMW10(root='./data/bmw10', image_set='train', transform=transform)
    elif key in ('pets', 'opets'):
        train_set = Pets(root='./data/pets', image_set='trainval', transform=transform)
    elif key == 'inn_stl10':
        train_set = InnSTL10(root=root, split='train', img_transform=transform, **neg)
    elif key == 'inn_cifar100':
        train_set = InnCIFAR100(root=root, train=True, img_transform=transform, **neg)
    elif key == 'inn_imgnet':
        train_set = InnImgNet(root='data/ILSVRC/Data/CLS-LOC/', split='train', img_transform=transform, **neg)
    elif key == 'inn_cifar10':
        train_set = InnCIFAR10(root=root, train=True, img_transform=transform, **neg)
    elif key == 'inn_cub200':
        train_set = InnCub(root=root, train=True, img_transform=transform, **neg)
    elif key == 'inn_cub20':
        train_set = InnCub20(root='./data/cub20', train=True, img_transform=transform, download=False, **neg)
    elif key in ('inn_opets', 'inn_pets'):
        train_set = InnPets(root=root + '/pets/', image_set='trainval', img_transform=transform, **neg)
    elif key == 'inn_bmw10':
        train_set = InnBMW10(root='./data/bmw10/', image_set='train', img_transform=transform, **neg)
    else:
        raise AssertionError('Dataset {} is currently not supported'.format(dataset))

    if not is_mgpu:
        loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory,
            shuffle=shuffle, collate_fn=collate)
        return loader, None

    # Shuffling is the sampler's job here; honour the caller's shuffle/pin_memory
    # instead of silently forcing both to True.
    sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=shuffle)
    loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, pin_memory=pin_memory, collate_fn=collate,
        sampler=sampler, num_workers=num_workers)
    return loader, sampler

def get_test_loader(root='./data', dataset='mnist', batch_size=64, num_workers=8, pin_memory=True, transform=None, target_transform=None):
    """Build the test ``DataLoader`` for ``dataset`` (never shuffled).

    Each ``inn_*`` key maps to the same plain test split as its base dataset.
    ``mnist`` (the default), ``fashionmnist``, ``opets``, ``inn_pets`` and
    ``cub`` are accepted so that the keys used for training also work here.

    Args:
        root: Currently unused; every dataset is loaded from a hard-coded path
            under ``./data``.
        dataset: Case-insensitive dataset key (see the branches below).
        batch_size: Samples per batch.
        num_workers: Worker processes used by the ``DataLoader``.
        pin_memory: Forwarded to the ``DataLoader``.
        transform: Image transform handed to the dataset constructor.
        target_transform: Currently unused; kept for interface compatibility.

    Returns:
        A ``torch.utils.data.DataLoader`` over the test split with ``shuffle=False``.

    Raises:
        AssertionError: ``dataset`` is not a supported key.
    """
    key = dataset.lower()
    if key == 'mnist':
        test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    elif key in ('fashionmnist', 'inn_fmnist', 'f-mnist'):
        test_set = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif key in ('cifar10', 'inn_cifar10'):
        test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif key in ('cifar100', 'inn_cifar100'):
        test_set = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
    elif key in ('imgnet', 'inn_imgnet'):
        test_set = torchvision.datasets.ImageFolder('./data/ILSVRC/Data/CLS-LOC/val', transform=transform)
    elif key in ('stl10', 'inn_stl10'):
        test_set = torchvision.datasets.STL10(root='./data', split='test', download=True, transform=transform)
    elif key in ('cars', 'inn_cars'):
        test_set = torchvision.datasets.ImageFolder(root='./data/stanford_cars/car_data/test', transform=transform)
    elif key in ('pets', 'opets', 'inn_pets', 'inn_opets'):
        test_set = Pets(root='./data/pets', image_set='test', transform=transform)
    elif key in ('tiny', 'inn_tiny'):
        test_set = torchvision.datasets.ImageFolder(root='./data/tiny-64/test', transform=transform)
    elif key in ('cub200', 'cub', 'inn_cub200'):
        test_set = Cub2011(root='./data/data', train=False, transform=transform)
    elif key in ('cub20', 'inn_cub20'):
        test_set = Cub2011_MOD(root='./data/cub20', train=False, transform=transform, download=False)
    elif key in ('bmw10', 'inn_bmw10'):
        test_set = BMW10(root='./data/bmw10', image_set='test', transform=transform)
    elif key == 'lsun':
        test_set = torchvision.datasets.LSUN(root='./data/lsun', classes='test', transform=transform)
    else:
        raise AssertionError('Dataset {} is currently not supported'.format(dataset))

    return torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False)


class TransformComposer:
    """Build a ``transforms.Compose`` pipeline from a list of transform names.

    ``get_composite`` makes two passes over the names:

    1. Augmentation names (``crop``, ``ccrop``, ``rcrop``, ``rrcrop``, ``recrop``,
       ``color``, ``rotation``, ``hflip``, ``vflip``, ``binning``, ``resize``,
       ``tolabel``, ``relabel``, ``grey``) are matched case-insensitively and
       appended in list order.
    2. Tensor-level options are appended in a fixed order, matched only against
       the exact spellings checked there (e.g. ``'norm'`` or ``'NORM'``).
    """

    def __init__(self, transforms, dataset, inp_size=32, re_size=None, rotation=15):
        """
        Args:
            transforms: Iterable of transform-name strings.
            dataset: Key into ``NORM`` used when normalization is requested.
            inp_size: Output size for the crop / ``recrop`` transforms.
            re_size: Target size for ``resize``; defaults to ``inp_size``.
            rotation: Degrees passed to ``transforms.RandomRotation`` for ``rotation``.
        """
        # The ``transforms`` parameter shadows the torchvision module only inside __init__.
        self.transforms = transforms
        self.dataset = dataset
        self.inp_size = inp_size
        self.rotation = rotation
        self.re_size = inp_size if re_size is None else re_size

    def get_composite(self, num_bins=None, to_tensor=True, norm_bin=None):
        """Return a ``transforms.Compose`` for the configured transform names.

        Args:
            num_bins: Number of intensity bins for ``binning``.
            to_tensor: Append ``transforms.ToTensor()`` after the pass-1 transforms.
            norm_bin: Bin key (stringified) used to look up normalization
                statistics in ``NORM[self.dataset]``; defaults to ``num_bins``.

        Returns:
            A ``transforms.Compose`` of the selected transforms.
        """
        if norm_bin is None:
            norm_bin = num_bins
        composite_transforms = []

        # Pass 1: PIL-level transforms in the order given, matched case-insensitively.
        for transform in self.transforms:
            name = transform.lower()
            if name == 'crop':
                composite_transforms.append(transforms.RandomCrop(self.inp_size, padding=4))
                print('cropping enabled')
            elif name == 'ccrop':
                composite_transforms.append(transforms.CenterCrop(self.inp_size))
                print('cropping enabled')
            elif name == 'rcrop':
                composite_transforms.append(transforms.RandomCrop((self.inp_size, self.inp_size)))
            elif name == 'rrcrop':
                composite_transforms.append(transforms.RandomResizedCrop(self.inp_size))
            elif name == 'recrop':
                composite_transforms.append(transforms.Resize((self.inp_size, self.inp_size)))
            elif name == 'color':
                s = 0.5
                color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
                composite_transforms.append(transforms.RandomApply([color_jitter], p=0.8))
            elif name == 'rotation':
                # Use the configured angle; this used to be hard-coded to 15,
                # silently ignoring the ``rotation`` constructor argument.
                composite_transforms.append(transforms.RandomRotation(self.rotation))
                print('rotation enabled')
            elif name == 'hflip':
                composite_transforms.append(transforms.RandomHorizontalFlip())
                print('horizontal flipping enabled')
            elif name == 'vflip':
                composite_transforms.append(transforms.RandomVerticalFlip())
                print('vertical flipping enabled')
            elif name == 'binning':
                composite_transforms.append(BinImageIntensities(num_bins=num_bins))
                print('binning enabled into {} bins'.format(num_bins))
            elif name == 'resize':
                composite_transforms.append(transforms.Resize((self.re_size, self.re_size)))
            elif name == 'tolabel':
                # NOTE(review): ToLabel/Relabel are not imported by name in this file;
                # presumably they arrive via a star import — verify.
                composite_transforms.append(ToLabel())
            elif name == 'relabel':
                composite_transforms.append(Relabel())
            elif name == 'grey':
                composite_transforms.append(transforms.Grayscale(num_output_channels=3))

        # Pass 2: fixed-order options, matched on the exact spellings below.
        flags = self.transforms
        if 'TO_PIL' in flags:
            composite_transforms.append(transforms.ToPILImage())
        if to_tensor:
            composite_transforms.append(transforms.ToTensor())
        if 'REPEAT' in flags:
            composite_transforms.append(RepeatChannels())
        if 'switch' in flags or 'SWITCH' in flags:
            composite_transforms.append(SwitchChannels([1, 2, 0]))
        if 'norm' in flags or 'NORM' in flags:
            print('Dataset Norm enabled for bin :{}'.format(norm_bin))
            stats = NORM[self.dataset][str(norm_bin)]
            composite_transforms.append(transforms.Normalize(stats[0], stats[1]))
        if 'random_switch' in flags or 'RANDOM_SWITCH' in flags:
            composite_transforms.append(RandomSwitchChannels())
        if 'NEGATIVE' in flags or 'negative' in flags:
            composite_transforms.append(NegativeImages())
        if 'PERMUTE' in flags or 'permute' in flags:
            composite_transforms.append(PermutePixels())

        return transforms.Compose(composite_transforms)
