import os
import glob
import PIL.Image
from collections import namedtuple
import fnmatch
import numpy as np
from typing import List, Callable

from datasets.cityscapes import CityscapesDataset

def gta(root: str,
        split: str,
        transforms: List[Callable]):
    """
    Build the GTA dataset for one split.

    GTA is loaded through ``CityscapesDataset``. The rest of this module
    treats GTA as stored in the Cityscapes directory layout
    (``leftImg8bit/<split>/...`` and ``gtFine/<split>/...``).

    Args:
        root: path to the GTA dataset root.
        split: dataset split name, forwarded unchanged.
        transforms: transforms forwarded unchanged to ``CityscapesDataset``.

    Returns:
        The ``CityscapesDataset`` instance constructed with these arguments.
    """
    dataset_kwargs = dict(root=root, split=split, transforms=transforms)
    return CityscapesDataset(**dataset_kwargs)

def preprocess(dataset_dir: str):
    """
    Remove data samples whose image and annotation sizes differ.

    For every split in ("train", "val", "test") this collects
    ``leftImg8bit/<split>/city/*_leftImg8bit.png`` images and
    ``gtFine/<split>/city/*_gtFine_labelIds.png`` labels. It pairs them by
    sorted position, checks that each pair shares the same
    ``<city>_<sequenceNb>_<frameNb>`` prefix, and deletes both files of any
    pair whose pixel sizes do not match.

    Args:
        dataset_dir: path to GTA folder

    Raises:
        ValueError: if a split has different numbers of images and labels,
            or if an image and its positionally-paired label do not share
            the same name prefix. These are explicit raises rather than
            ``assert`` statements because this function deletes files and
            the checks must not be stripped under ``python -O``.
    """
    dataset_split = ["train", "val", "test"]
    # Number of deleted image/label pairs
    count_del = 0

    for split in dataset_split:
        images = sorted(glob.glob(os.path.join(dataset_dir, "leftImg8bit", split, "city", "*_leftImg8bit.png")))
        labels = sorted(glob.glob(os.path.join(dataset_dir, "gtFine", split, "city", "*_gtFine_labelIds.png")))
        if len(images) != len(labels):
            raise ValueError(
                f"Length of catalogs do not match! ({len(images)} images vs "
                f"{len(labels)} labels in split '{split}')"
            )

        for image, label in zip(images, labels):
            # Filename pattern shared by the image and its annotation files
            file_pattern = get_file_info(image)

            # Positional pairing is only safe if both lists describe the same samples
            if not fnmatch.fnmatch(os.path.basename(label), file_pattern):
                raise ValueError(f"Label {label!r} does not correspond to image {image!r}")

            # Close both files before any deletion: PIL.Image.open keeps the
            # file handle open, which leaks handles and can make os.remove
            # fail on platforms that refuse to delete open files.
            with PIL.Image.open(image) as img, PIL.Image.open(label) as gt:
                sizes_differ = img.size != gt.size

            if sizes_differ:
                print(f"Found data sample pair with unmatching size. Deleting files with name: {file_pattern}.")
                # Delete mismatching data samples
                os.remove(path=image)
                os.remove(path=label)
                count_del += 1

    print(f"{count_del} images have been removed from the dataset")


def get_file_info(file: str):
    """
    Return an fnmatch pattern matching all files of the same sample.

    File names are expected to look like
    ``<city>_<sequenceNb>_<frameNb>_<type>[_<type2>...].<ext>``, e.g.
    ``aachen_000000_000019_leftImg8bit.png`` or
    ``aachen_000000_000019_gtFine_labelIds.png``. The returned pattern is
    ``"<city>_<sequenceNb>_<frameNb>*"``.

    Args:
        file: path (or bare file name) of a dataset file.

    Returns:
        The fnmatch pattern for the sample's name prefix. Metacharacters in
        the prefix are escaped so they match literally.

    Raises:
        ValueError: if the file name has fewer than four ``_``-separated
            fields.
    """
    base_name = os.path.basename(file)
    fields = base_name.split('_')
    if len(fields) < 4:
        raise ValueError(f"Unexpected file name format: {base_name!r}")
    # Only city, sequence number and frame number identify the sample; the
    # trailing type field(s) and extension are matched by the wildcard.
    # glob.escape uses the same bracket syntax fnmatch understands, so
    # '*', '?' and '[' in the prefix are matched literally.
    prefix = glob.escape("_".join(fields[:3]))
    return prefix + "*"


def count(image_folder):
    """
    Print how often each image height and each image width occurs.

    Only the top-level ``*.png`` files of ``image_folder`` are inspected
    (the glob is not recursive). Heights and widths are counted
    independently and printed in ascending order. Nothing is printed for a
    folder without PNG files.

    Args:
        image_folder: directory containing the PNG images to inspect.
    """
    height_list = []
    width_list = []

    for img_path in glob.glob(os.path.join(image_folder, '*.png')):
        # Context manager releases the file handle after reading the size
        # instead of leaking one open file per image.
        with PIL.Image.open(img_path) as image:
            height_list.append(image.height)
            width_list.append(image.width)

    # np.unique returns the sorted distinct values and how often each occurs
    heights, h_counts = np.unique(height_list, return_counts=True)
    widths, w_counts = np.unique(width_list, return_counts=True)

    for h, c in zip(heights, h_counts):
        print(f"height: {h}: counts:{c}")
    for w, c in zip(widths, w_counts):
        print(f"width: {w}: counts:{c}")