import torch
import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.nn import GCNConv, GINConv, global_mean_pool, BatchNorm
from ogb.graphproppred.mol_encoder import BondEncoder
from models.norms import Normalization
from models.model_utils import get_activation_function
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree

# Module-level switch. Nothing in the visible part of this file reads it;
# presumably consumed elsewhere to toggle gradient checkpointing -- confirm.
GRADIENT_CHECKPOINTING = False

# Default compute device: the GPU when CUDA is available, otherwise the CPU.
device = (
    torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
)



class ZINCGINConv(MessagePassing):
    """GIN-style convolution that folds an edge embedding into each message.

    For every edge, the source-node features and the embedded edge attribute
    are summed and passed through a ReLU; messages are sum-aggregated per
    target node, added to ``(1 + eps) * x`` and transformed by a two-layer MLP.

    Args:
        in_channels: Width of the incoming node features. The edge embedding
            uses the same width so it can be added to the node features.
        out_channels: Width of the node features produced by the MLP.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="add")

        # Node-update network applied to the combined self/neighbour signal.
        update_layers = [
            nn.Linear(in_channels, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        ]
        self.mlp = nn.Sequential(*update_layers)

        # Learnable scalar weighting of the node's own features, initialised to 0.
        self.eps = nn.Parameter(torch.Tensor([0]))

        # Lookup table with 4 rows: edge attributes must be integer indices 0..3.
        self.bond_encoder = nn.Embedding(4, in_channels)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing and return the updated node features.

        Args:
            x: Node feature tensor (last dimension ``in_channels``).
            edge_index: Edge connectivity handed to ``propagate``.
            edge_attr: Integer edge-type indices; presumably shaped
                ``(num_edges,)`` or ``(num_edges, 1)`` -- confirm with callers.

        Returns:
            The output of the MLP (last dimension ``out_channels``).
        """
        # NOTE(review): squeeze() without a dim drops *every* size-1 dimension,
        # so a graph with a single edge yields a 0-dim index and a 1-D
        # embedding; that case then relies on broadcasting inside message().
        bond_index = edge_attr.squeeze()
        bond_features = self.bond_encoder(bond_index)

        self_term = (1 + self.eps) * x
        neighbour_term = self.propagate(edge_index, x=x, edge_attr=bond_features)
        return self.mlp(self_term + neighbour_term)

    def message(self, x_j, edge_attr):
        """Build the per-edge message: ReLU of source features plus edge embedding."""
        combined = x_j + edge_attr
        return F.relu(combined)

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out


class GINConvLayer(torch.nn.Module):
    """Thin wrapper around ``GINConv`` with a Linear-BatchNorm-ReLU-Linear MLP.

    ``forward`` accepts an ``edge_attr`` argument but does not use it; only
    the node features and the edge connectivity reach the wrapped layer.

    Args:
        in_channels: Width of the incoming node features.
        out_channels: Width of the node features produced by the MLP.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Network handed to GINConv for the node update.
        update_layers = [
            nn.Linear(in_channels, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        ]
        update_mlp = nn.Sequential(*update_layers)

        # train_eps=False: the layer's epsilon is not a trainable parameter.
        self.layer = GINConv(nn=update_mlp, train_eps=False)

    def forward(self, x, edge_index, edge_attr):
        """Apply the wrapped ``GINConv`` to ``x`` over ``edge_index``.

        Args:
            x: Node feature tensor.
            edge_index: Edge connectivity passed to the wrapped layer.
            edge_attr: Ignored by this layer.

        Returns:
            The output of the wrapped ``GINConv``.
        """
        updated = self.layer(x, edge_index)
        return updated