import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import glob
import random
import os
import torch
import numpy as np
import datasets_main.dataset_voc as dataset_voc
import datasets_main.dataset_camvid as dataset_camvid
import datasets_main.dataset_suim as dataset_suim
import datasets_main.dataset_idd as dataset_idd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
import torchvision.transforms.functional as tfunc
import functools
from albumentations.pytorch.transforms import ToTensorV2
# from albumentations.augmentation.transforms import Normalize
import datasets_main.utils as utils
import albumentations as A


from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

#######################################################
#      Data Loader Factory and Dataset Wrapper
#######################################################


        



# Per-channel normalisation statistics (the widely used ImageNet mean/std)
# applied by every image pipeline below.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


class _LazyImagePairs:
    """Indexable collection of (image, annotation) PIL pairs opened on demand.

    Only file paths are kept in memory. Opening every file up front keeps one
    file handle alive per image, which can exhaust the process limit on open
    files for large directories.
    """

    def __init__(self, image_dir, annotation_dir, annotation_suffix):
        # The annotation name is the image name up to its first "." plus the
        # given suffix (e.g. "abc.jpg" -> "abc_train_id.png").
        self.samples = [
            (image_dir + name,
             annotation_dir + name.split(".")[0] + annotation_suffix)
            for name in os.listdir(image_dir)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image_path, annotation_path = self.samples[index]
        return self._read(image_path), self._read(annotation_path)

    @staticmethod
    def _read(path):
        # copy() decodes the pixel data into an independent image, so the
        # underlying file can be closed as soon as the `with` block exits.
        with Image.open(path) as handle:
            return handle.copy()


def _label_id_transform(params):
    """Build the transform for label maps that hold integer class ids."""
    return transforms.Compose([
        transforms.PILToTensor(),
        # Nearest-neighbour keeps every pixel a valid class id; the default
        # bilinear mode would blend neighbouring ids at region boundaries.
        transforms.Resize(
            (params.resize, params.resize),
            interpolation=transforms.InterpolationMode.NEAREST),
    ])


def _folder_pair_loaders(params, dataset_cls, train_dirs, val_dirs):
    """Build loaders for datasets stored as (image dir, label dir) pairs.

    ``dataset_cls`` is called positionally as
    ``dataset_cls(params, image_transform, image_dir, label_dir, label_transform)``.
    """
    transform_img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((params.resize, params.resize)),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])

    # NOTE(review): labels are resized with the default interpolation here.
    # Left unchanged because the label encoding expected by the dataset
    # classes is not visible from this file -- confirm whether nearest is
    # needed.
    transform_img_label = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((params.resize, params.resize)),
    ])

    traindataset = dataset_cls(params, transform_img, train_dirs[0],
                               train_dirs[1], transform_img_label)
    validdataset = dataset_cls(params, transform_img, val_dirs[0],
                               val_dirs[1], transform_img_label)

    train_loader = DataLoader(
        traindataset, batch_size=params.batch_size, shuffle=True, num_workers=2)
    # Validation order is kept fixed, matching the suim and pascalvoc loaders.
    valid_loader = DataLoader(
        validdataset, batch_size=params.batch_size, shuffle=False, num_workers=2)

    return train_loader, valid_loader


def _load_suim(params):
    """Build train/validation loaders for SUIM from one seeded random split."""
    root = '/home/semanticSegmentation/datasets/suim/train_val/'
    size = params.resize

    transform_img = A.Compose([
        A.PadIfNeeded(min_height=size, min_width=size, p=1),
        A.Resize(height=size, width=size, p=1),
        A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ToTensorV2(),
    ])
    transform_mask = A.Compose([
        A.PadIfNeeded(min_height=size, min_width=size, p=1),
        A.Resize(height=size, width=size, p=1),
        ToTensorV2(),
    ])

    dataset_main = dataset_suim.SUIM(datapath=root,
                                     transform_img=transform_img,
                                     transform_mask=transform_mask,
                                     image_size=size)

    # 80/20 split derived from the dataset size. For 1525 samples this gives
    # the previously hard-coded 1220/305, and the fixed seed keeps the split
    # reproducible.
    total = len(dataset_main)
    val_len = total // 5
    train_data, test_data = torch.utils.data.random_split(
        dataset_main, [total - val_len, val_len],
        generator=torch.Generator().manual_seed(42))

    train_data_loader = DataLoader(train_data, batch_size=params.batch_size,
                                   shuffle=True, num_workers=0,
                                   drop_last=True)
    val_data_loader = DataLoader(test_data, batch_size=params.batch_size,
                                 shuffle=False, num_workers=0,
                                 drop_last=False)

    return train_data_loader, val_data_loader


def _load_camvid(params):
    """Build train/validation loaders for CamVid."""
    root = "/home/semanticSegmentation/datasets/camvid/CamVid"
    return _folder_pair_loaders(
        params, dataset_camvid.trainset,
        (root + "/train/", root + "/train_labels/"),
        (root + "/val/", root + "/val_labels/"))


def _load_idd(params):
    """Build train/validation loaders for IDD."""
    root = "/home/semanticSegmentation/datasets/IDD/my_dataset_segmentation"
    return _folder_pair_loaders(
        params, dataset_idd.trainset,
        (root + "/images/train/", root + "/mask/train/"),
        (root + "/images/val/", root + "/mask/val/"))


