# ------------------------ : new rwpse ----------------
from typing import Union, Any, Optional
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
import torch_geometric as pyg
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import BaseTransform
from torch_scatter import scatter, scatter_add, scatter_max
from torch_geometric.utils.convert import to_networkx
from torch_geometric.utils import to_dense_adj


from torch_geometric.graphgym.config import cfg
import networkx as nx 
from torch_geometric.utils import (
    get_laplacian,
    get_self_loop_attr,
    to_scipy_sparse_matrix,
)
import torch_sparse
from torch_sparse import SparseTensor

from itertools import combinations

def graph_from_edge_index(data, max_dim):
    """Build an undirected ``networkx.Graph`` from ``data.edge_index``.

    Nodes ``0 .. max_dim - 1`` are always present (isolated ones included) and are
    inserted first, in ascending order, so ``G.nodes()`` order matches node ids.
    Any node id ``>= max_dim`` that appears in an edge is appended after them.
    Each undirected edge is added once, as ``(min(u, v), max(u, v))``.

    Args:
        data: object exposing ``edge_index`` (a 2 x E tensor of node ids).
        max_dim: number of node ids that must exist in the graph.

    Returns:
        The constructed ``nx.Graph``.
    """
    G = nx.Graph()
    # Pre-populate every id < max_dim. The former `range(len(G), max_dim)` padding
    # skipped isolated ids below len(G) and left node order to set iteration.
    G.add_nodes_from(range(max_dim))
    src = data.edge_index[0].tolist()
    dst = data.edge_index[1].tolist()
    # Sorted so the edge (and thus any new node) insertion order is deterministic.
    G.add_edges_from(sorted({(min(u, v), max(u, v)) for u, v in zip(src, dst)}))
    return G

def common_num_tensor(G):
    """Return a square tensor of degrees (diagonal) and common-neighbour counts.

    Row/column ``k`` corresponds to the k-th node of ``G.nodes()``. The diagonal
    holds the values reported by ``G.degree()``. Each off-diagonal entry
    ``[a, b]`` holds how many nodes ``nx.common_neighbors`` yields for the a-th
    and b-th nodes. The result is symmetric.
    """
    ordered = list(G.nodes())
    degrees = [deg for _, deg in G.degree()]
    result = torch.diag(torch.tensor(degrees))
    # Every unordered pair is visited exactly once; both mirrored cells start at 0.
    for (a, u), (b, v) in combinations(enumerate(ordered), 2):
        shared = sum(1 for _ in nx.common_neighbors(G, u, v))
        result[a, b] = shared
        result[b, a] = shared
    return result

def dist_exp_tensor(G):
    """Return ``exp(-d)`` where ``d = nx.floyd_warshall_numpy(G)``.

    Rows/columns follow ``G.nodes()`` order. Pairs with no connecting path
    (distance ``inf``) map to 0, and the zero-distance diagonal maps to 1.
    """
    shortest = nx.floyd_warshall_numpy(G)
    return torch.tensor(-shortest).exp()


def add_node_attr(data: Data, value: Any,
                  attr_name: Optional[str] = None) -> Data:
    """Attach ``value`` to ``data`` and return ``data`` (mutated in place).

    With ``attr_name`` given, ``value`` is stored under that key. Otherwise it is
    appended column-wise to the node features ``data.x``, cast to the existing
    features' device and dtype, with a 1-D ``x`` treated as a single column. If
    ``data`` has no ``x`` yet, ``value`` becomes ``x``.
    """
    if attr_name is not None:
        data[attr_name] = value
        return data

    if 'x' not in data:
        data.x = value
        return data

    features = data.x
    if features.dim() == 1:
        features = features.view(-1, 1)
    aligned = value.to(features.device, features.dtype)
    data.x = torch.cat([features, aligned], dim=-1)
    return data

def compute_common_nodes(data):
    """Store the non-zero entries of ``A @ A`` on ``data`` in COO form.

    ``A`` is the dense adjacency built by ``to_dense_adj`` from ``data.edge_index``
    with ``data.num_nodes`` nodes. For a simple graph without self-loops,
    ``(A @ A)[i, j]`` counts common neighbours of ``i`` and ``j`` (``i != j``),
    and the diagonal holds node degrees.

    Sets:
        data['num_common_idx']: (nnz x 2) long tensor of (row, col) indices of the
            non-zero entries, in row-major order.
        data['num_common_val']: (nnz,) tensor of the matching values.

    Returns:
        ``data`` (modified in place).
    """
    Adj = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes).squeeze(0)
    num_common = torch.matmul(Adj, Adj)

    # Same (row, col, value) triples, in the same row-major order, as
    # SparseTensor.from_dense(num_common).coo(), without building a SparseTensor.
    num_common_idx = torch.nonzero(num_common)
    num_common_val = num_common[num_common_idx[:, 0], num_common_idx[:, 1]]

    data['num_common_idx'] = num_common_idx
    data['num_common_val'] = num_common_val
    return data
    

