import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d
import torch.nn as nn
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import get_laplacian, to_dense_adj


class LaplacianPE(BaseTransform):
    """Laplacian Positional Encoding transform.

    Computes the eigenvectors of the symmetrically normalized graph Laplacian
    and stores them on the graph as positional encodings for the nodes.

    The transform modifies ``data`` in place and returns it. It always sets:

    * ``data.pos_enc``   -- ``[num_nodes, k]`` node encodings (all zeros when
      ``with_pos_enc`` is False or the eigendecomposition fails).
    * ``data.edge_attr`` -- ``[num_edges, k]`` edge features (element-wise
      product of the endpoint encodings when ``with_proj`` is True, zeros
      otherwise).

    When ``with_virtual`` is True, one extra node connected bidirectionally
    to every original node is appended to ``edge_index``, ``x`` (when present),
    ``pos_enc`` and ``num_nodes``.

    Args:
        k: Number of eigenvectors to keep (the first one is skipped).
        with_pos_enc: Store the eigenvectors in ``pos_enc`` (else zeros).
        with_proj: Derive ``edge_attr`` from the eigenvectors (else zeros).
        with_virtual: Append a virtual node linked to all original nodes.
    """
    def __init__(self, k=8, with_pos_enc=True, with_proj=False, with_virtual=False):
        self.k = k
        self.with_pos_enc = with_pos_enc
        self.with_proj = with_proj
        self.with_virtual = with_virtual

    def __call__(self, data):
        """Attach positional encodings (and optionally a virtual node) to ``data``."""
        original_num_nodes = data.num_nodes
        device = data.edge_index.device

        lap_index, lap_weight = get_laplacian(
            data.edge_index, normalization='sym',
            num_nodes=original_num_nodes
        )

        # Pin the dense size explicitly so the matrix is always
        # [num_nodes, num_nodes] regardless of which indices appear in edges.
        laplacian = to_dense_adj(
            lap_index, edge_attr=lap_weight, max_num_nodes=original_num_nodes
        )[0]

        try:
            eigvals, eigvecs = torch.linalg.eigh(laplacian)
        except RuntimeError:
            print("Warning: Eigendecomposition failed for a graph. Using zeros.")
            # Fall through to the shared code path below so that the virtual
            # node and edge attributes stay consistent with pos_enc.
            eigvecs_k = torch.zeros((original_num_nodes, self.k), device=device)
        else:
            idx = torch.argsort(eigvals)
            eigvecs = eigvecs[:, idx]

            # Get top k eigenvectors (excluding first one as it's constant)
            eigvecs_k = eigvecs[:, 1:self.k + 1]

            # Graphs with fewer than k + 1 nodes yield fewer than k columns.
            # Zero-pad so every graph has the same encoding width; otherwise
            # graphs cannot be batched together and the width would not match
            # the k used for the zero edge attributes below.
            missing = self.k - eigvecs_k.size(1)
            if missing > 0:
                padding = eigvecs_k.new_zeros((eigvecs_k.size(0), missing))
                eigvecs_k = torch.cat([eigvecs_k, padding], dim=1)

        if self.with_virtual:
            # Virtual node encoding is the mean encoding over all real nodes
            virtual_pos_enc = eigvecs_k.mean(dim=0, keepdim=True)  # [1, k]
            eigvecs_k = torch.cat([eigvecs_k, virtual_pos_enc], dim=0)

            # The virtual node takes the next free index
            virtual_node_idx = original_num_nodes

            real_nodes = torch.arange(original_num_nodes, device=device)
            virtual_node = torch.full((original_num_nodes,), virtual_node_idx,
                                      device=device)

            # Bidirectional edges: virtual -> real followed by real -> virtual
            virtual_edges = torch.stack([
                torch.cat([virtual_node, real_nodes]),
                torch.cat([real_nodes, virtual_node])
            ], dim=0)

            # Combine original and virtual edges
            data.edge_index = torch.cat([data.edge_index, virtual_edges], dim=1)

            # Virtual node feature is the mean feature over all real nodes
            if data.x is not None:
                virtual_feature = data.x.mean(dim=0, keepdim=True)
                data.x = torch.cat([data.x, virtual_feature], dim=0)

            # Update number of nodes
            data.num_nodes = original_num_nodes + 1

        if self.with_pos_enc:
            data.pos_enc = eigvecs_k
        else:
            data.pos_enc = torch.zeros_like(eigvecs_k)

        if self.with_proj:
            # Edge attributes: element-wise product of source and target
            # node eigenvectors (includes virtual edges when present)
            src, dst = data.edge_index[0], data.edge_index[1]
            data.edge_attr = eigvecs_k[src] * eigvecs_k[dst]
        else:
            data.edge_attr = torch.zeros(
                (data.edge_index.size(1), self.k), device=device
            )

        return data


