import torch

import numpy as np
import pandas as pd
from parametric_umap import ParametricUMAP
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

import load_data_for_baselines
from gnn.useful_utils import visualization_metric


def _block_offsets(block_size):
    """Return exact int64 boundaries ``[0, s0, s0+s1, ...]`` for consecutive blocks.

    Integer arithmetic is used on purpose: a float32 cumulative sum (what
    ``torch.cat([torch.zeros(1), sizes]).cumsum()`` yields) cannot represent
    integers above 2**24 exactly, which would silently misplace block edges on
    large concatenated inputs.
    """
    sizes = np.asarray(block_size, dtype=np.int64)
    return np.concatenate((np.zeros(1, dtype=np.int64), np.cumsum(sizes)))


def evaluate_pca(zhat, block_size, ys):
    """Score a concatenated embedding block by block with ``get_nmi_sc``.

    Despite the name, nothing here is PCA-specific; the caller in this file
    passes Parametric UMAP embeddings.

    Args:
        zhat: Row-stacked embeddings of all blocks, sliceable along the first
            axis (e.g. a NumPy array or tensor).
        block_size: Sequence of row counts, one per block, in the order the
            blocks were concatenated into ``zhat``.
        ys: Per-block label containers, indexable by block position; each
            must provide ``.tolist()``.

    Returns:
        ``(nmis, scs)``: two lists with one entry per block, holding the two
        values returned by ``visualization_metric.get_nmi_sc`` for that block.
    """
    offsets = _block_offsets(block_size)
    nmis = []
    scs = []
    for i, (start, end) in enumerate(zip(offsets[:-1], offsets[1:])):
        # int() so slicing works the same for NumPy arrays and torch tensors.
        block_z = zhat[int(start):int(end)]
        nmi, sc = visualization_metric.get_nmi_sc(block_z, ys[i].tolist())
        nmis.append(nmi)
        scs.append(sc)
    return nmis, scs


# Train a Parametric UMAP on standardized features from several blocks
# (presumably one per dataset), then score the 2-D embeddings block by block
# with evaluate_pca on both the train and the test split.
# NOTE(review): the ``pca_*`` names and ``evaluate_pca`` are used for a
# ParametricUMAP model — presumably inherited from a PCA baseline script.

print('loading')
# weights_only=False unpickles arbitrary Python objects: only load this file
# from a trusted source.
train_d, test_d = torch.load('./clip_datas_for_baselines.tar', weights_only=False)
print('finish loading')

# Each split unpacks into five parallel per-block collections; only the
# features (xs) and labels (ys) are used below.
xs, tsne_zs, umap_zs, ys, d_names = train_d
block_sizes = [x.shape[0] for x in xs]
train_x = np.concatenate(xs)

# Standardization statistics come from the training features only.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(train_x)

pumap = ParametricUMAP(
    device='cuda:0',
    n_components=2,
    hidden_dim=256,
    n_layers=10,
    use_batchnorm=True,
    batch_size=256,
    n_epochs=100,
)

pca_train_zhat = pumap.fit_transform(X_scaled)

train_nmis, train_scs = evaluate_pca(pca_train_zhat, block_sizes, ys)


test_xs, _, _, test_ys, test_d_names = test_d
test_block_sizes = [x.shape[0] for x in test_xs]
test_x = np.concatenate(test_xs)
# Reuse the train-fitted scaler. The encoder was trained on inputs
# standardized with train statistics; refitting on the test split would feed
# it differently scaled inputs and leak test statistics into preprocessing.
test_x = scaler.transform(test_x)
pca_test_zhat = pumap.transform(test_x)
test_nmis, test_scs = evaluate_pca(pca_test_zhat, test_block_sizes, test_ys)
# NOTE: saved order is (train_nmis, train_scs, test_nmis, train_d, test_scs,
# test_d, pumap) — train_d sits between test_nmis and test_scs. Kept as-is
# because readers of pumap_res.save presumably unpack it positionally.
torch.save((train_nmis, train_scs, test_nmis, train_d, test_scs, test_d, pumap), './pumap_res.save')



