import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv

class GraphEncoder(nn.Module):
    """Node encoder made of five stacked ``TransformerConv`` layers.

    Layers 1-4 each use ``num_heads`` attention heads and are followed by a
    ReLU. Layer 5 uses a single head, maps to ``out_channels`` and has no
    activation, so the encoder returns raw (un-activated) node embeddings.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads,
                 *, edge_dim=1, dropout=0.1):
        """
        Build the five convolution layers.

        Args:
            in_channels: Size of each input node feature vector.
            hidden_channels: Per-head output size of layers 1-4.
            out_channels: Output size of the final single-head layer.
            num_heads: Number of attention heads in layers 1-4.
            edge_dim: Edge feature dimensionality passed to every layer.
                Defaults to 1 (the value that used to be hard-coded). Pass
                ``None`` for graphs without edge features.
            dropout: Value passed to every layer's ``dropout`` argument
                (attention-coefficient dropout in ``TransformerConv``).
                Defaults to 0.1 (the value that used to be hard-coded).
        """
        super().__init__()

        # TransformerConv concatenates its heads by default, so every layer
        # after the first receives hidden_channels * num_heads features.
        stacked_width = hidden_channels * num_heads
        shared = {"edge_dim": edge_dim, "dropout": dropout}

        # Attribute names and creation order are kept stable so existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.conv1 = TransformerConv(
            in_channels, hidden_channels, heads=num_heads, **shared)
        self.conv2 = TransformerConv(
            stacked_width, hidden_channels, heads=num_heads, **shared)
        self.conv3 = TransformerConv(
            stacked_width, hidden_channels, heads=num_heads, **shared)
        self.conv4 = TransformerConv(
            stacked_width, hidden_channels, heads=num_heads, **shared)
        self.conv5 = TransformerConv(
            stacked_width, out_channels, heads=1, **shared)

    def forward(self, x, edge_index, edge_attr=None):
        """
        Encode node features.

        Args:
            x: Node feature tensor fed to the first layer.
            edge_index: Graph connectivity, passed unchanged to every layer.
            edge_attr: Edge features, passed unchanged to every layer. May be
                omitted only when the encoder was built with
                ``edge_dim=None``.

        Returns:
            The output of the fifth layer (no activation applied).
        """
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(layer(x, edge_index, edge_attr))
        return self.conv5(x, edge_index, edge_attr)

class SupGCLModel(torch.nn.Module):
    """Graph encoder followed by an MLP projection head.

    ``forward`` returns the projected node embeddings with every row scaled
    to unit L2 norm.
    """

    def __init__(self, in_c, hid, out, heads=8, proj_h=64, proj_out=64):
        """
        Args:
            in_c: Input node feature size given to the encoder.
            hid: Hidden channel size given to the encoder.
            out: Encoder output size; also the projection head's input size.
            heads: Number of attention heads given to the encoder.
            proj_h: Hidden width of the projection head.
            proj_out: Output width of the projection head.
        """
        super().__init__()
        self.encoder = GraphEncoder(in_c, hid, out, heads)

        # Linear -> ReLU -> Linear; kept at Sequential indices 0, 1, 2.
        projection_layers = [
            nn.Linear(out, proj_h),
            nn.ReLU(),
            nn.Linear(proj_h, proj_out),
        ]
        self.head = nn.Sequential(*projection_layers)

    def forward(self, x, edge_index, edge_attr):
        """Encode the graph, project the result, and L2-normalise along dim 1."""
        node_repr = self.encoder(x, edge_index, edge_attr)
        projected = self.head(node_repr)
        return F.normalize(projected, p=2, dim=1)
