# Add this to your models.py file

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, global_mean_pool

from e3nn import o3
from e3nn.nn import FullyConnectedNet
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
E3NN_AVAILABLE = True
from egnn_pytorch import EGNN
from torch import nn



# Memory-efficient EGNN for large graphs
class BatchedEGNNLinkPredictor(torch.nn.Module):
    """EGNN link predictor that encodes nodes in contiguous chunks.

    Nodes are split into chunks of ``batch_size`` consecutive indices and each
    chunk is run through the EGNN independently. Only edges whose two
    endpoints fall inside the same chunk take part in message passing; edges
    crossing a chunk boundary are ignored.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Size of the node embeddings.
        batch_size: Number of consecutive nodes encoded per chunk.
        dropout: Dropout probability for the EGNN and for the embeddings.
    """
    def __init__(self, in_channels, hidden_channels, batch_size=1000, dropout=0.1):
        super().__init__()
        self.feat_proj = torch.nn.Linear(in_channels, hidden_channels)
        self.egnn = EGNN(
            dim=hidden_channels,
            m_dim=hidden_channels // 2,
            dropout=dropout,
            update_feats=True,
            update_coors=False,
            # egnn_pytorch only honours ``adj_mat`` when this flag (or a
            # positive num_nearest_neighbors) is set; without it every pair
            # of nodes in the chunk exchanges messages.
            only_sparse_neighbors=True,
        )
        self.norm = torch.nn.LayerNorm(hidden_channels)
        self.batch_size = batch_size
        self.dropout = dropout
        # Linear map used by decode_break.
        self.decoder = torch.nn.Linear(hidden_channels, hidden_channels)
        # MLP scorer on concatenated endpoint embeddings, used by decode.
        self.predictor = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )

    def encode(self, x, edge_index):
        """Return node embeddings of shape [num_nodes, hidden_channels].

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Edge indices of shape [2, num_edges]; edges are
                treated as undirected.
        """
        num_nodes = x.size(0)
        device = x.device
        src, dst = edge_index
        all_embeddings = []

        for start_idx in range(0, num_nodes, self.batch_size):
            end_idx = min(start_idx + self.batch_size, num_nodes)
            n = end_idx - start_idx

            # Project once. The projected features also serve as the EGNN
            # "coordinates", so distances are measured in hidden space.
            batch_h = self.feat_proj(x[start_idx:end_idx]).unsqueeze(0)
            batch_coords = batch_h

            # Keep only edges with both endpoints inside this chunk.
            mask = (src >= start_idx) & (src < end_idx) & (dst >= start_idx) & (dst < end_idx)
            batch_edges = edge_index[:, mask] - start_idx

            # Symmetric boolean adjacency: egnn_pytorch uses it as a mask.
            adj_matrix = torch.zeros(1, n, n, dtype=torch.bool, device=device)
            adj_matrix[0, batch_edges[0], batch_edges[1]] = True
            adj_matrix[0, batch_edges[1], batch_edges[0]] = True

            batch_h, _ = self.egnn(batch_h, batch_coords, adj_mat=adj_matrix)

            batch_h = self.norm(batch_h.squeeze(0))
            batch_h = F.dropout(batch_h, p=self.dropout, training=self.training)
            all_embeddings.append(batch_h)

        return torch.cat(all_embeddings, dim=0)

    def decode(self, z, edge_label_index):
        """Return unnormalised edge scores (logits) from the MLP predictor.

        Args:
            z: Node embeddings from ``encode``.
            edge_label_index: Candidate edges, shape [2, num_candidates].
        Returns:
            A 1-D tensor with one logit per candidate edge.
        """
        pair = torch.cat([z[edge_label_index[0]], z[edge_label_index[1]]], dim=-1)
        return self.predictor(pair).view(-1)

    def decode_prob(self, z, edge_label_index):
        """Return sigmoid probabilities of the ``decode`` logits."""
        return torch.sigmoid(self.decode(z, edge_label_index))

    def decode_break(self, z, edge_label_index):
        """Return ``sigmoid(<decoder(z_u), z_v>)`` for every candidate edge (u, v)."""
        predict = self.decoder(z[edge_label_index[0]])
        scores = (predict * z[edge_label_index[1]]).sum(dim=1)
        return torch.sigmoid(scores)


