import os
import pathlib
import shutil
# NOTE: distutils is deprecated (PEP 632) and removed in Python 3.12.
from distutils.dir_util import copy_tree
from typing import Callable, Optional, Tuple, Any, List

from PIL import Image

class MapillaryDataset(object):
    """
    Semantic Mapillary dataset stored in a Cityscapes-like folder layout.

    Annotations are similar to Cityscapes annotations, therefore the same id
    transformations can be used.

    Expected folder structure (``city`` may be any sub-folder name)::

        mapillary/
            gtFine/
                train/
                    city/<prefix>_gtFine_labelIds.png
                val/
                    city/...
            leftImg8bit/
                train/
                    city/<prefix>_leftImg8bit.jpg
                val/
                    city/...

    Each image ``<prefix>_leftImg8bit.jpg`` is paired with the label
    ``<prefix>_gtFine_labelIds.png`` in the same city folder under ``gtFine``.
    Samples are indexed in sorted (city, file name) order so that a given index
    always refers to the same sample.
    """

    # File-name suffixes that link an image to its label.
    _IMAGE_SUFFIX = "_leftImg8bit.jpg"
    _TARGET_SUFFIX = "_gtFine_labelIds.png"

    def __init__(self,
                 root: str,
                 split: str,
                 transforms: Optional[Callable] = None):
        """
        Index all image/label pairs of one split.

        Args:
            root: dataset root containing ``leftImg8bit/`` and ``gtFine/``.
            split: sub-folder name to load, e.g. ``'train'`` or ``'val'``.
            transforms: optional callable invoked as
                ``transforms(image, target)`` and expected to return the
                transformed ``(image, target)`` pair.
        """
        super(MapillaryDataset, self).__init__()

        self.mode = 'gtFine'
        self.images_dir = os.path.join(root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(root, self.mode, split)
        self.split = split
        self.images = []
        self.targets = []
        self.transforms = transforms

        # os.listdir order is arbitrary; sort so index -> sample is reproducible.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            if not os.path.isdir(img_dir):
                # Stray files next to the city folders are not samples.
                continue
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in sorted(os.listdir(img_dir)):
                if not file_name.endswith(self._IMAGE_SUFFIX):
                    # e.g. ".DS_Store": it would map to a non-existent label.
                    continue
                prefix = file_name[:-len(self._IMAGE_SUFFIX)]
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(
                    os.path.join(target_dir, prefix + self._TARGET_SUFFIX))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)``: RGB image and label, after ``transforms`` if set."""
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Number of indexed image/label pairs."""
        return len(self.images)


def mapillary(root: str,
              split: str,
              transforms: Optional[Callable] = None) -> "MapillaryDataset":
    """
    Factory for :class:`MapillaryDataset`.

    Args:
        root: dataset root containing ``leftImg8bit/`` and ``gtFine/``.
        split: split sub-folder to load, e.g. ``'train'`` or ``'val'``.
        transforms: a single optional callable applied as
            ``transforms(image, target)``; compose several transforms into one
            callable before passing them in.
    """
    return MapillaryDataset(root=root,
                            split=split,
                            transforms=transforms)

def _export_pair(img_src: str, label_src: str, img_dst: str, label_dst: str,
                 resize: bool, out_size: Tuple[int, ...]) -> None:
    """Write one image/label pair to its destination, resizing both if requested."""
    if resize:
        # Labels use nearest-neighbour so ids are never blended into values
        # that are not real classes; images use bilinear.
        with Image.open(img_src) as img:
            img.resize(out_size, Image.BILINEAR).save(img_dst)
        with Image.open(label_src) as label:
            label.resize(out_size, Image.NEAREST).save(label_dst)
    else:
        shutil.copy2(img_src, img_dst)
        shutil.copy2(label_src, label_dst)


def create_mapillary_dataset(opts):
    """
    Convert (a subset of) the Mapillary training set into a Cityscapes-like layout.

    Reads images from ``<opts.root>/training/images`` and their labels (same
    stem, ``.png``) from ``<opts.root>/training/v1.2/labels``. Writes the first
    ``opts.num_images`` pairs, in sorted file-name order, to the sibling folder
    ``mapillary_new``:

        mapillary_new/leftImg8bit/train/city/city_00000_<i>_leftImg8bit.jpg
        mapillary_new/gtFine/train/city/city_00000_<i>_gtFine_labelIds.png

    The source directories are only read, never modified.

    Args:
        opts: object exposing
            ``root``: the Mapillary root folder.
            ``out_size``: sequence passed to PIL ``resize``, i.e. (width, height).
            ``num_images``: how many pairs to keep.
            ``resize``: whether to resize instead of copying verbatim.

    Raises:
        FileNotFoundError: if a selected image has no matching label.
    """
    # normpath strips a trailing separator; without it dirname() returns root
    # itself and the output would be nested inside the source tree.
    out_dir = os.path.join(os.path.dirname(os.path.normpath(opts.root)), "mapillary_new")
    train_root = os.path.join(opts.root, "training")
    imgs_root = os.path.join(train_root, "images")
    labels_root = os.path.join(train_root, "v1.2", "labels")
    out_images_root = os.path.join(out_dir, "leftImg8bit", "train", "city")
    out_labels_root = os.path.join(out_dir, "gtFine", "train", "city")
    pathlib.Path(out_labels_root).mkdir(parents=True, exist_ok=True)
    pathlib.Path(out_images_root).mkdir(parents=True, exist_ok=True)
    out_size = tuple(opts.out_size)

    # Sorted so the selected subset and its numbering are reproducible.
    img_names = sorted(
        name for name in os.listdir(imgs_root)
        if os.path.isfile(os.path.join(imgs_root, name))
    )
    for i, img_name in enumerate(img_names):
        if i >= opts.num_images:
            break
        stem = os.path.splitext(img_name)[0]
        out_pre = f"city_00000_{i:05d}"
        _export_pair(
            os.path.join(imgs_root, img_name),
            os.path.join(labels_root, stem + ".png"),
            os.path.join(out_images_root, f"{out_pre}_leftImg8bit.jpg"),
            os.path.join(out_labels_root, f"{out_pre}_gtFine_labelIds.png"),
            opts.resize,
            out_size,
        )