# models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv

class GraphEncoder(nn.Module):
    """Encode node features with a stack of five ``TransformerConv`` layers.

    The first four layers use ``num_heads`` attention heads and are each
    followed by a ReLU; the fifth layer uses a single head, no activation,
    and produces ``out_channels`` features per node.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Per-head width of the four hidden layers.
        out_channels: Width of the final node embedding. Also stored as
            ``self.out_channels`` so downstream modules can size themselves.
        num_heads: Number of attention heads in the four hidden layers.
        edge_dim: Edge-feature width passed to every ``TransformerConv``.
            Defaults to 1, the value that was previously hard-coded.
        dropout: Dropout value passed to every ``TransformerConv``.
            Defaults to 0.1, the value that was previously hard-coded.
    """

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int,
                 num_heads: int, *, edge_dim: int = 1, dropout: float = 0.1) -> None:
        super().__init__()
        self.out_channels = out_channels

        # Settings shared by every layer; previously repeated as literals.
        shared = {"edge_dim": edge_dim, "dropout": dropout}
        # Hidden layers feed the next layer `hidden_channels * num_heads`
        # features (the heads' outputs side by side).
        multi_head_width = hidden_channels * num_heads

        # Attribute names conv1..conv5 and their construction order are kept
        # so existing checkpoints (state_dict keys) keep loading.
        self.conv1 = TransformerConv(in_channels, hidden_channels, heads=num_heads, **shared)
        self.conv2 = TransformerConv(multi_head_width, hidden_channels, heads=num_heads, **shared)
        self.conv3 = TransformerConv(multi_head_width, hidden_channels, heads=num_heads, **shared)
        self.conv4 = TransformerConv(multi_head_width, hidden_channels, heads=num_heads, **shared)
        # Single head on the last layer, so its output width is `out_channels`.
        self.conv5 = TransformerConv(multi_head_width, out_channels, heads=1, **shared)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor) -> torch.Tensor:
        """Run the five-layer stack and return the final node embeddings.

        Args:
            x: Node features (presumably ``(num_nodes, in_channels)`` -- confirm
                against the data loader).
            edge_index: Graph connectivity, forwarded unchanged to every layer.
            edge_attr: Edge features, forwarded unchanged to every layer
                (presumably ``edge_dim`` values per edge -- confirm).

        Returns:
            The output of ``conv5``; no activation is applied to it.
        """
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x, edge_index, edge_attr))
        return self.conv5(x, edge_index, edge_attr)

class GraphCLModel(nn.Module):
    """Wrap an encoder with a two-layer MLP projection head.

    The head is ``Linear -> ReLU -> Linear`` and maps the encoder's output
    (``encoder.out_channels`` features) to ``proj_out_dim`` features via a
    hidden layer of ``proj_hidden_dim`` units.

    Args:
        encoder: Callable taking ``(x, edge_index, edge_attr)``; it must
            expose an ``out_channels`` attribute giving its output width.
        proj_hidden_dim: Width of the projection head's hidden layer.
        proj_out_dim: Width of the projected output.
    """

    def __init__(self, encoder, proj_hidden_dim, proj_out_dim):
        super().__init__()
        self.encoder = encoder

        # The head's input width has to equal the encoder's output width.
        input_layer = nn.Linear(encoder.out_channels, proj_hidden_dim)
        activation = nn.ReLU()
        output_layer = nn.Linear(proj_hidden_dim, proj_out_dim)
        self.projection_head = nn.Sequential(input_layer, activation, output_layer)

    def forward(self, x, edge_index, edge_attr):
        """Encode the inputs, then return the projection of the encoding."""
        embeddings = self.encoder(x, edge_index, edge_attr)
        return self.projection_head(embeddings)
