# models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv

class GraphTransformerEncoder(nn.Module):
    """Five stacked ``TransformerConv`` layers mapping node features to embeddings.

    Layers 1-4 use ``num_heads`` heads and are followed by ReLU. The input width
    of layers 2-5 is ``hidden_channels * num_heads``, so the per-head outputs
    are expected to be concatenated. Layer 5 uses a single head, emits
    ``out_channels`` features, and has no activation.

    Args:
        in_channels: Width of the input node features ``x``.
        hidden_channels: Per-head output width of layers 1-4.
        out_channels: Output width of the final layer.
        num_heads: Number of heads in layers 1-4.
        dropout: Dropout probability forwarded to every ``TransformerConv``
            (default 0.1).
        edge_dim: Edge-feature width forwarded to every ``TransformerConv``
            (default 1).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads,
                 dropout=0.1, edge_dim=1):
        super(GraphTransformerEncoder, self).__init__()
        self.edge_dim = edge_dim
        concat_dim = hidden_channels * num_heads
        conv_kwargs = dict(edge_dim=edge_dim, dropout=dropout)
        # The attribute names conv1..conv5 and their creation order are kept fixed
        # so state_dict keys and seeded initialisation stay stable.
        self.conv1 = TransformerConv(in_channels, hidden_channels, heads=num_heads, **conv_kwargs)
        self.conv2 = TransformerConv(concat_dim, hidden_channels, heads=num_heads, **conv_kwargs)
        self.conv3 = TransformerConv(concat_dim, hidden_channels, heads=num_heads, **conv_kwargs)
        self.conv4 = TransformerConv(concat_dim, hidden_channels, heads=num_heads, **conv_kwargs)
        self.conv5 = TransformerConv(concat_dim, out_channels, heads=1, **conv_kwargs)

    def forward(self, x, edge_index, edge_attr):
        """Run all five layers and return the final node embeddings.

        Args:
            x: Node features (presumably ``[N, in_channels]`` — TODO confirm).
            edge_index: Graph connectivity, passed unchanged to every layer.
            edge_attr: Edge features. When ``edge_dim == 1`` a 1-D tensor is
                also accepted and reshaped to ``[E, 1]``.

        Returns:
            Output of ``conv5`` (width ``out_channels``).
        """
        # TransformerConv built with edge_dim expects edge features laid out as
        # [E, edge_dim]. A flat [E] tensor of scalar edge weights would not
        # match that layout, so add the trailing feature axis here, once.
        if edge_attr is not None and edge_attr.dim() == 1 and self.edge_dim == 1:
            edge_attr = edge_attr.unsqueeze(-1)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x, edge_index, edge_attr))
        return self.conv5(x, edge_index, edge_attr)

class Encoder(nn.Module):
    """Graph-transformer backbone followed by an MLP projection head.

    The backbone emits ``proj_dim`` features per node. The head
    (Linear -> PReLU -> Dropout(0.1) -> Linear) keeps that width.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Per-head hidden width of the backbone.
        proj_dim: Backbone output width, which is also the projection-head width.
        num_heads: Number of heads in the backbone's multi-head layers.
    """

    def __init__(self, in_channels, hidden_channels, proj_dim, num_heads):
        super().__init__()
        self.backbone = GraphTransformerEncoder(
            in_channels, hidden_channels, proj_dim, num_heads
        )
        # The positions 0..3 inside the Sequential determine the state_dict keys
        # (projection_head.0.weight, projection_head.1.weight, ...).
        head_layers = [
            nn.Linear(proj_dim, proj_dim),
            nn.PReLU(),
            nn.Dropout(0.1),
            nn.Linear(proj_dim, proj_dim),
        ]
        self.projection_head = nn.Sequential(*head_layers)

    def forward(self, x, edge_index, edge_attr):
        """Encode the graph.

        Returns:
            ``(h, h_proj)``: the backbone embeddings and their projection.
        """
        embeddings = self.backbone(x, edge_index, edge_attr)
        projected = self.projection_head(embeddings)
        return embeddings, projected

class Online(nn.Module):
    """Online branch of an online/target setup with an EMA-updated target.

    Holds an online encoder, a target encoder, and a two-layer MLP predictor.
    The predictor is applied to the online projections after multi-hop
    propagation. The target is moved towards the online encoder by
    :meth:`update_target_encoder`.

    Args:
        online_encoder: Module called as ``encoder(x, edge_index, edge_attr)``
            that returns a ``(raw, projection)`` pair.
        target_encoder: Module with the same call convention and the same
            parameter layout as ``online_encoder``.
        hidden_dim: Input/output width of the predictor. Must equal the width
            of the encoder's projection output.
        momentum: EMA decay ``m`` in ``target = m * target + (1 - m) * online``.
    """

    def __init__(self, online_encoder: "Encoder", target_encoder: "Encoder", hidden_dim: int, momentum: float):
        super().__init__()
        self.online_encoder = online_encoder
        self.target_encoder = target_encoder
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.PReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.momentum = momentum

    @torch.no_grad()
    def update_target_encoder(self):
        """Perform one EMA step of the target parameters towards the online ones.

        Computes ``p_tgt = momentum * p_tgt + (1 - momentum) * p_onl`` in place.
        The target ``nn.Parameter`` objects and their storage are preserved, and
        no autograd history is recorded.

        Raises:
            ValueError: If the encoders expose a different number of parameters.
                Pairing them with ``zip`` would otherwise silently skip the extras.
        """
        tgt_params = list(self.target_encoder.parameters())
        onl_params = list(self.online_encoder.parameters())
        if len(tgt_params) != len(onl_params):
            raise ValueError(
                f"target encoder has {len(tgt_params)} parameters but online "
                f"encoder has {len(onl_params)}"
            )
        for p_tgt, p_onl in zip(tgt_params, onl_params):
            # In-place update instead of rebinding `.data` to a new tensor.
            p_tgt.mul_(self.momentum).add_(p_onl, alpha=1.0 - self.momentum)

    def forward(self, x, edge_index, edge_attr, dense_adj, global_hop=3):
        """Encode with both branches and predict from propagated online projections.

        Args:
            x, edge_index, edge_attr: Forwarded to both encoders.
            dense_adj: Matrix left-multiplied onto the projections
                ``global_hop`` times (presumably an ``[N, N]`` adjacency —
                TODO confirm normalisation with the caller).
            global_hop: Number of propagation steps.

        Returns:
            ``(h_raw, h_pred, h_tgt_proj)``: the online raw embeddings, the
            predictor output, and the target projections (computed without grad).
        """
        h_raw, h_proj = self.online_encoder(x, edge_index, edge_attr)
        # Equivalent to dense_adj^global_hop @ h_proj, computed as repeated
        # products. matmul returns new tensors, so no defensive clone is needed.
        h_prop = h_proj
        for _ in range(global_hop):
            h_prop = torch.matmul(dense_adj, h_prop)
        # Residual combination: the local projection plus its propagated context.
        h_pred = self.predictor(h_proj + h_prop)
        with torch.no_grad():
            _, h_tgt_proj = self.target_encoder(x, edge_index, edge_attr)
        return h_raw, h_pred, h_tgt_proj

    def get_loss(self, z_pred, z_tgt):
        """Return the negative mean row-wise cosine similarity of z_pred and z_tgt.

        Rows are L2-normalised first, so the result lies in ``[-1, 1]``.
        It equals -1 when every pair of rows points in the same direction.
        """
        z_pred = F.normalize(z_pred, dim=-1, p=2)
        z_tgt = F.normalize(z_tgt, dim=-1, p=2)
        return -(z_pred * z_tgt).sum(dim=-1).mean()

class Target(nn.Module):
    """Thin wrapper that exposes only the projection output of an encoder.

    ``forward`` is not wrapped in ``torch.no_grad``, so gradients from
    :meth:`get_loss` reach the wrapped encoder's parameters.

    Args:
        target_encoder: Module called as ``encoder(x, edge_index, edge_attr)``
            that returns a ``(raw, projection)`` pair.
    """

    def __init__(self, target_encoder: "Encoder"):
        super().__init__()
        self.target_encoder = target_encoder

    def forward(self, x, edge_index, edge_attr):
        """Return only the projection half of the encoder's output pair."""
        _unused_raw, projected = self.target_encoder(x, edge_index, edge_attr)
        return projected

    def get_loss(self, z):
        """Return the negative mean squared distance of unit rows from their centroid.

        Each row of ``z`` is L2-normalised, the centroid is the row mean, and
        the loss is ``-mean_i ||z_i - c||^2``. Minimising it pushes the
        normalised rows apart.
        """
        unit_rows = F.normalize(z, dim=-1, p=2)
        centroid = unit_rows.mean(dim=0)
        sq_dist = (unit_rows - centroid).pow(2).sum(dim=-1)
        return sq_dist.mean().neg()