import torch

import numpy as np
import pandas as pd
from parametric_umap import ParametricUMAP
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

import load_data_for_baselines
from gnn.useful_utils import visualization_metric
import ae_keras


def evaluate_pca(zhat, block_size, ys, *, metric=None):
    """Score a row-stacked embedding block by block.

    ``zhat`` holds several blocks concatenated along the first axis; block
    ``i`` occupies ``block_size[i]`` consecutive rows and is labelled by
    ``ys[i]``. Each block is scored independently.

    Args:
        zhat: Row-stacked embeddings (anything sliceable along axis 0, e.g. a
            NumPy array or torch tensor).
        block_size: Per-block row counts, in the order the blocks appear in
            ``zhat``. Plain ints, NumPy integers or 0-d tensors are accepted.
        ys: Per-block label containers, indexed in the same order as
            ``block_size``; each must provide ``.tolist()``.
        metric: Optional callable ``(block_embedding, labels_list) -> (nmi, sc)``.
            Defaults to ``visualization_metric.get_nmi_sc``.

    Returns:
        Tuple ``(nmis, scs)``: two lists with one entry per block, holding the
        two values returned by ``metric`` for that block.
    """
    if metric is None:
        metric = visualization_metric.get_nmi_sc

    nmis = []
    scs = []
    # Running integer offsets. The previous float32 cumsum (torch.zeros(1)
    # concatenated with the sizes) is only exact up to 2**24 rows.
    start = 0
    for i, size in enumerate(block_size):
        end = start + int(size)
        nmi, sc = metric(zhat[start:end], ys[i].tolist())
        nmis.append(nmi)
        scs.append(sc)
        start = end
    return nmis, scs


print('loading')
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# load this archive if it comes from a trusted source.
train_d, test_d = torch.load('./clip_datas_for_baselines.tar', weights_only=False)
print('finish loading')

# Each split is a 5-tuple; below only xs (per-block inputs) and ys
# (per-block labels) are actually used.
xs, tsne_zs, umap_zs, ys, d_names = train_d
block_sizes = [i.shape[0] for i in xs]
train_x = np.concatenate(xs)
tsne_zs = np.concatenate(tsne_zs)

# Fit the standardiser on the training data only, and train/encode with the
# *scaled* features. Previously the AE was trained on raw train_x while the
# test inputs were standardised, so the encoder saw two different input spaces.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(train_x)

ae, encoder = ae_keras.build_model(xs[0].shape[1], 2)
ae = ae_keras.train_ae(X_scaled, ae)
# The "pca_" prefix is inherited from the PCA baseline; these are AE codes.
pca_train_zhat = ae_keras.transform(encoder, X_scaled)
train_nmis, train_scs = evaluate_pca(pca_train_zhat, block_sizes, ys)


test_xs, _, _, test_ys, test_d_names = test_d
test_block_sizes = [i.shape[0] for i in test_xs]
test_x = np.concatenate(test_xs)
# Reuse the training statistics: refitting on the test set (fit_transform)
# leaks test information and shifts the input space the encoder was trained on.
test_x = scaler.transform(test_x)
pca_test_zhat = ae_keras.transform(encoder, test_x)
test_nmis, test_scs = evaluate_pca(pca_test_zhat, test_block_sizes, test_ys)
# Tuple order is kept unchanged for existing readers of ae_res.save
# (note that train_d sits between test_nmis and test_scs).
torch.save((train_nmis, train_scs, test_nmis, train_d, test_scs, test_d, (ae, encoder)), './ae_res.save')


print(np.mean(train_nmis), np.mean(test_nmis))



