
import torch
from torch.nn.functional import relu, sigmoid, leaky_relu
from torch.nn import Linear

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree

class MPBiSto(MessagePassing):
    """
    Sum-aggregating message passing with a normalised propagation matrix.

    Writing A for the adjacency encoded by ``edge_index`` and D for the
    diagonal matrix of degrees counted on ``edge_index[1]``:

    * ``bistochastic=True``:  P = Id - (D - A) / maxdeg
    * ``bistochastic=False``: P = D^(-1) A

    Degenerate cases are handled explicitly so that no NaN/inf is produced:
    a node without incoming edges gets a zero row in D^(-1) A, and on a
    graph without edges the bi-stochastic P reduces to the identity.

    Parameters
    ----------
    bistochastic : bool, optional
        Selects which of the two propagation matrices is applied.
        The default is False.
    """

    def __init__(self, bistochastic=False):
        # 'add' aggregation: propagate() returns the plain sum of messages,
        # i.e. A @ x; the normalisation is applied in forward().
        super().__init__(aggr='add')
        self.bistochastic = bistochastic

    def forward(self, x, edge_index):
        """
        Apply the propagation matrix to the node signal.

        Parameters
        ----------
        x : Tensor
            Node features; ``x.size(0)`` is taken as the number of nodes.
            NOTE(review): the per-node scaling below broadcasts with
            ``[:, None]``, so x is assumed to be 2-D -- confirm with callers.
        edge_index : Tensor
            Edge list unpacked into two index rows.

        Returns
        -------
        Tensor
            ``P @ x``, same shape as ``x``.
        """
        _, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        aggregated = self.propagate(edge_index, x=x)

        if self.bistochastic:
            # Guard against an empty graph (torch.max fails on an empty
            # tensor) and against a graph without edges (maxdeg == 0).
            if deg.numel() > 0:
                maxdeg = torch.max(deg).clamp(min=1)
            else:
                maxdeg = deg.new_ones(())
            return (1 - deg / maxdeg)[:, None] * x + aggregated / maxdeg

        # Nodes with deg == 0 receive no message, so their aggregated row is
        # already zero; clamping the divisor turns 0/0 = NaN into 0.
        return aggregated / deg.clamp(min=1)[:, None]

    def message(self, x_j):
        """Pass the neighbouring node's features through unchanged."""
        return x_j

