import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.data import Data
from typing import Dict, List, Optional, Tuple
import warnings
warnings.filterwarnings("ignore")

class MultiRelationalEGNNLayer(nn.Module):
    """Single layer of multi-relational EGNN (EGNN-R)"""
    
    def __init__(self, node_dim, edge_dim, hidden_dim, num_relations=4, 
                 act_fn=nn.SiLU(), residual=True, normalize=False, 
                 coords_agg='mean', tanh=False):
        super().__init__()
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        self.num_relations = num_relations
        self.residual = residual
        self.normalize = normalize
        self.coords_agg = coords_agg
        self.tanh = tanh
        self.epsilon = 1e-8
        
        # Relation-specific message MLPs
        self.message_mlps = nn.ModuleDict()
        self.coord_mlps = nn.ModuleDict()
        
        for r in range(num_relations):
            # Message MLP: φ_m^(r) 
            # Input: h[row] + h[col] + edge_feat + rbf_dist = node_dim*2 + edge_dim + 16
            self.message_mlps[str(r)] = nn.Sequential(
                nn.Linear(2 * node_dim + edge_dim + 16, hidden_dim),  # +16 for RBF distance
                act_fn,
                nn.Linear(hidden_dim, hidden_dim),
                act_fn
            )
            
            # Coordinate MLP: φ_x^(r) (outputs scalar)
            coord_layers = [
                nn.Linear(hidden_dim, hidden_dim),
                act_fn,
                nn.Linear(hidden_dim, 1, bias=False)
            ]
            if tanh:
                coord_layers.append(nn.Tanh())
            self.coord_mlps[str(r)] = nn.Sequential(*coord_layers)
        
        # Node update MLP: φ_h (shared across relations)
        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, hidden_dim),
            act_fn,
            nn.Linear(hidden_dim, node_dim)
        )
        
        # RBF parameters for distance encoding
        self.rbf_centers = nn.Parameter(torch.linspace(0, 20, 16), requires_grad=False)
        self.rbf_width = 1.0
        
    def compute_rbf(self, distances):
        """Compute radial basis function encoding γ(d_ij)"""
        rbf = torch.exp(-0.5 * ((distances.unsqueeze(-1) - self.rbf_centers) / self.rbf_width) ** 2)
        return rbf
    
    def forward(self, h, x, hetero_data, node_type):
        """
        Forward pass of multi-relational EGNN layer
        
        Args:
            h: Node features (num_nodes, node_dim)
            x: Node coordinates (num_nodes, 3) - CA positions
            hetero_data: HeteroData object containing edge information
            node_type: 'ag_res' or 'ab_res'
            
        Returns:
            h_new: Updated node features
            x_new: Updated coordinates
        """
        num_nodes = h.shape[0]
        
        # Aggregate messages by relation
        total_messages = torch.zeros_like(h)
        coord_updates = torch.zeros_like(x)
        
        for rel in range(4):  # Relations r0-r3
            edge_type = (node_type, f'r{rel}', node_type)
            
            # Skip if this edge type doesn't exist or has no edges
            if edge_type not in hetero_data.edge_types:
                continue
                
            edge_index = hetero_data[edge_type].edge_index
            if edge_index.size(1) == 0:
                continue
                
            # Get edge attributes
            edge_attr = hetero_data[edge_type].edge_attr
            
            row, col = edge_index

            # print(len(edge_attr), len(edge_index[1]), edge_type) 
            # debug: the above lengths should be the same
            
            # Compute coordinate differences and distances
            coord_diff = x[row] - x[col]  # δ_ij
            distances = torch.norm(coord_diff, dim=1, keepdim=True)  # d_ij
            
            if self.normalize:
                coord_diff = coord_diff / (distances + self.epsilon)
            
            # Compute RBF distance encoding
            rbf_dist = self.compute_rbf(distances.squeeze())

            # Fix for RBF distance tensor dimensions
            if rbf_dist.dim() == 1:
                # If rbf_dist is 1D, expand it to match the batch dimension
                rbf_dist = rbf_dist.unsqueeze(0)  # Shape: [1, 16]
            
            # Prepare message input
            try:
                message_input = torch.cat([
                    h[row],           # source node features
                    h[col],           # target node features  
                    edge_attr,        # edge features
                    rbf_dist          # RBF distance encoding
                ], dim=1)
            except RuntimeError as e:
                print(f"[DEBUG ResMP] Message input concatenation error: {e}")
                print(f"[DEBUG ResMP] h[row] shape: {h[row].shape}")
                print(f"[DEBUG ResMP] h[col] shape: {h[col].shape}")  
                print(f"[DEBUG ResMP] edge_attr shape: {edge_attr.shape}")
                print(f"[DEBUG ResMP] rbf_dist shape: {rbf_dist.shape}")
                raise e
            
            # Compute messages: m_ij^(r)
            rel_str = str(rel)
            messages = self.message_mlps[rel_str](message_input)
            
            # Compute coordinate scaling: s_ij^(r)
            coord_weights = self.coord_mlps[rel_str](messages)
            
            # Aggregate messages to nodes
            rel_messages = scatter_add(messages, col, dim=0, dim_size=num_nodes)
            total_messages += rel_messages
            
            # Aggregate coordinate updates
            coord_update = coord_diff * coord_weights
            if self.coords_agg == 'mean':
                # Normalize by degree for each node
                degree = scatter_add(torch.ones_like(col, dtype=torch.float), col, 
                                   dim=0, dim_size=num_nodes).unsqueeze(1)
                coord_update = coord_update / (degree[col] + self.epsilon)
            
            rel_coord_updates = scatter_add(coord_update, col, dim=0, dim_size=num_nodes)
            coord_updates += rel_coord_updates
        
        # Update node features
        node_input = torch.cat([h, total_messages], dim=1)
        h_new = self.node_mlp(node_input)
        
        if self.residual:
            h_new = h + h_new
            
        # Update coordinates  
        x_new = x + coord_updates
        
        return h_new, x_new


