import numpy as np
from torch.utils.data import Dataset
from os.path import join
import os
import torch
import random
import random
def augment_patch(lr):
    """Return ``lr`` randomly rotated and possibly mirrored.

    A quarter-turn count in ``{0, 1, 2, 3}`` is drawn first and applied as a
    rotation in the plane of axes 0 and 1; then a coin flip decides whether
    to mirror along axis -2. Both draws use NumPy's global random state.
    The result may be a non-contiguous view of the input.
    """
    quarter_turns = np.random.randint(4)
    mirror = np.random.randint(2)
    rotated = np.rot90(lr, k=quarter_turns, axes=(0, 1))
    if not mirror:
        return rotated
    return np.flip(rotated, axis=-2)
def augment_patch_n2t(lr,hr):
    """Apply one random rotation/flip jointly to an input patch and its target.

    The same quarter-turn count (rotation in the plane of axes 0 and 1) and
    the same flip decision (mirror along axis -2) are applied to both arrays,
    so pixel correspondence between ``lr`` and ``hr`` is preserved. The draws
    use NumPy's global random state, in the same order as ``augment_patch``.

    Args:
        lr: input patch.
        hr: target patch; expected to share ``lr``'s layout.

    Returns:
        ``(lr, hr)`` after augmentation (possibly non-contiguous views).
    """
    random_k = np.random.randint(4)
    random_flip = np.random.randint(2)

    lr = np.rot90(lr, k=random_k, axes=(0, 1))
    hr = np.rot90(hr, k=random_k, axes=(0, 1))
    if random_flip:
        # Flip both or neither; the target must never be replaced by the input.
        lr = np.flip(lr, axis=-2)
        hr = np.flip(hr, axis=-2)

    return lr, hr

def get_stratified_coords2D(coord_gen, box_size, shape):
    """Pick one point per ``box_size`` x ``box_size`` box tiling ``shape``.

    The grid of boxes is walked row by row. For each box, one ``(dy, dx)``
    offset is taken from ``coord_gen``, added to the box origin and
    truncated to ``int``. Points that land outside ``shape[:2]`` (possible
    in the partial boxes on the bottom/right edges) are discarded.

    Returns:
        ``(y_coords, x_coords)``: two equal-length lists of ints.
    """
    n_rows = int(np.ceil(shape[0] / box_size))
    n_cols = int(np.ceil(shape[1] / box_size))
    ys, xs = [], []
    for row in range(n_rows):
        y_origin = row * box_size
        for col in range(n_cols):
            dy, dx = next(coord_gen)
            y = int(y_origin + dy)
            x = int(col * box_size + dx)
            if y >= shape[0] or x >= shape[1]:
                continue
            ys.append(y)
            xs.append(x)
    return (ys, xs)

def rand_float_coords2D(boxsize):
    """Yield ``(dy, dx)`` float offsets in ``[0, boxsize)`` forever.

    Each offset is drawn from NumPy's global random state, y first, then x.
    """
    while True:
        dy = np.random.rand() * boxsize
        dx = np.random.rand() * boxsize
        yield (dy, dx)


class TrainDataset_Supervised_N2T(Dataset):
    """Random aligned (input, target) patch pairs from two directories of ``.npy`` files.

    File names are taken from ``dirname_hr``, and each name is expected to
    exist in ``dirname`` as well, so index ``i`` refers to the same name in
    both lists. ``__getitem__`` ignores its index: every call draws a random
    pair, crops the same random ``patch_size`` x ``patch_size`` window from
    both arrays, applies one shared rotation/flip (``augment_patch_n2t``) and
    returns CHW float tensors.
    """

    # Random files to try per sample before giving up.
    max_attempts = 150

    def __init__(self, dirname, dirname_hr, patch_size=64):
        super(TrainDataset_Supervised_N2T, self).__init__()
        self.dirname = dirname

        ## Train
        # Both lists are built from one sorted name list, so they stay paired.
        names = sorted(os.listdir(dirname_hr))
        self.filenames = [join(dirname, name) for name in names]
        self.filenames_hr = [join(dirname_hr, name) for name in names]
        self.patch_size = patch_size

    def __getitem__(self, idxx):
        """Return a random augmented ``(lr, hr)`` patch pair as float tensors.

        ``idxx`` is ignored. Pairs that fail to load, or whose ``lr`` array
        is not ``(H, W, 3)`` with ``H`` and ``W`` larger than ``patch_size``,
        are skipped and another random pair is tried.

        Raises:
            RuntimeError: if there are no files, or no usable pair is found
                within ``max_attempts`` tries.
        """
        if not self.filenames:
            raise RuntimeError(
                "%s: no files found in %r" % (type(self).__name__, self.dirname))
        p = self.patch_size
        last_error = None
        for _ in range(self.max_attempts):
            try:
                idx = random.randrange(len(self.filenames))
                lr = np.load(self.filenames[idx])
                hr = np.load(self.filenames_hr[idx])
                ih, iw, ic = lr.shape
                if ih > p and iw > p and ic == 3:
                    ## Random Patch (same window in both arrays)
                    ix = random.randrange(0, iw - p + 1)
                    iy = random.randrange(0, ih - p + 1)
                    lr = lr[iy:iy + p, ix:ix + p, :]
                    hr = hr[iy:iy + p, ix:ix + p, :]

                    ## Data Augmentation
                    lr, hr = augment_patch_n2t(lr, hr)

                    ## Numpy (H, W, C) -> Pytorch (C, H, W)
                    lr = torch.from_numpy(np.ascontiguousarray(lr.transpose((2, 0, 1)))).float()
                    hr = torch.from_numpy(np.ascontiguousarray(hr.transpose((2, 0, 1)))).float()
                    return lr, hr
            except Exception as err:  # best-effort: skip unreadable/malformed files
                last_error = err
        raise RuntimeError(
            "%s: no usable sample after %d attempts"
            % (type(self).__name__, self.max_attempts)) from last_error

    def __len__(self):
        return len(self.filenames)

