import torch
from torch.nn import Linear, Parameter
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import add_self_loops, degree
from models.model_utils import get_activation_function, initialize_layer_w_zero
from models.norms import Normalization
from torch_geometric.nn import MessagePassing, global_mean_pool, GINConv

# TODO(tm): send in the device while creating the model.
# Module-level default device: CUDA when it is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')



class EdgeWiseLayer(MessagePassing):
    """Linear edge-wise message-passing layer with a residual connection.

    The forward pass applies a single linear map ``J1`` to ``message``,
    aggregates the result over ``edgewise_edge_index`` (``aggr='add'``) and
    adds the rows of ``node_feature`` selected by ``edge_index[0]``::

        x = propagate(edgewise_edge_index, J1(message))
            + node_feature[edge_index[0]]
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 initialize_w_zero=False,
                 **kwargs,
        ):
        """Build the layer.

        Args:
            in_channels: Input feature size of ``J1``.
            out_channels: Output feature size of ``J1``.
            initialize_w_zero: When true, ``J1`` is passed through
                ``initialize_layer_w_zero`` after construction.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        super().__init__(aggr='add')

        self.J1 = nn.Linear(in_channels, out_channels)

        if initialize_w_zero:
            self.J1 = initialize_layer_w_zero(self.J1)

    def reset_parameters(self):
        """Re-initialise ``J1`` and set its bias to zero."""
        self.J1.reset_parameters()
        # The bias is a leaf tensor that requires grad, so a bare in-place
        # ``zero_()`` raises a RuntimeError. ``nn.init.zeros_`` performs the
        # same write without autograd tracking.
        nn.init.zeros_(self.J1.bias)

    def get_norm(self, edge_index, feature, node_feature=None):
        """Return one symmetric degree-normalisation coefficient per edge.

        Degrees are counted over the second row of ``edge_index`` with
        ``feature.size(0)`` as the number of nodes. The coefficient for an
        edge ``(row, col)`` is ``deg[row]**-0.5 * deg[col]**-0.5``.

        Args:
            edge_index: Index pair ``(row, col)`` describing the edges.
            feature: Tensor whose first dimension gives the node count and
                whose dtype is used for the degree tensor.
            node_feature: Unused; kept for interface compatibility.
        """
        row, col = edge_index
        deg = degree(col, feature.size(0), dtype=feature.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Zero-degree entries produce inf under pow(-0.5); neutralise them.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return norm

    def forward(self, message, edge_index, node_feature,
                edge_attr=None, edgewise_edge_index=None):
        """Transform, aggregate and add the residual.

        Args:
            message: Input features; fed to ``J1`` and indexed by
                ``edgewise_edge_index`` during propagation.
            edge_index: Its first row selects the rows of ``node_feature``
                added as the residual.
            node_feature: Source of the residual term.
            edge_attr: Unused; kept for interface compatibility.
            edgewise_edge_index: Connectivity used for propagation.

        Returns:
            Aggregated output plus ``node_feature[edge_index[0]]``.
        """
        middle = self.J1(message)

        # NOTE(review): ``norm`` is handed to ``propagate`` but this class
        # defines no ``message()`` that takes a ``norm`` argument, so it does
        # not appear to influence the output. Confirm whether the
        # normalisation was meant to be applied.
        norm = self.get_norm(edgewise_edge_index, feature=message)

        out = self.propagate(edgewise_edge_index, x=middle, norm=norm)
        x = out + node_feature[edge_index[0]]
        return x


class ZINCGINConv(MessagePassing):
    """GIN-style convolution that folds embedded edge attributes into messages.

    Used only for the ZINC dataset, since it requires the use of edge
    attributes.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(aggr="add")

        # Submodules are created in the same order as before (MLP layers,
        # eps, bond embedding) so seeded initialisation is unaffected.
        mlp_layers = [
            torch.nn.Linear(in_channels, out_channels),
            torch.nn.BatchNorm1d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(out_channels, out_channels),
        ]
        self.mlp = torch.nn.Sequential(*mlp_layers)
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = torch.nn.Embedding(4, in_channels)

    def forward(self, message, edge_index,
                node_feature=None,
                edge_attr=None,
                edgewise_edge_index=None):
        """Run one convolution step and add the residual.

        Args:
            message: Input features propagated over ``edgewise_edge_index``.
            edge_index: Its first row selects the residual rows of
                ``node_feature``.
            node_feature: Source of the residual term.
            edge_attr: Integer edge attributes; per the original author's
                note this has to be (#edgewise_edge_index indices x 1).
            edgewise_edge_index: Connectivity used for propagation.

        Returns:
            MLP output plus ``node_feature[edge_index[0]]``.
        """
        bond_embedding = self.bond_encoder(edge_attr.squeeze())
        self_term = (1 + self.eps) * message
        neighbour_term = self.propagate(
            edgewise_edge_index, x=message, edge_attr=bond_embedding)
        transformed = self.mlp(self_term + neighbour_term)

        residual = node_feature[edge_index[0]]
        return transformed + residual

    def message(self, x_j, edge_attr):
        """Combine the neighbour feature with the edge embedding via ReLU."""
        combined = x_j + edge_attr
        return F.relu(combined)

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out

class EdgewiseGINConv(GINConv):
    """GINConv variant that transforms first, then aggregates, then adds a residual.

    The forward pass computes::

        x = propagate(edgewise_edge_index, nn(message))
            + node_feature[edge_index[0]]
    """

    def __init__(self, *args, **kwargs):
        # All construction arguments are forwarded to GINConv untouched.
        super(EdgewiseGINConv, self).__init__(*args, **kwargs)

    def forward(self, message, edge_index, node_feature,
                edge_attr=None, edgewise_edge_index=None):
        """Apply ``self.nn``, propagate, and add the residual term.

        Args:
            message: Input features passed through ``self.nn``.
            edge_index: Its first row selects the residual rows of
                ``node_feature``.
            node_feature: Source of the residual term.
            edge_attr: Unused; kept for interface compatibility.
            edgewise_edge_index: Connectivity used for propagation.
        """
        transformed = self.nn(message)
        aggregated = self.propagate(edgewise_edge_index, x=transformed)
        residual = node_feature[edge_index[0]]
        return aggregated + residual
    
class EdgeGINConvLayer(torch.nn.Module):
    """Module wrapper that builds an MLP and runs it through ``EdgewiseGINConv``.

    The MLP is ``Linear -> Normalization -> ReLU -> Linear`` and the wrapped
    convolution is created with ``train_eps=False``.
    """

    def __init__(self, in_channels, out_channels, initialize_w_zero=False, norm_type='Identity'):
        """Build the MLP and the wrapped convolution.

        Args:
            in_channels: Input size of the first linear layer.
            out_channels: Output size of both linear layers.
            initialize_w_zero: Accepted for interface compatibility.
                NOTE(review): currently unused by this class — confirm
                whether zero-initialisation was intended here.
            norm_type: Forwarded to ``Normalization``.
        """
        super().__init__()
        mlp = torch.nn.Sequential(
            torch.nn.Linear(in_channels, out_channels),
            Normalization(out_channels, norm_type),
            torch.nn.ReLU(),
            torch.nn.Linear(out_channels, out_channels)
        )
        self.layer = EdgewiseGINConv(nn=mlp, train_eps=False)

    def forward(self, message, edge_index, node_feature,
                edge_attr=None, edgewise_edge_index=None):
        """Delegate to the wrapped ``EdgewiseGINConv``.

        The optional arguments are passed by keyword. Passing
        ``edgewise_edge_index`` positionally would bind it to the wrapped
        layer's ``edge_attr`` parameter and leave the propagation
        connectivity as ``None``.
        """
        return self.layer(
            message,
            edge_index,
            node_feature,
            edge_attr=edge_attr,
            edgewise_edge_index=edgewise_edge_index,
        )