class GNNBiSto(torch.nn.Module):
    """Stack of bias-free linear layers, each preceded by a propagation step."""

    def __init__(self, num_node_features, num_classes, num_layers=4,
                 num_units=32, activation='relu', bistochastic=False,
                 rec_intermediate_grad=False, dilation_factor=1,
                 is_id=False, std=0.01, is_MLP=False, skip=False, scale=1):
        """
        GNNs with bi-stochastic propagation matrices.

        Parameters
        ----------
        num_node_features : int
            Input dimension.
        num_classes : int
            Output dimension.
        num_layers : int, optional
            Number of layers (must be >= 1). The default is 4.
        num_units : int, optional
            Internal dimension. The default is 32.
        activation : 'relu', 'sigmoid', 'linear', 'leaky_relu', optional
            Activation function. Any other value is treated as 'linear' in
            the forward pass. The default is 'relu'.
        bistochastic : bool, optional
            Forwarded to the MPBiSto propagation module. The default is False.
        rec_intermediate_grad : bool, optional
            If True, records signals and gradients at each layer. The default is False.
        dilation_factor : int, optional
            Not used by this class; kept in the signature for compatibility.
        is_id : bool, optional
            If True, weights are initialized to identity + small noise. The default is False.
        std : float > 0, optional
            std of noise if is_id is True. The default is 0.01
        is_MLP : bool, optional
            If True, the GNN is in fact an MLP (identity propagation matrix). The default is False.
        skip : bool, optional
            If True, skip connections. The default is False.
        scale : scalar, optional
            Factor applied to all weights after initialization. The default is 1.

        Raises
        ------
        ValueError
            If num_layers < 1 (the forward pass would have no output).
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers
        self.num_units = num_units
        self.is_id = is_id
        self.is_MLP = is_MLP
        self.bistochastic = bistochastic
        self.std = std
        self.rec_intermediate_grad = rec_intermediate_grad
        self.MP = MPBiSto(bistochastic=bistochastic)
        self.lin_layers = torch.nn.ModuleList()
        self.activation = activation
        self.skip = skip
        self.scale = scale
        # Filled by forward() when rec_intermediate_grad is True.
        self.Xs, self.Fs, self.Hs = [], [], []

        if num_layers == 1:
            self.lin_layers.append(Linear(num_node_features, num_classes, bias=False))
        else:
            self.lin_layers.append(Linear(num_node_features, self.num_units, bias=False))
            for _ in range(num_layers - 2):
                self.lin_layers.append(Linear(self.num_units, self.num_units, bias=False))
            self.lin_layers.append(Linear(self.num_units, num_classes, bias=False))

        self.standard_init(is_id=is_id, std=std, scale=scale)

    def forward(self, data, device=None, cut_layer=None):
        """
        Run the network on a graph.

        Parameters
        ----------
        data : object with attributes ``x`` and ``edge_index``
            Node signal and edge list. ``data.x.requires_grad`` is set to
            True in place so that gradients w.r.t. the input are available.
        device : optional
            Not used by this method.
        cut_layer : optional
            Not used by this method.

        Returns
        -------
        Tensor
            Output of the last linear layer, *before* the activation.

        Notes
        -----
        When ``rec_intermediate_grad`` is True, the per-layer inputs,
        propagated signals and pre-activations are stored in ``self.Xs``,
        ``self.Fs`` and ``self.Hs`` with ``retain_grad()`` enabled; otherwise
        these lists are reset to empty.
        """
        X, edge_index = data.x, data.edge_index
        X.requires_grad = True

        record = self.rec_intermediate_grad
        # Resolve the activation once; unknown names mean "no activation".
        act = {
            'relu': relu,
            'sigmoid': sigmoid,
            'leaky_relu': leaky_relu,
        }.get(self.activation)
        last = self.num_layers - 1
        Xs, Fs, Hs = [], [], []

        for layer in range(self.num_layers):
            if record:
                X.retain_grad()
                Xs.append(X)

            F = X if self.is_MLP else self.MP(X, edge_index)
            H = self.lin_layers[layer](F)

            # recording intermediate signals/gradients
            if record:
                F.retain_grad()
                Fs.append(F)
                H.retain_grad()
                Hs.append(H)

            XX = H if act is None else act(H)
            # Skip connections only between hidden layers, where input and
            # output widths are both num_units.
            if self.skip and 0 < layer < last:
                X = X + XX
            else:
                X = XX

        self.Xs, self.Fs, self.Hs = Xs, Fs, Hs
        return H

    def standard_init(self, is_id=None, std=0.01, scale=1):
        """
        (Re-)initialize the weights of all linear layers in place.

        Parameters
        ----------
        is_id : bool or None, optional
            If True, hidden layers (all but the first and the last) are set
            to identity + Gaussian noise; other layers use Kaiming-normal
            initialization. If None, ``self.is_id`` is used. The default is None.
        std : float, optional
            std of the noise added to the identity. The default is 0.01.
        scale : scalar, optional
            Factor applied to every weight matrix afterwards. The default is 1.
        """
        if is_id is None:
            is_id = self.is_id
        last = len(self.lin_layers) - 1
        n = self.num_units

        for i, layer in enumerate(self.lin_layers):
            weight = layer.weight
            # Write in place so the Parameter object (and hence its device,
            # dtype and any optimizer reference to it) is preserved.
            with torch.no_grad():
                if is_id and 0 < i < last:
                    # unusual init to avoid vanishing signal: identity + small noise
                    eye = torch.eye(n, dtype=weight.dtype, device=weight.device)
                    noise = torch.randn((n, n), dtype=weight.dtype, device=weight.device)
                    weight.copy_(eye + std * noise)
                else:
                    torch.nn.init.kaiming_normal_(weight, nonlinearity=self.activation)
                weight.mul_(scale)