class CorrelationMatrix:
    """Dense (N x N) pairwise correlation matrix of one graph, parameterised by ``theta`` and ``t``.

    The closed form in ``compute_correlation_matrix_batched`` combines the adjacency
    matrix, node degrees and common-neighbour counts of the graph.
    NOTE(review): the model this formula is derived from is not documented in this
    file. Confirm the derivation against its source before changing any term.
    """

    def __init__(self,
                 h: torch.Tensor,
                 theta: torch.Tensor,
                 t: torch.Tensor,
                 device='cpu',
                 ) -> None:
        """Store the parameters on ``device``.

        Args:
            h: tensor moved to ``device`` (not read by the methods of this class).
            theta: angle; enters through ``cos(theta)**2`` and ``sin(theta)**2``.
            t: phase/time; enters through ``exp(1j * t * ...)``.
            device: device on which the computation runs.
        """
        super().__init__()
        self.device = device
        self.h = h.to(self.device)
        self.t = t.to(self.device)
        self.theta = theta.to(self.device)

    def w_ij_batched(self, Adj):
        """Return ``cos(theta)**2 + sin(theta)**2 * exp(1j * t * Adj)`` elementwise.

        Entries with ``Adj == 0`` evaluate to ``cos**2 + sin**2 == 1``, so they are
        neutral in the row products taken by the caller.
        """
        return torch.cos(self.theta) ** 2 + (torch.sin(self.theta) ** 2) * \
            torch.exp(Adj * self.t * 1j).to(self.device)

    def compute_correlation_matrix_batched(self, data):
        """Evaluate the correlation matrix for the single graph in ``data``.

        Despite the name, one graph is processed: ``to_dense_adj`` is called without
        a batch vector. ``data`` must carry ``edge_index``, ``num_nodes`` and the
        ``num_common_idx`` (nnz x 2) / ``num_common_val`` pair of common-neighbour
        counts in COO form.

        Returns:
            Real (N x N) tensor on CPU.
        """
        Adj = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes).to(self.device).squeeze(0)
        Deg = torch.sum(Adj, axis=1)
        Deg_mat = Deg[:, None] + Deg[None, :]                  # Deg[i] + Deg[j]
        Deg_row = Deg.unsqueeze(1).repeat(1, data.num_nodes)   # Deg[i] at (i, j)
        Deg_col = Deg.unsqueeze(0).repeat(data.num_nodes, 1)   # Deg[j] at (i, j)

        # Dense common-neighbour counts rebuilt from their COO encoding.
        com_ij = to_dense_adj(data.num_common_idx.t(), edge_attr=data.num_common_val,
                              max_num_nodes=data.num_nodes).to(self.device).squeeze(0)

        W = self.w_ij_batched(Adj)

        # rho_i = exp(1j t) * prod_j W[i, j]
        rho_vect = torch.prod(W, 1) * torch.exp(self.t * 1j).to(self.device)

        # Scalar prefactor 4 sin^4(theta) cos^4(theta). Named so it does not shadow
        # the module-level `F` (torch.nn.functional).
        prefactor = ((4 * (torch.sin(self.theta) ** 4) * (torch.cos(self.theta) ** 4))).to(self.device)
        et = (torch.cos(self.theta)**2 + torch.sin(self.theta)**2 * torch.exp(1j * self.t))
        e2t = (torch.cos(self.theta)**2 + torch.sin(self.theta)**2 * torch.exp(1j * self.t * 2))
        emt = (torch.cos(self.theta)**2 + torch.sin(self.theta)**2 * torch.exp(-1j * self.t))

        f1 = 1 - torch.exp(self.t * 1j) * torch.pow(et, Deg_row - Adj) - torch.exp(self.t * 1j) * torch.pow(et, Deg_col - Adj)
        f2 = .5 * torch.exp(1j * self.t * (Adj + 2)) * torch.pow(e2t, com_ij) * torch.pow(et, Deg_mat - com_ij)
        f3 = .5 * torch.pow(e2t, com_ij) * torch.pow(et, Deg_row - Adj) * torch.pow(emt, Deg_col - Adj)
        f4 = 1 - (rho_vect[:, None] + rho_vect[None, :]) + .5 * rho_vect[:, None] * rho_vect.t()[None, :] + .5 * rho_vect[:, None] * torch.conj(rho_vect.t())[None, :]

        corr = (prefactor * torch.real(f1 + f2 + f3 - f4)).to(self.device)
        # Only drop a leading singleton *batch* dim. An unconditional squeeze(0)
        # turned the (1 x 1) matrix of a single-node graph into shape (1,),
        # which cannot be stacked with the other (N x N) channels.
        if corr.dim() > 2:
            corr = corr.squeeze(0)
        return corr.cpu()
    

