import numpy as np
import glob
import os
# from collections import OrderedDict


def get_dataset(data_dir, model_type, agg_method, known_anom=None):
    """Load pre-computed embeddings from ``data_dir`` and split them into train/test sets.

    Files are matched with the pattern ``{data_dir}/{model_type}-{agg_method}-*``.
    After the extension and the last four characters of the base name are removed
    (presumably a ``-NNN`` file index -- TODO confirm), the remainder is read as
    ``{split}-{type}``, where ``type`` may itself contain hyphens.

    * Every ``train`` file is a normal training sample (label 0).
    * Test files of type ``good`` are normal test samples (label 0).
    * Every other test type is an anomaly class. Ids 1, 2, ... are assigned in
      the order the types are first seen while walking the files in sorted
      (alphabetical) order.
    * Anomaly classes whose id is in ``known_anom`` go to the training set with
      label 1. All other anomaly classes stay in the test set, labelled with
      their id.

    Args:
        data_dir: directory holding the ``.npy`` embedding files.
        model_type: first component of the file names.
        agg_method: second component of the file names.
        known_anom: ids of the anomaly classes to use during training (ids follow
            the alphabetical order of the files). Defaults to ``[1]``.

    Returns:
        Tuple ``(X, y, x_test, y_test_id, test_class_labels, test_class_ids)``:
        ``X`` and ``x_test`` are the loaded arrays combined with ``np.array``.
        ``y`` is 0 (normal) / 1 (known anomaly) per training sample.
        ``y_test_id`` is 0 (normal) or the anomaly id per test sample.
        ``test_class_labels`` / ``test_class_ids`` list the classes present in
        the test set, starting with ``"Normal"`` / 0.
    """
    if known_anom is None:
        known_anom = [1]

    normal_label = 0
    prefix = f'{model_type}-{agg_method}-'

    test_class_labels = ["Normal"]
    test_class_ids = [normal_label]
    anom_ids = {}  # anomaly type -> id, assigned in order of first appearance
    label_list_train = []
    img_emb_list_train = []
    label_list_test = []
    img_emb_list_test = []

    # glob's order is arbitrary; sort so the ids really follow alphabetical order.
    for file in sorted(glob.glob(f'{data_dir}/{prefix}*')):
        # Parse only the base name with the known prefix stripped, so hyphens in
        # data_dir / model_type / agg_method cannot shift the split/type fields.
        file_name = os.path.splitext(os.path.basename(file))[0]
        file_info = file_name[len(prefix):-4].split('-')

        if file_info[0] == "train":
            # All images from the training split are normal.
            img_emb_list_train.append(np.load(file))
            label_list_train.append(normal_label)
            continue

        # Any other split is treated as test.
        file_type = "-".join(file_info[1:])
        if file_type == "good":
            img_emb_list_test.append(np.load(file))
            label_list_test.append(normal_label)
            continue

        # Test-time anomaly: look up (or assign) this type's own id. Reusing the
        # most recently assigned id is only correct when files of one type are
        # listed contiguously.
        anom_id = anom_ids.setdefault(file_type, len(anom_ids) + 1)
        if anom_id in known_anom:
            img_emb_list_train.append(np.load(file))
            label_list_train.append(1 - normal_label)
        else:
            img_emb_list_test.append(np.load(file))
            label_list_test.append(anom_id)
            if file_type not in test_class_labels:
                test_class_labels.append(file_type)
                test_class_ids.append(anom_id)

    X = np.array(img_emb_list_train)
    # 0/1 for normal/known anomaly
    y = np.array(label_list_train)
    x_test = np.array(img_emb_list_test)
    # 0 for normal, 1/2/3/... for the anomaly id
    y_test_id = np.array(label_list_test)

    return X, y, x_test, y_test_id, test_class_labels, test_class_ids