
import glob
import os
import sys

import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from torchvision.datasets import VisionDataset
from torchvision.transforms.functional import pad

###
# Transforms
###

def get_padding(image):
    """Return the (left, top, right, bottom) padding that makes ``image`` square.

    ``image`` must expose ``size`` as a ``(width, height)`` pair, as PIL images
    do. The shorter side is padded up to the length of the longer side. When
    the required padding along an axis is odd, the extra pixel goes on the
    left/top.
    """
    width, height = image.size
    side = max(width, height)
    extra_w = side - width
    extra_h = side - height
    # Ceiling half to the left/top edge, floor half to the right/bottom edge.
    right = extra_w // 2
    bottom = extra_h // 2
    left = extra_w - right
    top = extra_h - bottom
    return tuple(int(amount) for amount in (left, top, right, bottom))

class MakeSquare:
    """Callable transform that pads an image so its width equals its height.

    The per-side amounts come from ``get_padding``. Padding uses a constant
    fill value of 0 via torchvision's functional ``pad``.
    """

    def __call__(self, img):
        padding = get_padding(img)
        return pad(img, padding, fill=0, padding_mode='constant')

    def __repr__(self):
        return type(self).__name__
        
def get_transform(mode = 'normalize'):
    """Build the torchvision preprocessing pipeline selected by ``mode``.

    Supported modes:
      'normalize'   -- ToTensor, then Normalize with ImageNet mean/std.
      'reshape'     -- MakeSquare (zero-pad to square), then Resize((224, 224)).
                       There is no ToTensor/Normalize step.
      'full'        -- the 'reshape' steps followed by the 'normalize' steps.
      'imagenet'    -- Resize(256), CenterCrop(224), then the 'normalize' steps.
      'resize-crop' -- Resize(256), CenterCrop(224). There is no
                       ToTensor/Normalize step.
      None          -- no pipeline; returns None.

    Returns:
        A ``transforms.Compose``, or None when ``mode`` is None.

    Raises:
        ValueError: if ``mode`` is anything else. Such values previously
            returned None silently, which made datasets skip preprocessing
            entirely on a typo.
    """
    if mode is None:
        return None

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def tensor_steps():
        # Convert to a tensor, then normalize with ImageNet channel statistics.
        return [transforms.ToTensor(), transforms.Normalize(imagenet_mean, imagenet_std)]

    def square_steps():
        return [MakeSquare(), transforms.Resize((224, 224))]

    def crop_steps():
        return [transforms.Resize(256), transforms.CenterCrop(224)]

    # Builders run lazily, so an unknown mode is rejected before any transform
    # object is constructed.
    pipelines = {
        'normalize': tensor_steps,
        'reshape': square_steps,
        'full': lambda: square_steps() + tensor_steps(),
        'imagenet': lambda: crop_steps() + tensor_steps(),
        'resize-crop': crop_steps,
    }
    build = pipelines.get(mode) if isinstance(mode, str) else None
    if build is None:
        raise ValueError(
            'unknown transform mode {!r}; expected None or one of {}'.format(
                mode, sorted(pipelines)))
    return transforms.Compose(build())

###
# Datasets
###

def get_loader(dataset, batch_size = 64, num_workers = 0, shuffle = True, pin_memory = True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: a map-style dataset to batch.
        batch_size: number of samples per batch.
        num_workers: number of loader subprocesses; 0 loads in the main process.
        shuffle: reshuffle the data every epoch. Defaults to True, matching the
            previously hard-coded behaviour. Pass False for a deterministic
            order, e.g. evaluation or feature extraction.
        pin_memory: place batches in page-locked memory. Defaults to True, as
            before.

    Returns:
        The configured DataLoader.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size = batch_size,
        shuffle = shuffle,
        num_workers = num_workers,
        pin_memory = pin_memory,
    )

class ImageDataset(VisionDataset):
    """Dataset over an explicit list of image files with a parallel label list.

    Each item is ``(img, label)``, or ``(img, label, filename)`` when
    ``get_names`` is true. Images are loaded as RGB and passed through
    ``get_transform(transform_mode)`` when that returns a transform.
    """

    def __init__(self, filenames, labels, transform_mode = 'normalize', get_names = False):
        """
        Args:
            filenames: indexable sequence of image paths.
            labels: indexable sequence of labels. Presumably aligned with
                ``filenames`` index-for-index; this is not checked here.
            transform_mode: mode name forwarded to ``get_transform``.
            get_names: if True, ``__getitem__`` also returns the filename.
        """
        transform = get_transform(mode = transform_mode)
        # No root directory: paths in ``filenames`` are used exactly as given.
        super(ImageDataset, self).__init__(None, None, transform, None)
        self.filenames = filenames
        self.labels = labels
        self.get_names = get_names

    def __getitem__(self, index):
        filename = self.filenames[index]
        # Use a context manager so the file handle is always released. Pillow
        # keeps the file open after loading for some formats (e.g. multi-frame
        # images). convert() returns a new, fully loaded image, so ``img``
        # stays valid after the file is closed.
        with Image.open(filename) as source:
            img = source.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        if self.get_names:
            return img, label, filename
        return img, label

    def __len__(self):
        return len(self.filenames)
    
class DirectoryDataset(torch.utils.data.Dataset):
    """Dataset over the files directly inside ``directory``, with no labels.

    Each item is ``(img, filename)``. Images are loaded as RGB and passed
    through ``get_transform(transform_mode)`` when that returns a transform.
    """

    def __init__(self, directory, transform_mode = 'normalize'):
        transform = get_transform(mode = transform_mode)
        super(DirectoryDataset, self).__init__()
        # glob returns entries in arbitrary, filesystem-dependent order. Sorting
        # makes indices reproducible across runs and machines. Subdirectories
        # are dropped because Image.open cannot read them.
        # NOTE: ``directory`` is interpolated unescaped into a glob pattern, so
        # glob metacharacters in it act as wildcards.
        self.filenames = sorted(
            path for path in glob.glob('{}/*'.format(directory))
            if os.path.isfile(path)
        )
        self.transform = transform

    def __getitem__(self, index):
        filename = self.filenames[index]
        # Context manager releases the file handle. convert() returns an
        # independent, fully loaded image.
        with Image.open(filename) as source:
            img = source.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, filename

    def __len__(self):
        return len(self.filenames)
    