import os
from data.base_dataset import BaseDataset, get_params, get_transform
from data.image_folder import make_dataset, make_dataset_all, make_dataset_all_text, make_dataset_3, make_dataset_5, make_dataset_6, make_dataset_4, make_dataset_2
from PIL import Image
from pathlib import Path
import numpy as np
import random
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
import Augmentor
import cv2
import glob
class AlignedDataset_all(BaseDataset):
    """Paired (degraded, clean) image dataset spanning several restoration tasks.

    Depending on ``task`` and ``opt.phase``, input ("A") and target ("B") image
    paths are collected from sub-folders of ``opt.dataroot`` (LOL, LSRW,
    syn_rain, Snow100K, Deblur, RESIDE and several real-world test sets).
    Real-world sets without ground truth reuse the input path as the target.
    """

    # Tasks whose source directories are only configured when opt.phase == 'train'.
    _TRAIN_ONLY_TASKS = frozenset({'4', '5', '6'})
    # Tasks whose source directories are only configured when opt.phase != 'train'.
    _TEST_ONLY_TASKS = frozenset({
        'all', 'rain1', 'rain2', 'rain3', 'rain4', 'rain5',
        'real_dark_mef', 'real_dark_dice', 'real_dark_npe',
        'real_rain', 'real_snow', 'real_hide', 'real_j', 'real_r',
    })
    # Real-world tasks without ground truth; attribute 'dir_A_' + task holds the input folder.
    _UNPAIRED_TASKS = ('real_dark_mef', 'real_dark_dice', 'real_dark_npe', 'real_rain', 'real_snow')
    # In training, non-low-light images whose original width or height is below
    # this value are resized to a square of this size.
    _MIN_TRAIN_SIZE = 256
    # Fallback extensions for load_flist when the instance defines no ``exts``.
    _DEFAULT_EXTS = ('jpg', 'jpeg', 'png', 'bmp', 'tif', 'tiff')

    def __init__(self, opt, image_size, augment_flip=True, equalizeHist=True, crop_patch=True, generation=False, task=None):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- experiment flags; must provide dataroot, phase,
                max_dataset_size, load_size and crop_size
            image_size (int) -- side of the random square patch cropped from low-light pairs
            augment_flip (bool) -- apply Augmentor's left-right flip to low-light pairs
            equalizeHist (bool) -- histogram-equalize LOL / "dark" inputs
            crop_patch (bool) -- randomly crop low-light pairs to image_size
            generation (bool) -- stored on the instance; not used by this class
            task (str) -- which sub-dataset(s) to load, e.g. 'light', 'rain', 'snow',
                'blur', 'fog', 'all', '4', '5', '6' or one of the 'real_*' sets

        Raises:
            ValueError -- unknown task, task not available in opt.phase, mismatched
                rain test file names, or load_size < crop_size
        """
        BaseDataset.__init__(self, opt)
        self.equalizeHist = equalizeHist
        self.augment_flip = augment_flip
        self.crop_patch = crop_patch
        self.generation = generation
        self.image_size = image_size
        self.opt = opt

        self._init_dirs(opt)
        self.A_paths, self.B_paths = self._collect_paths(task)

        self.A_size = len(self.A_paths)  # get the size of dataset A
        print(self.A_size, task)
        self.B_size = len(self.B_paths)  # get the size of dataset B
        print(self.B_size, task)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if self.opt.load_size < self.opt.crop_size:
            raise ValueError(
                f"crop_size ({self.opt.crop_size}) must not exceed load_size ({self.opt.load_size})"
            )

    # ------------------------------------------------------------------ setup

    def _init_dirs(self, opt):
        """Set the ``dir_*`` attributes for the sub-datasets used in ``opt.phase``."""
        phase = opt.phase

        def join(*parts):
            return os.path.join(opt.dataroot, *parts)

        # Configured for every phase.
        self.dir_Arain = join('rain1400', phase, 'rainy_image')
        self.dir_Brain = join('rain1400', phase, 'ground_truth')
        self.dir_Alsrw = join('LSRW', phase, 'low')
        self.dir_Blsrw = join('LSRW', phase, 'high')
        self.dir_Alol = join('LOL', phase, 'low')
        self.dir_Blol = join('LOL', phase, 'high')

        if phase == 'train':
            self.dir_Asnow = join('Snow100K', phase, 'Snow100K-L', 'synthetic')
            self.dir_Bsnow = join('Snow100K', phase, 'Snow100K-L', 'gt')
            self.dir_Arain_syn = join('syn_rain', phase, 'input')
            self.dir_Brain_syn = join('syn_rain', phase, 'target')
            self.dir_Ablur = join('Deblur', phase, 'input')
            self.dir_Bblur = join('Deblur', phase, 'target')
            self.dir_Afog = join('RESIDE', 'OTS_ALPHA', 'haze', 'OTS')
            self.dir_Bfog = join('RESIDE', 'OTS_ALPHA', 'clear', 'clear_images')
            return

        # Synthetic snow test splits: 1 = Snow100K-S, 2 = Snow100K-L.
        self.dir_Asnow1 = join('Snow100K', phase, 'Snow100K-S', 'synthetic')
        self.dir_Bsnow1 = join('Snow100K', phase, 'Snow100K-S', 'gt')
        self.dir_Asnow2 = join('Snow100K', phase, 'Snow100K-L', 'synthetic')
        self.dir_Bsnow2 = join('Snow100K', phase, 'Snow100K-L', 'gt')

        # Synthetic rain test sets 1..5.
        self.dir_Arain_syn1 = join('syn_rain', phase, 'Test2800', 'input')
        self.dir_Brain_syn1 = join('syn_rain', phase, 'Test2800', 'target')
        self.dir_Arain_syn2 = join('syn_rain', phase, 'Rain100H', 'input')
        self.dir_Brain_syn2 = join('syn_rain', phase, 'Rain100H', 'target')
        self.dir_Arain_syn3 = join('syn_rain', phase, 'Rain100L', 'input')
        self.dir_Brain_syn3 = join('syn_rain', phase, 'Rain100L', 'target')
        self.dir_Arain_syn4 = join('syn_rain', phase, 'Test100', 'input')
        self.dir_Brain_syn4 = join('syn_rain', phase, 'Test100', 'target')
        self.dir_Arain_syn5 = join('syn_rain', phase, 'Test1200', 'input')
        self.dir_Brain_syn5 = join('syn_rain', phase, 'Test1200', 'target')

        # Deblurring test sets.
        self.dir_Ablur = join('Deblur', phase, 'GoPro', 'input')
        self.dir_Bblur = join('Deblur', phase, 'GoPro', 'target')
        self.dir_Ablur_hide = join('Deblur', phase, 'HIDE', 'input')
        self.dir_Bblur_hide = join('Deblur', phase, 'HIDE', 'target')
        self.dir_Ablur_j = join('Deblur', phase, 'RealBlur_J', 'input')
        self.dir_Bblur_j = join('Deblur', phase, 'RealBlur_J', 'target')
        self.dir_Ablur_r = join('Deblur', phase, 'RealBlur_R', 'input')
        self.dir_Bblur_r = join('Deblur', phase, 'RealBlur_R', 'target')

        # Dehazing test set (path does not include the phase).
        self.dir_Afog = join('RESIDE', 'SOTS', 'outdoor', 'hazy')
        self.dir_Bfog = join('RESIDE', 'SOTS', 'outdoor', 'gt')

        # Real-world inputs without ground truth.
        self.dir_A_real_dark_mef = join('real_dark', 'real_dark', 'MEF')
        self.dir_A_real_dark_npe = join('real_dark', 'real_dark', 'NPE')
        self.dir_A_real_dark_dice = join('real_dark', 'real_dark', 'DICE')
        self.dir_A_real_rain = join('real_rain', 'real_rain', 'Practical')
        self.dir_A_real_snow = join('Snow100K', 'realistic')

    def _collect_paths(self, task):
        """Return ``(A_paths, B_paths)`` for ``task`` in the current phase.

        Raises:
            ValueError -- if the task is unknown or not configured for opt.phase
        """
        is_train = self.opt.phase == 'train'
        if task in self._TRAIN_ONLY_TASKS and not is_train:
            raise ValueError(
                f"task {task!r} is only supported when opt.phase == 'train' (got {self.opt.phase!r})"
            )
        if task in self._TEST_ONLY_TASKS and is_train:
            raise ValueError(f"task {task!r} is not supported when opt.phase == 'train'")

        max_size = self.opt.max_dataset_size
        if task == 'light':
            if is_train:
                return (sorted(make_dataset_2(self.dir_Alol, self.dir_Alsrw, max_size)),
                        sorted(make_dataset_2(self.dir_Blol, self.dir_Blsrw, max_size)))
            return self._paired([self.dir_Alol], [self.dir_Blol])
        if task == 'light_only':
            return self._paired([self.dir_Alol], [self.dir_Blol])
        if task == 'rain':
            if is_train:
                return self._paired([self.dir_Arain_syn], [self.dir_Brain_syn])
            a_paths, b_paths = self._paired(*self._rain_test_dirs())
            self._check_stems_match(a_paths, b_paths)
            return a_paths, b_paths
        if task in ('rain1', 'rain2', 'rain3', 'rain4', 'rain5'):
            idx = task[-1]
            return self._paired([getattr(self, f'dir_Arain_syn{idx}')],
                                [getattr(self, f'dir_Brain_syn{idx}')])
        if task in ('snow', 'snow1', 'snow2'):
            if is_train:
                return self._paired([self.dir_Asnow], [self.dir_Bsnow])
            a_dirs, b_dirs = [], []
            if task in ('snow', 'snow1'):
                a_dirs.append(self.dir_Asnow1)
                b_dirs.append(self.dir_Bsnow1)
            if task in ('snow', 'snow2'):
                a_dirs.append(self.dir_Asnow2)
                b_dirs.append(self.dir_Bsnow2)
            return self._paired(a_dirs, b_dirs)
        if task == 'blur':
            return self._paired([self.dir_Ablur], [self.dir_Bblur])
        if task == 'fog':
            if is_train:
                return self._paired([self.dir_Afog], [self.dir_Bfog])
            return self._pair_fog_test()
        if task == 'all':
            # Test-time union of light, rain, snow, blur and fog, in that order.
            rain_a, rain_b = self._rain_test_dirs()
            a_paths, b_paths = self._paired(
                [self.dir_Alol, *rain_a, self.dir_Asnow1, self.dir_Asnow2, self.dir_Ablur],
                [self.dir_Blol, *rain_b, self.dir_Bsnow1, self.dir_Bsnow2, self.dir_Bblur],
            )
            fog_a, fog_b = self._pair_fog_test()
            return a_paths + fog_a, b_paths + fog_b
        if task == '4':
            return (sorted(make_dataset_4(self.dir_Arain_syn, self.dir_Alsrw, self.dir_Alol,
                                          self.dir_Asnow, max_size)),
                    sorted(make_dataset_4(self.dir_Brain_syn, self.dir_Blsrw, self.dir_Blol,
                                          self.dir_Bsnow, max_size)))
        if task == '5':
            return (sorted(make_dataset_5(self.dir_Arain_syn, self.dir_Alsrw, self.dir_Alol,
                                          self.dir_Asnow, self.dir_Ablur, max_size)),
                    sorted(make_dataset_5(self.dir_Brain_syn, self.dir_Blsrw, self.dir_Blol,
                                          self.dir_Bsnow, self.dir_Bblur, max_size)))
        if task == '6':
            return (sorted(make_dataset_6(self.dir_Arain_syn, self.dir_Alol, self.dir_Asnow,
                                          self.dir_Ablur, self.dir_Afog, max_size)),
                    sorted(make_dataset_6(self.dir_Brain_syn, self.dir_Blol, self.dir_Bsnow,
                                          self.dir_Bblur, self.dir_Bfog, max_size)))
        if task in self._UNPAIRED_TASKS:
            # No ground truth: the input image doubles as the target.
            a_paths = sorted(make_dataset(getattr(self, 'dir_A_' + task), max_size))
            return a_paths, list(a_paths)
        if task in ('real_hide', 'real_j', 'real_r'):
            suffix = task[len('real_'):]
            return self._paired([getattr(self, f'dir_Ablur_{suffix}')],
                                [getattr(self, f'dir_Bblur_{suffix}')])
        raise ValueError(f"unknown task {task!r}")

    def _paired(self, a_dirs, b_dirs):
        """List each directory, sort it, and concatenate the results in order.

        Sorting both sides is what aligns A[i] with B[i], so matching files must
        share a sort order. NOTE(review): make_dataset presumably truncates to
        max_dataset_size before this sort, so with a finite limit A and B may keep
        different files -- confirm against make_dataset.
        """
        max_size = self.opt.max_dataset_size
        a_paths = [p for d in a_dirs for p in sorted(make_dataset(d, max_size))]
        b_paths = [p for d in b_dirs for p in sorted(make_dataset(d, max_size))]
        return a_paths, b_paths

    def _rain_test_dirs(self):
        """Return (input dirs, target dirs) of the five synthetic rain test sets."""
        return (
            [self.dir_Arain_syn1, self.dir_Arain_syn2, self.dir_Arain_syn3,
             self.dir_Arain_syn4, self.dir_Arain_syn5],
            [self.dir_Brain_syn1, self.dir_Brain_syn2, self.dir_Brain_syn3,
             self.dir_Brain_syn4, self.dir_Brain_syn5],
        )

    def _pair_fog_test(self):
        """Pair every hazy ``<id>_*.jpg`` in dir_Afog with the clear ``<id>.png`` in dir_Bfog.

        One clear image may be paired with several hazy images. Hazy files are
        keyed by the text before their first underscore.
        """
        clear_files = sorted(glob.glob(os.path.join(self.dir_Bfog, "*.png")))
        hazy_files = sorted(glob.glob(os.path.join(self.dir_Afog, "*.jpg")))

        hazy_by_id = {}
        for hazy in hazy_files:
            hazy_by_id.setdefault(os.path.basename(hazy).split("_")[0], []).append(hazy)

        a_paths, b_paths = [], []
        for clear in clear_files:
            clear_id = os.path.splitext(os.path.basename(clear))[0]
            for hazy in hazy_by_id.get(clear_id, ()):
                a_paths.append(hazy)
                b_paths.append(clear)
        return a_paths, b_paths

    @staticmethod
    def _check_stems_match(a_paths, b_paths):
        """Raise ValueError unless A and B have equal length and equal file stems pairwise."""
        if len(a_paths) != len(b_paths):
            raise ValueError(
                f"文件列表长度不匹配: A有 {len(a_paths)} 个文件, "
                f"B有 {len(b_paths)} 个文件。"
            )

        print("正在检查 A 和 B 列表中的文件名是否对应...")

        def stem(path):
            return os.path.splitext(os.path.basename(path))[0]

        mismatched = [(a, b) for a, b in zip(a_paths, b_paths) if stem(a) != stem(b)]
        if mismatched:
            # Report how many pairs differ and show at most five examples.
            error_msg = f"错误: 检测到 {len(mismatched)} 个文件名不匹配。\n"
            error_msg += "不匹配的文件示例 (最多显示5个):\n"
            for a, b in mismatched[:5]:
                error_msg += f"  - A: {os.path.basename(a)} (来自 {a})\n"
                error_msg += f"  - B: {os.path.basename(b)} (来自 {b})\n"
                error_msg += "---\n"
            raise ValueError(error_msg)

        print(f"文件校验通过：所有 {len(a_paths)} 个文件的基本名均一一对应。")

    # --------------------------------------------------------------- sampling

    def __getitem__(self, index):
        """Return one (input, target) pair.

        Parameters:
            index (int) -- sample index; wrapped modulo each path list's length

        Returns a dictionary with
            adap -- the degraded input ("condition") image after transforms
            gt -- the clean target image after transforms
            A_paths (str) -- path of the input image
            B_paths (str) -- path of the target image
        """
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[index % self.B_size]
        with Image.open(A_path) as img:
            condition = img.convert('RGB')
        with Image.open(B_path) as img:
            gt = img.convert('RGB')

        if 'LOL' in A_path or 'LSRW' in A_path or 'dark' in A_path:
            # Low-light branch: numpy arrays in BGR order until the final conversion.
            condition = cv2.cvtColor(np.asarray(condition), cv2.COLOR_RGB2BGR)
            gt = cv2.cvtColor(np.asarray(gt), cv2.COLOR_RGB2BGR)

            if self.crop_patch:
                gt, condition = self.get_patch([gt, condition], self.image_size)
            if self.equalizeHist and ('LOL' in A_path or 'dark' in A_path):
                condition = self.cv2equalizeHist(condition)

            pipeline = Augmentor.DataPipeline([[gt, condition]])
            if self.augment_flip:
                # NOTE(review): probability 1 mirrors every pair whenever augment_flip
                # is set; a random flip (e.g. 0.5) may have been intended -- confirm.
                pipeline.flip_left_right(1)
            gt, condition = next(pipeline.generator(batch_size=1))[0]

            gt = self.to_tensor(cv2.cvtColor(gt, cv2.COLOR_BGR2RGB))
            condition = self.to_tensor(cv2.cvtColor(condition, cv2.COLOR_BGR2RGB))
        else:
            w, h = condition.size
            transform_params = get_params(self.opt, condition.size)
            A_transform = get_transform(self.opt, transform_params, grayscale=False)
            B_transform = get_transform(self.opt, transform_params, grayscale=False)
            condition = A_transform(condition)
            gt = B_transform(gt)
            min_size = self._MIN_TRAIN_SIZE
            if self.opt.phase == 'train' and (h < min_size or w < min_size):
                resize = transforms.Resize([min_size, min_size], transforms.InterpolationMode.BICUBIC)
                condition = resize(condition)
                gt = resize(gt)
        return {'adap': condition, 'gt': gt, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return max(self.A_size, self.B_size)

    # ---------------------------------------------------------------- helpers

    def load_flist(self, flist):
        """Return image paths from a list, an image directory, or a text file of paths.

        A directory is searched recursively for the extensions in ``self.exts``
        (or ``_DEFAULT_EXTS`` when unset). A file that cannot be parsed as text
        is treated as a single image path. Anything else yields ``[]``.
        """
        if isinstance(flist, list):
            return flist

        if isinstance(flist, str):
            if os.path.isdir(flist):
                exts = getattr(self, 'exts', self._DEFAULT_EXTS)
                return [p for ext in exts for p in Path(flist).glob(f'**/*.{ext}')]

            if os.path.isfile(flist):
                try:
                    # `np.str` was removed in NumPy 1.24; builtin `str` is the replacement.
                    return np.genfromtxt(flist, dtype=str, encoding='utf-8')
                except (OSError, ValueError):
                    # Not a readable text list (e.g. an image file): treat as one path.
                    return [flist]
        return []

    def cv2equalizeHist(self, img):
        """Histogram-equalize each channel of an 8-bit multi-channel image independently."""
        return cv2.merge([cv2.equalizeHist(channel) for channel in cv2.split(img)])

    def to_tensor(self, img):
        """Convert a numpy image array to a float tensor via PIL and torchvision."""
        img = Image.fromarray(img)  # returns an image object.
        return TF.to_tensor(img).float()

    def load_name(self, index, sub_dir=False):
        """Return the file name of input ``index``, optionally prefixed by its parent folder.

        Inputs come from ``self.input`` when present, else ``self.A_paths``.
        ``sub_dir`` 0/False gives ``name.ext``; 1/True gives ``parent_name.ext``.
        Returns None when ``self.condition`` is set and falsy, or for other
        ``sub_dir`` values.
        """
        if not getattr(self, 'condition', True):
            return None
        name = getattr(self, 'input', self.A_paths)[index]
        if sub_dir == 0:
            return os.path.basename(name)
        if sub_dir == 1:
            parent = os.path.basename(os.path.dirname(name))
            return parent + "_" + os.path.basename(name)
        return None

    def get_patch(self, image_list, patch_size):
        """Crop the same random ``patch_size`` square from every image in ``image_list``.

        The crop window is drawn from the first image's height/width. Images must
        be 3-D arrays. ``image_list`` is modified in place and returned.

        Raises:
            ValueError -- if the first image is smaller than ``patch_size`` on either side
        """
        h, w = image_list[0].shape[:2]
        if h < patch_size or w < patch_size:
            raise ValueError(
                f"cannot crop a {patch_size}x{patch_size} patch from a {h}x{w} image"
            )
        top = random.randint(0, h - patch_size)
        left = random.randint(0, w - patch_size)
        for i, img in enumerate(image_list):
            image_list[i] = img[top:top + patch_size, left:left + patch_size, :]
        return image_list

    @staticmethod
    def _pad_amounts(h, w, patch_size, block_size=8):
        """Return (bottom, right) padding that makes each side >= patch_size and a multiple of block_size."""
        def amount(size):
            grow = max(0, patch_size - size)
            size = max(size, patch_size)
            # (-size) % block_size is the distance up to the next multiple of block_size.
            return grow + (-size) % block_size

        return amount(h), amount(w)

    def pad_img(self, img_list, patch_size, block_size=8):
        """Convert each image RGB->BGR and zero-pad it on the bottom/right.

        After padding, each side is at least ``patch_size`` and a multiple of
        ``block_size``. ``img_list`` is modified in place and returned.
        """
        for i, img in enumerate(img_list):
            img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
            h, w = img.shape[:2]
            bottom, right = self._pad_amounts(h, w, patch_size, block_size)
            img_list[i] = cv2.copyMakeBorder(
                img, 0, bottom, 0, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
        return img_list

    def get_pad_size(self, index, block_size=8):
        """Return ``[bottom, right]`` padding that ``pad_img`` would apply to input ``index``.

        Inputs come from ``self.input`` when present, else ``self.A_paths``;
        ``self.image_size`` is used as the patch size.
        """
        paths = getattr(self, 'input', self.A_paths)
        with Image.open(paths[index]) as img:
            w, h = img.size
        return list(self._pad_amounts(h, w, self.image_size, block_size))
