import sys
sys.path.append('/mnt/data01/****/****')

import torch
from cuml import UMAP
from parametric_umap import ParametricUMAP

from gnn.useful_utils import visualization_metric


def train_parametricUMAP(train_x, pumap):
    """Fit ``pumap`` on ``train_x`` and return it together with the embedding.

    The estimator is NOT constructed here; the caller must build it (in this
    script a ``ParametricUMAP`` instance) and pass it in.

    Args:
        train_x: Training features, passed unchanged to ``pumap.fit_transform``.
        pumap: Estimator object exposing ``fit_transform``.

    Returns:
        A ``(pumap, embeddings)`` tuple: the same estimator object that was
        passed in (presumably fitted in place by ``fit_transform`` — confirm
        against the estimator's docs) and the embedding it returned for
        ``train_x``.
    """
    embeddings = pumap.fit_transform(train_x)
    return pumap, embeddings


def load_data(d_name, data_dir="/home/****/autovisual/prepare_data/data"):
    """Load pre-extracted CLIP features and labels for dataset ``d_name``.

    Reads ``<data_dir>/<d_name>_features_clip.tar`` with ``torch.load``. The
    file must contain a two-element ``(features, labels)`` sequence.

    Args:
        d_name: Dataset name used as the file prefix (e.g. ``'mnist'``).
        data_dir: Directory holding the feature files. Defaults to the
            previously hard-coded location, so existing callers are unaffected.

    Returns:
        ``(features, labels)`` exactly as stored in the file.
    """
    from pathlib import Path

    path = Path(data_dir) / f"{d_name}_features_clip.tar"
    # SECURITY: weights_only=False makes torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load files from trusted sources.
    features, labels = torch.load(path, weights_only=False)
    return features, labels

# Number of samples taken from each dataset for the embedding comparison.
N_SAMPLES = 3000


def _reference_umap(x):
    """Return a 2-D embedding of ``x`` from non-parametric cuML UMAP.

    This is the baseline ("true") layout that the parametric model is
    compared against; ``random_state`` is fixed for reproducibility.
    """
    return UMAP(n_neighbors=15, n_components=2, verbose=0, init='spectral',
                random_state=42).fit_transform(x)


def _report(tag, embedding, labels):
    """Print ``tag`` followed by the (nmi, sc) pair from ``get_nmi_sc``."""
    nmi, sc = visualization_metric.get_nmi_sc(embedding, labels.tolist())
    print(tag, nmi, sc)


def main():
    """Compare cuML UMAP against ParametricUMAP on train and test features."""
    train_x, train_y = load_data('mnist')
    # NOTE(review): the "test" set is a different dataset (CIFAR-10) from the
    # training one (MNIST); confirm this cross-dataset evaluation is intended.
    test_x, test_y = load_data('cifar10')

    # Restrict both datasets to the first N_SAMPLES points once, up front.
    train_x, train_y = train_x[:N_SAMPLES], train_y[:N_SAMPLES]
    test_x, test_y = test_x[:N_SAMPLES], test_y[:N_SAMPLES]
    print(train_y)

    _report('train true', _reference_umap(train_x), train_y)

    pumap = ParametricUMAP(
        device='cuda:0',
        n_components=2,
        hidden_dim=256,
        n_layers=10,
        use_batchnorm=True,
        batch_size=128,
        n_epochs=100,
    )
    # The estimator must be passed explicitly: train_parametricUMAP only fits
    # the model it is given and raises TypeError if the argument is omitted.
    pumap, train_z = train_parametricUMAP(train_x, pumap)
    _report('train', train_z, train_y)

    _report('test true', _reference_umap(test_x), test_y)
    # Unlike the reference UMAP above, the parametric model is NOT refit on the
    # test data; it only maps the new points through the trained network.
    _report('test', pumap.transform(test_x), test_y)

    torch.save(pumap, './PUMAP.save')


if __name__ == "__main__":
    main()