# ------------------------ new random-walk PE/SE (RWPSE) utilities ----------------
from typing import Union, Any, Optional
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
import torch_geometric as pyg
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import BaseTransform
from torch_scatter import scatter, scatter_add, scatter_max
from torch_geometric.utils.convert import to_networkx

from torch_geometric.graphgym.config import cfg

from torch_geometric.utils import (
    get_laplacian,
    get_self_loop_attr,
    to_scipy_sparse_matrix,
)
import torch_sparse
from torch_sparse import SparseTensor


def add_node_attr(data: Data, value: Any,
                  attr_name: Optional[str] = None) -> Data:
    """Attach ``value`` to ``data`` as a node-level attribute.

    With an explicit ``attr_name`` the value is stored under that key.
    Without one, ``value`` is appended as extra columns of ``data.x`` (a 1-D
    ``x`` is first viewed as a single column, and ``value`` is cast to the
    device/dtype of ``x``); if ``data`` has no ``x`` yet, ``value`` becomes
    ``data.x``.

    Returns the same ``data`` object, mutated in place.
    """
    if attr_name is not None:
        data[attr_name] = value
        return data

    if 'x' not in data:
        data.x = value
        return data

    current = data.x
    if current.dim() == 1:
        current = current.view(-1, 1)
    extra = value.to(current.device, current.dtype)
    data.x = torch.cat([current, extra], dim=-1)
    return data


class MyTransform(BaseTransform):
    """Pad a graph to ``max_dim`` nodes and attach pairwise structural features.

    The graph is converted to an undirected NetworkX graph and then:

    * isolated nodes are added until it has exactly ``max_dim`` nodes, so
      every output tensor has the same size;
    * ``data.num_common`` receives a ``max_dim x max_dim`` int64 matrix whose
      diagonal holds the node degrees and whose off-diagonal entry ``(i, j)``
      is the number of common neighbours of ``i`` and ``j`` (``i`` and ``j``
      themselves excluded);
    * the original node features are moved to ``data.feat``, with the added
      nodes padded by ``-1``;
    * ``data.x`` is replaced by the ``max_dim x max_dim`` matrix
      ``exp(-d(i, j))``, where ``d`` is the shortest-path length (so
      unreachable pairs get ``0``).

    Args:
        max_dim: Number of nodes every graph is padded to.

    Raises:
        ValueError: If the graph has more than ``max_dim`` nodes.
    """

    def __init__(
        self,
        max_dim: int,
    ):
        self.max_dim = max_dim

    def forward(self, data: Data) -> Data:
        G = to_networkx(data, to_undirected=True)

        num_nodes = G.number_of_nodes()
        if num_nodes > self.max_dim:
            raise ValueError(
                f"graph has {num_nodes} nodes, more than max_dim={self.max_dim}"
            )

        # Add isolated nodes so that all output tensors have the same size.
        G.add_nodes_from(range(num_nodes, self.max_dim))
        node_list = list(G.nodes())

        # Common-neighbour counts in a single matrix product. With the
        # adjacency's diagonal zeroed, (A @ A)[i, j] for i != j counts the
        # nodes k not in {i, j} adjacent to both i and j -- exactly what
        # nx.common_neighbors returns. The diagonal is then overwritten with
        # the node degrees (NetworkX convention: a self-loop counts twice).
        adj = nx.to_numpy_array(G, nodelist=node_list, weight=None,
                                dtype=np.int64)
        np.fill_diagonal(adj, 0)
        common = adj @ adj
        np.fill_diagonal(common, [deg for _, deg in G.degree()])
        data.num_common = torch.from_numpy(common)

        # Keep the original features under ``feat``; padded nodes get -1.
        x = data.x
        padding = -torch.ones(self.max_dim - x.size(0), *x.shape[1:],
                              device=x.device)
        data.feat = torch.cat((x, padding))

        # exp(-shortest path length); unreachable pairs have infinite
        # distance and therefore map to 0.
        dist = nx.floyd_warshall_numpy(G, nodelist=node_list)
        data.x = torch.exp(torch.tensor(-dist))

        return data





# Example usage: pre-compute the features once on ZINC.
# (Requires `os.path as osp`, `torch_geometric.transforms as T`,
#  `torch_geometric.datasets.ZINC` and `torch_geometric.loader.DataLoader`.)
# max_dim = 37
# path = osp.join( '../data', 'ZINC-PE')
# transform = T.Compose([MyTransform(max_dim), T.AddRandomWalkPE(walk_length=20, attr_name='pe')])
# train_dataset = ZINC(path, subset=True, split='train', pre_transform=transform)
# val_dataset = ZINC(path, subset=True, split='val', pre_transform=transform)
# test_dataset = ZINC(path, subset=True, split='test', pre_transform=transform)
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=64)
# test_loader = DataLoader(test_dataset, batch_size=64)

