import numpy as np
import os
from os import listdir
from os.path import isfile, join
import shutil


def file_delete(path):
    """
    Empty the ``train`` and ``val`` sub-directories of ``path``.

    Every entry directly under ``<path>/train`` and ``<path>/val`` is removed:
    real sub-directories (e.g. per-class folders) recursively, plain files and
    symlinks with ``os.unlink``. A missing ``train``/``val`` directory is
    skipped, so the function can be run against a fresh, empty target.

    :param path: path where the ``train``/``val`` folders are located
    :return: None; the contents of the train/val folders are deleted
    """
    for subset, label in (('train', 'training set'), ('val', 'valid set')):
        subset_path = join(path, subset)
        if not os.path.isdir(subset_path):
            # nothing to clear (e.g. the target has not been populated yet)
            continue

        for item in listdir(subset_path):
            item_path = join(subset_path, item)
            if os.path.isdir(item_path) and not os.path.islink(item_path):
                # best-effort recursive delete (errors ignored, as before)
                shutil.rmtree(item_path, ignore_errors=True)
            else:
                # rmtree cannot remove files/symlinks, so delete them directly
                os.unlink(item_path)
        print(f"{label} folders are deleted successfully")

    return None


def file_add(path, class_list, *, src_root='../data/imagenet'):
    """
    Copy the selected classes of the full dataset into ``path``.

    For every folder name in ``class_list`` the directories
    ``<src_root>/train/<class>`` and ``<src_root>/val/<class>`` are copied to
    ``<path>/train/<class>`` and ``<path>/val/<class>``.

    :param path: path to add the selected classes to
    :param class_list: folder names of the selected classes, e.g. 'n01234'
    :param src_root: root of the full dataset that holds the ``train`` and
        ``val`` folders; defaults to the previously hard-coded
        '../data/imagenet' (resolved against the current working directory)
    :return: None; the classes are copied into the path
    :raises FileExistsError: if a destination class folder already exists
        (``shutil.copytree`` does not overwrite; clear the target first)
    """
    # data source address
    src_train_path = join(src_root, 'train')
    src_valid_path = join(src_root, 'val')

    # data target address
    train_path = join(path, 'train')
    valid_path = join(path, 'val')

    for subclass in class_list:  # e.g., 'n01234'
        # copy training set
        print(f"Copying subclass {subclass} training set")
        shutil.copytree(src=join(src_train_path, subclass),
                        dst=join(train_path, subclass))

        # copy valid set
        print(f"Copying subclass {subclass} valid set")
        shutil.copytree(src=join(src_valid_path, subclass),
                        dst=join(valid_path, subclass))
    print("subclasses are copied successfully")

    return None


def folder_rename(path, class_id):
    """
    Rename the given class folders to ordered placeholder names.

    The i-th entry of ``class_id`` is renamed to ``'n{:08d}'.format(i)``
    (``n00000000``, ``n00000001``, ...) inside both ``<path>/train`` and
    ``<path>/val``. The new names are zero-padded, so sorting the folder
    names alphabetically reproduces the order of ``class_id``. Presumably
    this lets a loader that assigns label indices by sorted folder name give
    index i to ``class_id[i]``; TODO confirm against the consuming loader.

    :param path: path containing the ``train``/``val`` folders
    :param class_id: folder names to rename, in the desired index order
    :return: None
    """
    # data target address
    train_path = join(path, 'train')
    valid_path = join(path, 'val')
    for i, name in enumerate(class_id):
        # generate the name instead of indexing a fixed-size list, so any
        # number of classes works (a fixed 3-item list raised IndexError
        # part-way through renaming when given more classes)
        new_name = 'n{:08d}'.format(i)
        os.rename(join(train_path, name), join(train_path, new_name))
        os.rename(join(valid_path, name), join(valid_path, new_name))
    print("Rename successfully")


if __name__ == '__main__':
    # The goal of this script is to create a sub-dataset of ImageNet that
    # contains only 3 classes, for visualization use.
    path = '../data/imagenet_visualization'  # target directory (relative to cwd)

    # target classes for visualization; CLASSES[i] is the human-readable name
    # of the ImageNet class folder CLASSES_id[i]
    CLASSES = ['miniature_poodle', 'standard_poodle', 'submarine']
    CLASSES_id = ['n02113712', 'n02113799', 'n04347754']

    # Step 1. clear the train/val folders of the target path
    file_delete(path)

    # Step 2. copy the target 3 classes from the whole ImageNet directory
    file_add(path, class_list=CLASSES_id)

    # Step 3. rename the 3 class folders so that their sorted order follows
    # CLASSES_id, i.e. the semantically dissimilar class (submarine) is always
    # the last one
    folder_rename(path, CLASSES_id)
