import torch.utils.data as data
from torchvision import transforms
from PIL import Image
import os
import torch
import numpy as np
import random
import torch.nn.functional as F

from .utils.mask import (bbox2mask, brush_stroke_mask, get_irregular_mask, random_bbox, random_cropping_bbox)
from .utils.auto_augment import ImageNetAutoAugment

# File-name suffixes recognised as images when scanning a directory.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image_file(filename):
    """Return True when *filename* ends with one of ``IMG_EXTENSIONS``.

    The comparison ignores case, so mixed-case suffixes such as ``.Jpg`` or
    ``.Png`` are accepted in addition to the spellings listed explicitly.
    """
    # Built per call so that later edits to IMG_EXTENSIONS are honoured.
    suffixes = tuple(extension.lower() for extension in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)

def make_dataset(dir, cond=None):
    """Collect image paths from a list file or by walking a directory.

    Args:
        dir: Path to either a text file parsed with ``np.genfromtxt`` (one
            entry per line) or a directory that is walked recursively.
        cond: Optional substring filter. It is applied only in the directory
            case, where a path is kept only if it contains ``cond``.

    Returns:
        A list of path strings. Directory results follow the sorted walk
        order; list-file results keep the order of the file.

    Raises:
        AssertionError: If ``dir`` is neither an existing file nor a directory.
    """
    if os.path.isfile(dir):
        # `np.str` was removed in NumPy 1.24; the builtin `str` is the same dtype.
        entries = np.genfromtxt(dir, dtype=str, encoding='utf-8')
        # A file with a single entry parses to a 0-d array, which cannot be
        # iterated; promote it so one-line list files work as well.
        return np.atleast_1d(entries).tolist()

    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if not is_image_file(fname):
                continue
            path = os.path.join(root, fname)
            if cond is None or cond in path:
                images.append(path)
    return images

def pil_loader(path):
    """Open the image stored at *path* and return it converted to RGB mode."""
    image = Image.open(path)
    return image.convert('RGB')

# Module-level ImageNetAutoAugment instance shared by the dataset classes in
# this file that call `data_aug(img)` on training images.
data_aug = ImageNetAutoAugment()

