import os
from PIL import Image
from typing import Callable, Optional, Tuple, Any, List
import pathlib
from distutils.dir_util import copy_tree

class WilddashDataset(object):
    """
    This class represents the semantic Wilddash dataset. Annotations are similar to Cityscapes annotations, so the
    same id transformations can be used.

    The Wilddash dataset is required to have the following folder structure:
    wilddash/
            /gtFine
                /train
                    /city
                /val
                    /city
            /leftImg8bit
                /train
                    /city
                /val
                    /city

    Each image ``<city>/<stem>_leftImg8bit.jpg`` is paired with the target
    ``<city>/<stem>_gtFine_labelIds.png`` in the corresponding ``gtFine`` folder.
    Samples are ordered by city folder name, then by file name, so a given index
    always refers to the same sample regardless of file-system listing order.
    """

    # File-name suffixes that pair an image with its label-id target.
    _IMAGE_SUFFIX = "_leftImg8bit.jpg"
    _TARGET_SUFFIX = "_gtFine_labelIds.png"

    def __init__(self,
                 root: str,
                 split: str,
                 transforms: Optional[Callable] = None):
        """
        Index all image/target pairs of one split.

        Args:
            root: Dataset root directory (the ``wilddash/`` folder above).
            split: Split sub-folder name, e.g. ``"train"`` or ``"val"``.
            transforms: Optional callable invoked as ``transforms(image, target)``
                that must return the transformed ``(image, target)`` pair.

        Raises:
            FileNotFoundError: if ``<root>/leftImg8bit/<split>`` does not exist.
        """
        super(WilddashDataset, self).__init__()

        self.mode = 'gtFine'
        self.images_dir = os.path.join(root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(root, self.mode, split)
        self.split = split
        self.images = []
        self.targets = []
        self.transforms = transforms

        # Sorted so the index -> sample mapping does not depend on the arbitrary
        # order returned by os.listdir.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            if not os.path.isdir(img_dir):
                # Stray files next to the city folders are not cities.
                continue
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in sorted(os.listdir(img_dir)):
                # Files outside the naming scheme (e.g. .DS_Store) have no
                # target and would otherwise yield a broken pair.
                if not file_name.endswith(self._IMAGE_SUFFIX):
                    continue
                stem = file_name[:-len(self._IMAGE_SUFFIX)]
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(os.path.join(target_dir, stem + self._TARGET_SUFFIX))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Load the sample at ``index``.

        The image is converted to RGB; the target is returned as opened by PIL.
        If ``transforms`` was given, the pair is passed through it jointly.

        Returns:
            The ``(image, target)`` pair (after ``transforms``, if any).
        """
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of indexed image/target pairs."""
        return len(self.images)


def wilddash(root: str,
             split: str,
             transforms: Optional[Callable] = None):
    """
    Build a ``WilddashDataset`` for ``split`` under ``root``.

    Args:
        root: Dataset root directory containing ``leftImg8bit`` and ``gtFine``.
        split: Split sub-folder name, e.g. ``"val"``.
        transforms: Optional callable applied jointly as
            ``transforms(image, target)``; forwarded unchanged to the dataset.
            It is a single callable, not a list. Defaults to ``None``.

    Returns:
        The constructed ``WilddashDataset``.
    """
    return WilddashDataset(root=root, split=split, transforms=transforms)

def create_wilddash_dataset(opts):
    """
    Convert a raw Wilddash folder into the Cityscapes-style layout expected by
    the dataset loader in this module.

    Reads images from ``<opts.root>/images`` and labels from
    ``<opts.root>/labels`` (label = image stem + ``.png``). The first
    ``opts.num_images`` images in sorted file-name order, together with their
    labels, are copied into a sibling ``wilddash/`` folder as:

        wilddash/leftImg8bit/val/city/city_00000_<i:05d>_leftImg8bit.jpg
        wilddash/gtFine/val/city/city_00000_<i:05d>_gtFine_labelIds.png

    Only the selected pairs are copied. The source folders are left untouched,
    and re-running overwrites the previously written output files.

    Args:
        opts: Object with attributes ``root`` (path to the raw Wilddash folder)
            and ``num_images`` (number of image/label pairs to export).

    Raises:
        FileNotFoundError: if the ``images`` folder is missing, or a selected
            image has no label with the same stem in ``labels``.
    """
    import shutil

    # normpath drops a trailing separator; otherwise dirname("a/b/") == "a/b"
    # and the output would be nested inside the source folder.
    root = os.path.normpath(opts.root)
    out_dir = os.path.join(os.path.dirname(root), "wilddash")
    labels_root = os.path.join(root, "labels")
    imgs_root = os.path.join(root, "images")
    out_labels_root = os.path.join(out_dir, "gtFine", "val", "city")
    out_images_root = os.path.join(out_dir, "leftImg8bit", "val", "city")
    pathlib.Path(out_labels_root).mkdir(parents=True, exist_ok=True)
    pathlib.Path(out_images_root).mkdir(parents=True, exist_ok=True)

    # Sorted so the subset kept for a given num_images is reproducible.
    img_names = sorted(
        name for name in os.listdir(imgs_root)
        if os.path.isfile(os.path.join(imgs_root, name))
    )
    for i, img_name in enumerate(img_names):
        if i >= opts.num_images:
            break
        # Swap only the extension; str.replace would also rewrite ".jpg"
        # occurring elsewhere in the name.
        label_name = os.path.splitext(img_name)[0] + ".png"
        # Plain 'd' formatting: 'n' is locale-dependent and may insert
        # thousands separators.
        out_pre = f"city_00000_{i:05d}"
        # copy2 keeps mode and timestamps, as the previous copy_tree did.
        shutil.copy2(os.path.join(imgs_root, img_name),
                     os.path.join(out_images_root, f"{out_pre}_leftImg8bit.jpg"))
        shutil.copy2(os.path.join(labels_root, label_name),
                     os.path.join(out_labels_root, f"{out_pre}_gtFine_labelIds.png"))