# ------------------------ : new rwpse ----------------
from typing import Union, Any, Optional
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
import torch_geometric as pyg
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import BaseTransform
from torch_scatter import scatter, scatter_add, scatter_max
from torch_geometric.utils.convert import to_networkx
from torch_geometric.utils import to_dense_adj
from torch_geometric.utils.convert import from_scipy_sparse_matrix

from torch_geometric.graphgym.config import cfg
import networkx as nx 
from torch_geometric.utils import (
    get_laplacian,
    get_self_loop_attr,
    to_scipy_sparse_matrix,
)
import torch_sparse
from torch_sparse import SparseTensor

def graph_from_edge_index(data, max_dim):
    """Build an undirected ``networkx.Graph`` from ``data.edge_index``.

    Nodes ``0 .. N-1`` are inserted first, in index order, where ``N`` is the
    largest of ``max_dim``, ``data.num_nodes`` (when available) and
    ``max(edge_index) + 1``.  Inserting them before any edge makes
    ``G.nodes()`` -- and therefore the row/column order of matrices derived
    from ``G`` -- follow node indices, and keeps isolated nodes in place.
    Edge direction is dropped, so ``(u, v)`` and ``(v, u)`` collapse into a
    single edge; self-loops are kept.

    Args:
        data: graph object exposing an ``edge_index`` tensor with a source
            row ``[0]`` and a target row ``[1]``.
        max_dim: minimum number of nodes of the returned graph (padding
            size for the dense tensors later computed from it).

    Returns:
        networkx.Graph whose nodes are ``0 .. N-1`` in that order.
    """
    sources = data.edge_index[0].tolist()
    targets = data.edge_index[1].tolist()
    edges = {(min(u, v), max(u, v)) for u, v in zip(sources, targets)}

    num_nodes = max(max_dim,
                    getattr(data, 'num_nodes', None) or 0,
                    max(sources + targets, default=-1) + 1)

    G = nx.Graph()
    # Nodes first: edge insertion order (set iteration order) must not decide node order.
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(edges)
    return G

def common_num_tensor(G):
    """Return a square integer tensor of degrees and common-neighbour counts.

    Rows and columns follow ``G.nodes()`` order.  The diagonal holds each
    node's degree as reported by ``G.degree()``; entry ``(a, b)`` off the
    diagonal holds the number of common neighbours of the a-th and b-th
    node (``nx.common_neighbors``).  The result is symmetric.
    """
    nodes = list(G.nodes())
    degrees = [deg for _, deg in G.degree()]
    counts = torch.diag(torch.tensor(degrees))

    for a, u in enumerate(nodes):
        for b in range(a + 1, len(nodes)):
            shared = sum(1 for _ in nx.common_neighbors(G, u, nodes[b]))
            counts[a, b] += shared
            # Mirror into the lower triangle (that entry starts at zero).
            counts[b, a] += counts[a, b]
    return counts

def dist_exp_tensor(G):
    """Return ``exp(-d(u, v))`` for every node pair of ``G`` as a tensor.

    ``d`` is the all-pairs shortest-path length matrix from
    ``nx.floyd_warshall_numpy`` (rows/columns in ``G.nodes()`` order).
    networkx reports unreachable pairs as ``inf``, which maps to 0 here.
    """
    neg_dist = -nx.floyd_warshall_numpy(G)
    return torch.tensor(neg_dist).exp()


def add_node_attr(data: Data, value: Any,
                  attr_name: Optional[str] = None) -> Data:
    """Attach ``value`` to ``data`` and return ``data``.

    With ``attr_name`` given, ``value`` is stored as ``data[attr_name]``.
    Otherwise it is appended as extra feature columns of ``data.x``: it is
    cast to the device and dtype of ``x``, and a 1-D ``x`` is first viewed
    as a single column.  If ``data`` has no ``x`` yet, ``value`` becomes
    ``data.x`` unchanged.
    """
    if attr_name is not None:
        data[attr_name] = value
        return data

    if 'x' not in data:
        data.x = value
        return data

    features = data.x
    if features.dim() == 1:
        features = features.view(-1, 1)
    data.x = torch.cat([features, value.to(features.device, features.dtype)], dim=-1)
    return data

