# models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv

class GraphEncoder(nn.Module):
    """Five-layer TransformerConv graph encoder producing node embeddings.

    Layers 1-4 use ``num_heads`` attention heads whose outputs are
    concatenated (width ``hidden_channels * num_heads``) and followed by ReLU.
    Layer 5 uses a single head and emits ``out_channels`` features with no
    activation. All layers take 1-dimensional edge features (``edge_dim=1``)
    and use ``dropout=0.1``.

    Args:
        in_channels: Input node feature size.
        hidden_channels: Per-head width of the hidden layers.
        out_channels: Output embedding size.
        num_heads: Number of attention heads in layers 1-4.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads):
        super().__init__()
        # Multi-head outputs are concatenated, so each hidden layer's input
        # width is hidden_channels * num_heads.
        hidden_width = hidden_channels * num_heads
        # Attribute names conv1..conv5 (and creation order) are kept so that
        # existing state_dicts load and seeded initialisation is unchanged.
        self.conv1 = TransformerConv(in_channels, hidden_channels, heads=num_heads, edge_dim=1, dropout=0.1)
        self.conv2 = TransformerConv(hidden_width, hidden_channels, heads=num_heads, edge_dim=1, dropout=0.1)
        self.conv3 = TransformerConv(hidden_width, hidden_channels, heads=num_heads, edge_dim=1, dropout=0.1)
        self.conv4 = TransformerConv(hidden_width, hidden_channels, heads=num_heads, edge_dim=1, dropout=0.1)
        self.conv5 = TransformerConv(hidden_width, out_channels, heads=1, edge_dim=1, dropout=0.1)

    def forward(self, x, edge_index, edge_attr):
        """Encode nodes.

        Args:
            x: Node features (``in_channels`` wide).
            edge_index: Graph connectivity passed straight to each conv.
            edge_attr: Edge features, either ``[E, 1]`` or 1-D ``[E]``. A 1-D
                tensor is promoted to ``[E, 1]``.

        Returns:
            Node embeddings of width ``out_channels``.
        """
        # Layers are built with edge_dim=1 and expect [E, 1]. A 1-D [E]
        # edge-weight vector would otherwise be fed to the edge projection
        # as one long feature vector and fail.
        if edge_attr is not None and edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x, edge_index, edge_attr))
        # Final projection has no activation.
        return self.conv5(x, edge_index, edge_attr)

class ADGCLModel(nn.Module):
    """Graph encoder followed by a two-layer MLP projection head.

    Args:
        encoder: Module invoked as ``encoder(x, edge_index, edge_attr)``.
        proj_hidden_dim: Hidden width of the projection MLP.
        proj_out_dim: Output width of the projection MLP.
        encoder_out_dim: Feature width produced by ``encoder``. It must match
            the projection MLP's input size.
    """

    def __init__(self, encoder, proj_hidden_dim, proj_out_dim, encoder_out_dim):
        super().__init__()
        self.encoder = encoder
        # Linear -> ReLU -> Linear; indices 0 and 2 hold the weights.
        head_layers = [
            nn.Linear(encoder_out_dim, proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, proj_out_dim),
        ]
        self.projection_head = nn.Sequential(*head_layers)

    def forward(self, x, edge_index, edge_attr):
        """Return ``projection_head(encoder(x, edge_index, edge_attr))``."""
        return self.projection_head(self.encoder(x, edge_index, edge_attr))

# ====== Augmenter ======

class AugmenterEncoder(nn.Module):
    """Two-layer TransformerConv node encoder used by the edge augmenter.

    Layer 1 uses ``num_heads`` heads whose outputs are concatenated (width
    ``hidden_channels * num_heads``) and followed by ReLU. Layer 2 uses a
    single head and returns ``out_channels`` features with no activation.
    Both layers take 1-dimensional edge features (``edge_dim=1``) and are
    built with ``dropout=0.0``.

    Args:
        in_channels: Input node feature size.
        hidden_channels: Per-head width of the first layer.
        out_channels: Output embedding size.
        num_heads: Number of attention heads in the first layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads):
        super().__init__()
        self.conv1 = TransformerConv(in_channels, hidden_channels, heads=num_heads, edge_dim=1, dropout=0.0)
        self.conv2 = TransformerConv(hidden_channels * num_heads, out_channels, heads=1, edge_dim=1, dropout=0.0)

    def forward(self, x, edge_index, edge_attr):
        """Encode nodes.

        Args:
            x: Node features (``in_channels`` wide).
            edge_index: Graph connectivity passed straight to each conv.
            edge_attr: Edge features, either ``[E, 1]`` or 1-D ``[E]``. A 1-D
                tensor is promoted to ``[E, 1]``.

        Returns:
            Node embeddings, ``[N, out_channels]``.
        """
        # Layers are built with edge_dim=1 and expect [E, 1]. Promote a 1-D
        # edge-weight vector instead of failing inside the conv.
        if edge_attr is not None and edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        return self.conv2(x, edge_index, edge_attr)  # [N, out_channels]

class EdgeDropAugmenter(nn.Module):
    """Learnable edge-dropping augmenter using a Concrete (logistic-noise) relaxation.

    An ``AugmenterEncoder`` embeds the nodes. Each edge is scored by an MLP
    over its concatenated endpoint embeddings. The score is then relaxed into
    a differentiable drop probability by adding logistic noise and applying a
    tempered sigmoid.

    Args:
        in_channels: Input node feature size.
        hidden_channels: Per-head hidden width of the augmenter encoder.
        out_channels: Node embedding width produced by the augmenter encoder.
        num_heads: Attention heads in the augmenter encoder's first layer.
        mlp_hidden: Hidden width of the edge-scoring MLP.
    """

    def __init__(self, in_channels, hidden_channels=64, out_channels=64, num_heads=4, mlp_hidden=64):
        super().__init__()
        # Registration order (encoder first, then edge_mlp) fixes both the
        # parameter-initialisation order under a seed and the state_dict keys.
        self.encoder = AugmenterEncoder(in_channels, hidden_channels, out_channels, num_heads)
        pair_width = 2 * out_channels  # [h_src || h_dst]
        self.edge_mlp = nn.Sequential(
            nn.Linear(pair_width, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, 1),
        )

    def forward(self, x, edge_index, edge_attr, temperature=0.5):
        """Sample relaxed per-edge drop and keep weights.

        Args:
            x: Node features.
            edge_index: Tensor whose rows 0 and 1 hold the source and target
                node index of each of the E edges.
            edge_attr: Edge features, forwarded to the encoder only.
            temperature: Concrete temperature. Values below ``1e-3`` are
                clamped to ``1e-3``. Smaller values push outputs toward 0 or 1.

        Returns:
            ``(p_drop, keep_mask)``, each shaped ``[E, 1]``, with
            ``keep_mask = 1 - p_drop``. Fresh noise is drawn on every call,
            including in eval mode.
        """
        node_emb = self.encoder(x, edge_index, edge_attr)
        endpoint_pairs = torch.cat([node_emb[edge_index[0]], node_emb[edge_index[1]]], dim=-1)
        edge_logits = self.edge_mlp(endpoint_pairs).squeeze(-1)  # [E]
        # Logistic(0, 1) sample via the inverse CDF log(u) - log(1 - u).
        # The 1e-12 offsets keep both logs finite.
        uniform = torch.rand_like(edge_logits)
        logistic_noise = torch.log(uniform + 1e-12) - torch.log(1 - uniform + 1e-12)
        tau = max(temperature, 1e-3)
        p_drop = torch.sigmoid((edge_logits + logistic_noise) / tau).unsqueeze(-1)
        return p_drop, 1.0 - p_drop