class EquiformerV2LinkPredictor(torch.nn.Module):
    """
    EquiformerV2-inspired model for link prediction using coordinates as features.

    Coordinates are padded or truncated to exactly three columns. Their edge
    vectors drive the geometric message passing, and scalar node inputs are
    built from the coordinate norm and values. Only the scalar (0e) channels
    are returned as embeddings.
    """
    def __init__(self, in_channels, hidden_channels, num_layers=3, max_ell=2, dropout=0.1):
        super().__init__()

        if not E3NN_AVAILABLE:
            raise ImportError("e3nn is required for EquiformerV2. Install with: pip install e3nn")

        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.max_ell = max_ell
        self.dropout = dropout

        # Irreps: scalar inputs, mixed-order hidden features, scalar outputs.
        # The scalar block comes first, so h[:, :hidden_channels] is the 0e part.
        self.irreps_in = o3.Irreps(f"{in_channels}x0e")
        self.irreps_hidden = o3.Irreps(f"{hidden_channels}x0e + {hidden_channels//2}x1o + {hidden_channels//4}x2e")
        self.irreps_out = o3.Irreps(f"{hidden_channels}x0e")

        self.input_proj = o3.Linear(self.irreps_in, self.irreps_hidden)

        self.layers = torch.nn.ModuleList([
            EquiformerV2Layer(
                irreps_in=self.irreps_hidden,
                irreps_out=self.irreps_hidden,
                max_ell=max_ell,
                dropout=dropout
            )
            for _ in range(num_layers)
        ])

        # Projection to scalars only.
        self.output_proj = o3.Linear(self.irreps_hidden, self.irreps_out)

        # layer_norms[0] normalises the projected input; layer_norms[i + 1]
        # normalises the output of layer i. Both act on scalar channels only.
        self.layer_norms = torch.nn.ModuleList([
            torch.nn.LayerNorm(hidden_channels)
            for _ in range(num_layers + 1)
        ])

    def encode(self, x, edge_index):
        """
        Encode node features using EquiformerV2.

        Args:
            x: Node coordinates [num_nodes, d]. Zero-padded to 3 columns when
               d < 3 and truncated to the first 3 columns when d > 3.
            edge_index: Edge indices [2, num_edges]
        Returns:
            z: Node embeddings [num_nodes, hidden_channels]
        """
        hc = self.hidden_channels
        in_dim = self.irreps_in.dim

        # Spherical harmonics need 3-vectors: pad narrow inputs, truncate wide ones.
        if x.shape[1] < 3:
            x = F.pad(x, (0, 3 - x.shape[1]), value=0.0)
        elif x.shape[1] > 3:
            x = x[:, :3]

        # Edge vectors and lengths.
        row, col = edge_index
        edge_vec = x[row] - x[col]  # [num_edges, 3]
        edge_len = torch.norm(edge_vec, dim=1, keepdim=True)  # [num_edges, 1]

        # Scalar inputs: [||x||, x...], padded or truncated to in_dim.
        node_features = torch.norm(x, dim=1, keepdim=True)
        if node_features.shape[1] < in_dim:
            additional_features = x[:, :min(3, in_dim - 1)]
            node_features = torch.cat([node_features, additional_features], dim=1)
        if node_features.shape[1] < in_dim:
            node_features = F.pad(node_features, (0, in_dim - node_features.shape[1]), value=0.0)
        elif node_features.shape[1] > in_dim:
            node_features = node_features[:, :in_dim]

        # Project to hidden space and normalise the scalar channels. The
        # normalised result is written back; previously it was discarded.
        h = self.input_proj(node_features)
        h = torch.cat([self.layer_norms[0](h[:, :hc]), h[:, hc:]], dim=1)

        for i, layer in enumerate(self.layers):
            h_residual = h
            h = layer(h, edge_index, edge_vec, edge_len)

            # LayerNorm on scalars; residual connection from the second layer on.
            h_scalars_new = self.layer_norms[i + 1](h[:, :hc])
            if i > 0:
                h_scalars_new = h_scalars_new + h_residual[:, :hc]

            h = torch.cat([h_scalars_new, h[:, hc:]], dim=1)
            h = F.dropout(h, p=self.dropout, training=self.training)

        # Scalars only for link prediction.
        return self.output_proj(h)

    def decode(self, z, edge_label_index):
        """Return dot-product edge scores (logits, not probabilities)."""
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=1)

    def decode_prob(self, z, edge_label_index):
        """Get edge probabilities using sigmoid"""
        return torch.sigmoid(self.decode(z, edge_label_index))
    