class Correlation_matrix:
    """Batched correlation matrices for a mini-batch of padded graphs.

    Every method takes a ``data_batch`` exposing:

    * ``edge_index`` -- ``2 x E`` edges with batch-global node indices;
    * ``batch`` -- graph id of every node;
    * ``ptr`` -- ``ptr[b]`` is the global index of the first node of graph b;
    * ``num_common`` -- per-graph ``max_dim x max_dim`` integer matrices
      stacked along dim 0 (``num_graphs * max_dim`` rows), e.g. the
      ``num_common`` attribute written by ``MyTransform``.

    ``theta`` and ``t`` are 2-D and indexed as ``[graph, node]`` (they are
    used as ``x[:, :, None]``); presumably shaped ``(num_graphs, max_dim)``
    or ``(1, max_dim)`` -- confirm with callers.

    Args:
        h: Scalar coefficient of the phase ``exp(i * h * t)``.
        theta: Parameter tensor entering through cos/sin (moved to ``device``).
        t: Parameter tensor entering through complex phases (moved to
            ``device``).
        device: Device on which every result is returned.
    """

    def __init__(self,
                 h: float,
                 theta: torch.Tensor,
                 t: torch.Tensor,
                 device: torch.device,
                 ) -> None:
        super().__init__()
        self.h = h
        self.t = t.to(device)
        self.theta = theta.to(device)
        self.device = device

    def _num_graphs(self, data_batch) -> int:
        """Number of graphs in the batch (largest graph id + 1)."""
        return data_batch.batch.max().item() + 1

    def _num_common_batched(self, data_batch) -> torch.Tensor:
        """``num_common`` reshaped to ``(num_graphs, max_dim, max_dim)``."""
        max_dim = data_batch.num_common.shape[1]
        num_graphs = self._num_graphs(data_batch)
        return data_batch.num_common.reshape(
            num_graphs, max_dim, max_dim).to(self.device)

    def _w_ij_from_adj(self, adj: torch.Tensor) -> torch.Tensor:
        """W[b, i, j] = cos^2(theta[b, i]) + sin^2(theta[b, i]) * exp(i t[b, i] A[b, i, j])."""
        cos2 = torch.cos(self.theta[:, :, None]) ** 2
        sin2 = torch.sin(self.theta[:, :, None]) ** 2
        phase = torch.exp(adj * self.t[:, :, None] * 1j)
        return (cos2 + sin2 * phase).to(self.device)

    def edge_index_to_batch_adj(self, data_batch) -> torch.Tensor:
        """Dense 0/1 adjacency of every graph in the batch.

        Each edge is mapped to graph-local indices by subtracting the ``ptr``
        offset of its source's graph. Every edge is kept, including self-loops
        on a graph's first node.

        Returns:
            Float tensor of shape ``(num_graphs, max_dim, max_dim)`` on
            ``self.device``.
        """
        max_dim = data_batch.num_common.shape[1]
        num_graphs = self._num_graphs(data_batch)

        src, dst = data_batch.edge_index
        graph = data_batch.batch[src]
        offset = data_batch.ptr[graph]

        # Build on the edges' device so the index tensors match, then move.
        adj = torch.zeros(num_graphs, max_dim, max_dim, device=src.device)
        adj[graph, src - offset, dst - offset] = 1
        return adj.to(self.device)

    def w_ij_batched(self, data_batch) -> torch.Tensor:
        """Complex ``W[b, i, j] = cos^2 theta_i + sin^2 theta_i * exp(i t_i A_ij)``."""
        return self._w_ij_from_adj(self.edge_index_to_batch_adj(data_batch))

    def w_plus_batched(self, data_batch) -> torch.Tensor:
        """Complex ``(cos^2 theta_i + sin^2 theta_i * exp(2 i t_i)) ** C_ij``.

        ``C`` is ``num_common`` reshaped per graph.
        """
        com_ij = self._num_common_batched(data_batch)
        cos2 = torch.cos(self.theta[:, :, None]) ** 2
        sin2 = torch.sin(self.theta[:, :, None]) ** 2
        base = cos2 + sin2 * torch.exp(2 * self.t[:, :, None] * 1j)
        return torch.pow(base, com_ij).to(self.device)

    def w_minus_batched(self, data_batch) -> torch.Tensor:
        """Real ``(cos^2 theta_i) ** C_ij``, with ``C = num_common`` per graph."""
        com_ij = self._num_common_batched(data_batch)
        return torch.pow(torch.cos(self.theta[:, :, None]) ** 2,
                         com_ij).to(self.device)

    def compute_correlation_matrix_batched(self, data_batch) -> torch.Tensor:
        """Real correlation matrix of every graph, stacked row-wise.

        With ``rho_i = prod_j W_ij * exp(i h t_i)`` the entry ``(i, j)`` is::

            4 sin^4(theta_i) cos^4(theta_i) * Re[
                (rho_i + rho_j) (1 - 1/W_ij)
              + 1/2 (1 - exp(i t_i A_ij) / W+_ij) rho_i rho_j
              + 1/2 (1 - 1/W-_ij) rho_i conj(rho_j) ]

        Returns:
            Tensor of shape ``(num_graphs * max_dim, max_dim)`` on
            ``self.device``.
        """
        max_dim = data_batch.num_common.shape[1]
        adj = self.edge_index_to_batch_adj(data_batch)

        # Named ``prefactor`` so the module-level ``F`` (functional) is not shadowed.
        prefactor = (4 * torch.sin(self.theta) ** 4
                     * torch.cos(self.theta) ** 4)[:, :, None].to(self.device)
        w = self._w_ij_from_adj(adj)

        rho = torch.prod(w, 2) * torch.exp(self.h * self.t * 1j).to(self.device)
        rho_i = rho[:, :, None]  # varies along rows (i)
        rho_j = rho[:, None, :]  # varies along columns (j)

        f1 = (rho_i + rho_j) * (1 - 1 / w)
        f2 = (0.5 * (1 - torch.exp(adj * self.t[:, :, None] * 1j)
                     / self.w_plus_batched(data_batch))) * (rho_i * rho_j)
        f3 = 0.5 * (1 - 1 / self.w_minus_batched(data_batch)) \
            * (rho_i * torch.conj(rho_j))

        return (prefactor * torch.real(f1 + f2 + f3)).view(-1, max_dim).to(self.device)




