import torch.nn as nn
import torch
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym import cfg
import torch_geometric.graphgym.register as register
import torch.nn.functional as F
import torch_scatter
from torch_geometric.nn.norm import PairNorm

from graphgps.layer.equivariant_ds import EquivariantDS
from torch_geometric.utils import to_dense_batch


class GCNConvLayer(nn.Module):
    """Graph Convolutional Network (GCN) layer with optional graph-level mixing.

    The layer applies ``pyg_nn.GCNConv`` followed by the activation registered
    under ``cfg.gnn.act`` and dropout. Depending on ``add_layer_pooling`` it
    then combines every node feature with a per-graph statistic, optionally
    adds a residual connection and a feed-forward block.

    Args:
        in_dim: Number of input node features.
        out_dim: Number of output node features.
        dropout: Dropout probability used after the activation and inside
            the feed-forward block.
        dropout_global: Dropout probability applied to the pooled per-graph
            term in the ``'learnable_mean'`` mode.
        pooling_layer: Name of the reduction passed as ``reduce`` to
            ``torch_scatter.scatter`` when pooling nodes per graph.
        residual: If true, the layer input is added to the layer output.
        add_layer_pooling: Falsy to disable, or one of ``'subtract_mean'``,
            ``'add_mean'``, ``'pairnorm'``, ``'VN'``, ``'learnable_mean'``
            or ``'JK'``. Any other truthy value leaves the features
            unchanged.
        add_feedforward: If true, a two-layer feed-forward block with a skip
            connection and batch normalisation is applied at the end.
        add_norm_weighting: In ``'learnable_mean'`` mode, scale the pooled
            term of each graph by ``num_nodes ** -0.5``.
        act: Unused; the activation is always taken from ``cfg.gnn.act``.
            Kept for interface compatibility.
        equivstable_pe: Unused. Kept for interface compatibility.
        **kwargs: Ignored.
    """

    def __init__(self, in_dim, out_dim, dropout, dropout_global, pooling_layer,
                 residual, add_layer_pooling, add_feedforward,
                 add_norm_weighting, act='relu',
                 equivstable_pe=False, **kwargs):
        super().__init__()
        self.dim_in = in_dim
        self.dim_out = out_dim
        self.dropout = dropout
        self.global_dropout = nn.Dropout(dropout_global)
        self.residual = residual

        self.pooling_layer = pooling_layer
        self.add_layer_pooling = add_layer_pooling

        self.add_feedforward = add_feedforward
        self.add_norm_weighting = add_norm_weighting

        # Every module below operates on the convolution output, so it is
        # sized with out_dim (identical to in_dim in the residual setting).
        if self.add_layer_pooling == 'learnable_mean':
            self.layer_norm = pyg_nn.Linear(out_dim, out_dim, bias=True)

        if self.add_layer_pooling == 'pairnorm':
            self.pair_norm = PairNorm()

        if self.add_layer_pooling == 'VN':
            # NOTE(review): this module also receives the convolution output;
            # its first argument presumably is the feature width and should
            # then be out_dim -- confirm against EquivariantDS.
            self.self_attn = EquivariantDS(in_dim, 1, reduction='mean',
                                           nonlinear='relu')

        if self.add_feedforward:
            # Feed Forward block.
            self.activation = F.relu
            self.ff_linear1 = nn.Linear(out_dim, out_dim * 2)
            self.ff_linear2 = nn.Linear(out_dim * 2, out_dim)
            self.norm2 = nn.BatchNorm1d(out_dim)
            self.ff_dropout1 = nn.Dropout(dropout)
            self.ff_dropout2 = nn.Dropout(dropout)

        self.act = nn.Sequential(
            register.act_dict[cfg.gnn.act](),
            nn.Dropout(self.dropout),
        )
        self.model = pyg_nn.GCNConv(self.dim_in, self.dim_out, bias=True)

    def _ff_block(self, x):
        """Feed Forward block: linear -> activation -> dropout -> linear -> dropout."""
        x = self.ff_dropout1(self.activation(self.ff_linear1(x)))
        return self.ff_dropout2(self.ff_linear2(x))

    def _apply_layer_pooling(self, x, batch):
        """Combine node features ``x`` with a per-graph term.

        The mode is selected by ``self.add_layer_pooling``; unrecognised
        modes return ``x`` unchanged.
        """
        mode = self.add_layer_pooling
        indices = batch.batch  # graph id of every node

        if mode in ('subtract_mean', 'add_mean'):
            pooled = torch_scatter.scatter(x, indices, dim=0,
                                           reduce=self.pooling_layer)[indices]
            return x - pooled if mode == 'subtract_mean' else x + pooled

        if mode == 'pairnorm':
            return self.pair_norm(x, indices)

        if mode == 'VN':
            x_dense, mask = to_dense_batch(x, indices)
            # Indexing with the mask drops the padding rows again.
            return x + self.self_attn(x_dense, mask=mask)[mask]

        if mode == 'learnable_mean':
            pooled = torch_scatter.scatter(self.layer_norm(x), indices, dim=0,
                                           reduce=self.pooling_layer)
            if self.add_norm_weighting:
                # Per-graph node counts; cast to the feature dtype before
                # taking the negative power.
                sizes = torch.diff(batch.ptr).to(pooled.dtype)
                pooled = torch.pow(sizes, -0.5).unsqueeze(-1) * pooled
            return x + self.global_dropout(pooled[indices])

        return x

    def forward(self, batch):
        """Run the layer on ``batch`` and return the same, updated object.

        ``batch.x`` is replaced by the layer output. In ``'JK'`` mode the
        output is additionally appended to ``batch.layer_values`` (the list
        is created on first use).
        """
        x_in = batch.x

        x = self.act(self.model(batch.x, batch.edge_index))

        if self.add_layer_pooling and self.add_layer_pooling != 'JK':
            x = self._apply_layer_pooling(x, batch)

        # Always write the (possibly pooled) features back; previously the
        # pooled result was dropped whenever the residual was disabled.
        batch.x = x_in + x if self.residual else x

        if self.add_feedforward:
            batch.x = batch.x + self._ff_block(batch.x)
            batch.x = self.norm2(batch.x)

        if self.add_layer_pooling == 'JK':
            if 'layer_values' not in batch:
                batch.layer_values = [batch.x]
            else:
                batch.layer_values.append(batch.x)

        return batch