def compute_common_nodes(data):
    """Store the degree / common-neighbour matrix of ``data`` in sparse COO form.

    The graph is converted with ``to_networkx(..., to_undirected=True)`` and
    passed to ``common_num_tensor``, which puts degrees on the diagonal and
    common-neighbour counts off the diagonal.  Only the non-zero entries
    are kept:

    * ``data['num_common_idx']``: (nnz, 2) tensor of (row, col) pairs;
    * ``data['num_common_val']``: (nnz,) tensor of the matching counts.

    Returns:
        ``data`` itself, with the two keys added.
    """
    G = to_networkx(data, to_undirected=True)

    # Reuse the shared helper instead of duplicating its double loop.
    num_common = SparseTensor.from_dense(common_num_tensor(G), has_value=True)
    row, col, val = num_common.coo()

    data['num_common_idx'] = torch.stack([row, col], dim=1)
    data['num_common_val'] = val
    return data
    

class CorrelationMatrix:
    """Dense, real-valued pairwise correlation matrix of one graph.

    The matrix is a closed-form expression in the parameters ``h``, ``theta``
    and ``t``, evaluated element-wise on the graph's dense adjacency matrix
    and its common-neighbour counts (see
    :meth:`compute_correlation_matrix_batched`).  Intermediate tensors are
    complex and are moved to ``device``; the final result is real and
    returned on CPU.

    NOTE(review): the physical model behind these formulas is not documented
    in this file -- confirm the expressions against the reference derivation.
    """

    def __init__(self,
                 h: torch.Tensor,
                 theta: torch.Tensor,
                 t: torch.Tensor,
                 device='cpu',
                 ) -> None:
        """Move the parameters to ``device`` and store them.

        Args:
            h: tensor broadcast against the per-node vector ``rho`` (callers in
                this file pass ``torch.ones(num_nodes)``).
            theta: angle parameter (presumably a scalar tensor -- confirm).
            t: time parameter (presumably a scalar tensor -- confirm).
            device: device used for intermediate computations.
        """
        super().__init__()
        self.device = device
        self.h = h.to(self.device)
        self.t = t.to(self.device)
        self.theta = theta.to(self.device)

    def _pair_factor(self):
        """Return B = (cos²θ + sin²θ·e^{it})·(cos²θ + sin²θ·e^{-it})."""
        cos2 = torch.cos(self.theta) ** 2
        sin2 = torch.sin(self.theta) ** 2
        return (cos2 + sin2 * torch.exp(1j * self.t)) * (cos2 + sin2 * torch.exp(-1j * self.t))

    def w_ij_batched(self, Adj):
        """Return cos²θ + sin²θ·exp(i·t·A_ij), element-wise over ``Adj``."""
        return torch.cos(self.theta) ** 2 + (torch.sin(self.theta) ** 2) * \
            torch.exp(Adj * self.t * 1j).to(self.device)

    def w_plus_batched(self, Adj, com_ij):
        """Return (cos²θ + sin²θ·e^{2it})^{com_ij} · B^{1 - A_ij}, element-wise.

        ``com_ij`` is the dense common-neighbour matrix; ``B`` is
        :meth:`_pair_factor`, so the second factor is 1 for neighbours.
        """
        cos2 = torch.cos(self.theta) ** 2
        sin2 = torch.sin(self.theta) ** 2
        return torch.pow(cos2 + sin2 * torch.exp(2 * self.t * 1j), com_ij).to(self.device) \
            * (self._pair_factor() ** (1 - Adj))

    def w_minus_batched(self, Adj):
        """Return B^{1 - A_ij}: 1 for neighbours, B = (cos²θ + sin²θ e^{it})(cos²θ + sin²θ e^{-it}) otherwise."""
        # The whole product B is exponentiated.  Writing ``x * y ** (1 - Adj)``
        # would raise only ``y`` (``**`` binds tighter than ``*``).
        return self._pair_factor() ** (1 - Adj)

    def compute_correlation_matrix_batched(self, data):
        """Return the real (n, n) correlation matrix of ``data`` on CPU.

        Reads ``data.edge_index`` plus ``data.num_common_idx`` (stored as
        (nnz, 2), hence the transpose) and ``data.num_common_val`` as
        written by ``compute_common_nodes``.  Both dense matrices are sized
        with ``data.num_nodes`` when available, so trailing isolated nodes
        are kept.  The result then matches ``self.h`` and other n x n
        encodings.
        """
        num_nodes = getattr(data, 'num_nodes', None)
        Adj = to_dense_adj(data.edge_index, max_num_nodes=num_nodes).to(self.device)
        com_ij = to_dense_adj(data.num_common_idx.t(), edge_attr=data.num_common_val,
                              max_num_nodes=num_nodes).to(self.device)

        # Global prefactor 4·sin⁴θ·cos⁴θ (not named ``F``: that would shadow
        # the module-level torch.nn.functional alias).
        prefactor = (4 * (torch.sin(self.theta) ** 4) * (torch.cos(self.theta) ** 4)).to(self.device)
        W = self.w_ij_batched(Adj)  # same (1, n, n) shape as Adj, complex

        # rho_k = prod_j W_kj · exp(i·h_k·t)  -> shape (1, n)
        rho_vect = torch.prod(W, 2) * torch.exp(self.h * self.t * 1j).to(self.device)
        rho_row = rho_vect[:, None]      # (1, 1, n): entry [0, i, j] is rho_j
        rho_col = rho_vect.t()[None, :]  # (1, n, 1): entry [0, i, j] is rho_i

        f1 = ((rho_row + rho_col) * (1 - 1 / W)).to(self.device)
        a = (.5 * (1 - (torch.exp(Adj * self.t * 1j) / self.w_plus_batched(Adj, com_ij)))).to(self.device)
        f2 = a * (rho_row * rho_col)
        f3 = .5 * (1 - (1 / self.w_minus_batched(Adj))) * (rho_row * torch.conj(rho_col))

        return (prefactor * torch.real(f1 + f2 + f3)).to(self.device).squeeze(0).cpu()

