import sys
sys.path.append('/mnt/data01/****/****')

import torch

import numpy as np


def get_clip_data_by_names(d_names):
    """Load CLIP features, t-SNE/UMAP embeddings and targets for each dataset.

    For every name in ``d_names`` the dataset family is obtained by dropping
    the last two ``-``-separated fields of the name. Four archives are then
    read with ``torch.load``:

    * the dataset archive, a 4-tuple whose 2nd-4th items are ``y``, ``ind``
      and ``selected_labels`` (the 1st item is not used here);
    * the family-wide ``*_features_clip.tar`` archive, a ``(features, labels)``
      pair;
    * the t-SNE and the UMAP ``*_selected_emb.tar`` archives, each an
      ``(embedding, hyper-parameters)`` pair whose embedding is a NumPy array.

    Args:
        d_names: iterable of dataset name strings.

    Returns:
        A tuple ``(xs, tsne_zs, umap_zs, ys, names)`` of parallel lists with
        one entry per dataset, in the order of ``d_names``. The embeddings are
        returned as ``torch.float32`` tensors. An empty ``d_names`` yields
        five empty lists.

    Raises:
        Whatever ``torch.load`` raises when an archive is missing or cannot
        be unpickled.
    """
    visual_path = '/mnt/data01/public/aad_data/bo'
    visual_path_umap = '/mnt/data01/public/aad_data/bo/umap'
    xs, tsne_zs, umap_zs, ys, names = [], [], [], [], []

    # NOTE: every load below uses weights_only=False, i.e. full unpickling.
    # Only point this function at archives from a trusted source.
    for d_name in d_names:
        dataset_name = '-'.join(d_name.split('-')[:-2])
        dataset_path = f'/mnt/data01/public/aad_data/{dataset_name}/{d_name}.tar'
        _, y, ind, selected_labels = torch.load(dataset_path, weights_only=False)
        features, labels = torch.load(
            f"/home/****/autovisual/prepare_data/data/{dataset_name}_features_clip.tar",
            weights_only=False)

        # Keep only the rows whose label was selected for this dataset, then
        # pick the rows addressed by ``ind`` within that filtered subset.
        selected_mask = np.isin(labels, selected_labels)
        x = features[selected_mask][ind]

        # The embedding archives contain NumPy arrays, which the restricted
        # unpickler (the default in recent PyTorch releases) refuses to load;
        # pass weights_only=False like the other loads in this function.
        tsne_emb, _ = torch.load(
            visual_path + '/' + f'{dataset_name}/visual-method-TSNE_dataset-{d_name}_selected_emb.tar',
            weights_only=False)
        tsne_z = torch.from_numpy(tsne_emb).to(torch.float32)

        umap_emb, _ = torch.load(
            visual_path_umap + '/' + f'{dataset_name}/visual-method-UMAP_dataset-{d_name}_selected_emb.tar',
            weights_only=False)
        umap_z = torch.from_numpy(umap_emb).to(torch.float32)

        xs.append(x)
        umap_zs.append(umap_z)
        tsne_zs.append(tsne_z)
        ys.append(y)
        names.append(d_name)
    return xs, tsne_zs, umap_zs, ys, names


def get_model_ckp(ckp_name):
    """Load a checkpoint archive and return its ``(net, save)`` pair.

    Args:
        ckp_name: path of the checkpoint file written with ``torch.save``.

    Returns:
        The two objects stored in the checkpoint, as ``(net, save)``.

    Raises:
        Whatever ``torch.load`` raises for a missing or unreadable file, and
        ``ValueError`` if the archive does not unpack into exactly two items.
    """
    # The checkpoint stores arbitrary pickled Python objects rather than a
    # plain state dict, so the restricted unpickler (the default in recent
    # PyTorch releases) cannot read it. This matches the other loads in this
    # file. Full unpickling executes code: only load trusted checkpoints.
    net, save = torch.load(ckp_name, weights_only=False)
    return net, save

def get_d_names_from_save(save):
    """Return the ``'train_names'`` and ``'test_names'`` entries of ``save``.

    Args:
        save: mapping that contains both keys.

    Returns:
        ``(train_names, test_names)`` exactly as stored in ``save``.

    Raises:
        KeyError: if either key is missing.
    """
    return save['train_names'], save['test_names']


def get_clip_data_from_ckp(ckp_name):
    """Load the data of the train and test datasets recorded in a checkpoint.

    The checkpoint at ``ckp_name`` is read with ``get_model_ckp``; the dataset
    names found in its ``save`` part are then passed to
    ``get_clip_data_by_names``, train split first.

    Args:
        ckp_name: path of the checkpoint file.

    Returns:
        ``(train_d, test_d)``, each being the value returned by
        ``get_clip_data_by_names`` for the respective list of names.
    """
    # Only the bookkeeping part of the checkpoint is needed; the network
    # object that comes with it is ignored.
    _, ckp_save = get_model_ckp(ckp_name)
    names_for_train, names_for_test = get_d_names_from_save(ckp_save)
    return (
        get_clip_data_by_names(names_for_train),
        get_clip_data_by_names(names_for_test),
    )


if __name__ == '__main__':
    # Script entry point: collect the train/test data referenced by one
    # checkpoint and dump both into a single archive in the working directory.
    root = '/mnt/data01/****/****/gnn/res/'
    # Alternative checkpoint that was used previously:
    # '500ds-GNN-clip_w_umap-umaponly-umap_knn_l2-new-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-84.tar'
    ckp_file = '500ds-GNN-clip_w_umap-umaponly-umap_knn_l2-new-seqgt-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-84.tar'
    # ``root`` already ends with '/', so the joined path contains a doubled
    # separator; this is kept as is.
    ckp_name = f'{root}/{ckp_file}'

    print('loading')
    train_d, test_d = get_clip_data_from_ckp(ckp_name)
    print('finish loading')

    torch.save((train_d, test_d), './clip_datas_for_baselines.tar')
