from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)
from torch_geometric.data import Data

import datetime
import os
import logging

import numpy as np
import pandas as pd

import torch

import create_adj

import posenc_stats
import posenc_config


def extract_graph_pe(data, pe_types=None):
    """Compute positional/structural encodings for one graph and concatenate them.

    Args:
        data: graph object handed to ``posenc_stats.compute_posenc_stats``.
            In this script it is the output of
            ``create_adj.create_adj_from_dist``; presumably a
            ``torch_geometric.data.Data`` — TODO confirm.
        pe_types: names of the encodings to compute, in output order.
            Defaults to ``['LapPE', 'SignNet', 'RWSE', 'RRWP']``. Each name
            must be understood by ``posenc_stats.compute_posenc_stats``.

    Returns:
        torch.Tensor: for every ``pe`` in ``pe_types``, the tensor stored on
        the returned object as ``f'{pe}_pos_enc'``, concatenated along dim 1.

    Raises:
        ValueError: if ``pe_types`` is empty (nothing to concatenate).
    """
    if pe_types is None:
        pe_types = ['LapPE', 'SignNet', 'RWSE', 'RRWP']
    # Materialize as a list: it is iterated twice (once inside
    # compute_posenc_stats, once below), so a generator would be exhausted.
    pe_types = list(pe_types)
    if not pe_types:
        raise ValueError("pe_types must contain at least one encoding name")

    # The graphs built by this script are treated as undirected.
    emb_data = posenc_stats.compute_posenc_stats(
        data, pe_types=pe_types, is_undirected=True, cfg=cfg)

    # Each encoding is attached as an attribute named '<type>_pos_enc'.
    encodings = [getattr(emb_data, f'{pe}_pos_enc') for pe in pe_types]
    return torch.cat(encodings, dim=1)


def normalize_cdist(cdist):
    """Standardize a square distance matrix using its off-diagonal statistics.

    The mean and standard deviation are taken over the strict upper triangle
    (``k=1``), i.e. over each pairwise distance once, excluding the diagonal.
    Every entry is then z-scored with those statistics, except the diagonal:
    it is shifted back so that a diagonal entry ``d`` becomes ``d / scale``
    (a zero diagonal therefore stays zero).

    Args:
        cdist: 2-D array-like of pairwise distances.

    Returns:
        numpy array of the same shape with the normalized distances.
    """
    num_rows, _ = cdist.shape
    rows, cols = np.triu_indices_from(cdist, k=1)
    pairwise = cdist[rows, cols]

    centre = np.mean(pairwise)
    # Small epsilon keeps the division finite when all distances are equal.
    scale = np.std(pairwise) + 1e-9

    standardized = (cdist - centre) / scale
    # Undo the mean shift on the diagonal only.
    diagonal_fix = np.eye(num_rows) * (centre / scale)
    return standardized + diagonal_fix


if __name__ == '__main__':
    set_cfg(cfg)
    posenc_config.set_cfg_posenc(cfg)

    # Datasets to process ('mnist' and 'fmnist' were also used previously).
    datasets = ['cifar10']

    # Site-specific location of the precomputed CLIP distance files.
    root = 'H:\\****\\phd\\****\\autoVisual\\prepare_data\\clip\\features'
    out_dir = './pe_data'

    # Graph constructions applied to every distance matrix, in output order:
    # all k-NN graphs first, then all Gaussian-kernel graphs. Each one
    # contributes one block of PE columns to the concatenated result.
    knn_ks = [1, 5, 10, 25]
    gaussian_sigmas = [0.1, 0.5, 1, 5, 10]
    graph_specs = (
        [dict(method='knn', k=k) for k in knn_ks]
        + [dict(method='gaussian', sigma=sigma, keep_edge_weight=False)
           for sigma in gaussian_sigmas]
    )

    # torch.save does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    for dataset in datasets:
        for ind in range(117, 252):
            in_path = os.path.join(
                root, f'{dataset}_comb{ind}_clip_cdist_3000.tar')
            # NOTE(review): torch.load unpickles the file — only load trusted
            # files. The 4-tuple layout (cdist, x, y, <unused>) comes from the
            # original unpacking; x and the last item are not used here.
            cdist, _x, y, _ = torch.load(in_path)
            # Optional, currently disabled: cdist = normalize_cdist(cdist)

            all_graph_emb = []
            for spec in graph_specs:
                # Convert the distance matrix into a graph, then encode it.
                graph = create_adj.create_adj_from_dist(cdist, **spec)
                graph_emb = extract_graph_pe(graph).numpy()
                all_graph_emb.append(graph_emb)
                print(graph_emb.shape)

            all_graph_emb = np.concatenate(all_graph_emb, axis=1)
            print(all_graph_emb.shape)

            out_path = os.path.join(out_dir, f'{dataset}_comb{ind}_pe_3.pth')
            torch.save((all_graph_emb, y), out_path)
