from torch import Tensor
import torch.nn as nn
from torch_scatter import scatter
from torch_geometric.utils import to_undirected
from torch_geometric.utils import degree
from torch_geometric.typing import OptTensor
import numpy as np
import scipy.sparse as sp
import torch

from .utils import sp_mtx_to_sp_tnsr

def TransposeNorm(adj):
    """Symmetrize ``adj`` by adding its transpose: ``S = A + A^T``.

    Works for any object supporting ``.T`` and ``+`` (NumPy arrays, SciPy
    sparse matrices, torch tensors); the result type follows ``adj + adj.T``.
    """
    transposed = adj.T
    return adj + transposed

@torch.no_grad()
def LaplaceNorm(adj):
    """Return the halved symmetric-normalized Laplacian.

    Computes ``S = 0.5 * D^-1/2 (D - A) D^-1/2``, where ``D`` is the diagonal
    matrix of row sums of ``adj``. Rows/columns whose degree is zero get a
    zero scaling factor instead of ``inf``.

    Args:
        adj: square adjacency, either a ``torch.Tensor`` (sparse COO or dense)
            or anything ``scipy.sparse.coo_matrix`` accepts.

    Returns:
        A sparse COO ``torch.Tensor`` for tensor input; otherwise a SciPy
        ``coo_matrix``.
    """
    if isinstance(adj, torch.Tensor):
        dense = adj.to_dense() if adj.is_sparse else adj
        d = dense.sum(1)
        d_inv_sqrt = torch.pow(d, -0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
        lap = torch.diag(d) - dense  # D - A
        # Broadcasting the row/column scale equals D^-1/2 @ L @ D^-1/2
        # without two dense n x n matrix products.
        return .5 * (d_inv_sqrt[:, None] * lap * d_inv_sqrt[None, :]).to_sparse()

    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1)).flatten()
    # Zero-degree rows produce inf here; they are zeroed right after.
    with np.errstate(divide='ignore'):
        d_inv_sqrt = np.power(row_sum, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    deg_mat = sp.diags(row_sum)
    return .5 * d_mat_inv_sqrt.dot(deg_mat - adj).dot(d_mat_inv_sqrt).tocoo()

def RWNorm(adj):
    """Return the random-walk Laplacian ``S = I - D^-1 A``.

    ``D`` is the diagonal matrix of row sums of ``adj``. Rows with zero
    degree get a zero inverse degree (so their row of ``S`` is the identity
    row) rather than ``inf``/``nan``.

    Args:
        adj: square adjacency, either a ``torch.Tensor`` (sparse COO or dense)
            or anything ``scipy.sparse.coo_matrix`` accepts. Integer dtypes
            are supported.

    Returns:
        A sparse COO ``torch.Tensor`` for tensor input; otherwise a SciPy
        sparse matrix.
    """
    if isinstance(adj, torch.Tensor):
        with torch.no_grad():
            dense = adj.to_dense() if adj.is_sparse else adj
            # Float exponent: integer tensors to a negative int power would raise.
            d_inv = torch.pow(dense.sum(1), -1.)
            d_inv[torch.isinf(d_inv)] = 0.
            eye = torch.eye(dense.shape[0], device=dense.device)
            # Row-scaling by broadcasting equals diag(d_inv) @ A.
            return (eye - d_inv[:, None] * dense).to_sparse()

    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1)).flatten()
    # Float exponent avoids "Integers to negative integer powers" ValueError
    # for integer matrices; zero-degree rows are zeroed right after.
    with np.errstate(divide='ignore'):
        d_inv = np.power(row_sum, -1.)
    d_inv[np.isinf(d_inv)] = 0.
    d_mat_inv = sp.diags(d_inv)
    return sp.eye(adj.shape[0]) - d_mat_inv.dot(adj)

def IdentNorm(adj):
    """Identity normalization (``S = A``): hand back the very same object."""
    unchanged = adj
    return unchanged

def DiagNorm(adj):
    """Add self-loops: ``S = A + I`` using a SciPy sparse identity.

    The identity is sized from ``adj.shape[0]``; the result type is whatever
    ``adj + scipy.sparse.eye(n)`` produces.
    """
    n_rows = adj.shape[0]
    identity = sp.eye(n_rows)
    return adj + identity