class EquiformerV2Layer(MessagePassing):
    """
    Simplified EquiformerV2 layer for link prediction.

    Each message is ``linear(tp(x_j, e_ij))``, where ``e_ij`` combines a
    learned embedding of the edge length with spherical harmonics of the edge
    direction. Messages are summed at the target node and added to a
    self-interaction term.
    """
    def __init__(self, irreps_in, irreps_out, max_ell=2, dropout=0.1):
        super().__init__(aggr='add', node_dim=0)

        self.irreps_in = irreps_in
        self.irreps_out = irreps_out
        self.max_ell = max_ell

        # Edge attribute irreps:
        # - the 32x0e block holds the radial embedding;
        # - the first copy of each l > 0 irrep holds the spherical harmonic of that type;
        # - all other copies stay zero.
        self.irreps_edge = o3.Irreps("32x0e + 16x1o + 8x2e")

        # Spherical harmonics of the edge direction, degrees 0..max_ell.
        self.sh = o3.SphericalHarmonics(list(range(max_ell + 1)), normalize=True, normalization='component')

        # Radial network: edge length -> 32 scalars (fills the 0e block).
        self.edge_embedding = FullyConnectedNet([1, 32, 32], torch.nn.SiLU())

        # Tensor product for message passing.
        self.tp = o3.FullTensorProduct(irreps_in, self.irreps_edge)

        # Linear layer applied to each message.
        self.linear = o3.Linear(self.tp.irreps_out, irreps_out)

        # Self-interaction.
        self.self_linear = o3.Linear(irreps_in, irreps_out)

        self.dropout = dropout

        # (edge offset, sh offset, width) for every l > 0 harmonic with a
        # matching irrep in irreps_edge. Degrees without a slot are dropped.
        self._sh_slots = self._locate_sh_slots()

    def _locate_sh_slots(self):
        """Map each l > 0 spherical-harmonic block to the first edge slot of the same irrep."""
        first_slot = {}
        for (_, ir), sl in zip(self.irreps_edge, self.irreps_edge.slices()):
            first_slot.setdefault(ir, sl.start)
        slots = []
        for (_, ir), sl in zip(self.sh.irreps_out, self.sh.irreps_out.slices()):
            # The 0e block is reserved for the radial embedding.
            if ir.l > 0 and ir in first_slot:
                slots.append((first_slot[ir], sl.start, ir.dim))
        return slots

    def _embed_edges(self, edge_vec, edge_len):
        """Return edge attributes [num_edges, irreps_edge.dim] laid out per irreps_edge."""
        radial = self.edge_embedding(edge_len)  # [num_edges, 32]
        edge_sh = self.sh(edge_vec / (edge_len + 1e-8))
        edge_attr = radial.new_zeros(edge_vec.shape[0], self.irreps_edge.dim)
        edge_attr[:, :radial.shape[1]] = radial
        for dst, src, width in self._sh_slots:
            edge_attr[:, dst:dst + width] = edge_sh[:, src:src + width]
        return edge_attr

    def forward(self, x, edge_index, edge_vec, edge_len):
        """
        Forward pass of EquiformerV2 layer.

        Args:
            x: Node features laid out according to ``irreps_in``.
            edge_index: Edge indices [2, num_edges].
            edge_vec: Edge vectors [num_edges, 3].
            edge_len: Edge lengths [num_edges, 1].
        Returns:
            Aggregated messages plus self-interaction, laid out per ``irreps_out``.
        """
        edge_attr = self._embed_edges(edge_vec, edge_len)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        return out + self.self_linear(x)

    def message(self, x_j, edge_attr):
        """
        Compute messages between nodes
        """
        # Tensor product between node features and edge features.
        messages = self.tp(x_j, edge_attr)
        messages = self.linear(messages)

        return F.dropout(messages, p=self.dropout, training=self.training)