def compute_random_walk_xy_2_matrix(data):
    """Transition matrix of a walk on unordered node pairs ``(i, j)`` with ``i < j``.

    Let ``f[x, y] = Adj[x, y] - [x == y]``. Two pair-states ``p = (i_p, j_p)`` and
    ``q = (i_q, j_q)`` gain one unit of ``H[p, q]`` for each matching, direct
    ``(i_p~i_q, j_p~j_q)`` or crossed ``(i_p~j_q, j_p~i_q)``, in which the two
    matched ``f`` values have opposite sign. Without self-loops this means one
    endpoint stays put while the other moves to a neighbour. ``M`` is ``H`` with
    column ``q`` scaled by ``1 / max(sum_k H[q, k], 1)``. For a symmetric ``Adj``,
    ``H`` is symmetric, so the non-empty columns of ``M`` sum to 1.

    Returns:
        M: (Ne x Ne) float tensor, ``Ne = N (N - 1) / 2``.
        index: (2 x Ne) long tensor; column ``p`` holds pair ``p`` in lexicographic
            order. For a single-node graph, ``M = [[1.]]`` and ``index = [[0], [0]]``.
    """
    N = data.num_nodes

    if N == 1:
        return torch.tensor([1.]).reshape((1, 1)), torch.tensor([[0], [0]])

    # Small graphs run on the configured accelerator, larger ones stay on CPU.
    if N <= 180:
        device = cfg.accelerator
    else:
        device = 'cpu'
    Adj = to_dense_adj(data.edge_index, max_num_nodes=N).squeeze(0).to(device)

    # Pair-states (i, j), i < j, in the same order as itertools.combinations(range(N), 2).
    index = torch.triu_indices(N, N, offset=1, device=Adj.device)
    rows, cols = index[0], index[1]

    f = Adj - torch.eye(N, dtype=Adj.dtype, device=Adj.device)
    f_rows = f[rows]   # (Ne, N): f[i_p, :]
    f_cols = f[cols]   # (Ne, N): f[j_p, :]

    # Gather (Ne x Ne) slices directly. This replaces the former
    # (Ne x Ne x 2 x 2) int64 index tensor (plus clone) over all pairs of pairs,
    # and its try/except that printed the error and then failed on an unbound name.
    H = (f_rows[:, rows] * f_cols[:, cols] < -.5).float()    # f[i_p, i_q] * f[j_p, j_q]
    H += (f_rows[:, cols] * f_cols[:, rows] < -.5).float()   # f[i_p, j_q] * f[j_p, i_q]

    # Column scaling done elementwise: same values as H @ diag(v), without the
    # O(Ne^3) dense matmul.
    inv_norm = 1 / torch.clamp(torch.sum(H, dim=1), min=1)
    M = H * inv_norm.unsqueeze(0)

    return M, index

