from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from natsort import natsorted
import random
from torch.utils.data import Subset
from util.utils import *
import util.config as c
import numpy as np

class Dataset(Dataset):
    """Image dataset over the PNG files of the configured train or test split.

    Args:
        transforms_: Callable applied to each RGB-converted PIL image. When
            ``None`` the converted image is returned as-is.
        mode: ``"train"`` lists files under ``c.TRAIN_PATH``; any other value
            lists files under ``c.TEST_PATH``.
    """

    def __init__(self, transforms_=None, mode="train"):
        self.transform = transforms_
        self.TRAIN_PATH = c.TRAIN_PATH
        self.TEST_PATH = c.TEST_PATH

        self.format_train = 'png'
        self.mode = mode

        # Both splits use the same extension; only the root directory differs.
        root = self.TRAIN_PATH if mode == 'train' else self.TEST_PATH
        # `sorted` first gives a deterministic tie-break for names natsort
        # treats as equal; `natsorted` then imposes natural (numeric) order.
        self.files = natsorted(sorted(imglist(root, self.format_train)))

    def _load(self, path):
        """Open *path*, convert it to RGB and apply the transform if one is set."""
        image = to_rgb(Image.open(path))
        if self.transform is None:
            return image
        return self.transform(image)

    def __getitem__(self, index):
        """Return the (transformed) image at *index*.

        Files that fail to open, convert or transform are skipped by trying
        the following index, wrapping around to the start of the list.

        Raises:
            IndexError: if *index* is outside the dataset (this also lets
                plain ``for`` iteration over the dataset terminate).
            RuntimeError: if no file in the dataset can be loaded.
        """
        num_files = len(self.files)
        if not -num_files <= index < num_files:
            raise IndexError(f"index {index} out of range for {num_files} files")
        start = index % num_files

        # Bounded, iterative retry: the previous recursive `index + 1` fallback
        # ran past the end of the list and recursed until RecursionError.
        for offset in range(num_files):
            path = self.files[(start + offset) % num_files]
            try:
                return self._load(path)
            except Exception:
                # Deliberate best-effort: skip unreadable / un-croppable files.
                continue
        raise RuntimeError(f"none of the {num_files} files in the dataset could be loaded")

    def __len__(self):
        """Number of image files found for this split."""
        return len(self.files)

# Training-time augmentation: random horizontal and vertical flips (each with
# probability 0.5), a random crop of size c.cropsize, then conversion of the
# PIL image to a tensor.
transform = T.Compose(
    [
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        T.RandomCrop(size=c.cropsize),
        T.ToTensor(),
    ]
)

# Validation-time preprocessing: a deterministic centre crop of size
# c.cropsize_val followed by tensor conversion (no random augmentation).
transform_val = T.Compose(
    [
        T.CenterCrop(size=c.cropsize_val),
        T.ToTensor(),
    ]
)

# Training loader: shuffled, augmented crops. The batch size is the one used
# for the second training stage (c.batch_size_2); the first stage used
# c.batch_size instead.
trainloader = DataLoader(
    dataset=Dataset(mode="train", transforms_=transform),
    batch_size=c.batch_size_2,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)

# Evaluation loader: fixed order with centre crops, sharing the second-stage
# batch size (the first stage used c.batchsize_val).
# NOTE(review): drop_last=True discards the final partial batch, so up to
# batch_size_2 - 1 test images are never evaluated. Presumably required by a
# model that needs a fixed batch size — confirm.
testloader = DataLoader(
    dataset=Dataset(mode="val", transforms_=transform_val),
    batch_size=c.batch_size_2,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)