import torch

import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

import load_data_for_baselines
from gnn.useful_utils import visualization_metric


def evaluate_pca(zhat, block_size, ys):
    """Score a stacked embedding block by block.

    ``zhat`` holds the embeddings of several datasets stacked along its first
    axis. Block ``i`` occupies the next ``block_size[i]`` rows and is scored
    against the labels ``ys[i]`` via ``visualization_metric.get_nmi_sc``.

    Args:
        zhat: array-like indexable by row slices (rows of all blocks stacked
            in order).
        block_size: sequence of per-block row counts (ints, numpy ints or
            0-d tensors).
        ys: indexable sequence of per-block labels; each element must
            provide ``.tolist()``.

    Returns:
        Tuple ``(nmis, scs)`` of lists with one entry per block, holding the
        two values returned by ``get_nmi_sc`` (presumably NMI and silhouette
        score -- confirm against ``visualization_metric``).
    """
    nmis = []
    scs = []
    # Accumulate offsets with exact Python ints; a float32 cumsum (what
    # torch.cat promotes to when mixed with torch.zeros(1)) silently loses
    # integer precision once the running total exceeds 2**24 rows.
    start = 0
    for i, size in enumerate(block_size):
        end = start + int(size)
        _zhat = zhat[start:end]
        _y = ys[i]

        _nmi, _sc = visualization_metric.get_nmi_sc(_zhat, _y.tolist())
        nmis.append(_nmi)
        scs.append(_sc)
        start = end
    return nmis, scs

# NOTE(review): ckp_name is not used anywhere below -- presumably a leftover
# from the GNN evaluation script this baseline was derived from.
root = '/mnt/data01/****/****/gnn/res/'
ckp_name = '500ds-GNN-clip_w_umap-umaponly-umap_knn_l2-new-seqgt-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-84.tar'
ckp_name = root + '/' + ckp_name

print('loading')
# weights_only=False unpickles arbitrary Python objects: only use this on
# trusted, locally produced files.
train_d, test_d = torch.load('./clip_datas_for_baselines.tar', weights_only=False)

print('finish loading')

# Each split is a 5-tuple of per-dataset collections; only the features (xs),
# labels (ys) and names are used by the PCA baseline.
xs, tsne_zs, umap_zs, ys, d_names = train_d
block_sizes = [i.shape[0] for i in xs]
train_x = np.concatenate(xs)

print(train_x.shape)

# Standardize with training statistics, then fit one 2-D PCA jointly on all
# training datasets.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(train_x)

pca = PCA(n_components=2)
pca_train_zhat = pca.fit_transform(X_scaled)

train_nmis, train_scs = evaluate_pca(pca_train_zhat, block_sizes, ys)

test_xs, _, _, test_ys, test_d_names = test_d
test_block_sizes = [i.shape[0] for i in test_xs]
test_x = np.concatenate(test_xs)
# Reuse the scaler fitted on the training data: the PCA was fit in that
# standardized space, so refitting on the test split would project test
# points with a mismatched mean/variance and leak test statistics.
test_x = scaler.transform(test_x)
pca_test_zhat = pca.transform(test_x)
test_nmis, test_scs = evaluate_pca(pca_test_zhat, test_block_sizes, test_ys)
# NOTE: the tuple order (train_d sits between test_nmis and test_scs) is kept
# unchanged; downstream readers presumably unpack it positionally -- check
# consumers of './pca_res.save' before reordering.
torch.save((train_nmis, train_scs, test_nmis, train_d, test_scs, test_d, pca), './pca_res.save')

print(np.mean(train_nmis), np.mean(test_nmis))



