import time
import os

import torch


def load_d_data_names(datasets, root='/mnt/data01/public/aad_data'):
    """Collect the file names (final extension removed) found in each dataset directory.

    Args:
        datasets: iterable of sub-directory names located directly under ``root``.
        root: base directory holding one sub-directory per dataset. Defaults to the
            previously hard-coded location so existing callers behave the same.

    Returns:
        A flat list of names, grouped in the order of ``datasets``. Within a dataset
        the order is whatever ``os.listdir`` returns (arbitrary, not sorted).
    """
    train_names = []
    for dataset in datasets:
        for file_name in os.listdir(os.path.join(root, dataset)):
            # splitext strips only the last extension and leaves extension-less
            # names intact; the previous '.'.join(name.split('.')[:-1]) turned
            # such names into '' (an unusable empty name).
            stem, _ = os.path.splitext(file_name)
            train_names.append(stem)
    return train_names


if __name__ == "__main__":
    # Benchmark: time how long each saved archive takes to deserialize.
    names = load_d_data_names(['cifar10'])
    cdist_path = '/mnt/data01/public/aad_data'
    for d_name in names:
        start = time.perf_counter()
        # The prefix before the first '-' selects the sub-directory; presumably
        # files are named "<dataset>-<suffix>.tar" — NOTE(review): confirm naming.
        dataset_name = d_name.split('-')[0]
        # Use torch's default pickle module. ujson is a JSON codec, not a pickle
        # implementation (no Unpickler), so passing it as pickle_module makes
        # torch.load fail. NOTE(review): on torch >= 2.6 weights_only defaults to
        # True; if the archive holds non-tensor objects this may need adjusting.
        clip_save = torch.load(os.path.join(cdist_path, dataset_name, f'{d_name}.tar'))
        # Only the first three entries are unpacked; their meaning (cdist, labels,
        # ...) is inferred from the names — TODO confirm against the saver.
        cdist, y, _ = clip_save[:3]
        elapsed = time.perf_counter() - start
        print(f'process time: {elapsed} s')