@torch.no_grad()
def add_full_qcorr(data,
                   walk_length=8,
                   attr_name_abs="qcorr",  # name: 'rrwp'
                   attr_name_rel="qcorr",  # name: ('rrwp_idx', 'rrwp_val')
                   add_identity=True,
                   spd=False,
                   **kwargs
                   ):
    """Attach dense random-walk style node and pair encodings to ``data``.

    Builds the row-normalised transition matrix ``P = D^{-1} A``. Isolated
    nodes keep an all-zero row. It then stacks ``k`` channels along the last
    axis: the identity (if ``add_identity``) followed by ``P, P^2, ...``.
    ``k == walk_length``, except that ``P`` is always included, so
    ``walk_length == 1`` with the identity still yields ``[I, P]``.

    The diagonal of the stack (``n x k``) is the node-level encoding. The
    non-zero ``(i, j)`` entries of the stack give the pair-level encoding in
    COO form.

    Args:
        data: Graph with ``edge_index`` and optional ``edge_weight``.
        walk_length: Number of stacked channels.
        attr_name_abs: Attribute name for the node-level encoding.
        attr_name_rel: Prefix; the pair encoding is stored in
            ``f"{attr_name_rel}_index"`` and ``f"{attr_name_rel}_val"``.
        add_identity: Whether the first channel is the identity matrix.
        spd: If true, each pair's values become a one-hot of the first
            channel in which the pair is positive, and the node encoding is
            zeroed.
        **kwargs: Unused.

    Returns:
        ``data`` with the encodings plus ``log_deg`` (``log(deg + 1)``) and
        ``deg`` (as ``long``) attached.
    """
    device = data.edge_index.device
    num_nodes = data.num_nodes
    edge_index, edge_weight = data.edge_index, data.edge_weight

    adj = SparseTensor.from_edge_index(edge_index, edge_weight,
                                       sparse_sizes=(num_nodes, num_nodes),
                                       )

    # Compute D^{-1} A; rows of isolated nodes (deg 0 -> inf) are zeroed.
    deg = adj.sum(dim=1)
    deg_inv = 1.0 / deg
    deg_inv[deg_inv == float('inf')] = 0
    adj = (adj * deg_inv.view(-1, 1)).to_dense()

    pe_list = []
    if add_identity:
        # Same device as ``adj`` so the stack below also works off-CPU.
        pe_list.append(torch.eye(num_nodes, dtype=torch.float, device=device))

    out = adj
    pe_list.append(adj)
    # Append further powers until the stack holds ``walk_length`` channels.
    for _ in range(len(pe_list), walk_length):
        out = out @ adj
        pe_list.append(out)

    pe = torch.stack(pe_list, dim=-1)  # n x n x k

    abs_pe = pe.diagonal().transpose(0, 1)  # n x k

    rel_pe = SparseTensor.from_dense(pe, has_value=True)
    rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
    rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0)

    if spd:
        # Decreasing weights make argmax pick the first positive channel.
        spd_idx = walk_length - torch.arange(walk_length, device=device)
        val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0)
        val = torch.argmax(val, dim=-1)
        rel_pe_val = F.one_hot(val, walk_length).type(torch.float)
        abs_pe = torch.zeros_like(abs_pe)

    # NOTE(review): these two lines discard the encodings computed above
    # (including the ``spd`` branch) and replace them with standard-normal
    # noise of the same shape; only the shapes and the sparsity pattern in
    # ``rel_pe_idx`` survive. Looks like a placeholder for the q-correlation
    # features -- confirm intent before relying on these values.
    abs_pe = torch.randn(abs_pe.size())
    rel_pe_val = torch.randn(rel_pe_val.size())

    data = add_node_attr(data, abs_pe, attr_name=attr_name_abs)
    data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index")
    data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val")
    data.log_deg = torch.log(deg + 1)
    data.deg = deg.type(torch.long)

    return data

