import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import glob
import random
import os
import cv2
import numpy as np

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image


#######################################################
#               Define Dataset Class
#######################################################


def load_img(params, folder):
    """Read every image file in ``folder`` and resize it to a square.

    Files are visited in sorted filename order. ``os.listdir`` returns
    entries in arbitrary, filesystem-dependent order, and images are later
    paired with label images purely by list position. A deterministic order
    is therefore required for that pairing to be reproducible.

    Args:
        params: configuration object; only ``params.resize`` (target side
            length in pixels) is read.
        folder: directory containing the images.

    Returns:
        list of arrays as returned by ``cv2.imread`` (BGR channel order),
        each resized to ``(params.resize, params.resize)`` with area
        interpolation.

    Raises:
        ValueError: if a file in ``folder`` cannot be decoded as an image.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cv2.resize.
            raise ValueError(f"could not read image: {path}")
        img = cv2.resize(img, (params.resize, params.resize),
                         interpolation=cv2.INTER_AREA)
        images.append(img)
    return images


def load_label_img(params, folder):
    """Read every colour-coded label image in ``folder`` as grayscale.

    Files are visited in sorted filename order so the output order is
    deterministic. Labels are paired with input images purely by list
    position, and ``os.listdir`` order is arbitrary.

    Resizing uses nearest-neighbour interpolation. Each label pixel is a
    class colour, and area/bilinear resampling would blend neighbouring class
    colours at region boundaries into new colours. Those match no class, or
    map to the wrong class after the grayscale conversion.

    Args:
        params: configuration object; only ``params.resize`` (target side
            length in pixels) is read.
        folder: directory containing the label images.

    Returns:
        list of 2-D grayscale arrays of shape
        ``(params.resize, params.resize)``.

    Raises:
        ValueError: if a file in ``folder`` cannot be decoded as an image.
    """
    labels = []
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise ValueError(f"could not read label image: {path}")
        img = cv2.resize(img, (params.resize, params.resize),
                         interpolation=cv2.INTER_NEAREST)
        labels.append(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
    return labels


def colortogray(cn):
    """Return the grayscale value OpenCV assigns to one BGR colour.

    ``cn`` holds three channel values (e.g. a uint8 array). It is reshaped
    into a 1x1, three-channel image so it can go through ``cv2.cvtColor``.
    The result is a 1x1 array, which broadcasts when compared element-wise
    against a 2-D grayscale label image.
    """
    one_pixel = np.reshape(cn, (1, 1, 3))
    return cv2.cvtColor(one_pixel, cv2.COLOR_BGR2GRAY)


# Label palette: one colour per class, position in the tuple == class index.
# Triples are interpreted as BGR (colortogray applies COLOR_BGR2GRAY), which
# matches how cv2.imread loads the label images.
# NOTE(review): these appear to be the CamVid 32-class colours with the RGB
# triples reversed into BGR order — confirm against the dataset's class list.
_LABEL_PALETTE_BGR = (
    (64, 128, 64),    # 0
    (128, 0, 192),    # 1
    (192, 128, 0),    # 2
    (64, 128, 0),     # 3
    (0, 0, 128),      # 4
    (128, 0, 64),     # 5
    (192, 0, 64),     # 6
    (64, 128, 192),   # 7
    (128, 192, 192),  # 8
    (128, 64, 64),    # 9
    (192, 0, 128),    # 10
    (64, 0, 192),     # 11
    (64, 128, 128),   # 12
    (192, 0, 192),    # 13
    (64, 64, 128),    # 14
    (128, 192, 64),   # 15
    (0, 64, 64),      # 16
    (128, 64, 128),   # 17
    (192, 128, 128),  # 18
    (192, 0, 0),      # 19
    (128, 128, 192),  # 20
    (128, 128, 128),  # 21
    (192, 128, 64),   # 22
    (64, 0, 0),       # 23
    (64, 64, 0),      # 24
    (128, 64, 192),   # 25
    (0, 128, 128),    # 26
    (192, 128, 192),  # 27
    (64, 0, 64),      # 28
    (0, 192, 192),    # 29
    (0, 0, 0),        # 30
    (0, 192, 64),     # 31
)


def class_pixel(params, label_img, fill_value=1):
    """Convert a grayscale label image into a per-pixel class-index mask.

    Each palette colour is reduced to its grayscale value with
    ``colortogray``. Every pixel of ``label_img`` equal to that value
    receives the colour's palette index.

    Args:
        params: configuration object; ``params.resize`` sets the output size.
        label_img: grayscale label image, expected to be
            ``(params.resize, params.resize)`` as produced by
            ``load_label_img``.
        fill_value: class index given to pixels that match no palette gray
            value. Defaults to 1 (the historical behaviour), which coincides
            with palette class 1, so unmatched pixels are silently labelled
            as that class unless a different value is passed.

    Returns:
        integer array of shape ``(params.resize, params.resize, 1)``.
    """
    grays = [colortogray(np.array(bgr, dtype='uint8'))
             for bgr in _LABEL_PALETTE_BGR]

    class_pix = np.full([params.resize, params.resize, 1], fill_value, dtype=int)
    for index, gray in enumerate(grays):
        # Matching is on grayscale only, which is many-to-one: any pixel
        # whose gray value equals this entry's gets this class, and a later
        # palette entry with the same gray value would overwrite it.
        class_pix[label_img == gray] = index
    return class_pix


def label_img_list(params, img_list):
    """Map each grayscale label image to its class-index mask.

    Returns a new list, in the same order as ``img_list``, holding
    ``class_pixel(params, label)`` for every element.
    """
    return [class_pixel(params, label) for label in img_list]


class trainset(Dataset):
    """Image / label-mask dataset, loaded eagerly into memory.

    Input images come from ``root_train`` via ``load_img``. Label images come
    from ``root_train_label`` via ``load_label_img`` and are converted to
    class-index masks by ``label_img_list``. The i-th image is paired with
    the i-th label purely by list position.

    Args:
        params: configuration object forwarded to the loaders.
        transform: optional callable applied to each image in
            ``__getitem__``; ``None`` means the raw array is returned.
        root_train: directory of input images.
        root_train_label: directory of label images.
        transform_label: optional callable applied to each label mask;
            ``None`` means the raw mask is returned.

    Raises:
        ValueError: if the two directories yield different numbers of images,
            in which case index-based pairing cannot be correct.
    """

    def __init__(self, params, transform=None, root_train=None, root_train_label=None, transform_label=None):
        self.train_img = load_img(params, root_train)
        self.transform = transform
        self.transform_label = transform_label
        self.train_label_img = label_img_list(params,
                                              load_label_img(params, root_train_label))
        if len(self.train_img) != len(self.train_label_img):
            raise ValueError(
                f"found {len(self.train_img)} images in {root_train!r} but "
                f"{len(self.train_label_img)} label images in {root_train_label!r}")

    def __len__(self):
        return len(self.train_img)

    def __getitem__(self, index):
        img = self.train_img[index]
        label = self.train_label_img[index]
        # Both transforms default to None in __init__; treat None as identity
        # instead of failing with "'NoneType' object is not callable".
        if self.transform is not None:
            img = self.transform(img)
        if self.transform_label is not None:
            label = self.transform_label(label)
        return img, label


def load_dataset(params):
    """Build the training and validation DataLoaders.

    Reads images and labels from ``params.data_dir + "/train/"``,
    ``"/train_labels/"``, ``"/val/"`` and ``"/val_labels/"``. Also reads
    ``params.resize`` and ``params.batch_size``, plus ``params.num_workers``
    when present (default 2). Prints the size of each dataset.

    Returns:
        ``(train_loader, valid_loader)``. Only the training loader shuffles.
    """
    # NOTE(review): images are loaded with cv2.imread (BGR channel order),
    # while these mean/std lists look like the usual RGB-ordered ImageNet
    # statistics — confirm whether a BGR->RGB conversion is intended.
    transform_img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Label masks hold class indices, so any resampling must be
    # nearest-neighbour; interpolating between indices invents classes.
    transform_img_label = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((params.resize, params.resize),
                          interpolation=transforms.InterpolationMode.NEAREST)
    ])

    traindataset = trainset(params, transform_img, params.data_dir +
                            "/train/", params.data_dir + "/train_labels/", transform_img_label)
    validdataset = trainset(params, transform_img, params.data_dir +
                            "/val/", params.data_dir + "/val_labels/", transform_img_label)

    print(len(traindataset))
    print(len(validdataset))

    num_workers = getattr(params, "num_workers", 2)
    train_loader = DataLoader(
        traindataset, batch_size=params.batch_size, shuffle=True, num_workers=num_workers)
    # Evaluation does not benefit from random order; a fixed order keeps
    # validation passes reproducible.
    valid_loader = DataLoader(
        validdataset, batch_size=params.batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, valid_loader

    