def AugNorm(adj, need_orig=False):
    """Return the renormalized adjacency ``D~^-1/2 (A + I) D~^-1/2``.

    ``D~`` is the diagonal matrix of row sums of the (augmented) adjacency.
    Rows/columns with zero degree get a zero scaling factor instead of ``inf``.

    Args:
        adj: square adjacency, either a ``torch.Tensor`` (sparse COO or dense)
            or anything ``scipy.sparse.coo_matrix`` accepts.
        need_orig: when True, self-loops are NOT added and ``adj`` itself is
            normalized (``D^-1/2 A D^-1/2``). Honored by both the torch and the
            SciPy branch.

    Returns:
        A coalesced sparse COO ``torch.Tensor`` for tensor input; otherwise a
        SciPy ``coo_matrix``.
    """
    if isinstance(adj, torch.Tensor):
        with torch.no_grad():
            n = adj.shape[0]
            adj = (adj if adj.is_sparse else adj.to_sparse()).coalesce()
            if not need_orig:
                diag_idx = torch.arange(n, device=adj.device)
                eye = torch.sparse_coo_tensor(
                    torch.stack([diag_idx, diag_idx]),
                    torch.ones(n, dtype=adj.dtype, device=adj.device),
                    (n, n),
                )
                adj = (adj + eye).coalesce()  # A + I
            indices = adj.indices()
            values = adj.values()
            row, col = indices
            # Row sums (degrees) accumulated on adj's device — no CPU round trip.
            deg = torch.zeros(n, dtype=values.dtype, device=values.device)
            deg.index_add_(0, row, values)
            d_inv_sqrt = deg.pow(-0.5)
            d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
            # Scaling each stored entry (i, j) by d_i^-1/2 * d_j^-1/2 is exactly
            # D^-1/2 (A+I) D^-1/2, without building a diagonal matrix.
            scaled = d_inv_sqrt[row] * values * d_inv_sqrt[col]
            return torch.sparse_coo_tensor(indices, scaled, adj.shape).coalesce()

    if not need_orig:
        adj = adj + sp.eye(adj.shape[0])
    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1))
    # Zero-degree rows produce inf here; they are zeroed right after.
    with np.errstate(divide='ignore'):
        d_inv_sqrt = np.power(row_sum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo()

def fetch_normalization(type):
    """Look up a normalization function by its name.

    Args:
        type: one of ``'AugNorm'``, ``'DiagNorm'``, ``'IdentNorm'``,
            ``'LaplaceNorm'``, ``'RWNorm'``, ``'TransposeNorm'``. (The name
            shadows the builtin but is kept for keyword-argument callers.)

    Returns:
        The matching callable. For an unknown name, a callable is still
        returned, but invoking it raises ``ValueError`` naming the bad key. The
        previous zero-argument fallback instead failed with an unrelated
        ``TypeError`` as soon as it was called with an adjacency.
    """
    switcher = {
        'AugNorm': AugNorm,  # S = (D + I)^-1/2 * (A + I) * (D + I)^-1/2
        'DiagNorm': DiagNorm,  # S = A + I
        'IdentNorm': IdentNorm,  # S = A
        'LaplaceNorm': LaplaceNorm,  # S = 0.5 * D^-1/2 (D - A) D^-1/2
        'RWNorm': RWNorm,  # S = I - (D^-1 A)
        'TransposeNorm': TransposeNorm,  # S = A + A'
    }
    requested = type

    def _invalid(*args, **kwargs):
        # Deferred so the failure still happens at call time, as before.
        raise ValueError(f"Invalid normalization technique: {requested!r}.")

    return switcher.get(requested, _invalid)

def row_normalize(mx):
    """Row-normalize a matrix so every non-empty row sums to 1.

    Computes ``D^-1 M`` with ``D`` the diagonal of row sums. Rows summing to
    zero are left as zeros instead of producing ``inf``/``nan``.

    Args:
        mx: a SciPy sparse matrix or 2-D NumPy array; integer dtypes are
            supported.

    Returns:
        The row-normalized matrix (``scipy.sparse.diags(...).dot(mx)``).
    """
    rowsum = np.array(mx.sum(1))
    # Float exponent: np.power(int_array, -1) raises ValueError, while -1.
    # promotes ints to float and keeps float32 inputs float32.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1.).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(mx)



class LayerNorm(nn.Module):
    """Scale-only feature normalization with an optional learnable weight.

    Unlike a classic LayerNorm, no mean is subtracted from ``x`` and no bias
    is learned; ``x`` is only divided by a scale statistic (see ``forward``).

    Args:
        in_channels: size of the learnable ``weight`` vector (assumed to equal
            ``x.size(-1)`` — TODO confirm against callers).
        eps: added to the denominator for numerical stability.
        affine: when True, multiply the output by a learnable ``weight``.
    """

    def __init__(self, in_channels, eps=1e-5, affine=True):
        super().__init__()

        self.in_channels = in_channels
        self.eps = eps

        if affine:
            self.weight = nn.Parameter(torch.empty((in_channels,)))
        else:
            self.register_parameter('weight', None)
        # A bias is never learned; registering None keeps `self.bias` defined
        # in both cases without adding anything to the state dict.
        self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize ``weight``: 0.5 when it has >= 256 entries, else 1.0.

        No-op when ``affine=False`` (there is no weight). Previously this
        dereferenced ``self.weight`` unconditionally and crashed in ``__init__``.
        """
        if self.weight is None:
            return
        with torch.no_grad():
            # NOTE(review): the 256-channel threshold / 0.5 value is a
            # hard-coded choice whose rationale isn't visible here.
            self.weight.fill_(0.5 if self.weight.size(0) >= 256 else 1.)

    def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
        """Scale ``x`` by a per-input (or per-group) magnitude statistic.

        Args:
            x: input features; presumably (num_nodes, in_channels) — TODO confirm.
            batch: optional group index per row of ``x`` (presumably a graph
                id per node, as in PyG batching — verify against callers).

        Returns:
            ``x`` divided by the statistic, times ``weight`` when affine.
        """
        if batch is None:
            # One global scale: population std over every element of x.
            out = x / (x.std(unbiased=False) + self.eps)

        else:
            batch_size = int(batch.max()) + 1

            # Elements per group = rows in the group * feature count.
            norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
            norm = norm.mul_(x.size(-1)).view(-1, 1)

            # Per-group mean of x**2 (not mean-centered, unlike the
            # batch-free branch which uses std).
            var = scatter(x * x, batch, dim=0, dim_size=batch_size,
                          reduce='add').sum(dim=-1, keepdim=True)
            var = var / norm

            # Broadcast each group's scale back to its rows.
            out = x / (var + self.eps).sqrt().index_select(0, batch)

        if self.weight is not None:
            out = out * self.weight

        return out


def cal_norm(edge_index, num_nodes=None, self_loop=False, cut=False):
    """Compute per-node normalization factors from an edge list.

    Degrees are counted over ``edge_index[0]``. Without ``cut`` the factor is
    ``(2 * deg)^-1/2``. With ``cut`` it is ``deg^-1/2``, and the edge list is
    symmetrized and reduced to entries with ``row < col``. That leaves each
    undirected edge once and drops self-loops. Zero-degree nodes get factor 0.

    Args:
        edge_index: edge list tensor indexed as ``edge_index[0]`` (sources)
            and ``edge_index[1]`` (targets).
        num_nodes: number of nodes; inferred as ``max index + 1`` when None,
            or 0 for an empty edge list.
        self_loop: add 1 to every degree (accounts for an implicit self-loop).
        cut: assumes a symmetric adjacency; see above.

    Returns:
        ``(D, edge_index)``: ``D`` as a column tensor (a 1-D result is
        unsqueezed to ``(N, 1)``), and the possibly-filtered edge list.
    """
    if num_nodes is None:
        # Python int (not a 0-dim tensor); .max() would raise on no edges.
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0

    D = degree(edge_index[0], num_nodes)
    if self_loop:
        D = D + 1

    if cut:  # for symmetric adj
        D = torch.sqrt(1 / D)
        D[D == float("inf")] = 0.
        edge_index = to_undirected(edge_index, num_nodes=num_nodes)
        row, col = edge_index
        mask = row < col
        edge_index = edge_index[:, mask]
    else:
        D = torch.sqrt(1 / 2 / D)
        D[D == float("inf")] = 0.

    if D.dim() == 1:
        D = D.unsqueeze(-1)

    return D, edge_index