from glob import glob
import random

import torch
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image


class SRSet(Dataset):
    """Paired low-/high-resolution random-crop dataset read from a DIV2K-style tree.

    HR images are globbed from ``<dir>/DIV2K_train_HR/*.png`` and LR images
    from ``<dir>/DIV2K_train_LR_bicubic/X4/*.png``. Both lists are sorted and
    paired by position.

    Each item is a ``(lr_crop, hr_crop)`` pair of ``uint8`` tensors taken from
    the same image location, with identical random flips/transposition applied
    to both. Normalisation is NOT applied by ``__getitem__``; call
    :meth:`normalize` on the (float) tensors where needed.

    Args:
        dir: Root directory containing the two sub-folders named above.
        crop_size: ``(height, width)`` of the LR crop. The HR crop is this
            size multiplied by the HR/LR scale factor.
        repeat: How many times the file list is repeated per epoch
            (``len(dataset) == num_files * repeat``).
        do_normalize: Whether :meth:`normalize` applies the
            ``mean=0.5, std=0.5`` per-channel normalisation or is the identity.
        cache: ``'none'`` to read images from disk on every access, or
            ``'in_memory'`` to load every image once in the constructor.

    Raises:
        ValueError: If ``dir`` is ``None``, ``cache`` is not a supported mode,
            or the numbers of HR and LR files differ.
    """

    _CACHE_MODES = ('none', 'in_memory')

    def __init__(self, dir=None, crop_size=(64, 64), repeat=4, do_normalize=True, cache='none') -> None:
        super().__init__()

        # Fail early with a clear message instead of an opaque
        # ``NoneType + str`` TypeError or a late UnboundLocalError.
        if dir is None:
            raise ValueError('SRSet requires `dir`, the dataset root directory')
        if cache not in self._CACHE_MODES:
            raise ValueError(
                'unsupported cache mode {!r}; expected one of {}'.format(cache, self._CACHE_MODES)
            )

        self.hr_paths = sorted(glob(dir + '/DIV2K_train_HR/*.png'))
        self.lr_paths = sorted(glob(dir + '/DIV2K_train_LR_bicubic/X4/*.png'))
        # Pairs are matched purely by sorted position, so differing counts
        # would silently mis-pair images (or fail with IndexError later).
        if len(self.hr_paths) != len(self.lr_paths):
            raise ValueError(
                'found {} HR images but {} LR images under {!r}'.format(
                    len(self.hr_paths), len(self.lr_paths), dir
                )
            )

        self.crop_size = crop_size
        self.repeat = repeat
        self.do_normalize = do_normalize
        self.cache = cache
        self.num_files = len(self.hr_paths)

        self.to_tensor = transforms.ToTensor()
        # Kept as a plain attribute (not wrapped in a lambda) so that the
        # dataset stays picklable for multi-process data loading.
        self._normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

        if cache == 'in_memory':
            self.lr_files = [self.load_img(path) for path in self.lr_paths]
            self.hr_files = [self.load_img(path) for path in self.hr_paths]

    def normalize(self, x):
        """Return ``x`` normalised with mean/std 0.5 per channel, or unchanged
        when the dataset was created with ``do_normalize=False``."""
        if self.do_normalize:
            return self._normalize(x)
        return x

    def __len__(self) -> int:
        """Number of samples per epoch: file count times ``repeat``."""
        return len(self.hr_paths) * self.repeat

    def __getitem__(self, index):
        """Return a randomly cropped and augmented ``(lr, hr)`` tensor pair.

        ``index`` is mapped onto the file list modulo the file count. Indices
        outside ``[-len(self), len(self))`` raise ``IndexError`` so that plain
        iteration over the dataset terminates.

        Note: the diagonal flip transposes height and width, so with a
        non-square ``crop_size`` the returned shapes vary between samples.
        """
        total = len(self)
        if not -total <= index < total:
            raise IndexError(
                'index {} is out of range for SRSet of length {}'.format(index, total)
            )
        file_idx = index % self.num_files

        if self.cache == 'in_memory':
            lr_img = self.lr_files[file_idx]
            hr_img = self.hr_files[file_idx]
        elif self.cache == 'none':
            lr_img = self.load_img(self.lr_paths[file_idx])
            hr_img = self.load_img(self.hr_paths[file_idx])
        else:
            raise ValueError('unsupported cache mode {!r}'.format(self.cache))

        crop_h, crop_w = self.crop_size
        # Integer upscaling factor, derived from the image widths.
        scale = hr_img.shape[-1] // lr_img.shape[-1]

        # Pick the crop origin in LR coordinates ...
        top = random.randint(0, lr_img.shape[1] - crop_h)
        left = random.randint(0, lr_img.shape[2] - crop_w)
        lr_img = lr_img[:, top : top + crop_h, left : left + crop_w]

        # ... and cut the matching window out of the HR image.
        hr_top, hr_left = top * scale, left * scale
        hr_h, hr_w = crop_h * scale, crop_w * scale
        hr_img = hr_img[:, hr_top : hr_top + hr_h, hr_left : hr_left + hr_w]

        # Draw the augmentation once so both images get the same transform.
        flip_rows = random.random() < 0.5
        flip_cols = random.random() < 0.5
        transpose = random.random() < 0.5

        def augment(x):
            if flip_rows:
                x = x.flip(-2)
            if flip_cols:
                x = x.flip(-1)
            if transpose:
                x = x.transpose(-2, -1)
            return x

        return augment(lr_img), augment(hr_img)

    def load_img(self, path):
        """Load the image at ``path`` as a channel-first ``uint8`` tensor."""
        # The context manager closes the underlying file handle once the
        # pixel data has been converted.
        with Image.open(path) as img:
            tensor = self.to_tensor(img)
        # Round before the cast: a bare cast truncates, so any float
        # round-off from the scale-back to 0..255 would lose a level.
        return (tensor * 255).round().to(torch.uint8)
