import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalEdgeIntervener(nn.Module):
    """Add a gated, attention-weighted mixture of learned prototypes to edge features.

    The module owns a learnable prototype matrix, ``confounder_dict``, of shape
    ``(num_prototypes, feature_dim)``. Each edge feature vector attends over
    these prototypes (projected queries against projected keys, plain
    dot-product scores, softmax over the prototypes). The resulting weighted
    sum of prototypes is scaled by a per-edge gate and added to the input.

    The gate is a small MLP ending in a sigmoid. Its input is the mean standard
    deviation of an edge, computed from the supplied log-variance.
    """

    def __init__(self, feature_dim, num_prototypes=32, hidden_dim=64):
        """Build the prototype bank, the attention projections and the gate MLP.

        Args:
            feature_dim: Size of the last dimension of the edge features; also
                the width of every prototype.
            num_prototypes: Number of rows in the prototype bank.
            hidden_dim: Width of the query/key projection space.
        """
        super().__init__()

        # Allocated with normal samples here; the values are overwritten by the
        # Xavier-uniform initialisation at the end of the constructor.
        self.confounder_dict = nn.Parameter(torch.randn(num_prototypes, feature_dim))

        # Edge features are projected to queries, prototypes to keys.
        self.query_layer = nn.Linear(feature_dim, hidden_dim)
        self.key_layer = nn.Linear(feature_dim, hidden_dim)

        # Gate MLP: scalar in -> 16 hidden units -> scalar in (0, 1).
        gate_hidden = nn.Linear(1, 16)
        gate_output = nn.Linear(16, 1)
        # Offset the pre-sigmoid logit by -1 at initialisation.
        nn.init.constant_(gate_output.bias, -1.0)
        self.gate_layer = nn.Sequential(
            gate_hidden,
            nn.ReLU(),
            gate_output,
            nn.Sigmoid(),
        )

        nn.init.xavier_uniform_(self.confounder_dict)

    def forward(self, edge_features, edge_logvar):
        """Return the adjusted edge features together with the gate values.

        Args:
            edge_features: Tensor whose last dimension is ``feature_dim``.
                Presumably ``(num_edges, feature_dim)`` - confirm with callers.
            edge_logvar: Log-variance tensor. It is reduced with a mean over
                dim 1 (kept as a size-1 dim) and fed to a ``Linear(1, 16)``,
                so a 2-D ``(num_edges, d)`` input is expected.

        Returns:
            A tuple ``(intervened_features, alpha)`` where ``alpha`` is the
            sigmoid gate output and ``intervened_features`` equals
            ``edge_features + alpha * context``.
        """
        # Standard deviation from log-variance, averaged over dim 1 so each
        # row contributes a single scalar to the gate.
        std = (0.5 * edge_logvar).exp()
        mean_std = std.mean(dim=1, keepdim=True)
        alpha = self.gate_layer(mean_std)

        queries = self.query_layer(edge_features)
        keys = self.key_layer(self.confounder_dict)

        # Unscaled dot-product attention over the prototypes.
        scores = queries @ keys.t()
        weights = scores.softmax(dim=-1)

        # The weights select over the raw prototypes, not over their key projections.
        context = weights @ self.confounder_dict

        return edge_features + alpha * context, alpha

