from PIL import Image
from typing import Optional, Callable, Tuple, Any, List
from collections import namedtuple
from tqdm import tqdm
import random
import imageio
import pathlib
import os
import shutil
import numpy as np
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


class SynthiaDataset(object):
    """SYNTHIA segmentation dataset stored in a Cityscapes-style layout.

    Images are discovered under ``<root>/leftImg8bit/<split>/<city>/`` and
    each one is paired with the label file
    ``<root>/gtFine/<split>/<city>/<id>_gtFine_labelIds.png``, where ``<id>``
    is the image file name up to ``_leftImg8bit``. Cities and files are
    visited in sorted order so that dataset indices are deterministic across
    machines and filesystems.

    Args:
        root: dataset root directory.
        split: split sub-directory name (e.g. ``"train"`` or ``"val"``).
        transforms: optional callable ``(image, target) -> (image, target)``
            applied jointly to every sample in ``__getitem__``.
    """

    # Per-class metadata. ``id`` is presumably the value stored in the SYNTHIA
    # label maps (TODO confirm); ``train_id`` is the contiguous training index,
    # with 255 used for every class marked ``ignore_in_eval``.
    SynthiaClass = namedtuple(
        "SynthiaClass",
        ["name", "id", "train_id", "ignore_in_eval", "color"]
    )

    classes = [
        SynthiaClass("road",            3, 0,   False,  (128, 64, 128)),
        SynthiaClass("sidewalk",        4, 1,   False,  (244, 35, 232)),
        SynthiaClass("building",        2, 2,   False,  (70, 70, 70)),
        SynthiaClass("wall",            21, 3,  False,  (102, 102, 156)),
        SynthiaClass("fence",           5, 4,   False,  (64, 64, 128)),
        SynthiaClass("pole",            7, 5,   False,  (153, 153, 153)),
        SynthiaClass("traffic light",   15, 6,  False,  (250, 170, 30)),
        SynthiaClass("traffic sign",    9, 7,   False,  (220, 220, 0)),
        SynthiaClass("vegetation",      6, 8,   False,  (107, 142, 35)),
        SynthiaClass("sky",             1, 9,   False,  (70, 130, 180)),
        SynthiaClass("pedestrian",      10, 10, False,  (220, 20, 60)),
        SynthiaClass("rider",           17, 11, False,  (255, 0, 0)),
        SynthiaClass("car",             8, 12,  False,  (0, 0, 142)),
        SynthiaClass("bus",             19, 13, False,  (0, 60, 100)),
        SynthiaClass("motorcycle",      12, 14, False,  (0, 0, 230)),
        SynthiaClass("bicycle",         11, 15, False,  (119, 11, 32)),
        SynthiaClass("void",            0, 255, True,   (0, 0, 0)),
        SynthiaClass("parking slot",    13, 255, True,  (250, 170, 160)),
        SynthiaClass("road-work",       14, 255, True,  (128, 64, 64)),
        # NOTE(review): same colour as "wall"; confirm this is intended.
        SynthiaClass("lanemarking",     22, 255, True,  (102, 102, 156)),
    ]

    def __init__(self,
                 root,
                 split,
                 transforms: Optional[Callable] = None):

        super().__init__()

        self.mode = 'gtFine'
        self.images_dir = os.path.join(root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(root, self.mode, split)
        self.split = split
        self.images = []
        self.targets = []
        self.transforms = transforms

        # Label files share the image id and use the Cityscapes label suffix.
        target_suffix = "_gtFine_labelIds"
        target_ext = ".png"

        # Sort so that index -> sample is stable regardless of listdir order.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in sorted(os.listdir(img_dir)):
                target_id = file_name.split('_leftImg8bit')[0]
                target_name = target_id + target_suffix + target_ext

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(os.path.join(target_dir, target_name))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load one sample.

        Args:
            index: position in ``self.images`` / ``self.targets``.

        Returns:
            tuple: ``(image, target)``. Without ``transforms`` the image is an
            RGB ``PIL.Image`` and the target is the label ``PIL.Image`` as
            stored on disk; otherwise whatever ``transforms`` returns.
        """
        # PIL opens files lazily; decode inside ``with`` so the file handles
        # are released here instead of lingering in long-running workers.
        with Image.open(self.images[index]) as img:
            image = img.convert('RGB')
        with Image.open(self.targets[index]) as lbl:
            lbl.load()
            target = lbl

        # Joint transforms so geometric ops stay aligned between image and label.
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of image/label pairs."""
        return len(self.images)

def synthia(root: str,
            split: str,
            transforms: Optional[Callable] = None) -> "SynthiaDataset":
    """Build a ``SynthiaDataset`` for ``split`` under ``root``.

    Args:
        root: dataset root in the Cityscapes-style layout.
        split: split sub-directory name (e.g. ``"train"``).
        transforms: optional single callable ``(image, target) -> (image,
            target)``; the dataset invokes it directly, so this is not a list.

    Returns:
        The constructed ``SynthiaDataset``.
    """
    return SynthiaDataset(root=root,
                          split=split,
                          transforms=transforms)

def split_synthia(args):
    """Convert one split of SYNTHIA RAND_CITYSCAPES into a Cityscapes-style layout.

    Every file in ``args.img_dir`` (sorted by name) is re-saved as
    ``synthia_00000_<id>_leftImg8bit.png`` under
    ``<args.img_out_dir>/<args.split_dir>/city``. Its label is saved as
    ``synthia_00000_<id>_gtFine_labelIds.png`` under
    ``<args.label_out_dir>/<args.split_dir>/city``. ``<id>`` is the
    zero-padded (5-digit) position of the image in the sorted listing. The
    suffixes allow the torchvision Cityscapes loader to pair images with
    labels by file name.

    Source SYNTHIA directory structure:
        RAND_CITYSCAPES/
            /GT
                /COLOR
                /LABELS
            /RGB

    Target directory structure (similar to Cityscapes):
        cityscapes/
            /gtFine
                /train/city
                /test/city
                /val/city
            /leftImg8bit
                /train/city
                /test/city
                /val/city

    Args:
        args: object with attributes ``img_dir``, ``label_dir``,
            ``img_out_dir``, ``label_out_dir`` and ``split_dir``.

            When ``split_dir == "val"``, only the first 6500 images are
            converted. Their labels are read from ``<label_dir>/<stem>.txt``
            with ``np.loadtxt``.

            For any other split, labels are read from the PNG
            ``<label_dir>/<img_name>`` via imageio's ``PNG-FI`` format, and
            channel 0 is kept.
    """
    # Suffixes matching torchvision.Cityscapes naming for images and labels.
    source_suffix = '_leftImg8bit'
    target_suffix = '_gtFine_labelIds'

    img_names = sorted(os.listdir(args.img_dir))
    if args.split_dir == "val":
        # Only the first 6500 images (in sorted order) are converted for val.
        img_names = img_names[:6500]

    # Output directories do not depend on the image, so create them once.
    path_city = os.path.join(args.img_out_dir, args.split_dir, 'city')
    path_city_label = os.path.join(args.label_out_dir, args.split_dir, 'city')
    pathlib.Path(path_city).mkdir(parents=True, exist_ok=True)
    pathlib.Path(path_city_label).mkdir(parents=True, exist_ok=True)

    for img_id, img_name in enumerate(tqdm(img_names)):
        base_name = 'synthia_{}_{}'.format(str(0).zfill(5), str(img_id).zfill(5))

        # Re-save the image as PNG; ``with`` closes the source file handle.
        with Image.open(os.path.join(args.img_dir, img_name)) as img:
            img.save(os.path.join(path_city, base_name + source_suffix + '.png'))

        # Load the corresponding label.
        if args.split_dir == "val":
            # NOTE(review): val labels come from numeric .txt matrices rather
            # than PNGs; confirm that label_dir points at that source.
            name_id = os.path.splitext(img_name)[0]
            label_txt = np.loadtxt(os.path.join(args.label_dir, name_id + ".txt"))
            label = Image.fromarray(np.uint8(label_txt))
        else:
            raw_label_imageio = imageio.imread(os.path.join(args.label_dir, img_name), format='PNG-FI')
            # Semantic class ids are taken from channel 0 of the label PNG.
            label = Image.fromarray(np.uint8(raw_label_imageio[:, :, 0]))

        label.save(os.path.join(path_city_label, base_name + target_suffix + '.png'))

def create_val_set(args, num_samples=400, seed=None):
    """Move a random subset of the converted train split into a val split.

    Picks ``num_samples`` files from ``<args.img_out_dir>/train/city``. Each
    one is moved, together with its ``<id>_gtFine_labelIds.png`` label from
    ``<args.label_out_dir>/train/city``, into the matching ``val/city``
    directories, which are created if missing.

    Args:
        args: object with ``img_out_dir`` and ``label_out_dir`` attributes.
        num_samples: number of image/label pairs to move (default 400).
        seed: optional seed for a private ``random.Random`` to make the
            selection reproducible. When ``None``, the module-level ``random``
            state is used, so a caller-side ``random.seed`` still applies.

    Raises:
        ValueError: if ``num_samples`` exceeds the number of train images
            (raised by ``random.sample``).
    """
    source_path = os.path.join(args.img_out_dir, "train", "city")
    source_label_path = os.path.join(args.label_out_dir, "train", "city")
    img_val_path = os.path.join(args.img_out_dir, "val", 'city')
    lbl_val_path = os.path.join(args.label_out_dir, "val", 'city')

    # Sorting makes the selection depend only on the RNG state, not on
    # filesystem listing order.
    synthia_list = sorted(os.listdir(source_path))
    rng = random if seed is None else random.Random(seed)
    val_samples = rng.sample(synthia_list, k=num_samples)

    pathlib.Path(img_val_path).mkdir(parents=True, exist_ok=True)
    pathlib.Path(lbl_val_path).mkdir(parents=True, exist_ok=True)

    for image_name in val_samples:
        # The label shares the image id with the Cityscapes label suffix.
        target_name = image_name.split('_leftImg8bit')[0] + "_gtFine_labelIds" + ".png"

        # Move the image and its label together so the splits stay paired.
        shutil.move(src=os.path.join(source_path, image_name),
                    dst=os.path.join(img_val_path, image_name))
        shutil.move(src=os.path.join(source_label_path, target_name),
                    dst=os.path.join(lbl_val_path, target_name))
