import os
from typing import Dict

# Names exported by ``from <module> import *``: the three mapping helpers below.
__all__ = ["get_class2id_map", "get_id2class_map", "get_n_classes"]


def get_class2id_map(dataset: str, dataset_dir: str = "./dataset") -> Dict[str, int]:
    """Load the class-name-to-id mapping of a dataset from its ``mapping.txt``.

    The file ``<dataset_dir>/<dataset>/mapping.txt`` is read line by line.
    Each non-blank line must contain at least two whitespace-separated
    fields: an integer id followed by a class name. Any further fields on
    the line are ignored. Blank lines are skipped, and the last line is
    parsed whether or not the file ends with a newline.

    Args:
        dataset: name of the sub-directory of ``dataset_dir`` that holds
            ``mapping.txt``.
        dataset_dir: the path to the dataset directory.

    Returns:
        A dict mapping each class name to its integer id.

    Raises:
        OSError: if the mapping file cannot be opened.
        ValueError: if the first field of a line is not an integer.
        IndexError: if a non-blank line has fewer than two fields.
    """
    mapping_path = os.path.join(dataset_dir, "{}/mapping.txt".format(dataset))

    class2id_map = dict()
    with open(mapping_path, "r") as f:
        for line in f:
            fields = line.split()
            # Skip blank lines so a missing or duplicated trailing newline
            # neither drops the last class nor raises.
            if not fields:
                continue
            class2id_map[fields[1]] = int(fields[0])

    return class2id_map


def get_id2class_map(dataset: str, dataset_dir: str = "./dataset") -> Dict[int, str]:
    """Build the inverse of the mapping returned by ``get_class2id_map``.

    Args:
        dataset: passed through to ``get_class2id_map``.
        dataset_dir: passed through to ``get_class2id_map``.

    Returns:
        A dict whose keys are the values of the class-to-id mapping and
        whose values are the corresponding keys.
    """
    id2class_map = dict()
    for class_name, class_id in get_class2id_map(dataset, dataset_dir).items():
        id2class_map[class_id] = class_name

    return id2class_map


def get_n_classes(dataset: str, dataset_dir: str = "./dataset") -> int:
    """Return the number of entries in the mapping from ``get_class2id_map``.

    Args:
        dataset: passed through to ``get_class2id_map``.
        dataset_dir: passed through to ``get_class2id_map``.
    """
    class2id_map = get_class2id_map(dataset, dataset_dir)
    n_classes = len(class2id_map)
    return n_classes