class TrainDataset_Supervised(Dataset):
    """Random augmented patches from a directory of ``.npy`` image files.

    ``__getitem__`` ignores its index: each call draws a random file, crops a
    random ``patch_size`` x ``patch_size`` window, applies a random
    rotation/flip (``augment_patch``) and returns a CHW float tensor.
    """

    # Random files to try per sample before giving up.
    max_attempts = 150

    def __init__(self, dirname,patch_size = 64):
        super(TrainDataset_Supervised, self).__init__()
        self.dirname = dirname

        ## Train
        self.filenames = [join(dirname, name) for name in sorted(os.listdir(dirname))]
        self.patch_size = patch_size

    def __getitem__(self, idxx):
        """Return one random augmented patch as a ``(3, P, P)`` float tensor.

        ``idxx`` is ignored. Files that fail to load, or that are not
        ``(H, W, 3)`` with ``H`` and ``W`` larger than ``patch_size``, are
        skipped and another random file is tried.

        Raises:
            RuntimeError: if there are no files, or no usable file is found
                within ``max_attempts`` tries.
        """
        if not self.filenames:
            raise RuntimeError(
                "%s: no files found in %r" % (type(self).__name__, self.dirname))
        p = self.patch_size
        last_error = None
        for _ in range(self.max_attempts):
            try:
                idx = random.randrange(len(self.filenames))
                lr = np.load(self.filenames[idx])
                ih, iw, ic = lr.shape
                if ih > p and iw > p and ic == 3:
                    ## Random Patch
                    ix = random.randrange(0, iw - p + 1)
                    iy = random.randrange(0, ih - p + 1)
                    lr = lr[iy:iy + p, ix:ix + p, :]

                    ## Data Augmentation
                    lr = augment_patch(lr)

                    ## Numpy (H, W, C) -> Pytorch (C, H, W)
                    return torch.from_numpy(np.ascontiguousarray(lr.transpose((2, 0, 1)))).float()
            except Exception as err:  # best-effort: skip unreadable/malformed files
                last_error = err
        raise RuntimeError(
            "%s: no usable sample after %d attempts"
            % (type(self).__name__, self.max_attempts)) from last_error

    def __len__(self):
        return len(self.filenames)
        