class BatchedEquiformerV2LinkPredictor(torch.nn.Module):
    """
    Memory-efficient batched EquiformerV2 for large graphs.

    Nodes are encoded in chunks of ``batch_size`` consecutive indices to
    avoid OOM errors. Only edges whose endpoints fall in the same chunk are
    used. A single EquiformerV2 layer is applied ``num_layers`` times, so
    weights are shared across depth.
    """
    def __init__(self, in_channels, hidden_channels, num_layers=3, max_ell=2, dropout=0.1, batch_size=500):
        super().__init__()

        if not E3NN_AVAILABLE:
            raise ImportError("e3nn is required for EquiformerV2. Install with: pip install e3nn")

        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.max_ell = max_ell
        self.dropout = dropout
        self.batch_size = batch_size

        # Irreps: scalar inputs, mixed-order hidden features, scalar outputs.
        # The scalar block comes first, so h[:, :hidden_channels] is the 0e part.
        self.irreps_in = o3.Irreps(f"{in_channels}x0e")
        self.irreps_hidden = o3.Irreps(f"{hidden_channels}x0e + {hidden_channels//2}x1o + {hidden_channels//4}x2e")
        self.irreps_out = o3.Irreps(f"{hidden_channels}x0e")

        self.input_proj = o3.Linear(self.irreps_in, self.irreps_hidden)

        # One layer, reused for every step of depth.
        self.layer = BatchedEquiformerV2Layer(
            irreps_in=self.irreps_hidden,
            irreps_out=self.irreps_hidden,
            max_ell=max_ell,
            dropout=dropout
        )

        # Projection to scalars only.
        self.output_proj = o3.Linear(self.irreps_hidden, self.irreps_out)

        # Shared LayerNorm over the scalar channels after every layer step.
        self.layer_norm = torch.nn.LayerNorm(hidden_channels)
        # Linear map used by decode_break.
        self.decoder = torch.nn.Linear(hidden_channels, hidden_channels)
        # MLP scorer on concatenated endpoint embeddings, used by decode.
        self.predictor = nn.Sequential(
            nn.Linear(hidden_channels*2, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1)
        )

    def encode(self, x, edge_index):
        """
        Encode node features using batched EquiformerV2.

        Args:
            x: Node coordinates [num_nodes, in_channels]
            edge_index: Edge indices [2, num_edges]
        Returns:
            z: Node embeddings [num_nodes, hidden_channels]
        """
        num_nodes = x.shape[0]
        src, dst = edge_index
        all_embeddings = []

        for start_idx in range(0, num_nodes, self.batch_size):
            end_idx = min(start_idx + self.batch_size, num_nodes)
            batch_x = x[start_idx:end_idx]

            # Exactly three position columns: truncate wide inputs, zero-pad narrow ones.
            if batch_x.shape[1] > 3:
                batch_pos = batch_x[:, :3]
            else:
                batch_pos = F.pad(batch_x, (0, 3 - batch_x.shape[1]), value=0.0)

            # Chunk-local edges; shape [2, 0] when none fall inside the chunk.
            mask = (src >= start_idx) & (src < end_idx) & (dst >= start_idx) & (dst < end_idx)
            batch_edge_index = edge_index[:, mask] - start_idx

            all_embeddings.append(self._process_batch(batch_x, batch_edge_index, batch_pos))

        return torch.cat(all_embeddings, dim=0)

    def _process_batch(self, batch_x, batch_edge_index, batch_pos):
        """Encode one chunk of nodes.

        Args:
            batch_x: Raw features of the chunk's nodes.
            batch_edge_index: Chunk-local edge indices [2, E].
            batch_pos: Positions of the chunk's nodes [n, 3].
        Returns:
            Scalar embeddings [n, hidden_channels].
        """
        hc = self.hidden_channels
        in_dim = self.irreps_in.dim

        # Scalar inputs: [||x||, x...], padded or truncated to in_dim.
        node_features = torch.norm(batch_x, dim=1, keepdim=True)
        if node_features.shape[1] < in_dim:
            additional_features = batch_x[:, :min(batch_x.shape[1], in_dim - 1)]
            node_features = torch.cat([node_features, additional_features], dim=1)
        if node_features.shape[1] < in_dim:
            node_features = F.pad(node_features, (0, in_dim - node_features.shape[1]), value=0.0)
        elif node_features.shape[1] > in_dim:
            node_features = node_features[:, :in_dim]

        h = self.input_proj(node_features)

        # Edge vectors and lengths; these are empty tensors when the chunk has no edges.
        row, col = batch_edge_index
        edge_vec = batch_pos[row] - batch_pos[col]
        edge_len = torch.norm(edge_vec, dim=1, keepdim=True)

        for layer_idx in range(self.num_layers):
            h_residual = h
            h = self.layer(h, batch_edge_index, edge_vec, edge_len)

            # LayerNorm on scalars; residual connection from the second step on.
            h_scalars = self.layer_norm(h[:, :hc])
            if layer_idx > 0:
                h_scalars = h_scalars + h_residual[:, :hc]

            # h[:, hc:] may be empty; concatenating an empty block is harmless.
            h = torch.cat([h_scalars, h[:, hc:]], dim=1)
            h = F.dropout(h, p=self.dropout, training=self.training)

        # Scalars only.
        return self.output_proj(h)

    def decode(self, z, edge_label_index):
        """Return unnormalised edge scores (logits) from the MLP predictor."""
        pair = torch.cat([z[edge_label_index[0]], z[edge_label_index[1]]], dim=-1)
        return self.predictor(pair).view(-1)

    def decode_prob(self, z, edge_label_index):
        """Get edge probabilities using sigmoid"""
        return torch.sigmoid(self.decode(z, edge_label_index))

    def decode_break(self, z, edge_label_index):
        """Return ``sigmoid(<decoder(z_u), z_v>)`` for every candidate edge (u, v)."""
        predict = self.decoder(z[edge_label_index[0]])
        scores = (predict * z[edge_label_index[1]]).sum(dim=1)
        return torch.sigmoid(scores)

