import os
from typing import Tuple, Any, List, Callable, Optional
from PIL import Image

class IDDDataset(object):
    """Image/label pairs from an IDD split stored in the Cityscapes layout.

    Images are discovered under ``<root>/leftImg8bit/<split>/<city>/``. Each
    ``<stem>_leftImg8bit.png`` is paired with
    ``<root>/gtFine/<split>/<city>/<stem>_gtFine_labelcsTrainIds.png``.
    Only the image tree is listed; label files are not checked for existence
    here, so a missing label surfaces when that index is loaded.
    """

    # File-name suffixes used to derive a label name from an image name.
    _IMAGE_SUFFIX = '_leftImg8bit.png'
    _TARGET_SUFFIX = '_gtFine_labelcsTrainIds.png'

    def __init__(self,
                 root: str,
                 split: str,
                 transforms: Optional[Callable] = None):
        """Index every image of ``split`` under ``root``.

        Args:
            root: Dataset root directory containing ``leftImg8bit`` and ``gtFine``.
            split: Name of the split sub-directory (e.g. ``'train'``).
            transforms: Optional callable taking ``(image, target)`` and
                returning the transformed ``(image, target)`` pair.

        Raises:
            FileNotFoundError: if ``<root>/leftImg8bit/<split>`` does not exist.
        """
        super().__init__()

        self.mode = 'gtFine'
        self.images_dir = os.path.join(root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(root, self.mode, split)
        self.split = split
        self.images = []
        self.targets = []
        self.transforms = transforms

        # os.listdir order is arbitrary; sorting keeps index -> sample stable
        # across machines and runs.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            if not os.path.isdir(img_dir):
                # Stray files (README, .DS_Store, ...) next to city folders.
                continue
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in sorted(os.listdir(img_dir)):
                if not file_name.endswith(self._IMAGE_SUFFIX):
                    # Any other file would map to a non-existent label path.
                    continue
                stem = file_name[:-len(self._IMAGE_SUFFIX)]
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(
                    os.path.join(target_dir, stem + self._TARGET_SUFFIX))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the pair at ``index`` and apply ``transforms`` when set.

        Returns:
            ``(image, target)``: the image converted to RGB and the label image
            in its stored mode, or whatever ``transforms`` returns for them.
        """
        with Image.open(self.images[index]) as src:
            # convert() returns a new image, so closing the source is safe.
            image = src.convert('RGB')
        target = Image.open(self.targets[index])
        # Image.open is lazy and keeps the file open; load() reads the pixels
        # now and lets Pillow close the file for single-frame images.
        target.load()

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of indexed image/label pairs."""
        return len(self.images)

def idd(root: str,
        split: str,
        transforms: Optional[Callable] = None) -> 'IDDDataset':
    """Build an ``IDDDataset`` for ``split`` under ``root``.

    Args:
        root: Dataset root directory containing ``leftImg8bit`` and ``gtFine``.
        split: Name of the split sub-directory (e.g. ``'train'``).
        transforms: Optional joint callable ``(image, target) -> (image, target)``.
            The dataset invokes it directly, so several transforms must be
            composed into a single callable before being passed here.

    Returns:
        The constructed dataset.
    """
    return IDDDataset(root=root,
                      split=split,
                      transforms=transforms)