from functools import partial

import numpy as np
import pandas as pd
import torch
# NOTE(review): this script also uses `plt` (matplotlib.pyplot),
# `CompleteMVGraphDatasets` and `visualization_metric` without importing them;
# they must be imported here or already exist in the running session.



def scatter_z(z, y, ax, label, marker, alpha=1.0, cmap='viridis'):
    """Scatter-plot the first two columns of ``z`` on ``ax``, coloured by ``y``.

    Args:
        z: Array indexable as ``z[:, 0]`` / ``z[:, 1]`` (e.g. an (n, d) NumPy
            array with d >= 2); only the first two columns are drawn.
        y: Per-point values mapped through ``cmap`` to colour the points.
        ax: Matplotlib Axes to draw on.
        label: Legend label for this point collection.
        marker: Matplotlib marker style.
        alpha: Point opacity.
        cmap: Colormap applied to ``y``. Defaults to ``'viridis'``, which was
            previously hard-coded.

    Returns:
        Whatever ``ax.scatter`` returns (a PathCollection for Matplotlib). It
        can be passed to ``plt.colorbar``.
    """
    return ax.scatter(z[:, 0], z[:, 1], c=y,
                      label=label, marker=marker, alpha=alpha,
                      cmap=cmap)


# Per-dataset metric tables for the ground-truth and the predicted layouts.
df1 = pd.read_csv('clip_500_gt_df.csv', index_col=0)
df2 = pd.read_csv('clip_500_pred_df.csv', index_col=0)

# Join the two tables on the dataset name (pandas' default inner join).
df = df1.merge(df2, on='d_name')
##df.to_csv('clip_500_merged.csv')

# Compare predicted NMI with ground-truth NMI for each embedding method:
#   *_precision = pred / gt  (a ratio; inf/NaN wherever gt is 0)
#   *_gap       = gt - pred  (a plain difference, despite the "relative" prefix)
relative_tsne_nmi_gap = df['gt_tsne_nmi'] - df['pred_tsne_nmi']
relative_tsne_nmi_precision = df['pred_tsne_nmi'] / df['gt_tsne_nmi']
relative_umap_nmi_gap = df['gt_umap_nmi'] - df['pred_umap_nmi']
relative_umap_nmi_precision = df['pred_umap_nmi'] / df['gt_umap_nmi']

# Attach the derived columns, in the same column order as before.
df = df.assign(
    relative_tsne_nmi_precision=relative_tsne_nmi_precision,
    relative_tsne_nmi_gap=relative_tsne_nmi_gap,
    relative_umap_nmi_precision=relative_umap_nmi_precision,
    relative_umap_nmi_gap=relative_umap_nmi_gap,
)

f_name = './res/GNN-sign-sigma_mv-run4-500epoch-kl_t_64dim-lap_only_complete_G_sep_gnn-epoch-499.tar'
# The checkpoint unpacks into (net, save), where `net` is a whole model object
# rather than a state_dict. It therefore needs weights_only=False: torch.load
# defaults to weights_only=True from PyTorch 2.6 on and would reject it.
# SECURITY: this unpickles arbitrary objects; only load trusted checkpoints.
net, save = torch.load(f_name, weights_only=False)

print(net)

# Show the tail of the recorded `running_loss` history
# (presumably per-epoch training loss).
print(save['running_loss'][-10:])
##    plt.plot(save['train_mse'], label='train MSE')
##    plt.plot(save['test_mse'], label='test MSE')
##
##    plt.xlabel('epoch')
##    plt.ylabel('MSE')
##
##    plt.title('training with y normalization')
##    plt.legend()
##    plt.show()


# Datasets the model was trained on: five fixed MNIST/FMNIST/CIFAR-10 groups
# plus 128 'mnist_comb{i}' combinations.
train_names = (
    ['mnist_group2', 'mnist_group1', 'fmnist_group2', 'fmnist_group1', 'cifar10_group1']
    + [f'mnist_comb{i}' for i in range(128)]
)

# Held-out dataset(s) to visualize. The trailing comma matters: without it,
# un-commenting the next entry would silently concatenate the two strings.
test_names = [
    'cifar10_group2',
    # 'mnist_group2',
    # 'seismic', 'musk', 'speech', 'abalone'
]

train_ds = CompleteMVGraphDatasets.DatasetGraphDataset(
    data_names=train_names,
    cdist_path='../prepare_data/clip/features',
    visual_path='../prepare_data/bo/res-2',
    normalize_z=True,
    precomputed_pe_path='../prepare_data/pe/pe_for_gat',
    z_mu=None,
    z_std=None,
)

# test_ds = GraphDatasets.DatasetGraphDataset(data_names=test_names, cdist_path='../prepare_data/clip/features',
#                                             visual_path='../prepare_data/bo/res-2', normalize_z=True,
#                                             precomputed_pe_path='../prepare_data/pe/pe_for_gat',
#                                             z_mu=train_ds.z_mu, z_std=train_ds.z_std)

