#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 23 16:37:42 2024

@author: anonymous
"""


import torch_geometric
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
import torch
import torch.nn as nn
from torch.autograd import Function
import numpy as np
import os
import argparse
import networkx as nx
from GraphRicciCurvature.FormanRicci import FormanRicci
from scipy.sparse.linalg import lsqr, cg, eigsh
from scipy.linalg import eigh
import shutil
import tqdm
import pickle


###############################################################################


def pre_process_edges(edge_index):
    """
    From https://github.com/TDA-Jyamiti/GRIL
    Preprocesses edges of graph: turns a torch_geometric edge index into a
    list of unique undirected edges.

    Each edge (u, v) is normalised to (min(u, v), max(u, v)) so that both
    directions of an undirected edge collapse to one row. Duplicates are
    removed and the result is sorted lexicographically, which makes the edge
    order deterministic.

    Parameters
    ----------
    edge_index : torch_geometric edge index
                    edges of the graph in torch_geometric format, shape (2, E)

    Returns
    -------
    Torch tensor
        preprocessed edges as a long tensor of shape (num_unique_edges, 2) on
        the same device as ``edge_index``; shape (0, 2) when there are no edges

    """
    # (2, E) -> (E, 2), then sort each row so (v, u) and (u, v) coincide.
    pairs = edge_index.t().sort(dim=1).values.tolist()
    unique_pairs = sorted({tuple(pair) for pair in pairs})
    # reshape keeps the (N, 2) layout even when there are no edges
    # (torch.tensor([]) alone would be 1-D).
    return torch.tensor(
        unique_pairs, dtype=torch.long, device=edge_index.device
    ).reshape(-1, 2)


###############################################################################


def get_hks(L, K, ts):
    """
    From https://github.com/ctralie/pyhks/blob/master/hks.py

    Computes the heat kernel signature
    hks[i, t] = sum_k phi_k(i)^2 * exp(-lambda_k * t)
    using the K eigenpairs of L with the smallest eigenvalues.

    Parameters
    ----------
    L : Graph Laplacian
        dense symmetric matrix of shape (N, N)
    K : int
        Number of eigenvalues/eigenvectors to use (the K smallest). Values
        larger than N, or None, use all eigenpairs.
    ts : array-like (T)
        The time scales at which to compute the HKS

    Returns
    -------
    hks : ndarray (N, T)
        A array of the heat kernel signatures at each of N points
        at T time intervals
    """
    eigvalues, eigvectors = eigh(L)
    # scipy's eigh returns eigenvalues in ascending order, so slicing keeps
    # the K lowest-frequency modes (previously K was silently ignored).
    eigvalues = eigvalues[:K]
    eigvectors = eigvectors[:, :K]
    ts = np.asarray(ts).flatten()
    res = (eigvectors[:, :, None] ** 2) * np.exp(
        -eigvalues[None, :, None] * ts[None, None, :]
    )
    return np.sum(res, 1)


###############################################################################



def get_hks_rc_bifiltration(num_nodes, edge_index, nn_k=6):
    """
    From https://github.com/TDA-Jyamiti/GRIL
    Computes the heat kernel signature - Ricci curvature - bifiltration.

    Each node gets the bidegree (HKS at t=10, Forman-Ricci node curvature).
    Both coordinates are min-max normalised over the nodes. Each edge enters
    at the maximum of its endpoints' first coordinates. For the second
    coordinate it takes the maximum of its endpoints' second coordinates and
    the edge's own Forman-Ricci curvature. Every edge bidegree is then
    shifted by 1e-4.

    Parameters
    ----------
    num_nodes : Int
                    number of nodes in graph

    edge_index : torch_geometric edge index
                    edges of the graph in torch_geometric format

    nn_k : int, optional
            Unused; kept so existing callers that pass it keep working.

    Returns
    -------
    filt : Torch tensor
            float64 tensor of bidegrees with num_nodes + num_edges rows and 2
            columns. The first num_nodes rows are the nodes; the remaining
            rows are the edges, in the order of ``edges``.
    edges : Torch tensor
                preprocessed edges in the bifiltration (as returned by
                pre_process_edges)
    """
    g = nx.Graph()
    g.add_nodes_from(range(num_nodes))
    edges = pre_process_edges(edge_index)
    edges_list = edges.tolist()
    g.add_edges_from((e[0], e[1]) for e in edges_list)

    frc = FormanRicci(g)
    frc.compute_ricci_curvature()

    graph_laplacian = nx.normalized_laplacian_matrix(g).toarray().astype(float)
    # Only the last time scale (t=10) is used as the first filtration coordinate.
    hks = get_hks(graph_laplacian, num_nodes, ts=np.array([1, 10]))
    f_v_x = hks[:, -1]

    f = np.array(
        [[f_v_x[n], frc.G.nodes[n]['formanCurvature']] for n in range(num_nodes)]
    )
    # Per-column min-max normalisation; the 1e-4 guards against a zero range.
    f = (f - f.min(axis=0)) / (f.max(0) - f.min(0) + 1e-4)

    f_e = []
    for u, v in edges_list:
        e_x = max(f[u, 0], f[v, 0])
        # NOTE(review): e_curv is the raw (unnormalised) edge curvature, while
        # the node values are normalised; confirm this matches the intended
        # GRIL construction.
        e_curv = frc.G[u][v]["formanCurvature"]
        e_y = max(f[u, 1], f[v, 1], e_curv)
        f_e.append([e_x, e_y])
    # reshape keeps a (0, 2) array for edgeless graphs so the stacking below
    # does not fail on a 1-D empty array.
    f_e = np.asarray(f_e, dtype=float).reshape(-1, 2)
    # The shift makes each edge coordinate strictly larger than the matching
    # coordinate of both endpoints.
    f_e = f_e + 1e-4
    # np.vstack replaces the deprecated alias np.row_stack (identical result).
    filt = torch.tensor(np.vstack((f, f_e)), device=edge_index.device)

    return filt, edges


###############################################################################



def _write_scc2020(path, filt, edges, num_nodes):
    """Write one bifiltered graph to ``path`` in the scc2020 text layout used for mpfree input."""
    node_grades = filt[:num_nodes].tolist()
    edge_grades = filt[num_nodes:].tolist()
    edge_list = edges.tolist()

    lines = ['scc2020\n', '2\n']
    # Header: number of edges, number of vertices, 0. The edge count must equal
    # the number of edge lines actually written (deduplicated edges), not
    # edge_index.shape[1] / 2, which breaks on self-loops or duplicate edges.
    lines.append(f'{len(edge_grades)} {num_nodes} 0 \n')
    # Edges: bidegree followed by the indices of their two vertices.
    for (x, y), (u, v) in zip(edge_grades, edge_list):
        lines.append(f'{x} {y} ; {u} {v}\n')
    # Vertices: bidegree with an empty boundary.
    for x, y in node_grades:
        lines.append(f'{x} {y} ; \n')

    with open(path, 'w') as f:
        f.writelines(lines)


def to_mpfree_inp(data):
    """
    Computes the heat kernel signature - Ricci curvature - bifiltration of the
    graphs in a torch_geometric graph dataset and writes them to files that
    are suitable inputs for mpfree.

    One file ``Data/Graph_Bifiltrations/<index>.txt`` is written per graph.
    The directory is deleted and recreated first. The graph labels are pickled,
    in dataset order, to ``Data/graph_labels.txt``.

    Parameters
    ----------
    data : List of torch_geometric graphs
            list of graphs used to compute bifiltrations
    """

    print('\n')
    print('Write to files','\n')
    directory = os.path.join('Data', 'Graph_Bifiltrations')
    # Start from an empty directory so files from an earlier run cannot linger.
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)

    labels = []

    for j, G in enumerate(tqdm.tqdm(data)):
        # Featureless graphs have G.x = None; fall back to PyG's node count.
        num_nodes = G.x.shape[0] if G.x is not None else G.num_nodes
        filt, edges = get_hks_rc_bifiltration(num_nodes, G.edge_index)
        labels.append(G.y.item())
        _write_scc2020(os.path.join(directory, str(j) + '.txt'), filt, edges, num_nodes)

    # Use a context manager so the label file is flushed and closed.
    with open(os.path.join('Data', 'graph_labels.txt'), 'wb') as f:
        pickle.dump(labels, f)


###############################################################################

# TUDatasets that have been used with this script, e.g.:
# name='COX2'
# name='DHFR'
# name='IMDB-BINARY'
# name='MUTAG'
# name='PROTEINS'

if __name__ == '__main__':
    # Guarded so that importing this module does not download data or delete
    # the output directory.
    parser = argparse.ArgumentParser(
        description='Write HKS / Forman-Ricci bifiltrations of a TUDataset '
                    'as mpfree (scc2020) input files.'
    )
    parser.add_argument(
        '--name',
        default='PROTEINS',
        help='TUDataset name, e.g. COX2, DHFR, IMDB-BINARY, MUTAG, PROTEINS '
             '(default: PROTEINS)',
    )
    name = parser.parse_args().name

    data = TUDataset(root='Data/TUDatasets/', name=name)

    to_mpfree_inp(data)



