import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, global_mean_pool

class TransformerEncoder(nn.Module):
    """Five-layer ``TransformerConv`` stack that returns node-level embeddings.

    Layers 1-4 use ``num_heads`` attention heads and are each followed by a
    ReLU; their output width is ``hidden_channels * num_heads``, which is the
    input width of the next layer. Layer 5 uses a single head, emits
    ``out_channels`` features per node and has no activation.

    Every layer is constructed with ``edge_dim=1``, i.e. it consumes one
    scalar feature per edge.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Per-head width of layers 1-4.
        out_channels: Width of the final node embedding.
        num_heads: Attention heads used by layers 1-4.
        dropout: ``dropout`` value forwarded to every ``TransformerConv``.
    """

    def __init__(self, in_channels, hidden_channels=64, out_channels=64, num_heads=8, dropout=0.1):
        super().__init__()
        # Recorded so copy_and_perturb_encoder can rebuild an encoder with the
        # same architecture before loading this one's weights.
        self._init_args = (in_channels, hidden_channels, out_channels, num_heads)
        self._init_kwargs = {'dropout': dropout}

        wide = hidden_channels * num_heads  # concatenated multi-head width
        self.conv1 = TransformerConv(in_channels, hidden_channels, heads=num_heads, edge_dim=1, dropout=dropout)
        self.conv2 = TransformerConv(wide, hidden_channels, heads=num_heads, edge_dim=1, dropout=dropout)
        self.conv3 = TransformerConv(wide, hidden_channels, heads=num_heads, edge_dim=1, dropout=dropout)
        self.conv4 = TransformerConv(wide, hidden_channels, heads=num_heads, edge_dim=1, dropout=dropout)
        self.conv5 = TransformerConv(wide, out_channels, heads=1, edge_dim=1, dropout=dropout)

    @staticmethod
    def _prepare_edge_attr(x, edge_index, edge_attr):
        """Return edge features as an ``[E, 1]`` tensor for the conv layers.

        * ``None`` becomes an all-zero column, so a graph without edge
          features can still be encoded. NOTE(review): this relies on
          TransformerConv projecting edge features with a bias-free linear
          layer, so zeros contribute nothing - confirm against the installed
          PyG version. It also assumes ``edge_index`` is a ``[2, E]`` tensor.
        * A 1-D ``[E]`` vector of edge weights gains a trailing feature axis.
        * Anything else is returned unchanged.
        """
        if edge_attr is None:
            return x.new_zeros((edge_index.size(1), 1))
        if edge_attr.dim() == 1:
            return edge_attr.unsqueeze(-1)
        return edge_attr

    def forward(self, x, edge_index, edge_attr=None):
        """Encode nodes.

        Args:
            x: Node feature tensor with ``in_channels`` features per node.
            edge_index: Graph connectivity passed to every layer.
            edge_attr: Optional per-edge scalar features, ``[E, 1]`` or
                ``[E]``. When omitted, zero features are substituted.

        Returns:
            Node-level representation with ``out_channels`` features per node.
        """
        # The layers are built with edge_dim=1 and therefore require a 2-D
        # edge feature tensor; normalise it once instead of failing inside PyG.
        edge_attr = self._prepare_edge_attr(x, edge_index, edge_attr)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x, edge_index, edge_attr))
        x = self.conv5(x, edge_index, edge_attr)
        return x  # node-level representation