class TrainDataset_Supervised_n2same(Dataset):
    """Normalised random patches, Gaussian noise and sparse pixel masks.

    ``__getitem__`` ignores its index. Each call:

    - draws a random ``.npy`` file;
    - standardises each channel over the whole image;
    - crops a random ``patch_size`` x ``patch_size`` window and applies a
      random rotation/flip;
    - builds a binary mask with one stratified random pixel per box per
      channel;
    - draws a Gaussian noise patch (std 0.2) of the same shape.

    Args:
        dirname: directory of ``.npy`` files.
        patch_size: side length of the square patch.
        max_files: use only the first ``max_files`` names in sorted order
            (``None`` means all of them).
        mask_perc: approximate percentage of pixels per channel set in the mask.
    """

    # Random files to try per sample before giving up.
    max_attempts = 150

    def __init__(self, dirname,patch_size = 64, max_files=20000, mask_perc=0.5):
        super(TrainDataset_Supervised_n2same, self).__init__()
        self.dirname = dirname

        ## Train
        names = sorted(os.listdir(dirname))[:max_files]
        self.filenames = [join(dirname, name) for name in names]
        self.patch_size = patch_size
        self.mask_perc = mask_perc

    def __getitem__(self, idxx):
        """Return ``(lr, noise_patch, mask)``, each a ``(3, P, P)`` float tensor.

        ``idxx`` is ignored. Files that fail to load, are not ``(H, W, 3)``
        with ``H`` and ``W`` larger than ``patch_size``, or have a constant
        channel (zero std) are skipped and another random file is tried.

        Raises:
            RuntimeError: if there are no files, or no usable file is found
                within ``max_attempts`` tries.
        """
        if not self.filenames:
            raise RuntimeError(
                "%s: no files found in %r" % (type(self).__name__, self.dirname))
        p = self.patch_size
        # One mask pixel per boxsize x boxsize box, so about
        # 100 / boxsize**2 ~= mask_perc percent of pixels are masked per channel.
        # (int() replaces np.int, which was removed in NumPy 1.24.)
        boxsize = int(np.round(np.sqrt(100 / self.mask_perc)))
        last_error = None
        for _ in range(self.max_attempts):
            try:
                idx = random.randrange(len(self.filenames))
                lr = np.load(self.filenames[idx])
                ih, iw, ic = lr.shape
                if ih <= p or iw <= p or ic != 3:
                    continue

                # Normalization: per-channel statistics over the full image
                lr_mean = lr.mean(axis=(0, 1), keepdims=True)
                lr_std = lr.std(axis=(0, 1), keepdims=True)
                if np.any(lr_std == 0):
                    # A constant channel would divide by zero and produce NaNs.
                    continue
                lr = (lr - lr_mean) / lr_std

                ## Random Patch Extraction
                ix = random.randrange(0, iw - p + 1)
                iy = random.randrange(0, ih - p + 1)
                lr = lr[iy:iy + p, ix:ix + p, :]

                ## Data Augmentation
                lr = augment_patch(lr)

                ## Mask Generation: independent stratified points per channel
                mask = np.zeros_like(lr)
                for c in range(3):
                    ys, xs = get_stratified_coords2D(rand_float_coords2D(boxsize),
                                                     box_size=boxsize, shape=(p, p))
                    mask[ys, xs, c] = 1.0
                noise_patch = np.random.normal(0, 0.2, lr.shape)

                ## Numpy (H, W, C) -> Pytorch (C, H, W)
                lr = torch.from_numpy(np.ascontiguousarray(lr.transpose((2, 0, 1)))).float()
                noise_patch = torch.from_numpy(
                    np.ascontiguousarray(noise_patch.transpose((2, 0, 1)))).float()
                mask = torch.from_numpy(np.ascontiguousarray(mask.transpose((2, 0, 1)))).float()
                return lr, noise_patch, mask
            except Exception as err:  # best-effort: skip unreadable/malformed files
                last_error = err
        raise RuntimeError(
            "%s: no usable sample after %d attempts"
            % (type(self).__name__, self.max_attempts)) from last_error

    def __len__(self):
        return len(self.filenames)

    
class TestDataset(Dataset):
    """Evaluation pairs, one ``.npy`` file per item.

    Each file holds a stacked pair: ``[0]`` is the target (``hr``) and
    ``[1]`` is the input (``lr``). Each 3-D slice is transposed with axes
    ``(2, 0, 1)``, presumably ``(H, W, C) -> (C, H, W)`` to match the
    training datasets (TODO confirm the on-disk layout).
    """

    def __init__(self, dirname):
        super(TestDataset, self).__init__()
        self.dirname = dirname

        ## Test
        self.filenames = [join(dirname, name) for name in sorted(os.listdir(dirname))]
        # Not read by this class; kept for callers that may rely on it.
        self.patch_size = 128

    def __getitem__(self, idx):
        """Return ``(lr, hr)`` float tensors for file ``idx`` (no cropping or augmentation)."""
        lr_hr = np.load(self.filenames[idx])
        hr = lr_hr[0, :, :, :]
        lr = lr_hr[1, :, :, :]
        hr = torch.from_numpy(np.ascontiguousarray(hr.transpose((2, 0, 1)))).float()
        lr = torch.from_numpy(np.ascontiguousarray(lr.transpose((2, 0, 1)))).float()
        return lr, hr

    def __len__(self):
        return len(self.filenames)

