from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)
from torch_geometric.data import Data

import datetime
import os
import logging

import numpy as np
import pandas as pd

import torch

import create_adj

import posenc_stats
import posenc_config


def extract_graph_pe(data, pe_types=('LapPE', 'SignNet', 'RWSE', 'RRWP')):
    """Compute positional/structural encodings for a graph and stack them column-wise.

    Calls ``posenc_stats.compute_posenc_stats`` on ``data`` (treated as an
    undirected graph, configured by the global GraphGym ``cfg``). It then reads
    back the ``<PE>_pos_enc`` attribute for every requested encoding and
    concatenates them along ``dim=1``. The shape of each encoding is printed
    as a progress/debug trace.

    Args:
        data: Graph object to encode. Presumably a torch_geometric ``Data`` as
            returned by ``create_adj.create_adj_from_dist``; TODO confirm.
        pe_types: Names of the encodings to compute, in output column order.
            Defaults to the four previously hard-coded here. Other names
            (e.g. 'EquivStableLapPE', 'HKdiagSE', 'ElstaticSE') were listed but
            disabled. Whether ``compute_posenc_stats`` supports them is not
            verifiable from this file.

    Returns:
        torch.Tensor: The per-encoding tensors concatenated along ``dim=1``.
        This assumes each encoding is 2-D with a shared first (node) axis;
        verify against ``posenc_stats``.

    Raises:
        ValueError: If ``pe_types`` is empty (``torch.cat`` would otherwise
            fail with an opaque error).
    """
    pes = list(pe_types)
    if not pes:
        raise ValueError('pe_types must name at least one positional encoding')

    emb_data = posenc_stats.compute_posenc_stats(data, pe_types=pes, is_undirected=True, cfg=cfg)

    final_emb = []
    for pe in pes:
        # compute_posenc_stats stores each result on the returned object as '<PE>_pos_enc'.
        emb = getattr(emb_data, f'{pe}_pos_enc')
        print(f'{pe}_pos_enc', emb.shape)
        final_emb.append(emb)
    return torch.cat(final_emb, dim=1)


def normalize_cdist(cdist):
    """Standardise a square distance matrix using its off-diagonal statistics.

    The mean and standard deviation are taken over the strict upper triangle
    (``k=1``): each unordered pair once, with the diagonal excluded. Every entry
    is shifted by that mean and divided by ``std + 1e-9``. The mean term is then
    added back on the diagonal, so a diagonal entry ``d_ii`` becomes
    ``d_ii / (std + 1e-9)``. A zero self-distance therefore stays zero.

    Args:
        cdist: 2-D array. Presumably a symmetric pairwise-distance matrix with
            a zero diagonal; verify against callers.

    Returns:
        The standardised matrix, same shape as ``cdist``.
    """
    size, _ = cdist.shape
    pair_dists = cdist[np.triu_indices_from(cdist, k=1)]
    centre = np.mean(pair_dists)
    scale = np.std(pair_dists) + 1e-9  # epsilon guards against a zero spread

    standardised = (cdist - centre) / scale
    # Undo the centring on the diagonal only, so self-distances are not shifted.
    return standardised + np.eye(size) * (centre / scale)


if __name__ == '__main__':
    # Populate the global GraphGym config, then register the positional-encoding
    # options that extract_graph_pe() reads through posenc_stats.
    set_cfg(cfg)
    posenc_config.set_cfg_posenc(cfg)

    datasets = [
        ('fmnist_clip_cdist_3000_group1.tar', 'fmnist_group1'),
        ('fmnist_clip_cdist_3000_group2.tar', 'fmnist_group2'),
        ('mnist_clip_cdist_3000_group1.tar', 'mnist_group1'),
        ('mnist_clip_cdist_3000_group2.tar', 'mnist_group2'),
        ('cifar10_clip_cdist_3000_group1.tar', 'cifar10_group1'),
        ('cifar10_clip_cdist_3000_group2.tar', 'cifar10_group2')
    ]
    root = 'H:\\****\\phd\\****\\autoVisual\\prepare_data\\clip\\features'
    out_dir = './pe_data'
    # torch.save does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    # Graph constructions applied to every distance matrix, in output column
    # order: kNN graphs first, then Gaussian-kernel graphs. (sigma=0.1 was
    # previously tried and disabled.)
    graph_configs = (
        [('knn', {'k': k}) for k in (1, 5, 10, 25)]
        + [('gaussian', {'sigma': sigma}) for sigma in (0.5, 1, 5, 10)]
    )

    for dataset_path, dataset in datasets:
        # The archive holds (distance matrix, features, labels); the features are unused here.
        cdist, _x, y = torch.load(os.path.join(root, dataset_path))
        # normalize_cdist(cdist) is available but deliberately not applied.

        all_graph_emb = []
        for method, params in graph_configs:
            graph = create_adj.create_adj_from_dist(cdist, method=method, **params)
            graph_emb = extract_graph_pe(graph).numpy()
            all_graph_emb.append(graph_emb)
            print(graph_emb.shape)

        # One feature block per graph construction, stacked column-wise.
        all_graph_emb = np.concatenate(all_graph_emb, axis=1)

        print(all_graph_emb.shape)
        torch.save((all_graph_emb, y), f'{out_dir}/{dataset}_pe_5_weight.pth')
