from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)
from torch_geometric.data import Data

import datetime
import os
import logging

import numpy as np
import pandas as pd

import torch

import create_adj

import posenc_stats
import posenc_config


def extract_graph_pe(a):
    """Compute positional/structural encodings for the graph described by ``a``.

    Args:
        a: 2-D adjacency matrix. Every strictly positive entry ``a[i, j]`` is
            treated as an edge ``i -> j``; the node count is ``a.shape[0]``.
            NOTE(review): the graph is declared undirected to
            ``compute_posenc_stats`` below, so ``a`` is presumably symmetric
            -- confirm against ``create_adj``.

    Returns:
        A NumPy array obtained by concatenating, along axis 1, the
        ``<name>_pos_enc`` attribute produced for each encoding in ``pes``
        (in list order).
    """
    # Indices of the strictly positive entries become the edge list.
    row, col = np.where(a > 0)
    # Collapse the two index arrays into a single ndarray before handing them
    # to torch: building a tensor from a Python list of ndarrays is slow and
    # makes PyTorch emit a UserWarning.
    edge_index = torch.tensor(np.array([row, col]), dtype=torch.long)

    # Create PyG graph
    num_nodes = a.shape[0]
    data = Data(edge_index=edge_index, num_nodes=num_nodes)

    pes = ['LapPE', 'EquivStableLapPE', 'SignNet',
           'RWSE', 'HKdiagSE', 'ElstaticSE', 'RRWP',
           'SVD', 'PPR', 'WLPE', 'GCKN', 'RWDIFF']

    emb_data = posenc_stats.compute_posenc_stats(data, pe_types=pes, is_undirected=True, cfg=cfg)

    # Each encoding is read back from the attribute named '<pe>_pos_enc'.
    per_encoding = [getattr(emb_data, f'{pe}_pos_enc') for pe in pes]
    return np.concatenate(per_encoding, axis=1)


def normalize_cdist(cdist):
    """Standardize a square matrix using its strict upper-triangle statistics.

    The mean and standard deviation are taken over the entries strictly above
    the main diagonal. Every entry is then shifted by that mean and divided by
    ``std + 1e-9``; on the main diagonal the shift is undone again, so diagonal
    entries end up only rescaled (a zero diagonal stays zero).

    Args:
        cdist: 2-D square array.

    Returns:
        A new array of the same shape; the input is not modified.
    """
    size, _ = cdist.shape
    upper_idx = np.triu_indices_from(cdist, k=1)
    pair_values = cdist[upper_idx]

    center = np.mean(pair_values)
    # The small epsilon keeps the division finite when all pairs are equal.
    scale = np.std(pair_values) + 1e-9

    standardized = (cdist - center) / scale
    # Add the shift back on the diagonal only.
    diag_correction = np.eye(size) * (center / scale)
    return standardized + diag_correction


if __name__ == '__main__':
    # Script entry point: for every dataset, build several graphs from its
    # normalized distance matrix, extract the encodings of each graph and save
    # them concatenated along axis 1.
    set_cfg(cfg)
    posenc_config.set_cfg_posenc(cfg)

    datasets = [
        'arrhythmia', 'wine', 'lympho', 'glass', 'vertebral', 'wbc', 'ecoli',
        'ionosphere', 'breastw', 'pima',
        'vowels',
        'letter',
        'cardio', 'seismic',
        'musk', 'speech', 'abalone',
        'pendigits', 'mammography',
        'mulcross',
        'forest_cover',
    ]

    # One (method, keyword-arguments) recipe per graph, listed in the order in
    # which the resulting embeddings are concatenated: kNN graphs first, then
    # Gaussian graphs.
    graph_recipes = (
        [('knn', {'k': k}) for k in (1, 5, 10, 25)]
        + [('gaussian', {'sigma': sigma}) for sigma in (0.1, 0.5, 1, 5, 10)]
    )

    # Create the output directory up front so the final torch.save cannot fail
    # on a missing folder after all the work has been done.
    os.makedirs('./pe_data', exist_ok=True)

    for dataset in datasets:
        # load data
        # NOTE(review): torch.load unpickles the file contents; only point this
        # at trusted files.
        # The backslash before 'cdist_data' is now escaped explicitly; the
        # resulting path string is unchanged.
        cdist, _x, _y = torch.load(f'G:\\****\\ad\\datasets\\cdist_data/{dataset}_cdist_data_for_tsne2.tar')
        cdist = normalize_cdist(cdist)

        all_graph_emb = []
        for method, params in graph_recipes:
            # convert cdist into adj
            adj_matrix_ = create_adj.create_adj_from_dist(cdist, method=method, **params)
            all_graph_emb.append(extract_graph_pe(adj_matrix_))

        all_graph_emb = np.concatenate(all_graph_emb, axis=1)

        print(all_graph_emb.shape)
        torch.save(all_graph_emb, f'./pe_data/{dataset}_pe_2.pth')