class BatchedEquiformerV2Layer(MessagePassing):
    """
    Batched EquiformerV2 layer for memory efficiency.

    Each message is ``linear(tp(x_j, e_ij))``, where ``e_ij`` combines a
    learned embedding of the edge length with spherical harmonics of the edge
    direction. Messages are summed at the target node and added to a
    self-interaction term. With no edges, only the self-interaction is
    returned.
    """
    def __init__(self, irreps_in, irreps_out, max_ell=2, dropout=0.1):
        super().__init__(aggr='add', node_dim=0)

        self.irreps_in = irreps_in
        self.irreps_out = irreps_out
        self.max_ell = max_ell

        # Edge attribute irreps:
        # - the 32x0e block holds the radial embedding;
        # - the first copy of each l > 0 irrep holds the spherical harmonic of that type;
        # - all other copies stay zero.
        self.irreps_edge = o3.Irreps("32x0e + 16x1o + 8x2e")

        # Spherical harmonics of the edge direction, degrees 0..max_ell.
        self.sh = o3.SphericalHarmonics(list(range(max_ell + 1)), normalize=True, normalization='component')

        # Radial network: edge length -> 32 scalars (fills the 0e block).
        self.edge_embedding = FullyConnectedNet([1, 32, 32], torch.nn.SiLU())

        # Tensor product for message passing.
        self.tp = o3.FullTensorProduct(irreps_in, self.irreps_edge)

        # Linear layer applied to each message.
        self.linear = o3.Linear(self.tp.irreps_out, irreps_out)

        # Self-interaction.
        self.self_linear = o3.Linear(irreps_in, irreps_out)

        self.dropout = dropout

        # (edge offset, sh offset, width) for every l > 0 harmonic with a
        # matching irrep in irreps_edge. Degrees without a slot are dropped.
        self._sh_slots = self._locate_sh_slots()

    def _locate_sh_slots(self):
        """Map each l > 0 spherical-harmonic block to the first edge slot of the same irrep."""
        first_slot = {}
        for (_, ir), sl in zip(self.irreps_edge, self.irreps_edge.slices()):
            first_slot.setdefault(ir, sl.start)
        slots = []
        for (_, ir), sl in zip(self.sh.irreps_out, self.sh.irreps_out.slices()):
            # The 0e block is reserved for the radial embedding.
            if ir.l > 0 and ir in first_slot:
                slots.append((first_slot[ir], sl.start, ir.dim))
        return slots

    def _embed_edges(self, edge_vec, edge_len):
        """Return edge attributes [num_edges, irreps_edge.dim] laid out per irreps_edge."""
        radial = self.edge_embedding(edge_len)  # [num_edges, 32]
        edge_sh = self.sh(edge_vec / (edge_len + 1e-8))
        edge_attr = radial.new_zeros(edge_vec.shape[0], self.irreps_edge.dim)
        edge_attr[:, :radial.shape[1]] = radial
        for dst, src, width in self._sh_slots:
            edge_attr[:, dst:dst + width] = edge_sh[:, src:src + width]
        return edge_attr

    def forward(self, x, edge_index, edge_vec, edge_len):
        """
        Forward pass with pre-computed edge vectors and lengths.

        Args:
            x: Node features laid out according to ``irreps_in``.
            edge_index: Edge indices [2, num_edges].
            edge_vec: Edge vectors [num_edges, 3].
            edge_len: Edge lengths [num_edges, 1].
        """
        if edge_index.shape[1] == 0:
            # No edges in this batch: only the self-interaction applies.
            return self.self_linear(x)

        edge_attr = self._embed_edges(edge_vec, edge_len)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        return out + self.self_linear(x)

    def message(self, x_j, edge_attr):
        """Compute messages between nodes"""
        # Tensor product between node features and edge features.
        messages = self.tp(x_j, edge_attr)
        messages = self.linear(messages)

        return F.dropout(messages, p=self.dropout, training=self.training)



