import torch
import torch.nn as nn
from torch_geometric.nn import GINConv, GINEConv, global_add_pool


class GINEEncoder(nn.Module):
    """Stack of GINE convolutions with an optional sum-pooling readout.

    Every convolution wraps a ``Linear -> ReLU -> BatchNorm1d`` MLP and uses a
    learnable epsilon (``train_eps=True``). Layer widths run
    ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``.

    Args:
        input_dim: Width of the incoming node features.
        hidden_dim: Width of every intermediate layer.
        output_dim: Width of the final layer; falls back to ``hidden_dim``
            when ``None``.
        num_layers: Number of convolutions. Values below 2 still build two
            layers (one input layer and one output layer).
        edge_dim: Edge-feature width forwarded to each ``GINEConv``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=None, num_layers=3, edge_dim=None):
        super().__init__()
        final_dim = hidden_dim if output_dim is None else output_dim

        # One entry per layer boundary. A list multiplied by a non-positive
        # count is empty, which mirrors range(num_layers - 2) producing no
        # middle layers.
        widths = [input_dim, hidden_dim] + [hidden_dim] * (num_layers - 2) + [final_dim]

        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            self._make_conv(d_in, d_out, edge_dim)
            for d_in, d_out in zip(widths[:-1], widths[1:])
        )

    @staticmethod
    def _make_conv(in_dim, out_dim, edge_dim):
        """Build one GINE layer around a Linear -> ReLU -> BatchNorm1d MLP."""
        mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.BatchNorm1d(out_dim),
        )
        return GINEConv(mlp, train_eps=True, edge_dim=edge_dim)

    def forward(self, x, edge_index, edge_attr, batch, node_mask, wo_pool=False):
        """Run all convolutions, zeroing masked rows before and after each one.

        Args:
            x: Node feature matrix; its rows are multiplied by the flattened
                ``node_mask``.
            edge_index: Graph connectivity passed to every convolution.
            edge_attr: Edge features passed to every convolution.
            batch: Graph id of each node, used by the sum-pooling readout.
            node_mask: Per-node mask, flattened to a column before multiplying.
            wo_pool: When true, return node embeddings instead of pooled ones.

        Returns:
            Node embeddings if ``wo_pool`` else their per-graph sums.
        """
        mask = node_mask.view(-1, 1)
        h = x * mask
        for conv in self.layers:
            h = conv(h, edge_index, edge_attr) * mask
        return h if wo_pool else global_add_pool(h, batch)


class GINEncoder(nn.Module):
    """Stack of GIN convolutions (no edge features) with optional sum pooling.

    Every convolution wraps a ``Linear -> ReLU -> BatchNorm1d`` MLP and uses a
    learnable epsilon (``train_eps=True``). Layer widths run
    ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``.

    Args:
        input_dim: Width of the incoming node features.
        hidden_dim: Width of every intermediate layer.
        num_layers: Number of convolutions. Values below 2 still build two
            layers (one input layer and one output layer).
        output_dim: Width of the final layer; falls back to ``hidden_dim``
            when ``None`` (the previous, hard-coded behaviour). Appended last
            so existing positional callers are unaffected.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=3, output_dim=None):
        super().__init__()
        if output_dim is None:
            output_dim = hidden_dim

        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        # Input layer, (num_layers - 2) hidden layers, output layer. A
        # non-positive count gives an empty middle, so at least two layers are
        # always built.
        widths = [input_dim, hidden_dim] + [hidden_dim] * (num_layers - 2) + [output_dim]
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            self.layers.append(
                GINConv(
                    nn.Sequential(
                        nn.Linear(d_in, d_out),
                        nn.ReLU(),
                        nn.BatchNorm1d(d_out),
                    ),
                    train_eps=True,
                )
            )

    def forward(self, x, edge_index, batch, node_mask, wo_pool=False):
        """Run all convolutions, zeroing masked rows before and after each one.

        Args:
            x: Node feature matrix; its rows are multiplied by the flattened
                ``node_mask``.
            edge_index: Graph connectivity passed to every convolution.
            batch: Graph id of each node, used by the sum-pooling readout.
            node_mask: Per-node mask, flattened to a column before multiplying.
            wo_pool: When true, return node embeddings instead of pooled ones.

        Returns:
            Node embeddings if ``wo_pool`` else their per-graph sums.
        """
        mask = node_mask.view(-1, 1)
        h = x * mask
        for layer in self.layers:
            h = layer(h, edge_index) * mask

        if wo_pool:
            return h
        return global_add_pool(h, batch)
        
    
class GNNSubgraphEncoder(nn.Module):
    """GINE stack whose node embeddings are sum-pooled per subgraph, per graph.

    Every convolution wraps a ``Linear -> ReLU -> BatchNorm1d`` MLP with a
    learnable epsilon. All layers after the first map ``hidden_dim`` to
    ``hidden_dim``.

    Args:
        input_dim: Width of the incoming node features.
        hidden_dim: Width of every layer's output.
        num_layers: Number of convolutions. Values below 2 still build two
            layers.
        edge_dim: Edge-feature width forwarded to each ``GINEConv``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=3, edge_dim=None):
        super().__init__()

        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        # Input layer, (num_layers - 2) hidden layers, output layer; the last
        # layer is hidden -> hidden just like the middle ones.
        widths = [input_dim, hidden_dim] + [hidden_dim] * (num_layers - 2) + [hidden_dim]
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            self.layers.append(
                GINEConv(
                    nn.Sequential(
                        nn.Linear(d_in, d_out),
                        nn.ReLU(),
                        nn.BatchNorm1d(d_out),
                    ),
                    train_eps=True,
                    edge_dim=edge_dim,
                )
            )

    def forward(self, x, edge_index, edge_attr, batch, subgraph_batch, node_mask):
        """Encode nodes, then sum node embeddings per subgraph id within each graph.

        Args:
            x: Node feature matrix; its rows are multiplied by the flattened
                ``node_mask`` (so it presumably holds B*N rows — confirm).
            edge_index: Graph connectivity passed to every convolution.
            edge_attr: Edge features passed to every convolution.
            batch: Graph id of each node row in ``x``.
            subgraph_batch: Indexable with one entry per graph; entry ``b``
                holds the subgraph id of every node with ``batch == b``
                (presumably a LongTensor — verify against caller).
            node_mask: 2-D mask whose first two dims give the output's
                (B, N) shape.

        Returns:
            Tensor of shape ``(B, N, hidden_dim)`` where ``out[b, k]`` is the
            sum of graph ``b``'s node embeddings assigned to subgraph id ``k``;
            slots for ids that do not occur stay zero.
        """
        mask = node_mask.view(-1, 1)
        h = x * mask
        for layer in self.layers:
            # Re-mask after every layer, as the sibling encoders do: the MLP's
            # Linear bias and BatchNorm shift can make masked rows non-zero,
            # and those rows would otherwise send messages and (if they share
            # a graph id) leak into the per-subgraph sums below.
            h = layer(h, edge_index, edge_attr) * mask

        # Allocate directly on h's device with h's dtype instead of building a
        # float32 CPU tensor and copying it over.
        out = h.new_zeros(node_mask.shape[0], node_mask.shape[1], h.shape[-1])  # (B, N, hidden_dim)
        for b in range(len(subgraph_batch)):
            sub_ids = subgraph_batch[b]
            pooled = global_add_pool(h[batch == b], sub_ids)  # (num_subgraphs, hidden_dim)
            out[b, sub_ids] = pooled[sub_ids]
        return out
    
    
