import os
import glob
import numpy as np


def imagenet_val_files_and_labels(dataset_directory):
    """Build the ImageNet validation file list and integer labels.

    Reads two text files from ``dataset_directory``:

    * ``imagenet_lsvrc_2015_synsets.txt``: one synset name per line; the
      0-based line number becomes the class index.
    * ``imagenet_2012_validation_synset_labels.txt``: one synset name per
      line; 0-based line ``i`` is the label of
      ``val/ILSVRC2012_val_{i + 1:08d}.JPEG``.

    Blank lines (e.g. a trailing newline) are ignored in both files; they do
    not shift the indices of the lines around them.

    Args:
        dataset_directory: Directory containing the two label files and the
            ``val`` image folder.

    Returns:
        ``(filenames, labels)``: equal-length lists of image paths (str) and
        class indices (int).

    Raises:
        KeyError: If a validation label names a synset missing from the
            synsets file.
    """
    synsets_file = os.path.join(dataset_directory, 'imagenet_lsvrc_2015_synsets.txt')
    with open(synsets_file) as f:
        synsets = [line.strip() for line in f]
    # Index by line position; skip empty names so a blank line does not
    # register a bogus '' class.
    class_to_indx = {name: i for i, name in enumerate(synsets) if name}

    images_path = os.path.join(dataset_directory, 'val')
    filenames = []
    labels = []
    labels_file = os.path.join(dataset_directory, 'imagenet_2012_validation_synset_labels.txt')
    with open(labels_file, 'r') as f:
        for i, line in enumerate(f):
            class_name = line.strip()
            if not class_name:
                # Tolerate blank lines instead of failing the lookup; the
                # image number still follows the line number.
                continue
            filenames.append(f'{images_path}/ILSVRC2012_val_{i + 1:08d}.JPEG')
            labels.append(class_to_indx[class_name])

    return filenames, labels


def _find_classes(dir):
    # Faster and available in Python 3.5 and above
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def get_cub_dataset_info(dataset_dir, subset='test', shuffle=False, random_seed=None):
    """List images, segmentation masks and labels of a CUB_200_2011 split.

    Expected layout::

        dataset_dir/<subset>/<class_name>/<image>
        dataset_dir/segmentations/<class_name>/<mask>

    The mask path is the image path relative to ``dataset_dir/<subset>``
    placed under ``dataset_dir/segmentations``. A ``.jpg``/``.jpeg``
    extension (any case) is replaced by ``.png``.

    Args:
        dataset_dir: Root directory of the dataset.
        subset: Name of the split sub-directory (default ``'test'``).
        shuffle: If True, apply one random permutation to all returned arrays.
        random_seed: Seed for that permutation; ``None`` draws fresh entropy.
            Only used when ``shuffle`` is True. NumPy's global RNG is left
            untouched.

    Returns:
        ``(image_paths, seg_paths, labels, num_classes)``: the first three are
        aligned NumPy arrays; ``num_classes`` is the number of class folders.

    Raises:
        FileNotFoundError: If an image has no matching segmentation mask.
    """
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

    subset_dir = os.path.join(dataset_dir, subset)
    seg_dir = os.path.join(dataset_dir, 'segmentations')

    class_names, class_to_idx = _find_classes(subset_dir)
    image_paths = []
    seg_paths = []
    labels = []
    for class_name in sorted(class_names):
        class_dir = os.path.join(subset_dir, class_name)
        for img_path in sorted(glob.glob(os.path.join(class_dir, '*'))):
            if not img_path.lower().endswith(IMG_EXTENSIONS):
                continue

            # Work on the path relative to the split directory. Splitting on
            # the substring f'{subset}/' broke when another path component
            # contained it, and str.replace('jpg', 'png') rewrote every 'jpg'
            # in the path rather than only the extension.
            rel_path = os.path.relpath(img_path, subset_dir)
            stem, ext = os.path.splitext(rel_path)
            if ext.lower() in ('.jpg', '.jpeg'):
                rel_path = stem + '.png'
            seg_path = os.path.join(seg_dir, rel_path)

            # Explicit check: an assert would be stripped under `python -O`.
            if not os.path.exists(seg_path):
                raise FileNotFoundError(seg_path)

            image_paths.append(img_path)
            seg_paths.append(seg_path)
            labels.append(class_to_idx[class_name])

    image_paths = np.array(image_paths)
    seg_paths = np.array(seg_paths)
    labels = np.array(labels)

    if shuffle:
        # A private RandomState gives the same permutation as
        # np.random.seed(random_seed) followed by np.random.permutation, but
        # does not reseed the process-wide RNG. Passing an int length also
        # yields an integer index array when the split is empty.
        random_per = np.random.RandomState(random_seed).permutation(len(image_paths))
        image_paths = image_paths[random_per]
        seg_paths = seg_paths[random_per]
        labels = labels[random_per]

    return image_paths, seg_paths, labels, len(class_names)


def get_pytorch_dataset_info(dataset_dir, subset='test', shuffle=False, random_seed=None):
    """List images and labels of a dataset stored one folder per class.

    Expected layout (the class-folder layout used by PyTorch's
    ``ImageFolder``)::

        dataset_dir/<subset>/<class_name>/<image>

    Class indices follow the sorted class-folder names. Images within a class
    are listed in sorted order, so the result (and any seeded shuffle of it)
    does not depend on filesystem enumeration order.

    Args:
        dataset_dir: Root directory of the dataset.
        subset: Name of the split sub-directory (default ``'test'``).
        shuffle: If True, apply one random permutation to both returned arrays.
        random_seed: Seed for that permutation; ``None`` draws fresh entropy.
            Only used when ``shuffle`` is True. NumPy's global RNG is left
            untouched.

    Returns:
        ``(image_paths, labels, num_classes)``: ``image_paths`` and ``labels``
        are aligned NumPy arrays; ``num_classes`` is the number of class
        folders.
    """
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

    subset_dir = os.path.join(dataset_dir, subset)

    class_names, class_to_idx = _find_classes(subset_dir)
    image_paths = []
    labels = []
    for class_name in class_names:
        class_dir = os.path.join(subset_dir, class_name)
        # glob.glob returns entries in arbitrary, filesystem-dependent order.
        for img_path in sorted(glob.glob(os.path.join(class_dir, '*'))):
            if not img_path.lower().endswith(IMG_EXTENSIONS):
                continue

            image_paths.append(img_path)
            labels.append(class_to_idx[class_name])

    image_paths = np.array(image_paths)
    labels = np.array(labels)

    if shuffle:
        # A private RandomState gives the same permutation as
        # np.random.seed(random_seed) followed by np.random.permutation, but
        # does not reseed the process-wide RNG. Passing an int length also
        # yields an integer index array when the split is empty.
        random_per = np.random.RandomState(random_seed).permutation(len(image_paths))
        image_paths = image_paths[random_per]
        labels = labels[random_per]

    return image_paths, labels, len(class_names)