# models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv
from torch_geometric.data import Data
from utils import gumbel_softmax

class GraphEncoder(nn.Module):
    """Five-layer TransformerConv encoder producing per-node embeddings.

    The first four layers use ``num_heads`` attention heads and are each
    followed by a ReLU. The last layer uses a single head, outputs
    ``out_channels`` features and has no activation. Every layer is built
    with ``edge_dim=1`` and ``dropout=0.1``.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Per-head output size of the first four layers.
        out_channels: Output size of the final layer.
        num_heads: Number of attention heads in the first four layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads):
        super().__init__()
        # Layers 2-5 consume hidden_channels * num_heads features, i.e. the
        # multi-head output width of the preceding layer.
        width = hidden_channels * num_heads
        shared = dict(edge_dim=1, dropout=0.1)

        # Attribute names and creation order are kept stable so checkpoints
        # and seeded initialisation stay compatible.
        self.conv1 = TransformerConv(in_channels, hidden_channels, heads=num_heads, **shared)
        self.conv2 = TransformerConv(width, hidden_channels, heads=num_heads, **shared)
        self.conv3 = TransformerConv(width, hidden_channels, heads=num_heads, **shared)
        self.conv4 = TransformerConv(width, hidden_channels, heads=num_heads, **shared)
        self.conv5 = TransformerConv(width, out_channels, heads=1, **shared)

    def forward(self, x, edge_index, edge_attr):
        """Run the five layers in order and return the final node embeddings.

        Args:
            x: Node feature tensor.
            edge_index: Graph connectivity passed unchanged to every layer.
            edge_attr: Edge features passed unchanged to every layer.

        Returns:
            The output of ``conv5`` (no activation applied).
        """
        out = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            out = F.relu(layer(out, edge_index, edge_attr))
        # The final layer is left linear.
        return self.conv5(out, edge_index, edge_attr)

class AutoGCLModel(nn.Module):
    """Encoder followed by a two-layer MLP projection head.

    Args:
        encoder: Callable taking ``(x, edge_index, edge_attr)`` and returning
            embeddings whose last dimension is ``encoder_out_dim``.
        proj_hidden_dim: Hidden width of the projection head.
        proj_out_dim: Output width of the projection head.
        encoder_out_dim: Feature size produced by ``encoder``.
    """

    def __init__(self, encoder, proj_hidden_dim, proj_out_dim, encoder_out_dim):
        super(AutoGCLModel, self).__init__()
        self.encoder = encoder
        self.projection_head = nn.Sequential(
            nn.Linear(encoder_out_dim, proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, proj_out_dim)
        )

    def forward(self, x, edge_index, edge_attr):
        """Encode the graph and project the embeddings.

        Args:
            x: Node feature tensor.
            edge_index: Graph connectivity, forwarded to the encoder.
            edge_attr: Edge features, forwarded to the encoder.

        Returns:
            Projected embeddings with last dimension ``proj_out_dim``.
        """
        h = self.encoder(x, edge_index, edge_attr)
        z = self.projection_head(h)
        return z

class ViewGenerator(nn.Module):
    """Learnable augmenter that assigns each node one of keep / mask / drop.

    A TransformerConv layer followed by an MLP produces three logits per
    node. The logits are sampled with ``gumbel_softmax`` and the resulting
    assignment is used to zero the features of non-kept nodes and to remove
    dropped nodes (and their incident edges) from the graph.

    Args:
        in_channels: Size of the input node features.
        hidden: Per-head width of the convolution and hidden width of the MLP.
        num_heads: Number of attention heads in the convolution.
        edge_dim: Edge feature size passed to the convolution.
        tau: Temperature passed to ``gumbel_softmax``.
    """

    def __init__(self, in_channels: int, hidden: int = 64, num_heads: int = 1, edge_dim: int = 1, tau: float = 1.0):
        super().__init__()
        self.feat = TransformerConv(in_channels, hidden, heads=num_heads, edge_dim=edge_dim, dropout=0.0)
        self.mlp = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden * num_heads, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 3),  # keep/mask/drop
        )
        self.tau = tau

    def forward(self, data: Data):
        """Sample a node-level augmentation and build the augmented graph.

        Args:
            data: Graph providing ``x``, ``edge_index`` and ``edge_attr``.

        Returns:
            A tuple ``(aug, A, keep_nodes, idx_map)``:

            * ``aug`` -- ``Data`` holding only the non-dropped nodes, with
              edges re-indexed to the compacted node numbering.
            * ``A`` -- the ``gumbel_softmax`` output; column 0 is used as the
              keep indicator and column 2 as the drop indicator.
            * ``keep_nodes`` -- boolean mask over the original nodes, True
              for nodes that were not dropped.
            * ``idx_map`` -- long tensor mapping each original node index to
              its index in ``aug`` (-1 for dropped nodes).
        """
        x, ei, ea = data.x, data.edge_index, data.edge_attr

        # Normalise 1-D edge attributes to (E, 1) before the convolution sees
        # them, so the layer and the returned graph use the same layout.
        if ea is not None and ea.dim() == 1:
            ea = ea.unsqueeze(-1)

        h = self.feat(x, ei, ea)
        logits = self.mlp(h)                      # (N,3)
        # Presumably one-hot rows with hard=True -- confirm against
        # utils.gumbel_softmax.
        A = gumbel_softmax(logits, self.tau, hard=True)

        keep_mask = A[:, 0:1]  # keep
        drop_mask = A[:, 2:3]  # drop

        # Multiplying by the keep column zeroes features of masked and
        # dropped nodes while keeping the operation differentiable w.r.t. A.
        x_masked = x * keep_mask

        # Every node that is not dropped survives (kept or masked).
        keep_nodes = drop_mask.squeeze(-1) < 0.5  # (N,)

        # Old index -> new compact index; dropped nodes stay at -1.
        idx_map = torch.full((x.size(0),), -1, dtype=torch.long, device=x.device)
        idx_map[keep_nodes] = torch.arange(int(keep_nodes.sum()), device=x.device)

        # Keep only edges whose two endpoints both survive, then relabel.
        src, dst = ei[0], ei[1]
        ekeep = keep_nodes[src] & keep_nodes[dst]
        new_ei = torch.stack([idx_map[src[ekeep]], idx_map[dst[ekeep]]], dim=0)
        new_ea = ea[ekeep] if ea is not None else None

        aug = Data(x=x_masked[keep_nodes], edge_index=new_ei, edge_attr=new_ea)
        return aug, A, keep_nodes, idx_map