import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import degree
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_max
from .utils import AtomEncoder, BondEncoder


class GINConv_GRIN(MessagePassing):
    """GIN-style message-passing layer with bond features and max aggregation.

    A message is ``f_agg([x_j, bond_embedding])``. Messages that share a
    destination index are reduced with an element-wise maximum, and the layer
    output is ``f_up([aggregated, x])``.

    Args:
        hidden_size: Width of the incoming node features ``x``.
        output_size: Width of the produced node features. Defaults to
            ``hidden_size``.
    """

    def __init__(self, hidden_size, output_size=None):
        # aggr=None: the reduction is provided by the `aggregate` override.
        super().__init__(aggr=None)
        if output_size is None:
            output_size = hidden_size

        # Both MLPs consume the concatenation of one hidden_size-wide tensor
        # (node features) and one output_size-wide tensor (bond embedding for
        # f_agg, aggregated messages for f_up). Sizing the first Linear from
        # that sum keeps the layer usable when output_size != hidden_size and
        # is identical to 2*hidden_size when the two are equal.
        # NOTE(review): assumes BondEncoder(hidden_size=k) emits k-wide
        # embeddings -- confirm against .utils.
        concat_size = hidden_size + output_size
        mlp_width = 2 * hidden_size

        self.f_agg = torch.nn.Sequential(
            torch.nn.Linear(concat_size, mlp_width),
            torch.nn.BatchNorm1d(mlp_width),
            torch.nn.ReLU(),
            torch.nn.Linear(mlp_width, output_size),
        )
        self.f_up = torch.nn.Sequential(
            torch.nn.Linear(concat_size, mlp_width),
            torch.nn.BatchNorm1d(mlp_width),
            torch.nn.ReLU(),
            torch.nn.Linear(mlp_width, output_size),
        )
        self.bond_encoder = BondEncoder(hidden_size=output_size)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing.

        Args:
            x: Node feature matrix; concatenated with the aggregate along
                dim 1, so it must be 2-D.
            edge_index: Edge connectivity handed to ``propagate``.
            edge_attr: Raw bond features, embedded by ``self.bond_encoder``.

        Returns:
            The updated node features produced by ``self.f_up``.
        """
        edge_emb = self.bond_encoder(edge_attr)
        agg = self.propagate(edge_index, x=x, edge_attr=edge_emb)
        return self.f_up(torch.cat([agg, x], dim=1))

    def message(self, x_j, edge_attr):
        """Build one message per edge from source features and bond embedding."""
        return self.f_agg(torch.cat([x_j, edge_attr], dim=1))

    def aggregate(self, inputs, index, dim_size=None):
        """Reduce messages per destination index with an element-wise max.

        Raises:
            ImportError: If ``scatter_max`` is unavailable.
        """
        # The original guard also read `_has_torch_scatter`, a name that is
        # never defined in this module, so every call died with NameError.
        if scatter_max is None:
            raise ImportError("GRIN requires `torch_scatter` package. Please install it via `pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`.")

        out, _ = scatter_max(inputs, index, dim=0, dim_size=dim_size)
        # Replace +inf entries with 0 (out-of-place, same values as before).
        # NOTE(review): for a max reduction an empty-segment sentinel would be
        # -inf rather than +inf -- confirm which value was intended here.
        return out.masked_fill(out == float('inf'), 0.0)

class GCNConv_GRIN(MessagePassing):
    """GCN-style message-passing layer with bond features and max aggregation.

    Node features are first projected by ``self.linear``. A message is
    ``norm * relu(f_agg([x_j, bond_embedding]))`` where ``norm`` is the
    symmetric ``deg^-1/2 * deg^-1/2`` coefficient of the edge. Messages that
    share a destination index are reduced with an element-wise maximum, and
    the layer output is ``f_up([aggregated, projected_x])``.

    Args:
        hidden_size: Width of the incoming node features ``x``.
        output_size: Width of the produced node features. Defaults to
            ``hidden_size``.
    """

    def __init__(self, hidden_size, output_size=None):
        # aggr=None: the reduction is provided by the `aggregate` override.
        super().__init__(aggr=None)
        if output_size is None:
            output_size = hidden_size
        self.linear = torch.nn.Linear(hidden_size, output_size)

        # After `self.linear` every tensor that gets concatenated (projected
        # node features, bond embedding, aggregated messages) is
        # output_size-wide, so the MLP inputs are 2*output_size. This equals
        # the previous 2*hidden_size whenever the two sizes match and fixes
        # the shape mismatch when they differ.
        # NOTE(review): assumes BondEncoder(hidden_size=k) emits k-wide
        # embeddings -- confirm against .utils.
        concat_size = 2 * output_size
        mlp_width = 2 * hidden_size

        self.f_agg = torch.nn.Sequential(
            torch.nn.Linear(concat_size, mlp_width),
            torch.nn.BatchNorm1d(mlp_width),
            torch.nn.ReLU(),
            torch.nn.Linear(mlp_width, output_size),
        )
        self.f_up = torch.nn.Sequential(
            torch.nn.Linear(concat_size, mlp_width),
            torch.nn.BatchNorm1d(mlp_width),
            torch.nn.ReLU(),
            torch.nn.Linear(mlp_width, output_size),
        )
        self.bond_encoder = BondEncoder(hidden_size=output_size)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of degree-normalised message passing.

        Args:
            x: Node feature matrix; its first dimension is used as the node
                count for the degree computation.
            edge_index: Edge connectivity; unpacked into ``row`` and ``col``.
            edge_attr: Raw bond features, embedded by ``self.bond_encoder``.

        Returns:
            The updated node features produced by ``self.f_up``.
        """
        x_lin = self.linear(x)

        row, col = edge_index
        # "+ 1" keeps every degree >= 1 (acts like an implicit self-loop count).
        deg = degree(row, x_lin.size(0), dtype=x_lin.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        # Defensive guard kept from the original; with deg >= 1 it is a no-op.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0.0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        edge_emb = self.bond_encoder(edge_attr)
        agg = self.propagate(
            edge_index,
            x=x_lin,
            edge_attr=edge_emb,
            norm=norm,
            size=None,
        )
        return self.f_up(torch.cat([agg, x_lin], dim=1))

    def message(self, x_j, edge_attr, norm):
        """Build one normalised, non-negative message per edge."""
        m = self.f_agg(torch.cat([x_j, edge_attr], dim=1))
        return norm.view(-1, 1) * F.relu(m)

    def aggregate(self, inputs, index, dim_size=None):
        """Reduce messages per destination index with an element-wise max.

        Raises:
            ImportError: If ``scatter_max`` is unavailable.
        """
        # The original guard also read `_has_torch_scatter`, a name that is
        # never defined in this module, so every call died with NameError.
        if scatter_max is None:
            raise ImportError("GRIN requires `torch_scatter` package. Please install it via `pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`.")

        out, _ = scatter_max(inputs, index, dim=0, dim_size=dim_size)
        # Replace +inf entries with 0 (out-of-place, same values as before).
        # NOTE(review): for a max reduction an empty-segment sentinel would be
        # -inf rather than +inf -- confirm which value was intended here.
        return out.masked_fill(out == float('inf'), 0.0)

### GNN that generates node embeddings
class GNN_node(torch.nn.Module):
    """
    Stack of GRIN message-passing layers applied to atom embeddings.

    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.1, JK = "last", residual = False, gnn_type = 'gcn'):
        '''
            num_layer (int): number of GNN message passing layers (must be >= 2)
            emb_dim (int): node embedding dimensionality
            drop_ratio (float): dropout probability applied after every layer
            JK (str): how layer outputs are combined, "last" or "sum"
            residual (bool): add each layer's input to its output
            gnn_type (str): 'gin' or 'gcn'

            Raises ValueError when num_layer < 2 or gnn_type is unknown.
        '''

        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ###List of GNNs
        self.convs = torch.nn.ModuleList()
        # Holds LayerNorm modules; the attribute name is kept so existing
        # checkpoints keep loading.
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv_GRIN(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv_GRIN(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.LayerNorm(emb_dim))

    def forward(self, batched_data):
        '''
            batched_data: object exposing `x`, `edge_index` and `edge_attr`

            Returns the node representation selected by `self.JK`.
            Raises ValueError when `self.JK` is neither "last" nor "sum".
        '''
        # Fail fast with a clear error; previously an unknown JK surfaced as
        # an UnboundLocalError at the return statement.
        if self.JK not in ("last", "sum"):
            raise ValueError('Undefined JK mode called {}'.format(self.JK))

        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr

        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        last_layer = self.num_layer - 1
        for layer in range(self.num_layer):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            # no relu for the last layer
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training = self.training)

            if self.residual:
                # Out-of-place on purpose: dropout hands back its input
                # unchanged when inactive, so `h += ...` would overwrite the
                # ReLU output that autograd saved for the backward pass.
                h = h + h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            return h_list[-1]

        # JK == "sum": add the input embedding and every layer output.
        node_representation = h_list[0]
        for h in h_list[1:]:
            node_representation = node_representation + h
        return node_representation
