import torch
from torch.utils import data
import imp
import os
from torchvision import datasets
from torchvision.transforms import transforms
from torchvision.transforms import  *
from PIL import Image
import random
import math
import numpy as np

class RandomErasing(object):
    '''
    Random Erasing augmentation from "Random Erasing Data Augmentation" by Zhong et al.
    -------------------------------------------------------------------------------------
    probability: probability that the operation is applied to a given image.
    sl: min erased area, as a fraction of the image area.
    sh: max erased area, as a fraction of the image area.
    r1: min aspect ratio of the erased rectangle (the max is 1/r1).
    mean: fill value(s) indexed per channel; defaults to [0.4914, 0.4822, 0.4465].
    -------------------------------------------------------------------------------------
    The image is modified in place and also returned. It is indexed as
    img[channel, row, col], so a 3-D tensor is expected (presumably the
    (C, H, W) output of ToTensor -- confirm at call sites).
    '''
    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=None):
        self.probability = probability
        # A fresh list per instance: a list literal as the default argument
        # would be one object shared by every RandomErasing using the default.
        self.mean = [0.4914, 0.4822, 0.4465] if mean is None else mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):
        if random.uniform(0, 1) > self.probability:
            return img

        # The shape does not change during this call (only pixel values are
        # written), so read it once instead of on every attempt.
        size = img.size()
        channels, height, width = size[0], size[1], size[2]
        area = height * width

        for _ in range(100):
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            # Only accept rectangles strictly smaller than the image.
            if w < width and h < height:
                x1 = random.randint(0, height - h)
                y1 = random.randint(0, width - w)
                if channels == 3:
                    for c in range(3):
                        img[c, x1:x1 + h, y1:y1 + w] = self.mean[c]
                else:
                    # Non-3-channel input: only the first channel is erased.
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                return img

        # No fitting rectangle was sampled in 100 attempts; leave img untouched.
        return img


class ImageFolderWithPaths(datasets.ImageFolder):
    """torchvision ``ImageFolder`` whose items also carry the image file path.

    Each item is the tuple produced by ``ImageFolder.__getitem__`` for that
    index, with ``self.imgs[index][0]`` (the file path) appended as the final
    element.
    """

    def __getitem__(self, index):
        """Return the parent's sample tuple extended with the sample's path."""
        sample = super().__getitem__(index)
        return sample + (self.imgs[index][0],)


# ImageFolder-style datasets served directly by get_dataloader:
# name -> (root directory, train batch-size multiplier, test batch size).
# Each root is read as root + '/train' and root + '/test'.
_FOLDER_DATASETS = {
    'cub': (r'/home/chaimb/AmeenNips21/Ours/datasets/dataset', 1, 1),
    # NOTE(review): the trailing slash yields '.../aYahoo//train'; harmless on
    # POSIX, kept so the image paths returned with each sample do not change.
    'ayahoo': (r'/home/chaimb/AmeenNips21/Ours/datasets/aYahoo/', 1, 1),
    'apascal': (r'/home/chaimb/AmeenNips21/Ours/datasets/VOC2008/dataset', 2, 16),
}


def _folder_loaders(data_path, train_batch_size, test_batch_size,
                    transform_train, transform_test):
    """Build (train, test) DataLoaders over ImageFolderWithPaths at data_path/{train,test}."""
    trainset = ImageFolderWithPaths(root=data_path + '/train', transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
    testset = ImageFolderWithPaths(root=data_path + '/test', transform=transform_test)
    # NOTE(review): the test split is shuffled here, unlike the generic path
    # below (shuffle=False); kept as-is to preserve existing behavior.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=test_batch_size, shuffle=True, num_workers=1)
    return trainloader, testloader


def get_dataloader(conf):
    """Return ``(train_loader, test_loader)`` for ``conf.dataset``.

    'cub', 'ayahoo' and 'apascal' are loaded from hard-coded ImageFolder roots
    (see _FOLDER_DATASETS) through ImageFolderWithPaths, so every sample also
    carries its file path. They use a resize-256 / crop-224 pipeline with
    rotation and horizontal-flip augmentation for training.

    Any other name is resolved by loading ``datasets/<name>.py`` and calling its
    ``get_dataset(conf)``. That call must return ``(train_dataset, test_dataset)``.

    conf attributes read: ``dataset`` and ``batch_size``. The generic path also
    reads ``workers`` and, optionally, ``trainshuffle`` (default True).
    """
    dataset = conf.dataset
    resize = 256
    cropsize = 224
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    if resize == cropsize:
        tflist = [transforms.RandomResizedCrop(cropsize)]
    else:
        tflist = [transforms.Resize(resize), transforms.RandomCrop(cropsize)]

    # tflist already yields a cropsize x cropsize image and RandomRotation keeps
    # the size, so a second RandomCrop(cropsize) after it would be an identity.
    transform_train = transforms.Compose(tflist + [
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(cropsize),
        transforms.ToTensor(),
        normalize,
    ])

    if dataset in _FOLDER_DATASETS:
        data_path, train_mult, test_batch_size = _FOLDER_DATASETS[dataset]
        return _folder_loaders(data_path, train_mult * conf.batch_size, test_batch_size,
                               transform_train, transform_test)

    # Generic path: datasets/<name>.py must define get_dataset(conf).
    # NOTE(review): `imp` was removed in Python 3.12; importlib.util is the
    # replacement if this module needs to run there.
    src_file = os.path.join('datasets', conf.dataset + '.py')
    dataimp = imp.load_source('loader', src_file)
    ds_train, ds_test = dataimp.get_dataset(conf)
    if 'trainshuffle' in conf:
        trainshuffle = conf.trainshuffle
    else:
        trainshuffle = True

    print('train shuffle:', trainshuffle)
    train_loader = data.DataLoader(ds_train, batch_size=conf.batch_size, shuffle=trainshuffle,
                                   num_workers=conf.workers, pin_memory=True)
    val_loader = data.DataLoader(ds_test, batch_size=conf.batch_size, shuffle=False,
                                 num_workers=conf.workers, pin_memory=True)

    return train_loader, val_loader