class EGNNLayer(MessagePassing):
    """E(n) Equivariant Graph Neural Network Layer adapted for PyTorch Geometric.

    This implementation follows the EGNN architecture described in the paper
    "E(n) Equivariant Graph Neural Networks" by Satorras et al.

    Messages and coordinate updates are computed and summed per edge source
    node (``edge_index[0]``) inside :meth:`message_and_aggregate`; the layer
    does not go through ``MessagePassing.propagate``.

    Args:
        node_dim: Width of the input/output node features.
        edge_dim: Width of the edge features (``None`` or 0 for none).
        hidden_dim: Width of the hidden layers of all internal MLPs.
        dropout: Stored on the layer; not applied inside the layer itself.
        norm_features: Enable BatchNorm inside the MLPs and a LayerNorm on
            the output node features.
        norm_coords: Normalize coordinate differences to unit length
            (Euclidean distance path only).
        coord_weights_clamp_value: If given, clamp the per-edge coordinate
            weights to ``[-value, value]``.
        k: Width of the distance feature fed to the edge MLP.
        update_coords: Whether coordinates are updated by this layer.
    """
    def __init__(self, node_dim, edge_dim=None, hidden_dim=128, dropout=0.0, 
                 norm_features=False, norm_coords=False, 
                 coord_weights_clamp_value=None, k=8, update_coords=True):
        super(EGNNLayer, self).__init__(aggr='add')

        # NOTE(review): MessagePassing also uses an attribute called `node_dim`
        # (its propagation axis); this assignment overwrites it. That appears
        # harmless here because propagate() is never called by this class --
        # confirm before ever switching this layer over to propagate().
        self.node_dim = node_dim
        self.edge_dim = edge_dim if edge_dim is not None else 0
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.norm_features = norm_features
        self.norm_coords = norm_coords
        self.coord_weights_clamp_value = coord_weights_clamp_value
        self.k = k  # Number of eigenvalues to use
        self.update_coords = update_coords  # Flag to enable/disable coordinate updates

        # Node feature transformation
        self.node_mlp_in = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim) if norm_features else nn.Identity()
        )

        # Eigenvalue transformation for alternative distance calculation
        self.eig_mlp = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, k)
        )

        # Edge MLP (m_ij)
        edge_input_dim = hidden_dim * 2 + k  # Using k eigenvalues instead of single distance
        if self.edge_dim > 0:
            edge_input_dim += self.edge_dim

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim) if norm_features else nn.Identity()
        )

        # Coordinate update MLP (phi_x)
        if self.update_coords:
            self.coord_mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1, bias=False),  # No bias for better stability
                nn.Tanh()  # Add tanh for stable coordinate updates
            )

        # Node feature update MLP (phi_h)
        self.node_mlp_out = nn.Sequential(
            nn.Linear(hidden_dim + node_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim) if norm_features else nn.Identity(),
            nn.Linear(hidden_dim, node_dim)
        )

        # Feature layernorm if requested
        self.feature_norm = nn.LayerNorm(node_dim) if norm_features else nn.Identity()

        # Initialize weights with small values for better stability
        self._init_weights()

    def _init_weights(self):
        """Initialize weights with small values for better stability."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def _replace_nan(tensor, posinf=0.0, neginf=0.0):
        """Return ``tensor`` with NaN/inf replaced, only if it contains a NaN.

        The replacement is deliberately performed outside ``torch.no_grad()``:
        a tensor created under ``no_grad`` is detached from the autograd
        graph, which would silently stop gradients from flowing through it.
        """
        if torch.isnan(tensor).any():
            return torch.nan_to_num(tensor, nan=0.0, posinf=posinf, neginf=neginf)
        return tensor

    def forward(self, x, pos, edge_index, edge_attr=None, batch=None, eigvals=None):
        """Run one EGNN update.

        Args:
            x: Node features [num_nodes, node_dim]
            pos: Node coordinates [num_nodes, 3 or k]
            edge_index: Graph connectivity [2, num_edges]
            edge_attr: Edge features [num_edges, edge_dim]
            batch: Batch indices [num_nodes]
            eigvals: Eigenvalues of the graph Laplacian [num_graphs * k]

        Returns:
            Tuple ``(x_updated, pos_updated)`` with the same shapes as ``x``
            and ``pos``. ``pos_updated`` is ``pos`` itself when coordinate
            updates are disabled.
        """
        # First, transform node features
        h = self.node_mlp_in(x)

        # Compute messages and coordinate updates
        msg_feat, coord_update = self.message_and_aggregate(h, pos, edge_index, edge_attr, batch, eigvals)

        # Update node features with residual connection
        x_updated = x + self.node_mlp_out(torch.cat([msg_feat, x], dim=1))

        # Update coordinates with small scaling factor for stability, only if enabled
        if self.update_coords:
            scaling_factor = 0.1  # Small scaling factor to prevent large coordinate changes
            pos_updated = pos + scaling_factor * coord_update
        else:
            # If coordinate updates are disabled, return pos unchanged
            pos_updated = pos

        # Apply feature normalization if requested
        x_updated = self.feature_norm(x_updated)

        return x_updated, pos_updated

    def message_and_aggregate(self, h, pos, edge_index, edge_attr=None, batch=None, eigvals=None):
        """Compute per-edge messages / coordinate shifts and sum them per node.

        Args:
            h: Hidden node features [num_nodes, hidden_dim].
            pos: Node coordinates [num_nodes, coord_dim].
            edge_index: Graph connectivity [2, num_edges].
            edge_attr: Optional edge features [num_edges, edge_dim]. A tensor
                with one row per node is also accepted and is gathered at the
                edge source nodes.
            batch: Optional graph index per node.
            eigvals: Optional Laplacian eigenvalues; when given together with
                ``batch`` the eigenvalue-based distance features are used.

        Returns:
            Tuple ``(aggr_msg, aggr_coords)`` shaped like ``h`` and ``pos``,
            both accumulated at ``edge_index[0]``.
        """
        row, col = edge_index

        # Get node features for message computation
        h_i, h_j = h[row], h[col]

        # Compute distances - use alternative method if eigvals are provided
        if eigvals is not None and batch is not None:
            # Alternative distance calculation using eigenvalues
            dist_features, coord_diff = self.coord2radial(edge_index, pos, batch, eigvals)
        else:
            # Standard Euclidean distance calculation
            coord_diff = pos[row] - pos[col]
            # Add a small epsilon to avoid numerical instability
            eps = 1e-8
            dist_sq = torch.sum(coord_diff ** 2, dim=1, keepdim=True) + eps

            # Normalize coordinates if requested
            if self.norm_coords:
                coord_diff = coord_diff / torch.sqrt(dist_sq)

            # Repeat the single squared distance k times so the edge MLP sees
            # the same input width as in the eigenvalue-based path.
            dist_features = dist_sq.expand(-1, self.k)

        # Compute edge features by concatenating node features and distance
        parts = [h_i, h_j, dist_features]
        if edge_attr is not None:
            # Treat edge_attr as per-node only when it cannot be per-edge.
            # Checking the edge count first avoids mis-indexing genuine edge
            # attributes on graphs where num_edges happens to equal num_nodes.
            if edge_attr.size(0) != row.size(0) and edge_attr.size(0) == h.size(0):
                edge_attr = edge_attr[row]
            parts.append(edge_attr)
        edge_features = torch.cat(parts, dim=1)

        # Sanitize NaN/inf inputs for stability (keeps the autograd graph)
        edge_features = self._replace_nan(edge_features, posinf=1.0, neginf=-1.0)

        # Compute edge messages
        m_ij = self.edge_mlp(edge_features)

        # Compute coordinate updates only if enabled
        if self.update_coords:
            # Compute coordinate influence weights
            coord_weights = self.coord_mlp(m_ij)

            # Optionally clamp coordinate weights for stability
            if self.coord_weights_clamp_value is not None:
                coord_weights = torch.clamp(
                    coord_weights,
                    -self.coord_weights_clamp_value,
                    self.coord_weights_clamp_value
                )

            # Scale the relative position by the computed weight
            weighted_rel_pos = self._replace_nan(coord_diff * coord_weights)
        else:
            # If coordinate updates are disabled, create dummy zero tensor
            weighted_rel_pos = torch.zeros_like(coord_diff)

        # Sanitize NaN values in messages
        m_ij = self._replace_nan(m_ij)

        # Aggregate node messages
        aggr_msg = torch.zeros_like(h)
        aggr_msg.index_add_(0, row, m_ij)

        # Aggregate coordinate updates
        aggr_coords = torch.zeros_like(pos)
        aggr_coords.index_add_(0, row, weighted_rel_pos)

        return aggr_msg, aggr_coords

    def coord2radial(self, edge_index, coord, batch, eigvals):
        """Compute alternative distance features using eigenvalues.

        Args:
            edge_index: Graph connectivity [2, num_edges].
            coord: Node coordinates [num_nodes, coord_dim].
            batch: Graph index per node.
            eigvals: Laplacian eigenvalues holding the same number of values
                for every graph in the batch.

        Returns:
            Tuple ``(weighted_dots, coord_diff)``: one weighted dot product of
            the endpoint coordinates per eigenvalue ``[num_edges, k]``, and
            the raw coordinate differences ``[num_edges, coord_dim]``.
        """
        src, dst = edge_index

        # Get coordinate differences
        coord_diff = coord[src] - coord[dst]

        # Get graph indices for each edge source node
        edge_graph_idx = batch[src]

        # Reshape eigenvalues to one row per graph. num_graphs is converted to
        # a Python int so it can be used as a view size, and numel() is used
        # so that eigvals given as [num_graphs, k] is handled as well.
        num_graphs = int(batch.max()) + 1
        k = eigvals.numel() // num_graphs
        eigvals_reshaped = eigvals.view(num_graphs, k)

        # Map to edges
        edge_eigvals = eigvals_reshaped[edge_graph_idx]  # [num_edges, k]

        # Apply transformation function to each eigenvalue separately.
        # No squeeze here: squeezing the trailing axis would collapse the
        # tensor to 2-D when the MLP output width is 1 and break the einsum.
        transformed_eigvals = self.eig_mlp(edge_eigvals.unsqueeze(-1))  # [num_edges, k, self.k]

        # Bring the coordinates to exactly k dimensions: truncate extra
        # dimensions, or zero-pad (matching dtype/device) when there are fewer.
        if coord.size(1) >= k:
            coord_k = coord[:, :k]
        else:
            pad = coord.new_zeros((coord.size(0), k - coord.size(1)))
            coord_k = torch.cat([coord, pad], dim=1)

        # Use einsum to compute k weighted dot products
        # This computes weighted relationships between node coordinates
        weighted_dots = torch.einsum('ekj,ej,ej->ek',
                                     transformed_eigvals,  # [num_edges, k, self.k]
                                     coord_k[src],         # [num_edges, k]
                                     coord_k[dst])         # [num_edges, k]

        return weighted_dots, coord_diff


class EGNNWithPE(nn.Module):
    """EGNN model for MNIST Superpixels with positional encoding.

    Node features are zero-padded by ``pos_enc_dim`` columns, while the
    positional encodings themselves are used as the initial node coordinates
    of the EGNN layers. Node features are mean-pooled per graph and passed
    through an MLP classifier.

    Args:
        num_features: Width of the raw node features ``data.x``.
        pos_enc_dim: Width of ``data.pos_enc`` (also used as edge feature
            width and as ``k`` of every layer).
        hidden_dim: Hidden width of the layers and the classifier.
        num_classes: Number of output classes.
        dropout: Dropout probability applied after every layer and inside
            the classifier.
        num_layers: Number of EGNN layers (at least one layer is always built).
        norm_features: Forwarded to every ``EGNNLayer``.
        norm_coords: Forwarded to every ``EGNNLayer``.
        coord_weights_clamp_value: Forwarded to every ``EGNNLayer``.
        update_coords: Forwarded to every ``EGNNLayer``.
    """
    def __init__(self, num_features=1, pos_enc_dim=8, hidden_dim=64, num_classes=10, dropout=0.2, 
                 num_layers=3, norm_features=True, norm_coords=True, coord_weights_clamp_value=2.0,
                 update_coords=True):
        super(EGNNWithPE, self).__init__()

        # Input: node features + positional encoding
        input_dim = num_features + pos_enc_dim

        self.edge_dim = pos_enc_dim  # Use positional encoding dimension for edge features
        self.dropout = dropout
        self.pos_enc_dim = pos_enc_dim
        self.update_coords = update_coords  # Flag to enable/disable coordinate updates

        # EGNN layers. Every layer maps input_dim -> input_dim, so they are
        # all configured identically; one layer is always created even when
        # num_layers < 1.
        self.layers = nn.ModuleList()
        for _ in range(max(num_layers, 1)):
            self.layers.append(
                EGNNLayer(
                    node_dim=input_dim,
                    edge_dim=self.edge_dim,
                    hidden_dim=hidden_dim,
                    dropout=dropout,
                    norm_features=norm_features,
                    norm_coords=norm_coords,
                    coord_weights_clamp_value=coord_weights_clamp_value,
                    k=pos_enc_dim,
                    update_coords=update_coords
                )
            )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

        # Initialize weights with small values for better stability
        self._init_weights()

    def _init_weights(self):
        """Initialize weights with small values for better stability.

        This runs after the layers are built and re-initializes every
        ``nn.Linear`` in the model, including those inside the EGNN layers.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def _nan_to_zero(tensor):
        """Return ``tensor`` with NaNs zeroed, only if it contains a NaN.

        Done outside ``torch.no_grad()`` on purpose: a tensor produced under
        ``no_grad`` is detached from the autograd graph, so sanitizing there
        would cut the gradient path (for the logits, the whole backward pass).
        """
        if torch.isnan(tensor).any():
            return torch.nan_to_num(tensor, nan=0.0)
        return tensor

    def forward(self, data):
        """Classify the graphs in ``data``.

        Args:
            data: Graph batch providing ``x``, ``pos_enc``, ``edge_index``,
                ``edge_attr``, ``batch`` and optionally ``eigvals``.

        Returns:
            Log-probabilities of shape ``[num_graphs, num_classes]``.
        """
        # Extract the eigvals from data if they exist
        eigvals = getattr(data, 'eigvals', None)

        # Pad node features with zeros in place of the positional encoding;
        # the encoding itself enters the network only through `pos` below.
        x = torch.cat([data.x, torch.zeros_like(data.pos_enc)], dim=-1)
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = data.batch

        # Initialize coordinates from a copy of the positional encodings
        pos = data.pos_enc.clone()

        # Apply EGNN layers, sanitizing NaNs before and after every layer
        for i, layer in enumerate(self.layers):
            x = self._nan_to_zero(x)
            pos = self._nan_to_zero(pos)

            # Plain references are enough to roll back: nothing in this
            # loop modifies x or pos in place.
            x_prev, pos_prev = x, pos
            x, pos = layer(x, pos, edge_index, edge_attr, batch, eigvals)

            # If NaNs appear, revert to previous state
            if torch.isnan(x).any() or torch.isnan(pos).any():
                print(f"NaN detected in layer {i+1}, reverting to previous state")
                x = self._nan_to_zero(x_prev)
                pos = self._nan_to_zero(pos_prev)

            # Apply dropout only to features
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Pool node features to graph-level representation
        x = global_mean_pool(x, batch)

        # Apply classifier
        x = self.classifier(x)

        # Final check for NaNs before softmax
        x = self._nan_to_zero(x)

        return F.log_softmax(x, dim=1)