# gnn-meta-graph/src/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, global_max_pool, global_mean_pool

"""
model.py

Contains PyTorch modules used in the GNN fusion framework, including:
- Weaker_First: First-level TransformerConv-based feature extractor.
- Weaker_Middle: Deeper representation refinement layer.
- GraphLevelPredictor: Graph classification head over meta-graphs.
- GTN: Standalone graph transformer network with node/graph embedding output.
"""


class Weaker_First(torch.nn.Module):
    """First-stage graph classifier: Conv1d + three TransformerConv layers + MLP head.

    Args:
        num_nodes: Used as both in- and out-channels of the leading ``Conv1d``.
        num_node_features: Input feature width of the first TransformerConv.
        num_classes: Number of output classes of the final linear layer.

    After every ``forward`` call, ``self.attention_scores`` holds the element-wise
    sum of the attention weights returned by the three TransformerConv layers.
    """

    def __init__(self, num_nodes, num_node_features, num_classes):
        super().__init__()
        heads = 4  # every TransformerConv here concatenates 4 heads
        self.num_nodes = num_nodes
        # NOTE(review): Conv1d treats its first input axis as channels; x is
        # presumably (num_nodes, num_node_features), i.e. nodes act as channels
        # and the kernel slides along the feature axis — confirm with callers.
        self.conv0 = nn.Conv1d(num_nodes, num_nodes, kernel_size=3, stride=1, padding=1)
        self.conv1 = TransformerConv(num_node_features, 32, heads=heads, concat=True, dropout=0.1)
        self.conv2 = TransformerConv(32 * heads, 128, heads=heads, concat=True, dropout=0.1)
        self.conv3 = TransformerConv(128 * heads, 32, heads=heads, concat=True, dropout=0.1)

        self.local_size = min(7, num_nodes)  # NOTE(review): not read anywhere in this class
        self.final_feature_dim = 32 * heads  # width of conv3's concatenated output

        self.lr1 = nn.Linear(self.final_feature_dim, 128)
        self.lr2 = nn.Linear(128, 32)
        self.lr3 = nn.Linear(32, num_classes)
        self.attention_scores = None

    def forward(self, x, edge_index, batch=None):
        """Run the network and return ``[log_probs, lap, logits]``.

        Args:
            x: Node features, first passed through ``conv0``.
            edge_index: Graph connectivity shared by all three TransformerConv layers.
            batch: Graph-assignment vector used for max pooling; when None, all
                nodes are pooled into a single graph.

        Returns:
            A list ``[log_softmax(logits, dim=1), lap, logits]`` where ``lap`` is
            the post-ReLU output of ``conv3`` (per-node representation).
        """
        h = self.conv0(x)

        h, attn_a = self.conv1(h, edge_index, return_attention_weights=True)
        # Default F.dropout probability (0.5), applied after the first layer only.
        h = F.dropout(F.relu(h), training=self.training)

        h, attn_b = self.conv2(h, edge_index, return_attention_weights=True)
        h = F.relu(h)

        h, attn_c = self.conv3(h, edge_index, return_attention_weights=True)
        h = F.relu(h)
        lap = h  # per-node representation handed back to the caller

        if batch is None:
            batch = h.new_zeros(h.size(0), dtype=torch.long)

        out = global_max_pool(h, batch)
        out = out.view(out.size(0), -1)
        out = self.lr3(F.relu(self.lr2(F.relu(self.lr1(out)))))

        # Index 1 of each attention tuple is the weight tensor.
        self.attention_scores = attn_a[1] + attn_b[1] + attn_c[1]
        return [F.log_softmax(out, dim=1), lap, out]


