from data_loader import load_partition_data_landmarks

"""
    Run `python check_download.py` to check that every data sample listed in
    federated_train.csv and test.csv is present locally.
"""

# Per-dataset settings: (client_number, train map CSV, test map CSV).
_DATASET_CONFIGS = {
    "g23k": (
        233,
        "./cache/datasets/mini_gld_train_split.csv",
        "./cache/datasets/mini_gld_test.csv",
    ),
    "g160k": (
        1262,
        "./cache/datasets/landmarks-user-160k/federated_train.csv",
        "./cache/datasets/landmarks-user-160k/test.csv",
    ),
}


def get_dataset_config(dataset_name):
    """Return ``(client_number, train_map_file, test_map_file)`` for a dataset.

    Args:
        dataset_name: One of the keys of ``_DATASET_CONFIGS`` ("g23k" or "g160k").

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset. Without this
            check the old script failed later with an unbound-name ``NameError``.
    """
    try:
        return _DATASET_CONFIGS[dataset_name]
    except KeyError:
        raise ValueError(
            "unknown dataset_name %r; expected one of %s"
            % (dataset_name, sorted(_DATASET_CONFIGS))
        ) from None


def check_client_loaders(loader_dict, client_number):
    """Iterate every batch of clients ``0 .. client_number - 1``, printing progress.

    Iterating a loader forces its samples to be read, so a missing file should
    surface here (assumes the data_loader raises on missing files -- TODO confirm).

    Args:
        loader_dict: Mapping from client index to an iterable of batches.
        client_number: Number of clients to check.

    Returns:
        The total number of batches iterated across all checked clients.
    """
    total_batches = 0
    for client_idx in range(client_number):
        for batch_idx, _ in enumerate(loader_dict[client_idx]):
            print("client_idx %d has %s-th data" % (client_idx, batch_idx))
            total_batches += 1
    return total_batches


def main(dataset_name="g160k", check_test_data=False, num_preview_batches=6):
    """Load the partitioned Landmarks data and walk every client's loaders.

    Args:
        dataset_name: Which dataset variant to check ("g23k" or "g160k").
        check_test_data: If True, also iterate every client's test loader.
        num_preview_batches: How many global training batches to print.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    data_dir = "./cache/images"
    client_number, fed_train_map_file, fed_test_map_file = get_dataset_config(
        dataset_name
    )

    (
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ) = load_partition_data_landmarks(
        None,
        data_dir,
        fed_train_map_file,
        fed_test_map_file,
        partition_method=None,
        partition_alpha=None,
        client_number=client_number,
        batch_size=10,
    )

    print(train_data_num, test_data_num, class_num)
    print(data_local_num_dict)

    # Preview the first few global batches. Break *after* printing so no extra
    # batch is pulled from the loader beyond the ones shown.
    if num_preview_batches > 0:
        for batch_num, (data, label) in enumerate(train_data_global, start=1):
            print(data)
            print(label)
            if batch_num >= num_preview_batches:
                break
    print("=============================\n")

    train_batches = check_client_loaders(train_data_local_dict, client_number)
    print("checked %d train batches across %d clients" % (train_batches, client_number))

    if check_test_data:
        test_batches = check_client_loaders(test_data_local_dict, client_number)
        print("checked %d test batches across %d clients" % (test_batches, client_number))


if __name__ == "__main__":
    main()
