#%%
import torch
from torch.utils.data import Dataset
import numpy as np
import cv2
import os
from torchvision import transforms
import glob
import random

class Normalize(object):
    """Rescale every non-``'flow'`` entry of a sample dict.

    Each value whose key is not ``'flow'`` is replaced by
    ``value / 255 - self.num``, where ``self.num`` is ``0`` when
    ``ZeroToOne`` is true and ``0.5`` otherwise. ``'flow'`` entries are
    passed through unchanged. The dict is modified in place and returned.

    Args:
        ZeroToOne: Selects the shift applied after dividing by 255
            (0 when true, 0.5 when false).
    """

    def __init__(self, ZeroToOne=False):
        super().__init__()
        self.ZeroToOne = ZeroToOne
        # Offset subtracted after scaling by 1/255.
        self.num = 0.5 if not ZeroToOne else 0

    def __call__(self, data):
        shift = self.num
        # Snapshot the items so the loop never iterates a dict being written to.
        for name, value in list(data.items()):
            if name == 'flow':
                continue
            data[name] = (value / 255 - shift).copy()
        return data

class ToTensor(object):
    """Convert every numpy array in a sample dict into a torch tensor.

    The last axis of each array is moved to the front via
    ``transpose((2, 0, 1))``, i.e. (H, W, C) -> (C, H, W) for image-like
    arrays, so each value must be exactly 3-D. The resulting tensor is
    cloned so it does not share memory with the source numpy array.
    The dict is modified in place and returned.
    """

    def __call__(self, data):
        for name in list(data):
            channels_first = data[name].transpose((2, 0, 1))
            # from_numpy shares memory with the array; clone detaches it.
            data[name] = torch.from_numpy(channels_first).clone()
        return data
           
class Test_Loader(Dataset):
    """Evaluation dataset of blurred images with optional sharp targets.

    ``data_path`` must contain an ``input`` directory of ``*.png`` images
    and may contain a ``target`` directory of ``*.png`` images. Input and
    target files are paired by their position after sorting by path, so
    both directories must hold the same number of images whose names sort
    in the same order.

    Each sample is a dict with key ``'blur'`` (plus ``'sharp'`` when a
    ``target`` directory exists) holding float32 RGB arrays, passed through
    ``self.transform`` when it is set.

    Args:
        data_path: Dataset root directory. Required.
        crop_size: Accepted for interface compatibility; currently unused.
        ZeroToOne: Forwarded to ``Normalize`` to select its output range.

    Raises:
        ValueError: If ``data_path`` is empty/None, or if ``input`` and
            ``target`` contain different numbers of images.
    """

    def __init__(self, data_path=None, crop_size=None, ZeroToOne=False):
        # Explicit raise instead of ``assert`` so validation survives ``python -O``.
        if not data_path:
            raise ValueError("must have one dataset path !")

        self.blur_list = sorted(glob.glob(os.path.join(data_path, "input", '*.png')))
        self.sharp_list = []
        self.is_sharp_dir = os.path.isdir(os.path.join(data_path, "target"))
        if self.is_sharp_dir:
            self.sharp_list = sorted(glob.glob(os.path.join(data_path, "target", '*.png')))
            # Pairing is positional, so the counts must agree.
            if len(self.sharp_list) != len(self.blur_list):
                raise ValueError("Missmatched Length!")

        self.transform = transforms.Compose([Normalize(ZeroToOne), ToTensor()])

    def __len__(self):
        """Return the number of input images."""
        return len(self.blur_list)

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it as a float32 RGB array.

        Raises:
            OSError: If ``cv2.imread`` cannot read the file; it signals a
                missing/unreadable/unsupported file by returning None
                rather than raising.
        """
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"cv2.imread could not read image: {path!r}")
        return cv2.cvtColor(image.astype(np.float32), cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return the (transformed) sample dict for index ``idx``."""
        sample = {'blur': self._read_rgb(self.blur_list[idx])}
        if self.is_sharp_dir:
            sample['sharp'] = self._read_rgb(self.sharp_list[idx])

        # ``transform`` is a public attribute; callers may set it to None.
        if self.transform:
            sample = self.transform(sample)

        return sample

    def get_path(self, idx):
        """Return the file path(s) backing sample ``idx``."""
        if self.is_sharp_dir:
            return {'blur_path': self.blur_list[idx], 'sharp_path': self.sharp_list[idx]}
        return {'blur_path': self.blur_list[idx]}

def get_image(path, ZeroToOne=False):
    """Load one image file as a normalized channels-first tensor.

    The file is read with OpenCV, cast to float32, converted from BGR to
    RGB, then passed through ``Normalize(ZeroToOne)`` and ``ToTensor``.

    Args:
        path: Path of the image file.
        ZeroToOne: Forwarded to ``Normalize``. The default (False) keeps
            the original behavior of ``pixel / 255 - 0.5``.

    Returns:
        The transformed image as a torch tensor.

    Raises:
        OSError: If ``cv2.imread`` cannot read ``path``.
    """
    image = cv2.imread(path)
    # cv2.imread returns None instead of raising on a missing/unreadable file.
    if image is None:
        raise OSError(f"cv2.imread could not read image: {path!r}")
    image = cv2.cvtColor(image.astype(np.float32), cv2.COLOR_BGR2RGB)

    transform = transforms.Compose([Normalize(ZeroToOne), ToTensor()])
    return transform({'image': image})['image']

