import glob
import torch
import random

import torch.nn.functional as F
import torchvision.transforms as transforms

from PIL import Image
from torch.utils.data import Dataset

class UnpairedImgDataset(Dataset):
    """Dataset yielding ``(imgA, imgB)`` tensor pairs from two image folders.

    The folder layout searched below ``data_source`` depends on ``mode``:

    * ``'train'``: ``train/blurry/*/*.*`` (A) and ``train/sharp/*/*.*`` (B)
    * ``'val1'``:  ``blurry/*/*.*`` (A) and ``sharp/*/*.*`` (B)
    * ``'val2'`` / ``'test'``: ``low/*.*`` (A) and ``high/*.*`` (B)

    In ``'train'`` mode the larger of the two lists is walked by index while
    the other one is sampled at random, and both images go through a random
    crop / flip augmentation that produces ``crop`` x ``crop`` tensors.
    In every other mode both lists are indexed with ``index`` (wrapping
    around the shorter one) and the images are only converted to normalised
    tensors, so their spatial size is whatever the files contain.

    Args:
        data_source: Root directory (string) the glob patterns are appended to.
        mode: One of ``'train'``, ``'val1'``, ``'val2'`` or ``'test'``.
        random_resize: If truthy, training images are rescaled to a randomly
            drawn short side before the regular training transform.
        crop: Side length of the square training crop.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    _MODES = ('train', 'val1', 'val2', 'test')

    def __init__(self, data_source, mode, random_resize, crop=256):
        # Fail early: an unknown mode would otherwise leave the path lists
        # undefined and only surface later as an AttributeError in __len__.
        if mode not in self._MODES:
            raise ValueError(
                'Unsupported mode {!r}; expected one of {}'.format(mode, self._MODES)
            )

        self.random_resize = random_resize
        self.crop = crop
        self.mode = mode

        # Training: bring the short side to 1.12 * crop, then take a random
        # crop and a random horizontal flip before normalising.
        self.transform_train = transforms.Compose([
            transforms.Resize(int(crop * 1.12), Image.BICUBIC),
            transforms.RandomCrop((crop, crop)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # Evaluation: no geometric changes, only tensor conversion and the
        # same mean 0.5 / std 0.5 normalisation.
        self.transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        if mode == 'train':
            pattern_a, pattern_b = '/train/blurry/*/*.*', '/train/sharp/*/*.*'
        elif mode == 'val1':
            pattern_a, pattern_b = '/blurry/*/*.*', '/sharp/*/*.*'
        else:
            # 'val2' and 'test' share the flat low/high layout.
            # NOTE: earlier revisions of 'test' also globbed
            # blurry|sharp/*/*.*, test/blurry|sharp/*/*.* and
            # test/low_blur|high_sharp_scaled/*/*.* below data_source.
            pattern_a, pattern_b = '/low/*.*', '/high/*.*'

        self.imgA_paths = sorted(glob.glob(data_source + pattern_a))
        self.imgB_paths = sorted(glob.glob(data_source + pattern_b))

        # Only consulted in 'train' mode: the longer list is iterated by
        # index, the shorter one is sampled randomly.
        self.imgA_leading = len(self.imgA_paths) >= len(self.imgB_paths)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return it as an RGB image, closing the file."""
        with Image.open(path) as handle:
            # convert() returns a fully loaded, independent image, so it is
            # safe to use after the file handle has been closed.
            return handle.convert('RGB')

    def _pick_short_side(self, size):
        """Draw the target short side for an image of ``size`` (width, height).

        The value is uniform in ``[crop, short side]``. When the image is
        already smaller than ``crop`` the lower bound is clamped so the range
        is never empty; the image then keeps its short side.
        """
        upper = min(size)
        lower = min(self.crop, upper)
        return random.randint(lower, upper)

    @staticmethod
    def _scale_to_short_side(size, short_side):
        """Return ``size`` rescaled so its short side equals ``short_side``.

        The long side is scaled by the same factor and truncated to an int,
        which preserves the aspect ratio of ``size``.
        """
        width, height = size
        if width < height:
            return (short_side, int(short_side * height / width))
        return (int(short_side * width / height), short_side)

    def __getitem__(self, index):
        """Return the ``(imgA, imgB)`` tensor pair for ``index``.

        Raises:
            RuntimeError: If either image list is empty, because no pair can
                be formed in that case.
        """
        if not self.imgA_paths or not self.imgB_paths:
            raise RuntimeError(
                'Cannot build an image pair: found {} image(s) for A and {} for B'.format(
                    len(self.imgA_paths), len(self.imgB_paths)
                )
            )

        if self.mode == 'train':
            if self.imgA_leading:
                imgA = self._load_rgb(self.imgA_paths[index % len(self.imgA_paths)])
                imgB = self._load_rgb(
                    self.imgB_paths[random.randint(0, len(self.imgB_paths) - 1)]
                )
            else:
                imgA = self._load_rgb(
                    self.imgA_paths[random.randint(0, len(self.imgA_paths) - 1)]
                )
                imgB = self._load_rgb(self.imgB_paths[index % len(self.imgB_paths)])

            if self.random_resize:
                # One short side is drawn from imgA's dimensions and shared by
                # both images; each image is scaled with its own aspect ratio
                # so imgB is not distorted when its size differs from imgA's.
                short_side = self._pick_short_side(imgA.size)
                imgA = imgA.resize(self._scale_to_short_side(imgA.size, short_side))
                imgB = imgB.resize(self._scale_to_short_side(imgB.size, short_side))

            imgA = self.transform_train(imgA)
            imgB = self.transform_train(imgB)
        else:
            imgA = self._load_rgb(self.imgA_paths[index % len(self.imgA_paths)])
            imgB = self._load_rgb(self.imgB_paths[index % len(self.imgB_paths)])

            imgA = self.transform_val(imgA)
            imgB = self.transform_val(imgB)

        return imgA, imgB

    def __len__(self):
        """Return the length of the longer of the two image lists."""
        return max(len(self.imgA_paths), len(self.imgB_paths))
    
class SingleImgDataset(Dataset):
    """Dataset that yields one blurry test image and its file path per item.

    Image files are collected with the glob
    ``<data_source>/test/blurry/*/*.*`` and kept in sorted order.
    """

    def __init__(self, data_source):
        """Collect the image paths and build the tensor transform.

        Args:
            data_source: Root directory (string) that contains
                ``test/blurry/<subdir>/<file>``.
        """
        # Convert to a tensor, then normalise each channel with
        # mean 0.5 and std 0.5.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        self.transform = transforms.Compose([to_tensor, normalize])

        pattern = data_source + '/test/blurry/*/*.*'
        self.img_paths = sorted(glob.glob(pattern))

    def __getitem__(self, index):
        """Return ``(tensor, path)`` for ``index``.

        The index is taken modulo the number of images, so values past the
        end wrap around to the beginning of the list.
        """
        wrapped = index % len(self.img_paths)
        path = self.img_paths[wrapped]
        rgb = Image.open(path).convert('RGB')
        return self.transform(rgb), path

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.img_paths)