import torch
import torch.nn as nn

from torch_geometric.nn import MessagePassing
from torch_geometric.nn.dense.linear import Linear


class GNNConv(MessagePassing):
    """Message-passing layer that jointly updates node and edge features.

    Every edge produces the message ``relu(x_j + e)``, where ``e`` is that
    edge's feature. Messages are combined with ``aggr``. The layer then
    returns two tensors:

    * nodes: ``lin_l(aggregated) + lin_r(x)``
    * edges: ``lin_edge_l(relu(x[u] + x[v])) * 0.5 + lin_edge_r(edge_attr)``
      where ``u, v = edge_index``. The 0.5 factor scales the output of
      ``lin_edge_l``, bias included.

    The edge update reads the *input* node features ``x``, not the freshly
    updated node features.

    Args:
        in_channels: Feature size of ``x``. ``edge_attr`` needs the same
            size, because it is added to ``x_j`` and fed to ``lin_edge_r``
            (a ``Linear(in_channels, ...)``).
        out_channels: Feature size of both returned tensors.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
        **kwargs: Extra options forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, aggr, **kwargs):
        super().__init__(aggr, **kwargs)

        self.in_channels, self.out_channels = in_channels, out_channels

        # Node branch: projection of the aggregated messages (biased) plus a
        # self/root term on the node's own features (unbiased).
        self.lin_l = Linear(in_channels, out_channels, bias=True)
        self.lin_r = Linear(in_channels, out_channels, bias=False)

        # Edge branch: projection of the endpoint-feature sum (biased) plus a
        # term on the incoming edge feature (unbiased).
        self.lin_edge_l = Linear(in_channels, out_channels, bias=True)
        self.lin_edge_r = Linear(in_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base class state and all four projections."""
        super().reset_parameters()
        for projection in (self.lin_l, self.lin_r, self.lin_edge_l, self.lin_edge_r):
            projection.reset_parameters()

    def forward(self, x, edge_index, edge_attr):
        """Return ``(node_features, edge_features)``, both with ``out_channels``.

        Args:
            x: Node features. The last dim is assumed to be ``in_channels``
                (TODO confirm with callers).
            edge_index: Edge index tensor, unpacked into two index rows.
            edge_attr: Per-edge features, one row per column of
                ``edge_index``.
        """
        # The same tensor serves as both the source and the target side.
        aggregated = self.propagate(edge_index, x=(x, x), xe=edge_attr)
        node_out = self.lin_l(aggregated) + self.lin_r(x)

        u_idx, v_idx = edge_index
        endpoint_sum = (x[u_idx] + x[v_idx]).relu()
        edge_out = self.lin_edge_l(endpoint_sum) * 0.5 + self.lin_edge_r(edge_attr)

        return node_out, edge_out

    def message(self, x_j, xe):
        """Build the per-edge message ``relu(x_j + xe)``.

        PyG lifts ``x_j`` to one row per edge.
        """
        return (x_j + xe).relu()
    

class Encoder(nn.Module):
    """Stack of ``GNNConv`` layers that jointly refine node and edge features.

    Layer 0 maps ``input_dim -> hidden_dim``. Every later layer maps
    ``hidden_dim -> hidden_dim``. After each layer, when ``normalize`` is
    anything other than ``'none'``, nodes and edges are batch-normalised.
    Each uses its *own* ``BatchNorm1d``. Every layer except the last is then
    followed by ``activation`` and dropout on both nodes and edges.

    Args:
        input_dim: Feature size of ``x``. ``edge_attr`` is fed to the first
            ``GNNConv`` next to ``x``, so it needs the same size.
        hidden_dim: Output feature size of every layer.
        activation: Zero-argument callable (e.g. a module class such as
            ``nn.ReLU``). It is instantiated once and shared by nodes and
            edges.
        num_layers: Number of ``GNNConv`` layers.
        normalize: ``'none'`` disables batch normalisation. Any other value
            enables it.
        dropout: Dropout probability passed to ``nn.Dropout``.

    Note:
        Edge features use the separate ``edge_norms`` modules. Checkpoints
        saved before ``edge_norms`` existed lack those keys. Load them with
        ``load_state_dict(..., strict=False)``.
    """

    def __init__(self, input_dim, hidden_dim, activation, num_layers, normalize, dropout):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.normalize = normalize

        self.activation = activation()
        self.layers = nn.ModuleList()
        # Separate BatchNorms for nodes and edges. Sharing one would blend
        # node and edge statistics into the same running mean/variance and
        # force a single affine transform onto both feature types.
        self.norms = nn.ModuleList()
        self.edge_norms = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)

        dims = [input_dim] + [hidden_dim] * num_layers

        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            self.layers.append(GNNConv(in_dim, out_dim, aggr='mean'))
            self.norms.append(nn.BatchNorm1d(out_dim))
            self.edge_norms.append(nn.BatchNorm1d(out_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every conv layer and every BatchNorm.

        BatchNorm resets include running statistics.
        """
        for layer in self.layers:
            layer.reset_parameters()
        for norm in (*self.norms, *self.edge_norms):
            norm.reset_parameters()

    def forward(self, x, edge_index, edge_attr):
        """Encode the graph and return the final node embeddings.

        Edge features are updated layer by layer as well. The edge output of
        the last layer is discarded; only node embeddings are returned.
        """
        h, e = x, edge_attr
        last = self.num_layers - 1
        for i, conv in enumerate(self.layers):
            h, e = conv(h, edge_index, e)
            if self.normalize != 'none':
                h = self.norms[i](h)
                e = self.edge_norms[i](e)
            if i < last:
                # Node branch first, then edges: same dropout RNG order as before.
                h = self.dropout(self.activation(h))
                e = self.dropout(self.activation(e))
        return h


class InnerProductDecoder(nn.Module):
    """Score node pairs by the dot product of linearly projected embeddings.

    Args:
        hidden_dim: Last-dimension size of the incoming embeddings ``z``.
        output_dim: Size of the projected space the dot product runs in.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.lin = nn.Linear(hidden_dim, output_dim)

    def forward(self, z, edge_index, sigmoid):
        """Return one score per edge in ``edge_index``.

        Args:
            z: Node embeddings whose last dimension is ``hidden_dim``.
            edge_index: Pair of index rows. Column ``k`` selects the two
                endpoints of edge ``k``.
            sigmoid: When truthy, squash scores through a sigmoid.
                Otherwise return the raw logits.
        """
        projected = self.lin(z)
        left, right = edge_index[0], edge_index[1]
        scores = torch.sum(projected[left] * projected[right], dim=1)
        if sigmoid:
            return scores.sigmoid()
        return scores
