import glob
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import sys
sys.path.append('./')
# import INN.config as c
from natsort import natsorted


def to_rgb(image):
    """Return a new three-channel ``"RGB"`` copy of *image*.

    The source pixels are pasted onto a freshly created ``"RGB"`` image of the
    same size; Pillow's ``paste`` handles converting from the source mode.
    """
    canvas = Image.new(mode="RGB", size=image.size)
    canvas.paste(image)
    return canvas


class Hinet_Dataset(Dataset):
    """Dataset yielding (optionally transformed) RGB images from one directory.

    The file list is collected once at construction with a glob over
    ``config.TRAIN_PATH`` (mode ``"train"``) or ``config.VAL_PATH`` (any other
    mode), restricted to the extension ``config.format_train`` /
    ``config.format_val`` respectively.

    Args:
        config: object exposing ``TRAIN_PATH``, ``VAL_PATH``, ``format_train``
            and ``format_val`` attributes.
        transforms_: callable applied to each RGB ``PIL.Image``; when ``None``
            the RGB image itself is returned.
        mode: ``"train"`` selects the training directory; anything else
            selects the validation directory.
    """

    def __init__(self, config, transforms_=None, mode="train"):
        self.transform = transforms_
        self.mode = mode
        if mode == 'train':
            # Natural ordering ("img2" before "img10"); the inner sorted() gives
            # natsorted a deterministic input order.
            pattern = config.TRAIN_PATH + "/*." + config.format_train
            self.files = natsorted(sorted(glob.glob(pattern)))
        else:
            # NOTE(review): validation uses plain lexicographic order, unlike
            # training; kept as-is so the index -> file mapping is unchanged.
            pattern = config.VAL_PATH + "/*." + config.format_val
            self.files = sorted(glob.glob(pattern))

    def __getitem__(self, index):
        """Return the image at ``index`` after conversion to RGB and ``transform``.

        If that file cannot be opened, converted or transformed, the following
        files are tried in order (wrapping around to the start) so a single bad
        image does not abort an epoch; each skipped file emits a warning.

        Raises:
            IndexError: if ``index`` is outside ``[-len(files), len(files))``;
                this also lets plain ``for x in dataset`` iteration terminate.
            RuntimeError: if no file in the dataset could be loaded; the last
                underlying error is chained as the cause.
        """
        import warnings

        n = len(self.files)
        if not -n <= index < n:
            raise IndexError(f"index {index} out of range for dataset of size {n}")

        start = index % n
        last_error = None
        for offset in range(n):
            path = self.files[(start + offset) % n]
            try:
                with Image.open(path) as raw:
                    # to_rgb copies the pixels into a new image, so the file
                    # handle can be closed before the transform runs.
                    rgb = to_rgb(raw)
                return rgb if self.transform is None else self.transform(rgb)
            except Exception as err:  # deliberate best-effort: skip bad samples
                warnings.warn(f"Skipping unreadable image {path!r}: {err}")
                last_error = err
        raise RuntimeError(
            f"none of the {n} files in the dataset could be loaded"
        ) from last_error

    def __len__(self):
        """Return the number of samples.

        In ``"shuffle"`` mode, when ``files_cover`` and ``files_secret`` are
        present, the length is that of the longer list. This class never sets
        those attributes (presumably a subclass does — TODO confirm), so
        otherwise the number of globbed files is returned.
        """
        if (
            self.mode == 'shuffle'
            and hasattr(self, 'files_cover')
            and hasattr(self, 'files_secret')
        ):
            return max(len(self.files_cover), len(self.files_secret))
        return len(self.files)




def get_data_loaders(config, batch_size, cropsize, mode, num_workers=None):
    """Build a ``DataLoader`` over ``Hinet_Dataset`` for the given mode.

    ``mode == "train"``: random horizontal and vertical flips plus a random
    ``cropsize`` x ``cropsize`` crop; batches are shuffled and an incomplete
    final batch is dropped.
    Any other mode: resize to ``(cropsize, cropsize)``; order is fixed and the
    final partial batch is kept.

    Args:
        config: passed through to ``Hinet_Dataset``.
        batch_size: samples per batch.
        cropsize: side length of the square images produced.
        mode: ``"train"``, or any other value for evaluation.
        num_workers: loader worker processes; ``None`` keeps the previous
            defaults (8 for training, 1 otherwise).

    Returns:
        The configured ``DataLoader``.
    """
    is_train = mode == 'train'

    if is_train:
        transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomCrop(cropsize),
            T.ToTensor(),
        ])
    else:
        transform = T.Compose([
            T.Resize((cropsize, cropsize)),
            T.ToTensor(),
        ])

    if num_workers is None:
        num_workers = 8 if is_train else 1

    return DataLoader(
        Hinet_Dataset(config, transforms_=transform, mode=mode),
        batch_size=batch_size,
        shuffle=is_train,
        pin_memory=True,
        num_workers=num_workers,
        drop_last=is_train,
    )