test_ds = CompleteMVGraphDatasets.DatasetGraphDataset(
    data_names=test_names,
    cdist_path='../prepare_data/clip/features',
    visual_path='../prepare_data/bo/res-2',
    normalize_z=True,
    z_cali_method='none',
    # z_anchor=dummy_train_ds.z[0],
    precomputed_pe_path='../prepare_data/pe/pe_for_gat',
    # Reuse train_ds's z_mu / z_std (presumably the training normalization
    # statistics) so the test z is normalized consistently with training.
    z_mu=train_ds.z_mu,
    z_std=train_ds.z_std,
)

# Evaluation loader factory: one sample per batch; collate_fn transposes the
# batch into a list of per-field tuples. shuffle=False keeps predictions in
# dataset order, so they stay aligned with the labels they are scored and
# coloured against, and so repeated runs produce the same order.
get_loader = partial(torch.utils.data.DataLoader, batch_size=1,
                     shuffle=False,
                     num_workers=0,
                     collate_fn=lambda x: list(zip(*x))
                     )

# First entries of the train and test `z` collections.
train_z = train_ds.z[0]
test_z = test_ds.z[0]

# Evaluation loader over the test set, plus the test set's original z
# (drawn as the reference scatter).
test_loader = get_loader(test_ds)
y_test = test_ds.z[0]

# Orthogonal matrix r = u @ vt from the SVD of the outer product of the mean
# test z and the mean train z. `r` is only consumed by the commented-out
# experiment `test_ds.out_data @ r` further down.
m = np.matmul(
    np.mean(test_z, axis=0).reshape(2, 1),
    np.mean(train_z, axis=0).reshape(1, 2),
)
u, s, vt = np.linalg.svd(m)
r = np.matmul(u, vt)

##U, S, Vt = np.linalg.svd(y_test)
##S_matrix = np.diag(S)
##y_test = np.dot(U[:, :2], S_matrix[:2, :2])

# A single 10x8-inch Axes with tight layout enabled at creation time.
fig = plt.figure(figsize=(10, 8), tight_layout=True)
ax = fig.add_subplot()

# Reference layout: the test set's original z, coloured by test_ds.y[0].
scatter_z(y_test, test_ds.y[0], ax, label='z: ori', marker='+', alpha=0.5)

##y_test =  test_ds.out_data @ r

##U, S, Vt = np.linalg.svd(y_test)
##S_matrix = np.diag(S)
##y_test = np.dot(U[:, :2] , S_matrix[:2, :2]) * np.array([-1, 1])
##scatter = ax.scatter(y_test[:, 0], y_test[:, 1], c='b',
##                     alpha=0.6,
##                     label='z: svd (-1, 1)', marker='o',
##                     cmap='rainbow'
##                     )

# Use the first GPU when one is present; otherwise fall back to the CPU
# instead of crashing on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
net = net.to(device)
# NOTE(review): net.eval() is disabled, so any dropout / batch-norm layers the
# model may have stay in training mode during this evaluation. Confirm this
# is intentional.
# net.eval()
y_hat = []
with torch.no_grad():
    for xs, n_view_graphs, _, zdist, _ in test_loader:
        # NOTE(review): xs / n_view_graphs are passed as-is (not moved to
        # `device`); presumably the model handles placement. Confirm.
        z_hat = net(xs, n_view_graphs)

        # Loss from net.loss_fn_kl_t against the first sample's zdist,
        # printed for monitoring only.
        zdist = zdist[0].to(device=device, dtype=torch.float32)
        student_t_loss = net.loss_fn_kl_t(xs, n_view_graphs, zdist)

        y_hat.append(z_hat)

        print(student_t_loss)

# Join the per-batch predictions along dim 0 into a single NumPy array.
y_hat = torch.cat(y_hat).cpu().numpy()
##y_hat = y_hat * train_ds.y_std + train_ds.y_mu

##    scatter = ax.scatter(y_hat[:, 0], y_hat[:, 1], c='b',
##                         label='y_hat', marker='o')

# Overlay the model's predictions on the same Axes, coloured by the same labels.
s = scatter_z(y_hat, test_ds.y[0], ax, label='z_hat', marker='o', alpha=1.0)

# Scores of the predicted layout against test_ds.y[0]
# (presumably NMI and silhouette coefficient).
nmi, sc = visualization_metric.get_nmi_sc(y_hat, test_ds.y[0])
print(f'NMI: {nmi:.4f}, SC: {sc:.4f}')

# Colour bar for the prediction scatter. plt.colorbar restores the current
# Axes, so the title and legend below still target `ax`.
plt.colorbar(s)
ax.set_title(f'Visualization of {test_names[0]} CLIP Feature')
ax.legend()
plt.show()