class Weaker_Middle(torch.nn.Module):
    """Second-stage graph classifier: one TransformerConv layer, max pooling, 3-layer MLP.

    Args:
        input_dim: Stored as ``self.input_dim``.
            NOTE(review): ``conv1`` is built with a hard-coded input width of 128
            and does not read ``input_dim``, so features passed to ``forward`` must
            be 128 wide. Weaker_First's ``lap`` output is 32 * 4 = 128 wide;
            presumably that is what gets fed here — verify before relying on
            ``input_dim``.
        num_classes: Number of output classes.
        heads: Number of attention heads in ``conv1`` (outputs are concatenated).
        dropout: Attention dropout probability passed to ``conv1``.
    """

    def __init__(self, input_dim, num_classes, heads=4, dropout=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes

        self.conv1 = TransformerConv(128, 32, heads=heads, concat=True, dropout=dropout)

        # conv1 concatenates `heads` outputs of 32 channels each.
        self.hidden_dim = 32 * heads
        self.fc_input_dim = self.hidden_dim

        self.lr1 = nn.Linear(self.fc_input_dim, 128)
        self.lr2 = nn.Linear(128, 48)
        self.lr3 = nn.Linear(48, self.num_classes)

        self.attention_scores = None

    def forward(self, x, edge_index, return_node_embeddings=False, batch=None):
        """Classify the graph(s) described by ``x`` / ``edge_index``.

        Args:
            x: Node features fed to ``conv1``.
            edge_index: Graph connectivity for ``conv1``.
            return_node_embeddings: If True, return only the raw ``lr3`` output.
                NOTE(review): despite the name, that output is the pooled,
                graph-level logit tensor, not per-node embeddings.
            batch: Graph-assignment vector for max pooling; when None, all nodes
                are pooled into a single graph.

        Returns:
            ``[log_softmax(logits, dim=1), lap, logits]`` where ``lap`` is the
            ``conv1`` output taken before the ReLU; or just ``logits`` when
            ``return_node_embeddings`` is True.
        """
        node_repr, attn = self.conv1(x, edge_index, return_attention_weights=True)
        self.attention_scores = attn[1]  # weight tensor from the attention tuple
        lap = node_repr  # pre-activation node representation

        # Default F.dropout probability (0.5), independent of the attention dropout.
        hidden = F.dropout(F.relu(node_repr), training=self.training)

        if batch is None:
            batch = hidden.new_zeros(hidden.size(0), dtype=torch.long)

        graph_repr = global_max_pool(hidden, batch)
        graph_repr = graph_repr.view(graph_repr.size(0), -1)
        for layer in (self.lr1, self.lr2):
            graph_repr = F.relu(layer(graph_repr))
        logits = self.lr3(graph_repr)

        if return_node_embeddings:
            return logits
        return [F.log_softmax(logits, dim=1), lap, logits]


class GraphLevelPredictor(nn.Module):
    """Graph classification head over meta-graphs.

    Two TransformerConv layers with ReLU, global mean pooling over each graph,
    then a linear projection to ``output_dim`` logits.

    Args:
        input_dim: Node feature width of the incoming graphs.
        hidden_dim: Per-head output channels of each TransformerConv.
        output_dim: Number of output logits per graph.
        heads: Number of attention heads per TransformerConv.
        dropout: Attention dropout probability passed to each TransformerConv.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=5, heads=4, dropout=0.1):
        super().__init__()
        # TransformerConv concatenates its heads by default (concat=True).
        concat_dim = hidden_dim * heads
        self.conv1 = TransformerConv(input_dim, hidden_dim, heads=heads, dropout=dropout)
        self.conv2 = TransformerConv(concat_dim, hidden_dim, heads=heads, dropout=dropout)

        self.fc = nn.Linear(concat_dim, output_dim)

    def forward(self, data):
        """Return graph-level logits for ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch`` attributes
                (presumably a torch_geometric ``Data``/``Batch``).

        Returns:
            The output of ``self.fc`` applied to the mean-pooled node embeddings.
        """
        node_feats = data.x
        for conv in (self.conv1, self.conv2):
            node_feats = F.relu(conv(node_feats, data.edge_index))

        pooled = global_mean_pool(node_feats, data.batch)
        return self.fc(pooled)


class GTN(torch.nn.Module):
    """Standalone graph transformer network.

    Three stacked TransformerConv layers (the last one single-headed with 32
    channels), a LayerNorm, and then either per-node embeddings or mean-pooled
    graph-level logits.

    Args:
        num_node_features: Input node feature width.
        hidden_channels: Per-head output channels of the first two layers.
        num_classes: Number of output logits per graph.
        heads: Number of attention heads in the first two layers.
        dropout: Used both as attention dropout inside each TransformerConv and
            as the probability of the ``nn.Dropout`` applied between layers.

    After every ``forward`` call, ``self.attention_scores`` holds the summed
    attention weights of the three layers for that call.
    """

    def __init__(self, num_node_features, hidden_channels, num_classes, heads=4, dropout=0.1):
        super(GTN, self).__init__()
        self.conv1 = TransformerConv(num_node_features, hidden_channels, heads=heads, dropout=dropout)
        self.conv2 = TransformerConv(hidden_channels * heads, hidden_channels, heads=heads, dropout=dropout)
        self.conv3 = TransformerConv(hidden_channels * heads, 32, heads=1, dropout=dropout)

        # Despite its name this is a LayerNorm; the attribute name is kept so
        # existing state_dict keys ("batch_norm.*") still load.
        self.batch_norm = nn.LayerNorm(32)
        self.fc = nn.Linear(32, num_classes)
        self.dropout = nn.Dropout(p=dropout)
        self.attention_scores = None

    def forward(self, data, return_node_embeddings=False):
        """Run the network on ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch`` attributes
                (presumably a torch_geometric ``Data``/``Batch``).
            return_node_embeddings: If True, return the normalized per-node
                embeddings instead of graph-level logits.

        Returns:
            Per-node embeddings (LayerNorm output) when ``return_node_embeddings``
            is True, otherwise ``self.fc`` applied to the mean-pooled embeddings.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x, attention1 = self.conv1(x, edge_index, return_attention_weights=True)
        # nn.Dropout is already the identity in eval mode, so no training guard is needed.
        x = self.dropout(F.relu(x))
        x, attention2 = self.conv2(x, edge_index, return_attention_weights=True)
        x = self.dropout(F.relu(x))
        x, attention3 = self.conv3(x, edge_index, return_attention_weights=True)
        x = self.batch_norm(F.relu(x))

        # Record attention before any early return; previously an embedding-only
        # call left attention_scores stale from an earlier forward pass.
        # NOTE(review): conv3 uses heads=1 while conv1/conv2 use `heads`; this sum
        # presumably relies on broadcasting the single-head weights — confirm the
        # intended aggregation.
        self.attention_scores = attention1[1] + attention2[1] + attention3[1]

        if return_node_embeddings:
            return x  # per-node embeddings

        # Aggregate node embeddings into one vector per graph.
        x = global_mean_pool(x, batch)
        return self.fc(x)