class InpaintDataset(data.Dataset):
    """Image dataset that pairs every image with an inpainting mask.

    Each item is a dict with the keys:
        image: the transformed image tensor.
        con_image: ``image`` with the pixels where the mask is 1 replaced by
            ``torch.randn_like`` noise.
        mask_image: ``image`` with the pixels where the mask is 1 set to 1.
        mask: the mask returned by :meth:`get_mask`.
        name: last ``/``-separated component of the path with its final
            ``.``-suffix dropped (any remaining dots are removed too).

    Args:
        data_root: Directory or list file handed to ``make_dataset``.
        mask_config: Dict with a ``'mask_mode'`` entry; ``'root'`` (prefix of
            the mask PNG files) is required only for ``'manual'`` mode.
        data_len: When positive, keep only the first ``data_len`` sorted paths.
        image_size: Target size as ``[height, width]``.
        is_train: When true, ``data_aug`` is applied to each image.

    Raises:
        KeyError: If ``'mask_mode'`` is missing, or ``'root'`` is missing in
            ``'manual'`` mode.
    """

    # NOTE: the list/dict defaults are kept for interface compatibility; they
    # are only read, never mutated.
    def __init__(self, data_root, mask_config={}, data_len=-1, image_size=[256, 256], is_train=False):
        imgs = make_dataset(data_root)
        if data_len > 0:
            self.imgs = sorted(imgs)[:int(data_len)]
        else:
            self.imgs = imgs

        # Resizing happens in __getitem__ on the PIL image, so the tensor
        # pipeline only converts and normalises.
        self.tfs = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        self.data_root = data_root
        self.mask_config = mask_config
        self.mask_mode = self.mask_config['mask_mode']
        # Only 'manual' mode reads the mask root, so do not demand it for the
        # generated mask modes.
        self.mask_root = self.mask_config.get("root")
        if self.mask_mode == 'manual' and self.mask_root is None:
            raise KeyError("root")
        self.image_size = image_size
        self.is_train = is_train

    def __getitem__(self, index):
        """Load, resize and transform image ``index`` and attach its mask."""
        path = self.imgs[index]
        # image_size is (height, width) while PIL expects (width, height);
        # passing it unswapped only works for square sizes.
        height, width = self.image_size
        img = Image.open(path).convert("RGB").resize(
            (width, height), resample=Image.Resampling.BILINEAR)

        if self.is_train:
            img = data_aug(img)

        img = self.tfs(img)
        mask = self.get_mask(index)

        ret = {}
        ret['image'] = img
        ret['con_image'] = img*(1-mask) + torch.randn_like(img)*mask
        ret['mask_image'] = img*(1-mask) + mask
        ret['mask'] = mask
        ret['name'] = "".join(path.split("/")[-1].split(".")[:-1])
        return ret

    def __len__(self):
        """Return the number of image paths."""
        return len(self.imgs)

    def get_mask(self, index=0):
        """Build the mask tensor for item ``index`` according to ``mask_mode``.

        ``'manual'`` loads ``f"{mask_root}{index % 500}.png"`` and returns it
        with a leading axis added. The generated modes produce a NumPy mask
        through the project mask helpers and return it permuted with
        ``(2, 0, 1)``.

        Raises:
            NotImplementedError: For ``'file'`` and any unknown mask mode.
        """
        if self.mask_mode == 'manual':
            mask_path = self.mask_root + f"{index % 500}.png"
            return torch.from_numpy(np.array([np.asarray(Image.open(mask_path))]))

        if self.mask_mode == 'bbox':
            mask = bbox2mask(self.image_size, random_bbox())
        elif self.mask_mode == 'center':
            h, w = self.image_size
            mask = bbox2mask(self.image_size, (h//4, w//4, h//2, w//2))
        elif self.mask_mode == 'irregular':
            mask = get_irregular_mask(self.image_size)
        elif self.mask_mode == 'free_form':
            mask = brush_stroke_mask(self.image_size)
        elif self.mask_mode == 'hybrid':
            half_size = (self.image_size[0] // 2, self.image_size[1] // 2)
            regular_mask = bbox2mask(self.image_size, random_bbox(self.image_size, half_size))
            irregular_mask = brush_stroke_mask(self.image_size, )
            mask = regular_mask | irregular_mask
        else:
            # 'file' mode was a silent no-op that then failed on an unbound
            # `mask`; report it as unimplemented like any other unknown mode.
            raise NotImplementedError(
                f'Mask mode {self.mask_mode} has not been implemented.')
        return torch.from_numpy(mask).permute(2,0,1)
    
class UncroppingDataset(data.Dataset):
    """Image dataset that pairs every image with an uncropping mask.

    Each item is a dict with the keys:
        image: the transformed image tensor.
        con_image: ``image`` with the pixels where the mask is 1 replaced by
            ``torch.randn_like`` noise.
        mask_image: ``image`` with the pixels where the mask is 1 set to 1.
        mask: the mask returned by :meth:`get_mask`.
        name: last ``/``-separated component of the path with its final
            ``.``-suffix dropped (any remaining dots are removed too).

    Args:
        data_root: Directory or list file handed to ``make_dataset``.
        mask_config: Dict with a ``'mask_mode'`` entry; ``'manual'`` mode also
            reads ``'shape'``.
        data_len: When positive, keep only the first ``data_len`` paths.
        image_size: Target size as ``[height, width]``.
        is_train: When true, ``data_aug`` is applied to each image.
    """

    # NOTE: the list/dict defaults are kept for interface compatibility; they
    # are only read, never mutated.
    def __init__(self, data_root, mask_config={}, data_len=-1, image_size=[512, 512], is_train=False):
        imgs = make_dataset(data_root)
        if data_len > 0:
            self.imgs = imgs[:int(data_len)]
        else:
            self.imgs = imgs
        self.tfs = transforms.Compose([
            transforms.Resize((image_size[0], image_size[1])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        # Single-channel pipeline; not used inside this class but kept as a
        # public attribute.
        self.tfs_p = transforms.Compose([
            transforms.Resize((image_size[0], image_size[1])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        self.is_train = is_train
        self.mask_config = mask_config
        self.mask_mode = self.mask_config['mask_mode']
        self.image_size = image_size

    def __getitem__(self, index):
        """Load, resize and transform image ``index`` and attach its mask."""
        path = self.imgs[index]
        # image_size is (height, width), as used by transforms.Resize above,
        # while PIL expects (width, height). Swapping here avoids a distorting
        # double resample for non-square sizes.
        height, width = self.image_size
        img = Image.open(path).convert("RGB").resize(
            (width, height), resample=Image.Resampling.BICUBIC)
        mask = self.get_mask()

        if self.is_train:
            img = data_aug(img)
        img = self.tfs(img)

        ret = {}
        ret['image'] = img
        ret['con_image'] = img*(1-mask) + torch.randn_like(img)*mask
        ret['mask_image'] = img*(1-mask) + mask
        ret['mask'] = mask
        ret['name'] = "".join(path.split("/")[-1].split(".")[:-1])
        return ret

    def __len__(self):
        """Return the number of image paths."""
        return len(self.imgs)

    def get_mask(self):
        """Build a mask tensor according to ``mask_mode``.

        The NumPy mask produced by the project mask helpers is returned
        permuted with ``(2, 0, 1)``.

        Raises:
            NotImplementedError: For ``'file'`` and any unknown mask mode.
        """
        if self.mask_mode == 'manual':
            bbox = self.mask_config['shape']
        elif self.mask_mode == 'fourdirection' or self.mask_mode == 'onedirection':
            bbox = random_cropping_bbox(self.image_size, mask_mode=self.mask_mode)
        elif self.mask_mode == 'hybrid':
            # Pick one of the two cropping styles with equal probability.
            if np.random.randint(0,2)<1:
                bbox = random_cropping_bbox(self.image_size, mask_mode='onedirection')
            else:
                bbox = random_cropping_bbox(self.image_size, mask_mode='fourdirection')
        else:
            # 'file' mode was a silent no-op that then failed on an unbound
            # `mask`; report it as unimplemented like any other unknown mode.
            raise NotImplementedError(
                f'Mask mode {self.mask_mode} has not been implemented.')
        mask = bbox2mask(self.image_size, bbox)
        return torch.from_numpy(mask).permute(2,0,1)
    
class ColorizationDataset(data.Dataset):
    """Dataset of (grayscale, colour) image pairs derived from one image file.

    Each item is a dict with ``'gt'`` (the transformed RGB image), ``'lq'``
    (the transformed single-channel ``"L"`` conversion of the same image) and
    ``'fname'`` (base name of the source path).
    """

    def __init__(self, data_root, data_len=-1, image_size=[256, 256], is_train=False):
        paths = make_dataset(data_root)
        # A positive data_len keeps the first data_len entries of the sorted list.
        self.imgs = sorted(paths)[:int(data_len)] if data_len > 0 else paths

        rgb_steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5,0.5, 0.5]),
        ]
        gray_steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]
        self.tfs = transforms.Compose(rgb_steps)
        self.tfs_g = transforms.Compose(gray_steps)

        self.data_root = data_root
        self.image_size = image_size
        self.is_train = is_train

    def __getitem__(self, index):
        """Return the colour/grayscale pair for image ``index``."""
        path = self.imgs[index]
        colour = Image.open(path).convert("RGB").resize(self.image_size)

        # Training only: each flip is applied independently when its own
        # random draw exceeds 0.5.
        if self.is_train:
            for flip in (Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM):
                if random.random() > 0.5:
                    colour = colour.transpose(flip)

        # Grayscale is taken after flipping so both outputs stay aligned.
        gray = colour.convert("L")
        gt_tensor = self.tfs(colour)
        lq_tensor = self.tfs_g(gray)

        return {
            'gt': gt_tensor,
            'lq': lq_tensor,
            'fname': os.path.basename(path),
        }

    def __len__(self):
        """Return the number of image paths."""
        return len(self.imgs)
    
class GoproDataset(data.Dataset):
    """Dataset of paired sharp/blurred images.

    Sharp images are the paths under ``data_root`` that contain ``"/sharp/"``;
    the matching blurred image is found by replacing ``"/sharp/"`` with
    ``blur_type`` in the path. Each item is a dict with ``'gt'`` (sharp),
    ``'lq'`` (blurred) and ``'fname'`` (base name of the sharp path).

    Args:
        data_root: Directory or list file handed to ``make_dataset``.
        blur_type: Path fragment substituted for ``"/sharp/"``.
        data_len: When positive, keep only the first ``data_len`` sorted paths.
        image_size: Crop size, read as ``[width, height]`` when cropping.
        is_train: When true, random horizontal/vertical flips are applied.
    """

    # NOTE: the list default is kept for interface compatibility; it is only
    # read, never mutated.
    def __init__(self, data_root, blur_type="/blur/", data_len=-1, image_size=[256, 256], is_train=False):
        imgs = make_dataset(data_root, cond="/sharp/")
        if data_len > 0:
            self.imgs = sorted(imgs)[:int(data_len)]
        else:
            self.imgs = imgs

        self.tfs = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        self.data_root = data_root
        self.image_size = image_size
        self.blur_type = blur_type
        self.is_train = is_train

    @staticmethod
    def _random_crop_box(img_size, crop_size):
        """Return a random ``(left, upper, right, lower)`` box, or None if the crop does not fit."""
        w, h = img_size
        crop_w, crop_h = crop_size
        # Check each axis on its own: comparing only the minima lets a
        # non-square crop reach the random draw with a negative range.
        if crop_w > w or crop_h > h:
            return None
        # randint is inclusive, so the last valid offset can be chosen and an
        # exact fit needs no special case.
        left = random.randint(0, w - crop_w)
        up = random.randint(0, h - crop_h)
        return (left, up, left + crop_w, up + crop_h)

    def __getitem__(self, index):
        """Return the aligned sharp/blurred pair for image ``index``."""
        path = self.imgs[index]
        img = Image.open(path).convert("RGB")
        blur = Image.open(path.replace("/sharp/", self.blur_type)).convert("RGB")

        # The same box is applied to both images to keep them aligned; images
        # smaller than the crop are passed through uncropped.
        box = self._random_crop_box(img.size, self.image_size)
        if box is not None:
            img = img.crop(box)
            blur = blur.crop(box)

        if self.is_train and random.random() > 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            blur = blur.transpose(Image.FLIP_LEFT_RIGHT)

        if self.is_train and random.random() > 0.5:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
            blur = blur.transpose(Image.FLIP_TOP_BOTTOM)

        img = self.tfs(img)
        blur = self.tfs(blur)

        ret = {}
        ret['gt'] = img
        ret['lq'] = blur
        ret['fname'] = os.path.basename(path)
        return ret

    def __len__(self):
        """Return the number of sharp image paths."""
        return len(self.imgs)

class IDMsr(data.Dataset):
    """Super-resolution dataset that derives the low-res input from a crop.

    Each item is a dict with ``'gt'`` (a random ``hr_size`` crop of the image,
    or the whole image when the crop does not fit), ``'lq'`` (that crop
    resized to ``lr_size`` with bicubic resampling) and ``'fname'`` (base name
    of the source path).

    Args:
        data_root: Directory or list file handed to ``make_dataset``.
        data_len: When positive, keep only the first ``data_len`` paths.
        lr_size: Low-resolution size passed to PIL ``resize``.
        hr_size: Crop size, read as ``[width, height]``.
        is_train: When true, a random horizontal flip is applied.
    """

    # NOTE: the list defaults are kept for interface compatibility; they are
    # only read, never mutated.
    def __init__(self, data_root, data_len=-1, lr_size= [32, 32], hr_size=[256, 256], is_train=False):
        imgs = make_dataset(data_root)
        if data_len > 0:
            self.imgs = imgs[:int(data_len)]
        else:
            self.imgs = imgs
        self.tfs = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        self.is_train = is_train
        self.lr_size = lr_size
        self.hr_size = hr_size

    @staticmethod
    def _random_crop_box(img_size, crop_size):
        """Return a random ``(left, upper, right, lower)`` box, or None if the crop does not fit."""
        w, h = img_size
        crop_w, crop_h = crop_size
        # Both axes are checked; testing only hr_size[0] against min(w, h)
        # let a taller crop reach the random draw with a negative range.
        if crop_w > w or crop_h > h:
            return None
        # randint is inclusive, so an exact fit yields offset 0 and the last
        # valid offset can be chosen.
        left = random.randint(0, w - crop_w)
        up = random.randint(0, h - crop_h)
        return (left, up, left + crop_w, up + crop_h)

    def __getitem__(self, index):
        """Return the low-res/high-res pair for image ``index``."""
        path = self.imgs[index]
        img = Image.open(path).convert("RGB")

        # Crop whenever the image is at least hr_size; an image whose short
        # side equals hr_size used to be returned uncropped.
        box = self._random_crop_box(img.size, self.hr_size)
        hr_img = img if box is None else img.crop(box)

        if self.is_train and random.random() > 0.5:
            hr_img = hr_img.transpose(Image.FLIP_LEFT_RIGHT)

        lr_img = hr_img.resize(self.lr_size, resample=Image.Resampling.BICUBIC)

        ret = {}
        ret['lq'] = self.tfs(lr_img)
        ret['gt'] = self.tfs(hr_img)
        ret['fname'] = os.path.basename(path)
        return ret

    def __len__(self):
        """Return the number of image paths."""
        return len(self.imgs)
    
class SRtest(data.Dataset):
    """Super-resolution test set reading pre-computed low-res images.

    The low-quality image for each path is found by replacing ``"/hr"`` in
    the path with ``lq_path``. Items are dicts with ``'lq'``, ``'gt'`` and
    ``'fname'`` (base name of the high-res path).
    """

    def __init__(self, data_root, data_len=-1, lq_path= "/bicubic_x4", is_train=False):
        paths = make_dataset(data_root)
        # A positive data_len truncates the list; anything else keeps it whole.
        self.imgs = paths[:int(data_len)] if data_len > 0 else paths

        pipeline = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5,0.5, 0.5]),
        ]
        self.tfs = transforms.Compose(pipeline)

        self.is_train = is_train
        self.lq_path = lq_path

    def __getitem__(self, index):
        """Return the low-res/high-res pair for image ``index``."""
        hr_path = self.imgs[index]
        hr_img = Image.open(hr_path).convert("RGB")
        lq_img = Image.open(hr_path.replace("/hr", self.lq_path)).convert("RGB")

        sample = {}
        sample['lq'] = self.tfs(lq_img)
        sample["gt"] = self.tfs(hr_img)
        sample['fname'] = os.path.basename(hr_path)
        return sample

    def __len__(self):
        """Return the number of image paths."""
        return len(self.imgs)

class MultiSR(data.Dataset):
    """Multi-stage dataset: a random crop at several scales plus matching masks.

    Each item is a dict with the keys:
        image: the crop resized to 1/8 of ``image_size``.
        stg1, stg2, stg3: channel-wise concatenation of the crop at 1/8, 1/4
            and 1/2 scale (masked with ``1 - mask``) and the corresponding
            mask shifted by -0.5.
        gt: the transformed full-size crop.
        mask: the base mask resized to 4x ``mask_size``.
        name: last ``/``-separated component of the path with its final
            ``.``-suffix dropped (any remaining dots are removed too).

    Args:
        data_root: Directory or list file handed to ``make_dataset``.
        mask_config: Dict with a ``'mask_mode'`` entry; ``'root'`` (prefix of
            the mask PNG files) is required only for ``'manual'`` mode.
        data_len: When positive, keep only the first ``data_len`` sorted paths.
        image_size: Crop size, read as ``[width, height]`` when cropping.
        mask_size: Size of the base mask produced by :meth:`get_mask`.
        is_train: Accepted for interface parity; not used by this class.

    Raises:
        KeyError: If ``'mask_mode'`` is missing, or ``'root'`` is missing in
            ``'manual'`` mode.
    """

    # NOTE: the list/dict defaults are kept for interface compatibility; they
    # are only read, never mutated.
    def __init__(self, data_root, mask_config={}, data_len=-1, image_size=[512, 512], mask_size=[128, 128], is_train=False):
        self.data_root = data_root
        imgs = make_dataset(data_root)
        if data_len > 0:
            self.imgs = sorted(imgs)[:int(data_len)]
        else:
            self.imgs = imgs
        self.tfs = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        self.mask_config = mask_config
        self.mask_mode = self.mask_config['mask_mode']
        # Only 'manual' mode reads the mask root, so do not demand it for the
        # generated mask modes.
        self.mask_root = self.mask_config.get("root")
        if self.mask_mode == 'manual' and self.mask_root is None:
            raise KeyError("root")
        self.image_size = image_size
        self.mask_size = mask_size

    def __getitem__(self, index):
        """Return the multi-scale inputs, ground truth and mask for image ``index``.

        Raises:
            ValueError: If the image is smaller than ``image_size``.
        """
        path = self.imgs[index]
        img = Image.open(path).convert("RGB")
        w, h = img.size
        crop_w, crop_h = self.image_size[0], self.image_size[1]
        if crop_w > w or crop_h > h:
            raise ValueError(
                f'Image {path} of size {w}x{h} is smaller than the crop size {crop_w}x{crop_h}.')
        # randint is inclusive: an image exactly the crop size works, and the
        # last valid offset can be drawn.
        left = random.randint(0, w - crop_w)
        up = random.randint(0, h - crop_h)
        img = img.crop((left, up, left + crop_w, up + crop_h))

        img1 = self.tfs(img.resize((crop_w // 8, crop_h // 8), resample=Image.BILINEAR))
        img2 = self.tfs(img.resize((crop_w // 4, crop_h // 4), resample=Image.BILINEAR))
        img3 = self.tfs(img.resize((crop_w // 2, crop_h // 2), resample=Image.BILINEAR))

        # mask2 is the base mask; the other scales are resized copies.
        mask2 = self.get_mask(index)
        mask1 = transforms.Resize((self.mask_size[0] // 2, self.mask_size[1] // 2))(mask2)
        mask3 = transforms.Resize((self.mask_size[0] * 2, self.mask_size[1] * 2))(mask2)
        mask4 = transforms.Resize((self.mask_size[0] * 4, self.mask_size[1] * 4))(mask2)
        # Re-binarise: anything that became non-zero through resizing is set to 1.
        mask1[mask1>0] = 1
        mask3[mask3>0] = 1
        mask4[mask4>0] = 1

        ret = {}
        ret['image'] = img1
        ret['stg1'] = torch.cat([img1 * (1-mask1), mask1-0.5], dim=0)
        ret['stg2'] = torch.cat([img2 * (1-mask2), mask2-0.5], dim=0)
        ret['stg3'] = torch.cat([img3 * (1-mask3), mask3-0.5], dim=0)
        ret['gt'] = self.tfs(img)
        ret['mask'] = mask4
        ret['name'] = "".join(path.split("/")[-1].split(".")[:-1])
        return ret

    def __len__(self):
        """Return the number of image paths."""
        return len(self.imgs)

    def get_mask(self, index=0):
        """Build the base mask for item ``index`` according to ``mask_mode``.

        ``'manual'`` loads ``f"{mask_root}{index % 400}.png"`` and returns it
        with a leading axis added. The generated modes produce a NumPy mask of
        ``mask_size`` through the project mask helpers and return it permuted
        with ``(2, 0, 1)``.

        Raises:
            NotImplementedError: For ``'file'`` and any unknown mask mode.
        """
        if self.mask_mode == 'manual':
            mask_path = self.mask_root + f"{index % 400}.png"
            return torch.from_numpy(np.array([np.asarray(Image.open(mask_path))]))

        if self.mask_mode == 'bbox':
            mask = bbox2mask(self.mask_size, random_bbox())
        elif self.mask_mode == 'center':
            # The box is drawn on a mask_size canvas, so it is derived from
            # mask_size; deriving it from image_size was inconsistent with
            # every other mode. NOTE(review): assumes bbox2mask coordinates
            # are relative to the mask canvas - confirm against utils.mask.
            h, w = self.mask_size
            mask = bbox2mask(self.mask_size, (h//4, w//4, h//2, w//2))
        elif self.mask_mode == 'irregular':
            mask = get_irregular_mask(self.mask_size)
        elif self.mask_mode == 'free_form':
            mask = brush_stroke_mask(self.mask_size)
        elif self.mask_mode == 'hybrid':
            half_size = (self.mask_size[0] // 2, self.mask_size[1] // 2)
            regular_mask = bbox2mask(self.mask_size, random_bbox(self.mask_size, half_size))
            irregular_mask = brush_stroke_mask(self.mask_size, )
            mask = regular_mask | irregular_mask
        else:
            # 'file' mode was a silent no-op that then failed on an unbound
            # `mask`; report it as unimplemented like any other unknown mode.
            raise NotImplementedError(
                f'Mask mode {self.mask_mode} has not been implemented.')
        return torch.from_numpy(mask).permute(2,0,1)