class Regressor(torch.nn.Module):
    """Two-layer MLP head: Linear -> ReLU -> Dropout -> Linear.

    Args:
        hidden_channels: Input width; the hidden layer has ``hidden_channels // 2`` units.
        out_channels: Output width.
        dropout: Dropout probability applied after the ReLU.
    """
    def __init__(self, hidden_channels, out_channels=1, dropout=0.5):
        super().__init__()
        mid = hidden_channels // 2
        layers = [
            torch.nn.Linear(hidden_channels, mid),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(mid, out_channels),
        ]
        self.regressor = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``, whose last dimension must equal ``hidden_channels``."""
        out = self.regressor(x)
        return out
    

class EGNNGraphRegressor(torch.nn.Module):
    """Graph-level regressor: stacked EGNN layers, per-graph mean pooling, MLP head.

    All nodes of a mini-batch are passed to EGNN as one set. Messages are
    restricted to the edges in ``edge_index`` through a boolean adjacency
    matrix. Nodes of different graphs therefore do not interact, provided
    ``edge_index`` has no cross-graph edges, as with PyG-collated batches.

    Args:
        in_channels: Input feature size; the raw features also serve as EGNN coordinates.
        hidden_channels: Hidden feature size.
        num_layers: Number of EGNN layers.
        dropout: Dropout probability for EGNN, embeddings and the head.
    """
    def __init__(self, in_channels, hidden_channels, num_layers=2, dropout=0.1):
        super().__init__()
        self.feat_proj = torch.nn.Linear(in_channels, hidden_channels)
        self.egnn_layers = torch.nn.ModuleList([
            EGNN(
                dim=hidden_channels,
                m_dim=hidden_channels // 2,
                dropout=dropout,
                update_feats=True,
                update_coors=True,
                fourier_features=4,
                num_nearest_neighbors=0,
                # Required for egnn_pytorch to honour ``adj_mat``; otherwise
                # every node exchanges messages with every other node.
                only_sparse_neighbors=True,
            )
            for _ in range(num_layers)
        ])
        self.norms = torch.nn.ModuleList([
            torch.nn.LayerNorm(hidden_channels)
            for _ in range(num_layers)
        ])
        self.dropout = dropout
        self.regressor = Regressor(hidden_channels, out_channels=1, dropout=dropout)

    def forward(self, x, edge_index, batch):
        """Return predictions of shape [num_graphs, 1].

        Args:
            x: Node features [num_nodes, in_channels].
            edge_index: Edge indices [2, num_edges]; treated as undirected.
            batch: Graph id of each node, as consumed by ``global_mean_pool``.
        """
        num_nodes = x.size(0)
        h = self.feat_proj(x).unsqueeze(0)
        coords = x.unsqueeze(0)

        # Symmetric boolean adjacency: egnn_pytorch uses it as a mask.
        adj_matrix = torch.zeros(1, num_nodes, num_nodes, dtype=torch.bool, device=x.device)
        adj_matrix[0, edge_index[0], edge_index[1]] = True
        adj_matrix[0, edge_index[1], edge_index[0]] = True

        for i, (egnn_layer, norm) in enumerate(zip(self.egnn_layers, self.norms)):
            h_residual = h.squeeze(0)
            h, coords = egnn_layer(
                feats=h,
                coors=coords,
                adj_mat=adj_matrix
            )
            h_flat = norm(h.squeeze(0))
            # Residual connection from the second layer on.
            if i > 0:
                h_flat = h_flat + h_residual
            h_flat = F.dropout(h_flat, p=self.dropout, training=self.training)
            h = h_flat.unsqueeze(0)

        h = global_mean_pool(h.squeeze(0), batch)
        return self.regressor(h)



class BatchedEGNNGraphRegressor(torch.nn.Module):
    """
    Memory-efficient EGNN regressor for large graphs.
    Processes nodes in batches to avoid OOM errors.

    Nodes are encoded in chunks of ``batch_size`` consecutive indices, and
    only edges with both endpoints in the same chunk are used. Coordinates
    (the raw input features) are not updated. Node embeddings are
    mean-pooled per graph and fed to an MLP head.
    """
    def __init__(self, in_channels, hidden_channels, num_layers=2, dropout=0.1, batch_size=1000):
        super().__init__()
        self.feat_proj = torch.nn.Linear(in_channels, hidden_channels)
        self.egnn_layers = torch.nn.ModuleList([
            EGNN(
                dim=hidden_channels,
                m_dim=hidden_channels // 2,
                dropout=dropout,
                update_feats=True,
                update_coors=False,
                fourier_features=4,
                num_nearest_neighbors=0,
                # Required for egnn_pytorch to honour ``adj_mat``; otherwise
                # every node in a chunk exchanges messages with every other.
                only_sparse_neighbors=True,
            )
            for _ in range(num_layers)
        ])
        self.norms = torch.nn.ModuleList([
            torch.nn.LayerNorm(hidden_channels)
            for _ in range(num_layers)
        ])
        self.dropout = dropout
        self.batch_size = batch_size
        self.regressor = Regressor(hidden_channels, out_channels=1, dropout=dropout)

    def forward(self, x, edge_index, batch):
        """Return predictions of shape [num_graphs, 1].

        Args:
            x: Node features [num_nodes, in_channels].
            edge_index: Edge indices [2, num_edges]; treated as undirected.
            batch: Graph id of each node, as consumed by ``global_mean_pool``.
        """
        num_nodes = x.size(0)
        device = x.device
        src, dst = edge_index
        all_h = []

        for start_idx in range(0, num_nodes, self.batch_size):
            end_idx = min(start_idx + self.batch_size, num_nodes)
            n = end_idx - start_idx
            batch_x = x[start_idx:end_idx]
            batch_h = self.feat_proj(batch_x).unsqueeze(0)
            batch_coords = batch_x.unsqueeze(0)

            # Keep only edges with both endpoints inside this chunk.
            mask = (src >= start_idx) & (src < end_idx) & (dst >= start_idx) & (dst < end_idx)
            batch_edge_index = edge_index[:, mask] - start_idx

            # Symmetric boolean adjacency: egnn_pytorch uses it as a mask.
            adj_matrix = torch.zeros(1, n, n, dtype=torch.bool, device=device)
            adj_matrix[0, batch_edge_index[0], batch_edge_index[1]] = True
            adj_matrix[0, batch_edge_index[1], batch_edge_index[0]] = True

            for i, (egnn_layer, norm) in enumerate(zip(self.egnn_layers, self.norms)):
                h_residual = batch_h.squeeze(0)
                batch_h, batch_coords = egnn_layer(
                    feats=batch_h,
                    coors=batch_coords,
                    adj_mat=adj_matrix
                )
                h_flat = norm(batch_h.squeeze(0))
                # Residual connection from the second layer on.
                if i > 0:
                    h_flat = h_flat + h_residual
                h_flat = F.dropout(h_flat, p=self.dropout, training=self.training)
                batch_h = h_flat.unsqueeze(0)

            all_h.append(batch_h.squeeze(0))

        h = global_mean_pool(torch.cat(all_h, dim=0), batch)
        return self.regressor(h)



class EquiformerGraphRegressor(torch.nn.Module):
    """
    Graph-level regressor: EquiformerV2-style encoder, per-graph mean pooling, MLP head.

    Coordinates are padded or truncated to exactly three columns. Scalar node
    inputs are built from the coordinate norm and values, and only the scalar
    (0e) channels are pooled.
    """
    def __init__(self, in_channels, hidden_channels, num_layers=3, max_ell=2, dropout=0.1):
        super().__init__()
        if not E3NN_AVAILABLE:
            raise ImportError("e3nn is required for EquiformerV2. Install with: pip install e3nn")
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.max_ell = max_ell
        self.dropout = dropout
        # Irreps: scalar inputs, mixed-order hidden features, scalar outputs.
        # The scalar block comes first, so h[:, :hidden_channels] is the 0e part.
        self.irreps_in = o3.Irreps(f"{in_channels}x0e")
        self.irreps_hidden = o3.Irreps(f"{hidden_channels}x0e + {hidden_channels//2}x1o + {hidden_channels//4}x2e")
        self.irreps_out = o3.Irreps(f"{hidden_channels}x0e")
        self.input_proj = o3.Linear(self.irreps_in, self.irreps_hidden)
        self.layers = torch.nn.ModuleList([
            EquiformerV2Layer(
                irreps_in=self.irreps_hidden,
                irreps_out=self.irreps_hidden,
                max_ell=max_ell,
                dropout=dropout
            )
            for _ in range(num_layers)
        ])
        self.output_proj = o3.Linear(self.irreps_hidden, self.irreps_out)
        # layer_norms[0] normalises the projected input; layer_norms[i + 1]
        # normalises the output of layer i. Both act on scalar channels only.
        self.layer_norms = torch.nn.ModuleList([
            torch.nn.LayerNorm(hidden_channels)
            for _ in range(num_layers + 1)
        ])
        self.regressor = Regressor(hidden_channels, out_channels=1, dropout=dropout)

    def forward(self, x, edge_index, batch):
        """Return predictions of shape [num_graphs, 1].

        Args:
            x: Node coordinates [num_nodes, d]. Zero-padded to 3 columns when
               d < 3 and truncated to 3 columns when d > 3.
            edge_index: Edge indices [2, num_edges].
            batch: Graph id of each node, as consumed by ``global_mean_pool``.
        """
        hc = self.hidden_channels
        in_dim = self.irreps_in.dim

        # Spherical harmonics need 3-vectors: pad narrow inputs, truncate wide ones.
        if x.shape[1] < 3:
            x = F.pad(x, (0, 3 - x.shape[1]), value=0.0)
        elif x.shape[1] > 3:
            x = x[:, :3]

        row, col = edge_index
        edge_vec = x[row] - x[col]
        edge_len = torch.norm(edge_vec, dim=1, keepdim=True)

        # Scalar inputs: [||x||, x...], padded or truncated to in_dim.
        node_features = torch.norm(x, dim=1, keepdim=True)
        if node_features.shape[1] < in_dim:
            additional_features = x[:, :min(3, in_dim - 1)]
            node_features = torch.cat([node_features, additional_features], dim=1)
        if node_features.shape[1] < in_dim:
            node_features = F.pad(node_features, (0, in_dim - node_features.shape[1]), value=0.0)
        elif node_features.shape[1] > in_dim:
            node_features = node_features[:, :in_dim]

        # Project and normalise the scalar channels. The normalised result is
        # written back; previously it was discarded.
        h = self.input_proj(node_features)
        h = torch.cat([self.layer_norms[0](h[:, :hc]), h[:, hc:]], dim=1)

        for i, layer in enumerate(self.layers):
            h_residual = h
            h = layer(h, edge_index, edge_vec, edge_len)
            # LayerNorm on scalars; residual connection from the second layer on.
            h_scalars_new = self.layer_norms[i + 1](h[:, :hc])
            if i > 0:
                h_scalars_new = h_scalars_new + h_residual[:, :hc]
            h = torch.cat([h_scalars_new, h[:, hc:]], dim=1)
            h = F.dropout(h, p=self.dropout, training=self.training)

        z = global_mean_pool(self.output_proj(h), batch)
        return self.regressor(z)


