from PIL import Image
import glob
import os
import pdb
import random
from utils import utils_blindsr as blindsr
import torchvision.transforms as transforms
import torch


# ---------------------------------------------------------------------------
# Builds the DIV2K4D x4 test set.  Each LR "part" is a different degradation:
#   part1 - naive bicubic downsampling by 4 of the DIV2K validation HR images
#   part2 - the x4 LR images borrowed directly from DIV2KRK
#   part3 - DIV2KRK x2 LR images downsampled by 2, then JPEG-compressed
#   part4 - the BSRGAN degradation model applied to the DIV2K validation HR
# ---------------------------------------------------------------------------
ROOT_NAIVE_DIV2K_VALIDSET = '/mnt/server5_hard1/seungjun/KAIR-master/testsets/DIV2K_valid_HR'
ROOT_DIV2KRK = '/mnt/server5_hard1/seungjun/KAIR-master/testsets/DIV2KRK'

ROOT_HR = '/mnt/server5_hard1/seungjun/KAIR-master/testsets/DIV2K4D/HR'
ROOT_PART1 = '/mnt/server5_hard1/seungjun/KAIR-master/testsets/DIV2K4D/LR_x4/part1'
ROOT_PART2 = '/mnt/server5_hard1/seungjun/KAIR-master/testsets/DIV2K4D/LR_x4/part2'
ROOT_PART3 = '/mnt/server5_hard1/seungjun/KAIR-master/testsets/DIV2K4D/LR_x4/part3'
ROOT_PART4 = '/mnt/server5_hard1/seungjun/KAIR-master/testsets/DIV2K4D/LR_x4/part4'


def output_filename(path, ext=None):
    """Return the file name of *path*, optionally with its extension replaced.

    Args:
        path: Source image path.
        ext: New extension including the dot (e.g. ``'.jpg'``), or None to
            keep the original extension.
    """
    name = os.path.basename(path)
    if ext is not None:
        name = os.path.splitext(name)[0] + ext
    return name


def _sorted_pngs(directory):
    """Return the ``*.png`` files in *directory*, sorted so the processing order is stable."""
    # glob.glob returns files in arbitrary (filesystem) order.
    return sorted(glob.glob(os.path.join(directory, '*.png')))


def _load_rgb(path):
    """Open *path* as an RGB PIL image and close the underlying file handle."""
    with Image.open(path) as im:
        return im.convert('RGB')


def make_part1(src_dir=ROOT_NAIVE_DIV2K_VALIDSET, hr_dir=ROOT_HR, dst_dir=ROOT_PART1):
    """Save RGB copies of the HR images to *hr_dir* and their x4 bicubic downsamples to *dst_dir*."""
    os.makedirs(hr_dir, exist_ok=True)
    os.makedirs(dst_dir, exist_ok=True)
    for filename in _sorted_pngs(src_dir):
        img = _load_rgb(filename)
        img_name = output_filename(filename)
        img.save(os.path.join(hr_dir, img_name))
        # Explicit BICUBIC: Pillow's default resize filter was NEAREST before 7.0.
        img_down = img.resize((img.width // 4, img.height // 4), Image.BICUBIC)
        save_path = os.path.join(dst_dir, img_name)
        print("saving {}".format(save_path))
        img_down.save(save_path)


def make_part2(src_dir=os.path.join(ROOT_DIV2KRK, 'lr_x4'), dst_dir=ROOT_PART2):
    """Re-save DIV2KRK's x4 LR images into *dst_dir*, converted to RGB."""
    # A plain file copy would also work; going through PIL normalises every image to RGB.
    os.makedirs(dst_dir, exist_ok=True)
    for filename in _sorted_pngs(src_dir):
        img = _load_rgb(filename)
        save_path = os.path.join(dst_dir, output_filename(filename))
        print("saving {}".format(save_path))
        img.save(save_path)


def make_part3(src_dir=os.path.join(ROOT_DIV2KRK, 'lr_x2'), dst_dir=ROOT_PART3):
    """Downsample DIV2KRK's x2 LR images by 2 and save them as JPEGs of random quality."""
    os.makedirs(dst_dir, exist_ok=True)
    for filename in _sorted_pngs(src_dir):
        img = _load_rgb(filename)
        img_down = img.resize((img.width // 2, img.height // 2), Image.BICUBIC)
        # Uniform integer in [41, 90]. `random` is not seeded here, so each run differs.
        quality_factor = random.randint(41, 90)
        save_path = os.path.join(dst_dir, output_filename(filename, '.jpg'))
        print("saving {}".format(save_path))
        img_down.save(save_path, "JPEG", quality=quality_factor, optimize=True, progressive=True)


def make_part4(src_dir=ROOT_NAIVE_DIV2K_VALIDSET, dst_dir=ROOT_PART4, sf=4, lq_patchsize=72):
    """Apply the BSRGAN degradation model to each HR image and save the LR output as JPEG.

    Args:
        src_dir: Directory of HR ``*.png`` images.
        dst_dir: Output directory (created if missing).
        sf: Scale factor passed to ``degradation_bsrgan``.
        lq_patchsize: ``lq_patchsize`` passed to ``degradation_bsrgan``.
    """
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()
    os.makedirs(dst_dir, exist_ok=True)
    for filename in _sorted_pngs(src_dir):
        # ToTensor gives (C, H, W); degradation_bsrgan is fed the (H, W, C) numpy view.
        img_hwc = to_tensor(_load_rgb(filename)).permute(1, 2, 0).numpy()
        # NOTE(review): the second (HQ) return value is discarded. If degradation_bsrgan
        # returns a random lq_patchsize crop, these LR images will not align with the
        # full-size images written to ROOT_HR -- confirm this is intended.
        img_lq, _ = blindsr.degradation_bsrgan(img_hwc, sf=sf, lq_patchsize=lq_patchsize)
        # Back to (C, H, W) for ToPILImage.
        img_lq = to_pil(torch.from_numpy(img_lq).permute(2, 0, 1))
        save_path = os.path.join(dst_dir, output_filename(filename, '.jpg'))
        # PIL chooses the JPEG encoder from the '.jpg' suffix (default quality settings).
        img_lq.save(save_path)
        print("saving {}".format(save_path))


def main():
    """Generate the DIV2K4D LR set.

    Only part 4 is generated by default. Call make_part1 / make_part2 /
    make_part3 as well to rebuild the other parts (make_part1 also writes
    the HR images).
    """
    make_part4()


if __name__ == "__main__":
    main()

    
        
    
    