class ResMP(nn.Module):
    """
    Residue-level Message Passing module with multi-relational EGNN
    Processes HeteroData directly without requiring edge_lists
    """
    
    def __init__(self, node_dim=105, edge_dim=100, hidden_dim=128, num_layers=4,
                 num_relations=4, act_fn=nn.SiLU(), residual=True, 
                 normalize=False, dropout=0.1, layer_norm=True):
        super().__init__()
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_relations = num_relations
        self.layer_norm = layer_norm
        
        # Input projection to working dimension
        self.node_proj_in = nn.Linear(node_dim, hidden_dim)
        
        # Multi-relational EGNN layers
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                MultiRelationalEGNNLayer(
                    node_dim=hidden_dim,
                    edge_dim=edge_dim,
                    hidden_dim=hidden_dim,
                    num_relations=num_relations,
                    act_fn=act_fn,
                    residual=residual,
                    normalize=normalize
                )
            )
        
        # Layer normalization for each layer
        if layer_norm:
            self.layer_norms = nn.ModuleList([
                nn.LayerNorm(hidden_dim) for _ in range(num_layers)
            ])
        
        # Dropout
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        
        # Output projection
        self.node_proj_out = nn.Linear(hidden_dim, node_dim)
    
    def forward(self, hetero_data, node_type):
        """
        Process residue-level interactions directly from HeteroData
        
        Args:
            hetero_data: PyG HeteroData object
            node_type: 'ag_res' or 'ab_res'
        """
        # Extract data directly from HeteroData
        h = self.node_proj_in(hetero_data[node_type].x)
        x = hetero_data[node_type].pos.clone()
        
        # Process through layers
        for i, layer in enumerate(self.layers):
            h, x = layer(h, x, hetero_data, node_type)
            if self.layer_norm:
                h = self.layer_norms[i](h)
            if self.dropout:
                h = self.dropout(h)
        
        # Update HeteroData and return projected output
        h_out = self.node_proj_out(h)  # 128 -> 105
        hetero_data[node_type].x = h_out
        hetero_data[node_type].pos = x
        
        return h_out, x







def test_res_mp_with_hetero_data():
    """Test ResMP with HeteroData structure"""
    print("Testing ResMP with HeteroData...")
    
    from torch_geometric.data import HeteroData
    
    # Create dummy HeteroData
    hetero_data = HeteroData()
    
    # Add residue nodes
    num_residues = 20
    hetero_data['ag_res'].x = torch.randn(num_residues, 105)
    hetero_data['ag_res'].pos = torch.randn(num_residues, 3)
    
    # Add residue relations
    for rel in range(4):
        edge_type = ('ag_res', f'r{rel}', 'ag_res')
        num_edges = 10 + rel * 5  # Different number of edges per relation
        
        # Create edge index
        edge_index = torch.randint(0, num_residues, (2, num_edges))
        hetero_data[edge_type].edge_index = edge_index
        
        # Create edge attributes
        hetero_data[edge_type].edge_attr = torch.randn(num_edges, 100)
    
    # Create ResMP
    res_mp = ResMP(
        node_dim=105,
        edge_dim=100,
        hidden_dim=128,
        num_layers=2,
        num_relations=4
    )
    
    # Test forward pass
    h_out, x_out = res_mp(hetero_data, 'ag_res')
    
    print(f"✓ ResMP forward pass successful!")
    print(f"Input residues: {num_residues}")
    print(f"Output features: {h_out.shape}")
    print(f"Output coordinates: {x_out.shape}")
    
    # Calculate coordinate changes
    pos_diff = torch.norm(hetero_data['ag_res'].pos - x_out, dim=1)
    print(f"Coordinate changes - mean: {pos_diff.mean():.4f}, max: {pos_diff.max():.4f}")
    
    return res_mp, h_out, x_out


if __name__ == "__main__":
    test_res_mp_with_hetero_data()


