import torch

import numpy as np
import pandas as pd
from parametric_umap import ParametricUMAP
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

import load_data_for_baselines
from gnn.useful_utils import visualization_metric


def evaluate_zhat(zhat, block_size, ys):
    """Score a stacked embedding block by block.

    ``zhat`` and ``ys`` are cut along their first axis into consecutive
    blocks whose lengths are given by ``block_size``; each block is scored
    independently with ``visualization_metric.get_nmi_sc``.

    Args:
        zhat: Stacked embeddings, sliceable along the first axis.
        block_size: Per-block row counts, as a 1-D tensor or any sequence
            of integers accepted by ``torch.as_tensor``.
        ys: Labels aligned row-for-row with ``zhat``.

    Returns:
        A ``(nmis, scs)`` tuple of lists holding one entry per block, in
        block order. Both lists are empty when ``block_size`` is empty.
    """
    sizes = torch.as_tensor(block_size, dtype=torch.long).reshape(-1)
    # Block boundaries as plain ints: [0, s0, s0 + s1, ...].
    bounds = [0] + torch.cumsum(sizes, dim=0).tolist()

    nmis = []
    scs = []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        # Score the block's own slice of the embedding, not the full
        # stacked ``zhat``, so it lines up with the block's labels.
        _nmi, _sc = visualization_metric.get_nmi_sc(
            zhat[start:stop], ys[start:stop]
        )
        nmis.append(_nmi)
        scs.append(_sc)
    return nmis, scs


def evaluate_zhat_one_by_one(zhat, ys):
    """Score each embedding against the labels at the same position.

    Args:
        zhat: Sequence of embeddings, one per dataset.
        ys: Sequence of labels, indexed in parallel with ``zhat``.

    Returns:
        A ``(nmis, scs)`` tuple of lists with one entry per element of
        ``zhat``, as returned by ``visualization_metric.get_nmi_sc``.
    """
    nmis, scs = [], []
    for idx, embedding in enumerate(zhat):
        # Labels are looked up by index so a too-short ``ys`` still raises.
        nmi, sc = visualization_metric.get_nmi_sc(embedding, ys[idx])
        nmis.append(nmi)
        scs.append(sc)
    return nmis, scs


def train_pumap(pumap_obj, xs):
    """Fit ``pumap_obj`` on every dataset in ``xs``, one after another.

    Each dataset is standardized with its own freshly fitted
    ``StandardScaler`` before being handed to ``fit``.

    NOTE(review): whether successive ``fit`` calls continue training or
    start over depends on the ParametricUMAP implementation -- confirm.

    Args:
        pumap_obj: Model exposing ``fit(X)`` that returns the fitted model.
        xs: Indexable collection of 2-D feature arrays.

    Returns:
        The object returned by the last ``fit`` call, or ``pumap_obj``
        itself when ``xs`` is empty.
    """
    model = pumap_obj
    for idx in range(len(xs)):
        standardized = StandardScaler().fit_transform(xs[idx])
        model = model.fit(standardized)
    return model


def eval_pumap(pumap_obj, xs):
    """Embed every dataset in ``xs`` with an already-fitted model.

    Each dataset is standardized with a ``StandardScaler`` fitted on that
    same dataset (no scaler is carried over from training), then passed to
    ``pumap_obj.transform``.

    Args:
        pumap_obj: Fitted model exposing ``transform(X)``.
        xs: Indexable collection of 2-D feature arrays.

    Returns:
        A list with one ``transform`` result per dataset, in input order.
    """
    return [
        pumap_obj.transform(StandardScaler().fit_transform(xs[idx]))
        for idx in range(len(xs))
    ]

# ---------------------------------------------------------------------------
# Script body: load the prepared datasets, train one Parametric UMAP model on
# the training datasets, embed train and test datasets, score the embeddings
# and save everything to disk.
# ---------------------------------------------------------------------------
print('loading')
# NOTE(security): weights_only=False performs a full unpickle, which can
# execute arbitrary code; only load this archive from a trusted source.
train_d, test_d = torch.load('./clip_datas_for_baselines.tar', weights_only=False)
print('finish loading')

# The training split unpacks into five parallel collections; only `xs` and
# `ys` are used below.
xs, tsne_zs, umap_zs, ys, d_names = train_d
# NOTE(review): `block_sizes` is not used anywhere in the visible script.
block_sizes = [i.shape[0] for i in xs]

pumap = ParametricUMAP(
        device='cuda:0',
        n_components=2,
        hidden_dim=256,
        n_layers=10,
        use_batchnorm=True,
        batch_size=3000,
        n_epochs=50
    )

# Fit the model on each training dataset in turn.
pumap = train_pumap(pumap, xs)


# NOTE(review): the `pca_` prefix is misleading -- these hold Parametric UMAP
# embeddings (presumably a leftover from a PCA baseline script).
pca_train_zhat = eval_pumap(pumap, xs)
train_nmis, train_scs = evaluate_zhat_one_by_one(pca_train_zhat, ys)

# The test split has the same five-field layout; the second and third fields
# are discarded here.
test_xs, _, _, test_ys, test_d_names = test_d
pca_test_zhat = eval_pumap(pumap, test_xs)

test_nmis, test_scs = evaluate_zhat_one_by_one(pca_test_zhat, test_ys)
# NOTE(review): the saved tuple interleaves metrics and data
# (train_nmis, train_scs, test_nmis, train_d, test_scs, test_d, pumap);
# readers of this file must unpack in exactly this order.
torch.save((train_nmis, train_scs, test_nmis, train_d, test_scs, test_d, pumap), './pumap2_res2.save')