def _load_pascalvoc(params):
    """Build train/validation loaders for Pascal VOC."""
    root = '/home/semanticSegmentation/datasets/datasets/pascalvoc/VOCdevkit/'
    image_size = (params.resize, params.resize)
    transform = transforms.Compose([
        transforms.Pad(10),
        transforms.ToTensor(),
        transforms.Resize(image_size),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])

    train_data_set = dataset_voc.VOC(root=root,
                                     image_size=image_size,
                                     dataset_type='train',
                                     transform=transform)
    val_data_set = dataset_voc.VOC(root=root,
                                   image_size=image_size,
                                   dataset_type='val',
                                   transform=transform)

    train_data_loader = DataLoader(train_data_set,
                                   batch_size=params.batch_size,
                                   shuffle=True)
    # shuffle=False so samples produced by different models line up.
    val_data_loader = DataLoader(val_data_set,
                                 batch_size=params.batch_size,
                                 shuffle=False)

    return train_data_loader, val_data_loader


def _load_cityscapes(params):
    """Build train/validation loaders for Cityscapes (semantic targets)."""
    root = '/home/semanticSegmentation/datasets/cityscapes'
    train_ds = datasets.Cityscapes(
        root, split='train', target_type='semantic', transform=None)
    val_ds = datasets.Cityscapes(
        root, split='val', target_type='semantic', transform=None)

    preprocess_fn_img = transforms.Compose([
        transforms.Resize((params.resize, params.resize)),
        transforms.ToTensor(),
        # NOTE(review): this second Resize targets the size the image already
        # has and looks redundant; kept so the pipeline output is unchanged.
        transforms.Resize((params.resize, params.resize)),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])
    preprocess_fn_smnt = _label_id_transform(params)

    train_dataset_customize = CustomDataset(
        train_ds, params, preprocess_fn_img, preprocess_fn_smnt)
    valid_dataset_customize = CustomDataset(
        val_ds, params, preprocess_fn_img, preprocess_fn_smnt)

    # Training batches are shuffled, as in the other dataset branches.
    train_loader = DataLoader(train_dataset_customize,
                              batch_size=params.batch_size, shuffle=True,
                              num_workers=1, drop_last=True)
    valid_loader = DataLoader(valid_dataset_customize, batch_size=1)

    return train_loader, valid_loader


def _load_bdd100k(params):
    """Build train/validation loaders for the BDD100K segmentation subset."""
    root = "/home/semanticSegmentation/datasets/bdd100k/seg"

    train_image_directory = root + "/images/train/"
    train_annotations_directory = root + "/labels2/train/"
    print(train_image_directory)
    train_ds = _LazyImagePairs(
        train_image_directory, train_annotations_directory, "_train_id.png")

    val_image_directory = root + "/images/val/"
    val_annotations_directory = root + "/labels2/val/"
    print(val_image_directory)
    val_ds = _LazyImagePairs(
        val_image_directory, val_annotations_directory, "_train_id.png")

    preprocess_fn_img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((params.resize, params.resize)),
        # Project helper; presumably forces a 3-channel tensor -- confirm
        # against datasets_main.utils.
        utils.make_3channels(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])
    preprocess_fn_smnt = _label_id_transform(params)

    train_dataset_customize = CustomDataset(
        train_ds, params, preprocess_fn_img, preprocess_fn_smnt)
    valid_dataset_customize = CustomDataset(
        val_ds, params, preprocess_fn_img, preprocess_fn_smnt)

    # Training batches are shuffled, as in the other dataset branches.
    train_loader = DataLoader(
        train_dataset_customize, batch_size=params.batch_size, shuffle=True,
        num_workers=1)
    valid_loader = DataLoader(valid_dataset_customize,
                              batch_size=params.batch_size)

    return train_loader, valid_loader


# Maps each supported ``params.dataset_name`` to the function that builds its
# loaders.
_DATASET_LOADERS = {
    "suim": _load_suim,
    "camvid": _load_camvid,
    "idd": _load_idd,
    "pascalvoc": _load_pascalvoc,
    "cityscapes": _load_cityscapes,
    "bdd100k": _load_bdd100k,
}


def load_dataset(params):
    """Load a dataset and return its training and validation data loaders.

    Args:
        params: Configuration object. The attributes read here are
            ``dataset_name`` (one of "suim", "camvid", "idd", "pascalvoc",
            "cityscapes", "bdd100k"), ``resize`` (side length of the square
            output images) and ``batch_size``. The object is also passed
            through to the project dataset classes.

    Returns:
        A ``(train_loader, val_loader)`` tuple of ``DataLoader`` objects.

    Raises:
        ValueError: If ``params.dataset_name`` is not a supported dataset.
    """
    try:
        build_loaders = _DATASET_LOADERS[params.dataset_name]
    except (KeyError, TypeError):
        # Fail loudly instead of returning None, which would only surface
        # later as an unpacking error in the caller.
        supported = ", ".join(sorted(_DATASET_LOADERS))
        raise ValueError(
            f"Unknown dataset_name {params.dataset_name!r}; "
            f"expected one of: {supported}") from None
    return build_loaders(params)


class CustomDataset(Dataset):
    """Dataset wrapper that applies optional transforms to (image, annotation) pairs.

    Args:
        train_ds: Indexable, sized collection whose items unpack into an
            ``(image, annotation)`` pair.
        params: Configuration object; stored on the instance but not read by
            this class.
        img_transforms: Optional callable applied to the image.
        smnt_transforms: Optional callable applied to the annotation.
    """

    def __init__(self, train_ds, params, img_transforms=None, smnt_transforms=None):
        self.dataset, self.params = train_ds, params
        self.img_transforms = img_transforms
        self.smnt_transforms = smnt_transforms

    def __len__(self):
        """Return the number of pairs in the wrapped collection."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the pair at ``index`` with each configured transform applied."""
        image, annotation = self.dataset[index]

        # A transform left as None means that element is passed through untouched.
        image_fn = self.img_transforms
        if image_fn is not None:
            image = image_fn(image)

        annotation_fn = self.smnt_transforms
        if annotation_fn is not None:
            annotation = annotation_fn(annotation)

        return image, annotation