@torch.no_grad()
def add_full_rrwp(data,
                  walk_length=8,
                  attr_name_abs="rrwp", # name: 'rrwp'
                  attr_name_rel="rrwp", # name: ('rrwp_idx', 'rrwp_val')
                  add_identity=True,
                  spd=False,
                  n_quantum=0,
                  n_quantum_xy=0,
                  start_adj=False,
                  **kwargs
                  ):
    """Compute relative random-walk positional encodings (plus optional extra channels).

    Builds a dense (n x n x k) tensor ``pe`` whose channels are, in order:

    * ``I`` (only if ``add_identity``), then ``A_norm, A_norm^2, ...``, where
      ``A_norm = D^{-1} A`` is the row-normalised adjacency (rows of zero-degree
      nodes stay zero). For ``walk_length > 2`` these walk channels number
      ``walk_length`` in total; otherwise only ``[I, A_norm]`` (or ``[A_norm]``).
    * ``n_quantum`` matrices from ``CorrelationMatrix``, with ``(theta, t)`` read
      from rows of ``data/q_params.pt`` (loaded on every call).
    * ``n_quantum_xy`` channels from repeatedly applying the pair-walk matrix of
      ``compute_random_walk_xy_2_matrix``. Each is scattered back to (n x n),
      symmetrised and scaled by ``max(n / 2, 1)``.

    Args:
        data: graph exposing ``edge_index``, ``edge_weight`` (passed as edge values)
            and ``num_nodes``.
        walk_length: number of walk channels (see above).
        attr_name_abs: attribute receiving the diagonal of ``pe`` (n x k).
        attr_name_rel: prefix of ``{prefix}_index`` (2 x nnz) and ``{prefix}_val``
            (nnz x k), the non-zero entries of ``pe``.
        add_identity: prepend the identity channel.
        spd: replace the relative values by a one-hot (width ``walk_length``) of the
            first channel in which the entry is positive, and zero the absolute part.
        n_quantum, n_quantum_xy: number of extra channels of each kind.
        start_adj: start the pair walk from the pairs' adjacency entries (normalised
            by ``max(sum, 1)``) instead of a uniform vector.
        **kwargs: accepted and ignored.

    Returns:
        ``data`` with the attributes above, plus ``log_deg = log(deg + 1)`` and
        ``deg`` cast to long. With ``n_quantum > 0`` it also carries
        ``num_common_idx`` / ``num_common_val`` from ``compute_common_nodes``.
    """
    import os

    device = data.edge_index.device
    num_nodes = data.num_nodes
    edge_index, edge_weight = data.edge_index, data.edge_weight

    adj = SparseTensor.from_edge_index(edge_index, edge_weight,
                                       sparse_sizes=(num_nodes, num_nodes),
                                       )
    adj_raw = adj.clone()

    # Compute D^{-1} A, leaving rows of zero-degree nodes at zero.
    deg = adj.sum(dim=1)
    deg_inv = 1.0 / deg
    deg_inv[deg_inv == float('inf')] = 0
    adj = adj * deg_inv.view(-1, 1)
    adj = adj.to_dense()

    pe_list = []
    start = 0
    if add_identity:
        # Same device as `adj`, so torch.stack below does not mix devices.
        pe_list.append(torch.eye(num_nodes, dtype=torch.float, device=device))
        start = 1

    out = adj
    pe_list.append(adj)
    if walk_length > 2:
        for _ in range(start + 1, walk_length):
            out = out @ adj
            pe_list.append(out)

    if n_quantum > 0:
        q_params = torch.load('data/q_params.pt')
        data = compute_common_nodes(data)
        for q in range(n_quantum):
            correlation = CorrelationMatrix(torch.ones(data.num_nodes),
                                            torch.tensor(q_params[q, 0]), torch.tensor(q_params[q, 1]),
                                            device=torch.device(cfg.accelerator))
            pe_list.append(correlation.compute_correlation_matrix_batched(data))

    if n_quantum_xy > 0:
        M, idx = compute_random_walk_xy_2_matrix(data)

        if start_adj:
            P = adj_raw.to_dense().to(M.device)[idx[0], idx[1]]
            P = P/max(torch.sum(P), 1)
        else:
            P = torch.ones(M.shape[0]).to(M.device) / max(M.shape[0], 1)

        for _ in range(n_quantum_xy):
            P = torch.matmul(M, P)
            mat = to_dense_adj(idx.cpu(), edge_attr=P.cpu().clone(), max_num_nodes=data.num_nodes).squeeze(0)
            mat = mat + mat.t()
            if torch.isnan(mat).any():
                # Debug dump of the offending graph, then fall back to a zero channel.
                # Ensure the directory exists and move tensors to CPU so the dump
                # cannot crash before the fallback is applied.
                os.makedirs('results/trash', exist_ok=True)
                np.save('results/trash/edges.npy', edge_index.cpu().numpy())
                np.save('results/trash/M.npy', M.cpu().numpy())
                mat = torch.zeros_like(mat)
            pe_list.append(mat.clone()*max(num_nodes/2, 1)) # Normalize the XY matrix

    pe = torch.stack(pe_list, dim=-1) # n x n x k

    abs_pe = pe.diagonal().transpose(0, 1) # n x k

    rel_pe = SparseTensor.from_dense(pe, has_value=True)
    rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
    rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0)

    if spd:
        # NOTE(review): (nnz x k) * (1 x walk_length) only broadcasts when
        # k == walk_length, i.e. without extra quantum channels.
        spd_idx = walk_length - torch.arange(walk_length)
        val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0)
        val = torch.argmax(val, dim=-1)
        rel_pe_val = F.one_hot(val, walk_length).type(torch.float)
        abs_pe = torch.zeros_like(abs_pe)

    data = add_node_attr(data, abs_pe, attr_name=attr_name_abs)
    data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index")
    data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val")

    data.log_deg = torch.log(deg + 1)
    data.deg = deg.type(torch.long)

    return data