@torch.no_grad()
def add_full_rrwp_static(data,
                         walk_length=8,
                         attr_name_abs="rrwp",  # key: 'rrwp'
                         attr_name_rel="rrwp",  # keys: 'rrwp_index', 'rrwp_val'
                         add_identity=True,
                         spd=False,
                         n_quantum=0,
                         max_dim=37,
                         q_params_path='data/q_params.pt',
                         **kwargs
                         ):
    """Attach full relative random-walk encodings (RRWP) and structural tensors.

    Builds the row-normalised transition matrix ``P = D^{-1} A`` and stacks
    ``[I,] P, P^2, ...`` into an ``n x n x k`` tensor.  For ``walk_length >= 2``
    it holds ``walk_length`` channels, the identity counting as one when
    ``add_identity``.  It then appends ``n_quantum`` correlation matrices.

    Adds to ``data``:

    * ``data[attr_name_abs]``: diagonal of the stack, shape (n, k);
    * ``data[f"{attr_name_rel}_index"]`` / ``data[f"{attr_name_rel}_val"]``:
      non-zero entries of the stack in COO form;
    * ``data.num_common`` / ``data.dist``: degree/common-neighbour counts and
      ``exp(-shortest_path)`` on a graph padded to ``max_dim`` nodes;
    * ``data.log_deg`` / ``data.deg``: (log) row sums of the adjacency.

    Args:
        data: PyG-style graph with ``edge_index``, ``edge_weight``, ``num_nodes``.
        walk_length: number of random-walk channels.
        attr_name_abs: key for the absolute (diagonal) encoding.
        attr_name_rel: prefix for the relative (pairwise) encoding keys.
        add_identity: whether channel 0 is the identity matrix.
        spd: replace relative values by a one-hot of the first channel with a
            positive entry, and zero the absolute encoding.
        n_quantum: number of ``CorrelationMatrix`` channels to append; their
            (theta, t) come from rows of the tensor loaded from ``q_params_path``.
        max_dim: padding size for ``num_common`` / ``dist``.
            NOTE(review): default 37 looks dataset-specific.
        q_params_path: file loaded with ``torch.load`` when ``n_quantum > 0``.
        **kwargs: ignored.

    Returns:
        ``data`` with the attributes above.
    """
    num_nodes = data.num_nodes
    edge_index, edge_weight = data.edge_index, data.edge_weight

    adj = SparseTensor.from_edge_index(edge_index, edge_weight,
                                       sparse_sizes=(num_nodes, num_nodes))

    # Row-normalise: P = D^{-1} A (rows of isolated nodes stay all-zero).
    deg = adj.sum(dim=1)
    deg_inv = 1.0 / deg
    deg_inv[deg_inv == float('inf')] = 0
    adj = (adj * deg_inv.view(-1, 1)).to_dense()

    pe_list = []
    if add_identity:
        pe_list.append(torch.eye(num_nodes, dtype=torch.float, device=adj.device))
    out = adj
    pe_list.append(adj)
    # Keep multiplying by P until the stack holds ``walk_length`` channels.
    for _ in range(len(pe_list), walk_length):
        out = out @ adj
        pe_list.append(out)

    if n_quantum > 0:
        q_params = torch.load(q_params_path)
        data = compute_common_nodes(data)
        for q in range(n_quantum):
            correlation = CorrelationMatrix(torch.ones(data.num_nodes),
                                            torch.as_tensor(q_params[q, 0]),
                                            torch.as_tensor(q_params[q, 1]),
                                            device=torch.device(cfg.accelerator))
            pe_list.append(correlation.compute_correlation_matrix_batched(data))

    pe = torch.stack(pe_list, dim=-1)  # n x n x k
    abs_pe = pe.diagonal().transpose(0, 1)  # n x k

    rel_pe = SparseTensor.from_dense(pe, has_value=True)
    rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
    rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0)

    if spd:
        # Weights walk_length, ..., 1 make argmax pick the first positive channel.
        # NOTE(review): assumes k == walk_length (fails when n_quantum > 0).
        spd_idx = walk_length - torch.arange(walk_length, device=rel_pe_val.device)
        val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0)
        val = torch.argmax(val, dim=-1)
        rel_pe_val = F.one_hot(val, walk_length).type(torch.float)
        abs_pe = torch.zeros_like(abs_pe)

    data = add_node_attr(data, abs_pe, attr_name=attr_name_abs)
    data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index")
    data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val")

    # Dense per-graph structural tensors, padded to ``max_dim`` nodes.
    G = graph_from_edge_index(data, max_dim)
    data = add_node_attr(data, common_num_tensor(G), attr_name="num_common")
    data = add_node_attr(data, dist_exp_tensor(G), attr_name="dist")

    data.log_deg = torch.log(deg + 1)
    data.deg = deg.type(torch.long)
    return data

