# models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv

class GraphEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=8,
                 edge_dim=1, dropout=0.1):
        """
        Encoder with 5 stacked TransformerConv layers.

        Layers 1-4 use ``num_heads`` heads; their outputs are consumed as
        ``hidden_channels * num_heads`` features by the next layer (i.e. the
        heads are concatenated, TransformerConv's default). Each of them is
        followed by a ReLU. Layer 5 uses a single head, produces
        ``out_channels`` features and has no activation.

        Args:
            in_channels: Number of input node features.
            hidden_channels: Per-head output size of the four hidden layers.
            out_channels: Size of the final node embeddings.
            num_heads: Number of attention heads in the four hidden layers.
            edge_dim: Number of edge features per edge (default 1, the value
                previously hard-coded).
            dropout: ``dropout`` argument forwarded to every TransformerConv
                (default 0.1, the value previously hard-coded).
        """
        super().__init__()
        hidden_width = hidden_channels * num_heads
        # Shared settings for all five layers; keeping them in one place
        # guarantees the layers stay consistent.
        conv_kwargs = dict(edge_dim=edge_dim, dropout=dropout)
        # Attribute names conv1..conv5 are kept so existing state_dict
        # checkpoints still load.
        self.conv1 = TransformerConv(in_channels, hidden_channels, heads=num_heads, **conv_kwargs)
        self.conv2 = TransformerConv(hidden_width, hidden_channels, heads=num_heads, **conv_kwargs)
        self.conv3 = TransformerConv(hidden_width, hidden_channels, heads=num_heads, **conv_kwargs)
        self.conv4 = TransformerConv(hidden_width, hidden_channels, heads=num_heads, **conv_kwargs)
        self.conv5 = TransformerConv(hidden_width, out_channels, heads=1, **conv_kwargs)

    def forward(self, x, edge_index, edge_attr):
        """
        Encode node features into node embeddings.

        Args:
            x: Node feature matrix fed to the first layer.
            edge_index: Graph connectivity, shared by all five layers.
            edge_attr: Edge features. A 1-D tensor of per-edge scalars is
                reshaped to ``[num_edges, 1]``, because TransformerConv's edge
                projection expects a trailing feature dimension of size
                ``edge_dim``. This shortcut only fits when ``edge_dim == 1``.

        Returns:
            Output of the fifth layer (``out_channels`` features per node).
        """
        if edge_attr is not None and edge_attr.dim() == 1:
            # Without this, Linear(edge_dim=1, ...) would see a last dimension
            # of size num_edges and raise a shape error.
            edge_attr = edge_attr.unsqueeze(-1)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x, edge_index, edge_attr))
        # Final projection: no activation so embeddings can take any sign.
        return self.conv5(x, edge_index, edge_attr)

class CustomDecoder(nn.Module):
    """
    Linear decoder that maps node embeddings back to node features and to one
    scalar per edge.
    """

    def __init__(self, out_channels, num_features):
        """
        Build the two linear reconstruction heads.

        Args:
            out_channels: Width of the incoming node embeddings ``z``.
            num_features: Number of node features to reconstruct.
        """
        super().__init__()
        # Node head is registered before the edge head; keeping this order
        # keeps seeded parameter initialisation identical.
        self.node_recon = nn.Linear(out_channels, num_features)
        self.edge_recon = nn.Linear(out_channels, 1)

    def forward(self, z, edge_index):
        """
        Decode node embeddings.

        Args:
            z: Node embeddings, indexed by node id along the first axis.
            edge_index: Pair of index sequences ``(src, dst)``, one entry per edge.

        Returns:
            Tuple ``(node_features_recon, edge_features_recon)``:
            ``node_recon`` applied to every row of ``z``, and ``edge_recon``
            applied to ``z[src] * z[dst]`` with the trailing size-1 dimension
            squeezed away.
        """
        src, dst = edge_index
        # Element-wise product of the two endpoint embeddings of each edge.
        pair_embeddings = z[src] * z[dst]
        edge_scores = self.edge_recon(pair_embeddings).squeeze(-1)
        node_scores = self.node_recon(z)
        return node_scores, edge_scores

class GraphModel(nn.Module):
    """
    Graph AutoEncoder (GAE) model combining encoder and decoder.
    """

    def __init__(self, encoder, decoder):
        """
        Store the encoder/decoder pair.

        Args:
            encoder: Callable invoked as ``encoder(x, edge_index, edge_attr)``;
                its result is passed to the decoder as the embeddings.
            decoder: Callable invoked as ``decoder(z, edge_index)``; must
                return a pair ``(node_recon, edge_recon)``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, data):
        """
        Encode a graph and reconstruct it.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``edge_attr``
                attributes (presumably a torch_geometric ``Data`` batch —
                confirm against callers).

        Returns:
            Tuple ``(z, node_features_recon, edge_features_recon)`` where ``z``
            is the encoder output.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        latent = self.encoder(x, edge_index, edge_attr)
        node_out, edge_out = self.decoder(latent, edge_index)
        return latent, node_out, edge_out
