import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import glob
import random
import os
import cv2
import numpy as np

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


#######################################################
#               Define Dataset Class
#######################################################


def load_img(params, folder):
    """Read every file in ``folder`` as a PIL image, in sorted filename order.

    Sorting matters: ``trainset`` pairs the images of one folder with the
    masks of another purely by list index, and ``os.listdir`` returns entries
    in arbitrary order. Without a deterministic order, image *i* and mask *i*
    need not correspond.

    Each image is fully decoded into memory and its file closed before the
    next one is opened, so no file handles stay open. The whole folder is
    therefore held in memory.

    Args:
        params: Unused; kept for signature compatibility with callers.
        folder: Directory to read. Every entry must be a file PIL can open.

    Returns:
        list of ``PIL.Image.Image`` objects, ordered by filename.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        with Image.open(os.path.join(folder, filename)) as img:
            # Image.open is lazy; copy() forces a full decode so the image
            # stays usable after the context manager closes the file.
            images.append(img.copy())
    return images


class trainset(Dataset):
    """Paired image/mask dataset read from two folders and held in memory.

    Images and masks are paired by position: item *i* is the *i*-th entry
    returned by ``load_img`` for each folder. Both folders must therefore
    contain the same number of files, listed in matching order.
    """

    def __init__(self, params, transform=None, root_train=None, root_train_label=None, transform_label=None):
        """Load both folders eagerly.

        Args:
            params: Passed through to ``load_img``.
            transform: Optional callable applied to each input image.
            root_train: Folder containing the input images (required).
            root_train_label: Folder containing the label masks (required).
            transform_label: Optional callable applied to each mask.

        Raises:
            ValueError: if either folder is missing, or the folders hold a
                different number of entries (pairing by index would be wrong).
        """
        # os.listdir(None) would silently scan the current directory.
        if root_train is None or root_train_label is None:
            raise ValueError("root_train and root_train_label are required")
        self.train_img = load_img(params, root_train)
        self.transform = transform
        self.transform_label = transform_label
        self.train_label_img = load_img(params, root_train_label)
        if len(self.train_img) != len(self.train_label_img):
            raise ValueError(
                f"found {len(self.train_img)} images in {root_train!r} but "
                f"{len(self.train_label_img)} masks in {root_train_label!r}"
            )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.train_img)

    def __getitem__(self, index):
        """Return ``(image, mask)`` at ``index``, each transformed if a transform is set."""
        img = self.train_img[index]
        label = self.train_label_img[index]
        # The transforms default to None, so apply them only when provided.
        if self.transform is not None:
            img = self.transform(img)
        if self.transform_label is not None:
            label = self.transform_label(label)
        return img, label


def load_dataset(params):
    """Build the training and validation DataLoaders.

    ``params`` must provide ``data_dir``, ``resize`` and ``batch_size``; an
    optional ``num_workers`` attribute overrides the default of 2 workers.
    Images are read from ``<data_dir>/images/{train,val}`` and masks from
    ``<data_dir>/mask/{train,val}``.

    Returns:
        tuple ``(train_loader, valid_loader)``.
    """
    size = (params.resize, params.resize)
    num_workers = getattr(params, "num_workers", 2)

    transform_img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size),
        # Standard ImageNet mean/std; requires 3-channel (RGB) inputs.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Nearest-neighbour resizing keeps mask values within the original label
    # set; the default bilinear interpolation blends labels at boundaries.
    transform_img_label = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
    ])

    def _make_split(split):
        # Dataset for one split, e.g. "train" or "val".
        return trainset(
            params,
            transform_img,
            os.path.join(params.data_dir, "images", split),
            os.path.join(params.data_dir, "mask", split),
            transform_img_label,
        )

    traindataset = _make_split("train")
    validdataset = _make_split("val")

    train_loader = DataLoader(
        traindataset, batch_size=params.batch_size, shuffle=True, num_workers=num_workers)
    # Validation metrics do not depend on order; a fixed order keeps runs reproducible.
    valid_loader = DataLoader(
        validdataset, batch_size=params.batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, valid_loader