@torch.no_grad()
def add_full_rrwp_static_random(data,
                                walk_length=8,
                                attr_name_abs="rrwp",  # key: 'rrwp'
                                attr_name_rel="rrwp",  # keys: 'rrwp_index', 'rrwp_val'
                                add_identity=True,
                                randomize_edges=False,
                                randomize_features=False,
                                keep_deg_sequence=False,
                                spd=False,
                                **kwargs
                                ):
    """Attach RRWP encodings computed on an optionally randomised graph.

    The encodings are the row-normalised random-walk stack ``[I,] P, P^2, ...``
    (``walk_length`` channels for ``walk_length >= 2``).  They are computed
    on one of two graphs:

    * the graph of ``data`` itself; or
    * with ``randomize_edges``, a random graph on the same node count.
      With ``keep_deg_sequence`` it is a configuration-model graph sharing
      the row-sum degree sequence (a networkx MultiGraph, so parallel edges
      and self-loops can occur).  Otherwise it is a G(n, m) graph, where
      ``m`` is the number of distinct undirected, non-self-loop edges of
      ``data``.

    ``data.edge_index`` itself is left unchanged; ``data.deg`` /
    ``data.log_deg`` describe the graph the encodings were computed on.
    With ``randomize_features`` the rows of ``data.x`` are shuffled.

    Returns:
        ``data`` with ``data[attr_name_abs]`` (n x k diagonal),
        ``data[f"{attr_name_rel}_index"]``, ``data[f"{attr_name_rel}_val"]``
        (COO non-zeros of the n x n x k stack), ``log_deg`` and ``deg``.
    """
    num_nodes = data.num_nodes

    if randomize_edges:
        if keep_deg_sequence:
            adj_ori = SparseTensor.from_edge_index(data.edge_index, data.edge_weight,
                                                   sparse_sizes=(num_nodes, num_nodes))
            deg_sequence = adj_ori.sum(dim=1).int().tolist()
            G_rand = nx.configuration_model(deg_sequence)
        else:
            # edge_index of an undirected graph lists every edge in both
            # directions; count each unordered pair once, as G(n, m) does.
            sources, targets = data.edge_index[0].tolist(), data.edge_index[1].tolist()
            m = len({(min(u, v), max(u, v)) for u, v in zip(sources, targets) if u != v})
            G_rand = nx.gnm_random_graph(num_nodes, m)
        edge_index, edge_weight = from_scipy_sparse_matrix(nx.adjacency_matrix(G_rand))
    else:
        edge_index, edge_weight = data.edge_index, data.edge_weight

    if randomize_features:
        data.x = data.x[torch.randperm(data.x.shape[0])]

    adj = SparseTensor.from_edge_index(edge_index, edge_weight,
                                       sparse_sizes=(num_nodes, num_nodes))

    # Row-normalise: P = D^{-1} A (rows of isolated nodes stay all-zero).
    deg = adj.sum(dim=1)
    deg_inv = 1.0 / deg
    deg_inv[deg_inv == float('inf')] = 0
    adj = (adj * deg_inv.view(-1, 1)).to_dense()

    pe_list = []
    if add_identity:
        pe_list.append(torch.eye(num_nodes, dtype=torch.float, device=adj.device))
    out = adj
    pe_list.append(adj)
    # Keep multiplying by P until the stack holds ``walk_length`` channels.
    for _ in range(len(pe_list), walk_length):
        out = out @ adj
        pe_list.append(out)

    pe = torch.stack(pe_list, dim=-1)  # n x n x k
    abs_pe = pe.diagonal().transpose(0, 1)  # n x k

    rel_pe = SparseTensor.from_dense(pe, has_value=True)
    rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
    rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0)

    if spd:
        # Weights walk_length, ..., 1 make argmax pick the first positive channel.
        spd_idx = walk_length - torch.arange(walk_length, device=rel_pe_val.device)
        val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0)
        val = torch.argmax(val, dim=-1)
        rel_pe_val = F.one_hot(val, walk_length).type(torch.float)
        abs_pe = torch.zeros_like(abs_pe)

    data = add_node_attr(data, abs_pe, attr_name=attr_name_abs)
    data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index")
    data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val")
    data.log_deg = torch.log(deg + 1)
    data.deg = deg.type(torch.long)
    return data