def copy_and_perturb_encoder(encoder: "TransformerEncoder", eta: float):
    """Return a frozen copy of ``encoder`` with Gaussian noise added to its weights.

    A new module of the same class is built from ``encoder._init_args`` and
    ``encoder._init_kwargs``, moved to the device/dtype of the encoder's first
    parameter, and loaded with the encoder's ``state_dict``. Each parameter
    then receives zero-mean noise whose standard deviation is ``eta`` times
    that parameter's own standard deviation. The original encoder is not
    modified.

    Args:
        encoder: Module to copy. Must expose ``_init_args`` / ``_init_kwargs``
            and have at least one parameter.
        eta: Noise scale relative to each parameter's standard deviation;
            ``0`` yields an exact copy.

    Returns:
        The perturbed copy, with ``requires_grad=False`` on every parameter
        and the same train/eval mode as ``encoder``.
    """
    reference = next(encoder.parameters())
    # type(encoder) rather than a hard-coded class, so a subclass is rebuilt
    # as itself and its state_dict keys still match.
    pert = type(encoder)(*encoder._init_args, **encoder._init_kwargs)
    pert = pert.to(device=reference.device, dtype=reference.dtype)
    pert.load_state_dict(encoder.state_dict())
    # A freshly constructed module is always in training mode; mirror the
    # source so dropout is not silently re-enabled for an eval-mode encoder.
    pert.train(encoder.training)
    # ensure perturbed encoder does not contribute gradients
    pert.requires_grad_(False)

    with torch.no_grad():
        for p_orig, p_pert in zip(encoder.parameters(), pert.parameters()):
            # std() of fewer than two elements is NaN, so skip it there.
            std = p_orig.std().item() if p_orig.numel() > 1 else 0.0
            # `not std > 0` also catches NaN, which `std <= 0` would let through.
            if not std > 0:
                std = 1e-6
            # p_pert already equals p_orig after load_state_dict; randn_like
            # keeps the noise on the parameter's own device and dtype.
            p_pert.add_(torch.randn_like(p_orig) * (std * eta))
    return pert

class SimGRACEModel(nn.Module):
    """Two-view graph encoder: a trainable student plus a weight-perturbed copy.

    ``forward`` encodes the same batch twice - once with ``self.student`` and
    once with a freshly perturbed, frozen copy of it - and mean-pools each
    set of node embeddings into graph-level embeddings.

    Args:
        in_dim: Number of input features per node.
        hidden_dim: Per-head hidden width of the encoder.
        out_dim: Width of the node / graph embeddings.
        num_heads: Attention heads used by the encoder's hidden layers.
        eta: Noise scale passed to ``copy_and_perturb_encoder``.
        dropout: Dropout value forwarded to the encoder (same default as
            ``TransformerEncoder``).
    """

    def __init__(self, in_dim, hidden_dim=64, out_dim=64, num_heads=8, eta=1.0, dropout=0.1):
        super().__init__()
        self.student = TransformerEncoder(in_dim, hidden_dim, out_dim, num_heads, dropout=dropout)
        self.eta = eta
        self.projector = nn.Identity()  # place to swap in MLP if desired

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Return graph-level embeddings ``(g1, g2)`` for the two views.

        Args:
            x: Node features for all graphs in the batch.
            edge_index: Graph connectivity.
            edge_attr: Optional edge features, forwarded to both encoders.
            batch: Node-to-graph assignment forwarded to ``global_mean_pool``.

        Returns:
            ``g1`` from the student encoder and ``g2`` from the perturbed
            copy, each pooled per graph.
        """
        # node-level embeddings
        z1_nodes = self.student(x, edge_index, edge_attr)  # [N_total, D]

        # A new perturbed copy is drawn on every call.
        perturbed_encoder = copy_and_perturb_encoder(self.student, self.eta)
        # Keep dropout behaviour of the second view aligned with the student.
        perturbed_encoder.train(self.student.training)
        # The copy's weights are frozen; no_grad additionally stops autograd
        # from recording this pass when the inputs themselves require grad.
        with torch.no_grad():
            z2_nodes = perturbed_encoder(x, edge_index, edge_attr)  # [N_total, D] (no grad path)

        # graph-level pooling per graph in batch
        g1 = global_mean_pool(self.projector(z1_nodes), batch)  # [B, D]
        g2 = global_mean_pool(self.projector(z2_nodes), batch)  # [B, D]
        return g1, g2

