import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, GCNConv, global_mean_pool, MessagePassing, BatchNorm
from torch_geometric.utils import add_self_loops, degree
import numpy as np
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
import time
from utils import generate_random_quaternion, rotate_molecule, canonicalize_molecule, preprocess_data_for_MLP, preprocess
from torch_geometric.utils import scatter
import utils
from torch_geometric.utils import to_dense_batch
from typing import Dict,Union
import e3nn
from e3nn import o3
import e3tools
import e3tools.nn
from typing import Callable, Tuple
from torch import Tensor
from torch_geometric.nn import knn_graph
from torchvision import datasets, transforms
import random
import copy

from torch_geometric.data import Batch

# Import-time runtime configuration for TorchScript / torch.compile.
# The four torch._C calls are private PyTorch hooks. Going by their names they
# switch off the JIT profiling mode and profiling executor and allow operator
# fusion on CPU and GPU.
# NOTE(review): private API -- confirm these still exist in the pinned PyTorch.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
# The models are meant to be used with torch.compile, so e3nn's jit_script_fx
# optimization default is turned off here.
e3nn.set_optimization_defaults(jit_script_fx=False)
# Selects the "high" float32 matmul precision setting (see the PyTorch docs for
# the exact speed/accuracy trade-off on the target hardware).
torch.set_float32_matmul_precision("high")

class MoleculeNet(nn.Module):
    """Graph-level predictor: a stack of GCN layers, mean pooling and a two-layer head."""

    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=10):
        """Build the GCN stack and the fully connected head.

        Args:
            input_dim (int): Width of the incoming node features.
            hidden_dim (int): Width of every GCN layer and of the first head layer.
            num_classes (int): Width of the final output.
            num_layers (int): Number of GCN layers. One layer is always created,
                even when ``num_layers`` is smaller than 1.
        """
        super().__init__()

        self.num_layers = num_layers

        # The first layer maps input_dim -> hidden_dim; the rest keep hidden_dim.
        layer_inputs = [input_dim] + [hidden_dim] * (num_layers - 1)
        self.convs = nn.ModuleList(
            [GCNConv(in_width, hidden_dim) for in_width in layer_inputs]
        )

        # Head applied to the pooled graph representation.
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, data):
        """Return one ``num_classes``-wide row per graph in ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        hidden = data.x
        for gcn in self.convs:
            hidden = F.relu(gcn(hidden, data.edge_index))

        # Average node states within each graph.
        graph_repr = global_mean_pool(hidden, data.batch)

        return self.fc2(F.relu(self.fc1(graph_repr)))

# Step 1: Custom GNN layer with position-awareness
class PositionAwareConv(MessagePassing):
    """Sum-aggregating message-passing layer whose messages see the edge length."""

    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels (int): Width of the incoming node features.
            out_channels (int): Width of the produced node features.
        """
        super().__init__(aggr='add')
        # Input is the neighbour feature vector plus one scalar distance.
        self.lin = nn.Linear(in_channels + 1, out_channels)
        self.mlp = nn.Sequential(
            nn.Linear(out_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, x, edge_index, pos, **kwargs):
        """Run one round of message passing.

        Self-loops are added first, so every node also receives a message from
        itself (with distance 0).
        """
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # One scalar per edge: Euclidean distance between its two endpoints.
        first, second = edge_index
        distances = (pos[first] - pos[second]).norm(dim=-1, keepdim=True)

        return self.propagate(edge_index, x=x, edge_attr=distances, pos=pos)

    def message(self, x_j, edge_attr):
        """Linear map of the neighbour features concatenated with the edge distance."""
        neighbour_and_distance = torch.cat((x_j, edge_attr), dim=-1)
        return self.lin(neighbour_and_distance)

    def update(self, aggr_out):
        """Transform the summed messages with the two-layer MLP."""
        return self.mlp(aggr_out)

# Step 2: Full MoleculeNet model
class MoleculeNetWithPositions(nn.Module):
    """Graph-level predictor built from ``PositionAwareConv`` layers."""

    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=4, use_global_pool=True):
        """
        Args:
            input_dim (int): Width of the incoming node features.
            hidden_dim (int): Width of every conv layer and of the first head layer.
            num_classes (int): Width of the final output.
            num_layers (int): Number of conv layers; one layer is always created.
            use_global_pool (bool): Pool with ``global_mean_pool`` when True,
                otherwise with ``scatter(..., reduce='mean')``.
        """
        super().__init__()
        self.use_global_pool = use_global_pool

        # The first layer consumes input_dim; the rest keep hidden_dim.
        layer_inputs = [input_dim] + [hidden_dim] * (num_layers - 1)
        self.convs = nn.ModuleList(
            [PositionAwareConv(in_width, hidden_dim) for in_width in layer_inputs]
        )

        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, data, **kwargs):
        """Return one ``num_classes``-wide row per graph.

        Args:
            data: Object exposing ``x``, ``edge_index``, ``pos`` and ``batch``.
        """
        hidden = data.x
        for layer in self.convs:
            hidden = F.relu(layer(hidden, data.edge_index, data.pos))

        if self.use_global_pool:
            graph_repr = global_mean_pool(hidden, data.batch)
        else:
            graph_repr = scatter(hidden, data.batch, dim=0, reduce='mean')

        return self.fc2(F.relu(self.fc1(graph_repr)))

class MoleculeGNN(nn.Module):
    """GAT-based graph-level predictor over node features concatenated with positions."""

    def __init__(self, input_dim, edge_dim, hidden_dim, num_classes, num_layers, dropout,
                 norm="batch", use_residual=True, use_layer_norm=True, use_edge_updates=True,
                 activation="relu"):
        """
        Args:
            input_dim (int): Width of the tensor fed to the first GAT layer. ``forward``
                feeds ``cat([data.x, data.pos])``, so this must cover both parts.
            edge_dim (int): Width of ``data.edge_attr`` (input of ``edge_fc``).
            hidden_dim (int): Width of every GAT layer.
            num_classes (int): Width of the final output.
            num_layers (int): Number of GAT layers used in ``forward``.
            dropout (float): Dropout probability applied after every layer.
            norm (str): Currently unused; kept for interface compatibility.
            use_residual (bool): Add the previous layer's state to each layer output.
            use_layer_norm (bool): Apply ``LayerNorm`` after each activation.
            use_edge_updates (bool): Pass edge attributes through ``edge_fc`` per layer.
            activation (str): Name of a function in ``torch.nn.functional``; unknown
                names fall back to ``relu``.
        """
        super(MoleculeGNN, self).__init__()

        self.num_layers = num_layers
        self.use_residual = use_residual
        self.use_layer_norm = use_layer_norm
        self.use_edge_updates = use_edge_updates

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        # Input layer
        self.convs.append(GATConv(input_dim, hidden_dim))
        if self.use_layer_norm:
            self.norms.append(nn.LayerNorm(hidden_dim))

        # Hidden layers
        for _ in range(num_layers - 1):
            self.convs.append(GATConv(hidden_dim, hidden_dim))
            if self.use_layer_norm:
                self.norms.append(nn.LayerNorm(hidden_dim))

        # Output layer
        self.fc = nn.Linear(hidden_dim, num_classes)

        # Dropout probability (applied functionally in forward)
        self.dropout = dropout

        # Edge projection
        self.edge_fc = nn.Linear(edge_dim, hidden_dim) if use_edge_updates else None

        # Activation looked up by name, falling back to relu for unknown names.
        self.activation = getattr(F, activation, F.relu)

    def forward(self, data, **kwargs):
        """Return one ``num_classes``-wide row per graph.

        Args:
            data: Object exposing ``x``, ``pos``, ``edge_index``, ``edge_attr``
                and ``batch``. It is NOT modified.
        """
        # Build the node input locally. Writing the concatenation back into data.x
        # (as an earlier version did) mutates the caller's batch and makes a second
        # forward pass on the same object append the positions again.
        x = torch.cat([data.x, data.pos], dim=1)
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = data.batch

        # NOTE(review): the first residual has the input width, so the residual sum
        # below only type-checks when input width == hidden_dim -- confirm intent.
        residual = x.clone() if self.use_residual else None
        for i in range(self.num_layers):
            # NOTE(review): edge_fc is re-applied to its own output on every layer,
            # which requires edge_dim == hidden_dim when num_layers > 1 -- confirm.
            if self.use_edge_updates:
                edge_attr = self.edge_fc(edge_attr)

            # NOTE(review): the GATConv layers are built without `edge_dim`; verify
            # that edge_attr is actually consumed by the convolution.
            x = self.convs[i](x, edge_index, edge_attr)
            x = self.activation(x)

            # Optional normalization
            if self.use_layer_norm:
                x = self.norms[i](x)

            # Residual connection
            if self.use_residual:
                x = x + residual
                residual = x.clone()

            # Dropout
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Graph-level pooling (global mean pool)
        x = global_mean_pool(x, batch)

        # Output layer (classification/regression)
        return self.fc(x)



# Step 1: Define a custom Message Passing Layer with Rotation Equivariance
class EquivariantMessagePassing(MessagePassing):
    """Sum-aggregating layer: message = neighbour features + embedded edge attributes."""

    def __init__(self, input_dim, hidden_dim):
        """
        Args:
            input_dim (int): Width expected by ``fc_edge`` (it is applied to ``edge_attr``).
            hidden_dim (int): Width of the edge embedding and of the node update.
        """
        super().__init__(aggr='add')
        self.fc_edge = nn.Linear(input_dim, hidden_dim)
        self.fc_node = nn.Linear(hidden_dim, hidden_dim)
        # Not used by message/update; retained so existing checkpoints keep loading.
        self.fc_vector = nn.Linear(1, hidden_dim)

    def forward(self, x, edge_index, edge_attr, **kwargs):
        """Propagate ``x`` along ``edge_index`` using ``edge_attr``.

        Args:
            x: Node features, one row per node.
            edge_index: Graph connectivity, shape [2, num_edges].
            edge_attr: Per-edge attributes fed to ``fc_edge``.
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """Return the neighbour features shifted by the embedded edge attributes.

        ``x_i`` (the receiving node's features) is requested but not used.
        """
        return x_j + self.fc_edge(edge_attr)

    def update(self, aggr_out, x):
        """Apply ``fc_node`` to the summed messages and add the input features back."""
        transformed = self.fc_node(aggr_out)
        return transformed + x

# Step 2: Define the Model
class EquivariantMPNN(nn.Module):
    """Two ``EquivariantMessagePassing`` layers, mean pooling and an MLP head."""

    def __init__(self, hidden_dim, output_dim):
        """
        Args:
            hidden_dim (int): Width of the message-passing layers and the MLP hidden layer.
            output_dim (int): Width of the per-graph output.
        """
        super().__init__()

        # Input width of the first layer is fixed at 4 here.
        input_dim = 4
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.mp1 = EquivariantMessagePassing(input_dim, hidden_dim)
        self.mp2 = EquivariantMessagePassing(hidden_dim, hidden_dim)

        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, data, **kwargs):
        """Return one ``output_dim``-wide row per graph.

        Args:
            data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and ``batch``.
        """
        node_state = data.x
        # Both rounds reuse the same connectivity and edge attributes.
        for layer in (self.mp1, self.mp2):
            node_state = layer(node_state, data.edge_index, data.edge_attr)

        # Mean over the nodes of each graph.
        per_graph = scatter(node_state, data.batch, dim=0, reduce='mean')

        return self.mlp(per_graph)


# non-rotation-equivariant transformer (so just permutations)
# Transformer Model

class TransformerModel(nn.Module):
    """Transformer encoder over padded per-molecule node sequences.

    Each node is embedded as (categorical embedding + linear position embedding),
    run through a ``TransformerEncoder``, mean-pooled over the valid nodes and
    mapped to ``output_dim`` by a linear layer.
    """

    def __init__(self, num_node_features, hidden_dim, num_heads, num_layers, output_dim, num_inputs_without_pos = 1, max_nodes=29,use_all_encoding =False):
        """
        Args:
            num_node_features (int): Number of entries in the categorical embedding table.
            hidden_dim (int): Model width.
            num_heads (int): Attention heads per encoder layer.
            num_layers (int): Number of encoder layers.
            output_dim (int): Width of the prediction.
            num_inputs_without_pos (int): Column index where the 3 position
                columns start in the preprocessed input.
            max_nodes (int): Padding length passed to ``preprocess``.
            use_all_encoding (bool): Add a learned per-slot encoding to every node.
        """
        super().__init__()
        self.max_nodes = max_nodes
        self.use_all_encoding = use_all_encoding
        self.num_inputs_without_pos = num_inputs_without_pos

        # Categorical node embedding (e.g. atom types) and linear position embedding.
        self.node_embedding = nn.Embedding(num_node_features, hidden_dim)
        self.pos_embedding = nn.Linear(3, hidden_dim)

        # Optional learned encoding indexed by slot in the padded sequence.
        if self.use_all_encoding:
            self.all_encoding = nn.Parameter(torch.zeros(max_nodes, hidden_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Not used in forward/get_embedding; kept so existing checkpoints still load.
        self.pooling = nn.Linear(hidden_dim, hidden_dim)

        # Final prediction layer
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, data, **kwargs):
        """Return the prediction, shape [batch_size, output_dim].

        Args:
            data: Raw batch handed to ``preprocess``.
            **kwargs: Must contain ``dataset_type`` and ``filter_mol`` (forwarded
                to ``preprocess``); a missing key raises ``KeyError``.
        """
        return self.output_layer(self.get_embedding(data, **kwargs))

    def get_embedding(self, data, **kwargs):
        """Return the pooled encoder representation, shape [batch_size, hidden_dim].

        Uses ``self.max_nodes`` for padding so that it always matches ``forward``
        (it used to be hard-coded to 29).
        """
        x, mask = preprocess(data, dataset_type=kwargs['dataset_type'], max_nodes=self.max_nodes, filter_mol=kwargs['filter_mol'])

        # Column 0 holds the categorical id; positions start at num_inputs_without_pos.
        categorical_feats = x[:, :, 0].long()
        pos_feats = x[:, :, self.num_inputs_without_pos:]

        tokens = self.node_embedding(categorical_feats) + self.pos_embedding(pos_feats)
        if self.use_all_encoding:
            tokens = tokens + self.all_encoding[:x.size(1)]

        if mask is None:
            # No padding information: attend to and average over every slot.
            out = self.transformer(tokens)
            return out.mean(dim=1)

        # `mask` marks valid nodes; the encoder expects the opposite (True = ignore).
        out = self.transformer(tokens, src_key_padding_mask=~mask)

        # Mean over the valid nodes of each molecule only.
        return torch.sum(out * mask.unsqueeze(-1), dim=1) / mask.sum(dim=1, keepdim=True)

class SO3FrameAveragedTransformer(nn.Module):
    """Averages a wrapped model's output over randomly rotated copies of the input."""

    def __init__(self, base_model: nn.Module, n_rotations: int = 10, output_is_vector: bool = False):
        """
        Args:
            base_model (nn.Module): Wrapped model; must expose ``max_nodes`` and
                ``num_inputs_without_pos`` (as ``TransformerModel`` does).
            n_rotations (int): Number of rotations sampled per forward pass.
            output_is_vector (bool): If True, each output is multiplied by the
                rotation matrix before averaging (undoing the input rotation).
        """
        super().__init__()
        self.model = base_model
        self.n_rotations = n_rotations
        self.output_is_vector = output_is_vector

    def forward(self, data, **kwargs):
        """Return the mean of the wrapped model's outputs over sampled rotations.

        Args:
            data: Mapping with an ``items()`` method; tensor values are moved to
                the wrapped model's device.
            **kwargs: Must contain ``dataset_type`` and ``filter_mol``; all
                keyword arguments are also forwarded to the wrapped model.
        """
        device = next(self.model.parameters()).device
        data = {
            key: (value.to(device) if isinstance(value, torch.Tensor) else value)
            for key, value in data.items()
        }

        # Same preprocessing the wrapped model performs; the mask is not needed here.
        x, mask = preprocess(data, dataset_type=kwargs['dataset_type'], max_nodes=self.model.max_nodes, filter_mol=kwargs['filter_mol'])
        reference = x.clone()

        # Column range holding the 3D coordinates.
        first_col = self.model.num_inputs_without_pos
        last_col = first_col + 3

        rotations = utils.sample_random_rotations(self.n_rotations).to(device)

        def evaluate(rot):
            # Rotate only the coordinate columns of a fresh copy.
            rotated = reference.clone()
            coords = rotated[:, :, first_col:last_col]  # (B, N, 3)
            rotated[:, :, first_col:last_col] = coords @ rot.T

            # Independent copy of the batch carrying the rotated features.
            data_rot = copy.deepcopy(data)
            data_rot['x'] = rotated

            result = self.model(data_rot, **kwargs)  # (B, output_dim)
            if self.output_is_vector:
                # Map the predicted vectors back to the unrotated frame.
                result = result @ rot
            return result

        stacked = torch.stack([evaluate(rot) for rot in rotations], dim=0)  # (n_rot, B, output_dim)
        return stacked.mean(dim=0)

class TransformerforRotModel(nn.Module):
    """``TransformerModel`` whose output can optionally be turned into rotation matrices."""

    def __init__(self, make_rotation=False, **kwargs):
        """
        Args:
            make_rotation (bool): If True, the first 6 output columns are
                orthonormalised into a 3x3 matrix via Gram-Schmidt.
            **kwargs: Forwarded unchanged to ``TransformerModel``.
        """
        super().__init__()
        self.transformer = TransformerModel(**kwargs)
        self.make_rotation = make_rotation

    def forward(self, data, **kwargs):
        """Return the (optionally orthonormalised) output flattened to [batch, -1]."""
        prediction = self.transformer(data, **kwargs)
        if self.make_rotation:
            six_d = prediction[:, 0:6]
            prediction = safe_gram_schmidt_to_rotation_matrix(six_d, eps=1e-6)
        batch_size = prediction.shape[0]
        return prediction.reshape(batch_size, -1)


class TransformerForMD17Rotation(torch.nn.Module):
    """Transformer over per-node features that outputs sigmoid probabilities."""

    def __init__(self, num_node_features, hidden_dim, num_heads, num_layers, output_dim, max_nodes):
        """
        Args:
            num_node_features (int): Width of each node's input features.
            hidden_dim (int): Model width.
            num_heads (int): Attention heads per encoder layer.
            num_layers (int): Number of encoder layers.
            output_dim (int): Width of the output.
            max_nodes (int): Number of nodes per sample assumed by the reshape in ``forward``.
        """
        super().__init__()
        self.embedding = nn.Linear(num_node_features, hidden_dim)
        encoder_stack = [
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads)
            for _ in range(num_layers)
        ]
        self.transformer_layers = nn.ModuleList(encoder_stack)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.max_nodes = max_nodes

    def forward(self, node_features, **kwargs):
        """Return sigmoid outputs of shape [batch_size, output_dim].

        Args:
            node_features: Tensor whose last dimension is ``num_node_features``
                and whose element count is a multiple of ``max_nodes``.
        """
        hidden = self.embedding(node_features)

        # Reinterpret as (max_nodes, batch, hidden) for the sequence-first layers.
        # NOTE(review): `view` keeps memory order, so this assumes the input is
        # already laid out node-major -- confirm against the caller.
        hidden = hidden.view(self.max_nodes, -1, hidden.size(-1))

        for encoder in self.transformer_layers:
            hidden = encoder(hidden)

        # Average over the node axis, then squash the logits into (0, 1).
        logits = self.fc_out(hidden.mean(dim=0))
        return torch.sigmoid(logits)
    

class TransformerForModelNet(nn.Module):
    """Transformer encoder over point clouds that only uses 3D positions."""

    def __init__(self, hidden_dim, num_heads, num_layers, output_dim, use_all_encoding=False):
        """
        Args:
            hidden_dim (int): Model width.
            num_heads (int): Attention heads per encoder layer.
            num_layers (int): Number of encoder layers.
            output_dim (int): Width of the prediction.
            use_all_encoding (bool): Add a learned per-point encoding (table of
                1024 rows, so at most 1024 points are supported in that mode).
        """
        super().__init__()
        self.use_all_encoding = use_all_encoding

        # Positions are the only input feature.
        self.pos_embedding = nn.Linear(3, hidden_dim)

        # Optional learned encoding, one row per point slot.
        if self.use_all_encoding:
            self.all_encoding = nn.Parameter(torch.zeros(1024, hidden_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def _embed_points(self, pos_feats):
        """Embed positions and, if enabled, add the per-slot encoding.

        The encoding table is sliced to the actual number of points so that
        clouds with fewer than 1024 points work too (adding the whole table only
        broadcasts when there are exactly 1024 points).
        """
        tokens = self.pos_embedding(pos_feats)
        if self.use_all_encoding:
            tokens = tokens + self.all_encoding[:pos_feats.size(-2)]
        return tokens

    def forward(self, data, **kwargs):
        """Return the prediction, shape [batch_size, output_dim].

        Args:
            data: Raw batch handed to ``preprocess(data, dataset_type='modelnet')``,
                which is expected to yield positions of shape
                [batch_size, num_points, 3] and a mask (or None).
        """
        pos_feats, mask = preprocess(data, dataset_type='modelnet')
        x = self._embed_points(pos_feats)

        # `mask` is inverted because the encoder expects True = ignore.
        out = self.transformer(x, src_key_padding_mask=~mask if mask is not None else None)

        # Plain mean over all points (the mask is not applied to the pooling).
        pooled = out.mean(dim=1)
        return self.output_layer(pooled)

    def get_embedding(self, data, **kwargs):
        """Return the pooled encoder representation, shape [batch_size, hidden_dim].

        NOTE(review): unlike ``forward`` this reads ``data.pos`` directly (no
        ``preprocess``, no mask) -- confirm that callers pass positions already
        shaped [batch_size, num_points, 3].
        """
        x = self._embed_points(data.pos)
        out = self.transformer(x)
        return out.mean(dim=1)

    
# Define a small neural network
class SimpleNN(nn.Module):
    """Tiny two-layer perceptron: 2 inputs -> ``hidden_dim`` -> 2 logits."""

    def __init__(self, hidden_dim=4):
        """
        Args:
            hidden_dim (int): Width of the single hidden layer.
        """
        super().__init__()
        self.fc1 = nn.Linear(2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 2)

    def forward(self, x, **kwargs):
        """Return raw logits (no softmax is applied here)."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)

class CustomMLP(nn.Module):
    """Configurable MLP of Linear -> LayerNorm -> ReLU -> Dropout blocks."""

    def __init__(self, input_features, output_features, hidden_features, num_layers, dropout=0.1):
        """
        Args:
            input_features (int): Number of input features after flattening.
            output_features (int): Number of output features.
            hidden_features (int): Width of every hidden block.
            num_layers (int): Number of hidden blocks (the first maps from
                ``input_features``); at least one block is always created.
            dropout (float): Dropout rate (default: 0.1).
        """
        super().__init__()

        # One block per hidden layer; only the first one changes the width.
        block_inputs = [input_features] + [hidden_features] * (num_layers - 1)
        layers = []
        for in_width in block_inputs:
            layers.append(nn.Linear(in_width, hidden_features))
            layers.append(nn.LayerNorm(hidden_features))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))

        # Output projection (no norm/activation).
        layers.append(nn.Linear(hidden_features, output_features))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x, **kwargs):
        """Return the MLP output for ``x``.

        Args:
            x: Either a tensor (any ``torch.Tensor`` subclass, e.g.
                ``nn.Parameter``) or a raw batch that is first converted with
                ``preprocess_data_for_MLP``.

        Returns:
            Tensor of shape [x.shape[0], output_features].
        """
        # isinstance (rather than an exact type comparison) keeps Tensor
        # subclasses on the tensor path instead of sending them to preprocessing.
        feats = x if isinstance(x, torch.Tensor) else preprocess_data_for_MLP(x)
        # reshape also handles non-contiguous inputs, where view would raise.
        return self.mlp(feats.reshape(feats.shape[0], -1))

import torch.nn.functional as F

import torch.nn.functional as F

class C2LiftLayer(nn.Module):
    """Lifts an input to two feature copies: one for ``x`` and one for its flipped version.

    The flip replaces column 2 by ``1 - x[:, 2]`` and keeps columns 0 and 1.
    """

    def __init__(self, n_input=3, n_output=256):
        """
        Args:
            n_input (int): Input width of the shared linear layer.
            n_output (int): Feature width produced for each of the two copies.
        """
        super().__init__()
        self.fc1 = nn.Sequential(nn.Linear(n_input, n_output), nn.ReLU(inplace=True))

    def forward(self, x, **kwargs):
        """Return ``[fc1(x), fc1(flip(x))]`` concatenated, shape [batch, 2 * n_output]."""
        flipped_last = (1 - x[:, 2]).unsqueeze(1)
        flipped = torch.cat([x[:, :2], flipped_last], dim=1)
        # The same weights process both the original and the flipped input.
        return torch.cat([self.fc1(x), self.fc1(flipped)], dim=1)
    
class C2EMLP(nn.Module):
    """Linear layer on two-copy features whose weights are shared across the swap of the copies.

    The input is ``[copy_a, copy_b]`` (each ``n_input`` wide). Swapping the two
    input copies swaps the two ``n_output``-wide output copies.
    """

    def __init__(self, n_input=256, n_output=256, activation="relu"):
        """
        Args:
            n_input (int): Width of each input copy.
            n_output (int): Width of each output copy.
            activation: ``"relu"``, ``"leaky_relu"``, ``"elu"`` or a falsy value
                for no activation. Any other truthy value raises ``ValueError``
                in ``forward``.
        """
        super().__init__()
        self.n_input = n_input
        self.n_output = n_output
        self.activation = activation

        self.weights = torch.nn.Parameter(torch.zeros(1, n_output, n_input * 2))
        torch.nn.init.kaiming_uniform_(self.weights.data)
        self.bias = torch.nn.Parameter(torch.randn(n_output))

    def generate_weights_bank(self):
        """Return the base weights and their half-swapped copy, shape [2, 1, n_output, 2 * n_input]."""
        base = self.weights.clone()
        left_half = base[:, :, :self.n_input]
        right_half = base[:, :, self.n_input:]
        # Second filter: same weights with the two input halves exchanged.
        swapped = torch.cat([right_half, left_half], dim=2)
        return torch.stack([base, swapped], dim=0)

    def forward(self, x, **kwargs):
        """Apply both filters and return them side by side, shape [batch, 2 * n_output]."""
        batch_size = x.shape[0]

        # Flatten the bank into one (2 * n_output, 2 * n_input) weight matrix.
        bank = self.generate_weights_bank()
        flat_weights = bank.reshape(2 * self.n_output, self.n_input * 2)
        # Both filters share the same bias.
        shared_bias = self.bias.repeat(2)

        x = F.linear(x, flat_weights, shared_bias)
        x = x.view(batch_size, 2, self.n_output)

        if self.activation:
            known = {"leaky_relu": F.leaky_relu, "relu": F.relu, "elu": F.elu}
            if self.activation not in known:
                raise ValueError('Wrong Activation Function')
            x = known[self.activation](x)

        return x.reshape(batch_size, -1)
    
class C2EMLP_for_swiss_roll(nn.Module):
    """``C2LiftLayer`` followed by three ``C2EMLP`` layers."""

    def __init__(self, in_dim=3, hidden_dim = 256, output_dim=2, activation="relu"):
        """
        Args:
            in_dim (int): Input width of the lifting layer.
            hidden_dim (int): Per-copy width of the hidden layers.
            output_dim (int): Per-copy width of the last layer.
            activation: Activation name handed to every ``C2EMLP``, including the last one.
        """
        super().__init__()
        self.n_input = in_dim
        self.n_output = hidden_dim
        self.activation = activation

        stages = [
            C2LiftLayer(n_input=in_dim, n_output=hidden_dim),
            C2EMLP(n_input=hidden_dim, n_output=hidden_dim, activation=activation),
            C2EMLP(n_input=hidden_dim, n_output=hidden_dim, activation=activation),
            C2EMLP(n_input=hidden_dim, n_output=output_dim, activation=activation),
        ]
        self.network = nn.Sequential(*stages)

    def forward(self, x,**kwargs):
        """Run the stack; the last ``C2EMLP`` emits two copies, i.e. [batch, 2 * output_dim]."""
        result = self.network(x)
        return result

class C2Canonicalization(nn.Module):
    """Chooses, per sample, between the input and its flipped version.

    The flip replaces column 2 by ``1 - x[:, 2]``. A small network scores both
    options; the arg-max option is selected, with a straight-through softmax
    gradient while training.
    """

    def __init__(self, n_input=3, hidden_dim=512, beta=1.0, output_g=False):
        """
        Args:
            n_input (int): Input width of the lifting layer.
            hidden_dim (int): Per-copy width of the lifted features.
            beta (float): Multiplier on the scores before the softmax.
            output_g (bool): Default for whether ``forward`` returns the
                flip indicator instead of the canonicalised input.
        """
        super().__init__()
        self.beta = beta
        scoring_stages = [
            C2LiftLayer(n_input=n_input, n_output=hidden_dim),
            C2EMLP(n_input=hidden_dim, n_output=1, activation=None),
        ]
        self.network = nn.Sequential(*scoring_stages)
        self.output_g = output_g

    def forward(self, x, output_g=None):
        """Return the canonicalised input, or the flip indicator.

        Args:
            x: Input whose columns 0-1 are kept and column 2 is flipped.
            output_g: Overrides ``self.output_g`` when not None.

        Returns:
            ``x_canonical`` (same shape as ``x``), or
            ``|x[:, -1:] - x_canonical[:, -1:]|`` when the indicator is requested.
        """
        flipped = torch.cat([x[:, :2], (1 - x[:, 2]).unsqueeze(1)], dim=1)

        # One score per option (keep / flip).
        scores = self.network(x)

        hard_choice = F.one_hot(scores.argmax(dim=-1), num_classes=scores.shape[-1])
        soft_choice = F.softmax(self.beta * scores, dim=-1)
        if self.training:
            # Straight-through: forward value is the one-hot, gradient is the softmax's.
            selector = hard_choice + soft_choice - soft_choice.detach()
        else:
            selector = hard_choice

        keep_weight = selector[:, 0].unsqueeze(1)
        flip_weight = selector[:, 1].unsqueeze(1)
        x_canonical = x * keep_weight + flipped * flip_weight

        # The call-site argument wins over the constructor default.
        want_indicator = self.output_g if output_g is None else output_g
        if want_indicator:
            return torch.abs(x[:, -1:] - x_canonical[:, -1:])
        return x_canonical

def safe_gram_schmidt_to_rotation_matrix(vectors, eps=1e-6):
    """
    Converts a batch of 6D vectors to 3x3 rotation matrices using Gram-Schmidt,
    with a deterministic fallback for degenerate or linearly dependent cases.

    The first three entries (v1) give the first column after normalisation, the
    last three (v2) give the second column after removing their component along
    the first, and the third column is their cross product.

    Degenerate rows are handled without data-dependent indexing:
      * ||v1|| < eps: the first column falls back to [1, 0, 0];
      * ||v2 - proj_b1(v2)|| < eps: the second column falls back to [1, 0, 0]
        (or [0, 1, 0] when the first column is nearly parallel to [1, 0, 0]),
        orthogonalised against the first column.
    Non-degenerate rows are computed exactly as before.

    NOTE(review): the fallback axes are fixed in the input frame, so outputs for
    degenerate rows do not rotate with the input.

    Args:
        vectors: Tensor of shape (B, 6)
        eps: Small threshold to detect degenerate vectors

    Returns:
        Tensor of shape (B, 3, 3) whose columns are b1, b2, b3
    """
    assert vectors.shape[1] == 6, "Input must be of shape (B, 6)"
    v1 = vectors[:, 0:3]
    v2 = vectors[:, 3:6]

    axis_x = torch.tensor([1.0, 0.0, 0.0], device=vectors.device, dtype=vectors.dtype).expand_as(v1)
    axis_y = torch.tensor([0.0, 1.0, 0.0], device=vectors.device, dtype=vectors.dtype).expand_as(v1)

    # Step 1: normalize v1 to get b1 (fallback axis when v1 is ~zero).
    v1_norm = v1.norm(dim=1, keepdim=True)
    b1 = torch.where(v1_norm < eps, axis_x, v1 / v1_norm.clamp(min=eps))

    # Step 2: remove the b1 component from v2.
    u2 = v2 - (v2 * b1).sum(dim=1, keepdim=True) * b1
    u2_norm = u2.norm(dim=1, keepdim=True)

    # Fallback for b2: a fixed axis that is not (nearly) parallel to b1,
    # orthogonalised against b1.
    nearly_parallel = (b1 * axis_x).sum(dim=1, keepdim=True).abs() > 0.99
    helper = torch.where(nearly_parallel, axis_y, axis_x)
    helper_u2 = helper - (helper * b1).sum(dim=1, keepdim=True) * b1
    fallback_b2 = F.normalize(helper_u2, dim=1)

    # torch.where keeps shapes static (no boolean-mask assignment).
    b2 = torch.where(u2_norm < eps, fallback_b2, u2 / u2_norm.clamp(min=eps))

    # Step 3: b3 = b1 x b2 (ensures right-handedness).
    b3 = torch.cross(b1, b2, dim=1)

    # Stack the basis vectors as columns.
    return torch.stack([b1, b2, b3], dim=2)  # shape (B, 3, 3)

### Equivariant models, based on https://github.com/prescient-design/e3tools/tree/main/examples/models ###

class E3ConvNet(nn.Module):
    """A simple E(3)-equivariant convolutional neural network, similar to NequIP."""

    def __init__(
        self,
        irreps_out: Union[str, e3nn.o3.Irreps],
        irreps_hidden: Union[str, e3nn.o3.Irreps],
        irreps_sh: Union[str, e3nn.o3.Irreps],
        num_layers: int,
        edge_attr_dim: int,
        atom_type_embedding_dim: int,
        num_atom_types: int,
        max_radius: float,
        **kwargs
    ):
        """
        Args:
            irreps_out: Irreps of the per-graph output.
            irreps_hidden: Irreps of the node features between layers.
            irreps_sh: Irreps of the spherical-harmonics edge embedding.
            num_layers: Number of ``ConvBlock`` message-passing layers.
            edge_attr_dim: Size of the radial (distance) edge embedding.
            atom_type_embedding_dim: Width of the scalar atom-type embedding.
            num_atom_types: Number of rows in the atom-type embedding table.
            max_radius: Upper end of the radial basis range.
            **kwargs: Optional ``for_canon`` (orthonormalise the output into
                3x3 matrices) and ``reshape_out`` (flatten those matrices);
                both default to False.
        """
        super().__init__()

        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_hidden = o3.Irreps(irreps_hidden)
        self.irreps_sh = o3.Irreps(irreps_sh)
        self.num_layers = num_layers
        self.edge_attr_dim = edge_attr_dim
        self.max_radius = max_radius
        self.for_canon = kwargs.get('for_canon', False)

        self.sh = o3.SphericalHarmonics(
            irreps_out=self.irreps_sh, normalize=True, normalization="component"
        )
        # Bond information is not used, so the whole edge embedding is radial.
        self.radial_edge_attr_dim = self.edge_attr_dim

        self.atom_embedder = nn.Embedding(
            num_embeddings=num_atom_types,
            embedding_dim=atom_type_embedding_dim,
        )
        # Lift the scalar atom embedding into the hidden irreps.
        self.initial_linear = o3.Linear(
            f"{atom_type_embedding_dim}x0e", self.irreps_hidden
        )

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                e3tools.nn.ConvBlock(
                    irreps_in=self.irreps_hidden,
                    irreps_out=self.irreps_hidden,
                    irreps_sh=self.irreps_sh,
                    edge_attr_dim=self.edge_attr_dim,
                )
            )
        self.output_head = e3tools.nn.EquivariantMLP(
            irreps_in=self.irreps_hidden,
            irreps_out=self.irreps_out,
            irreps_hidden_list=[self.irreps_hidden],
        )
        self.reshape_out = kwargs.get('reshape_out', False)

    @staticmethod
    @torch._dynamo.disable
    def process_input(data, radius: float = 5.0):
        """Ensure ``data`` has an ``edge_index``, building a radius graph if missing.

        Declared as a staticmethod so that both ``E3ConvNet.process_input(data)``
        and ``self.process_input(data)`` work (without it, an instance call would
        bind the module itself to ``data``). Excluded from torch.compile tracing.

        Args:
            data: Batch supporting ``in``, item access and ``get``.
            radius: Cut-off handed to ``e3tools.radius_graph`` (default 5.0).

        Returns:
            The same ``data`` object, with ``edge_index`` set.
        """
        if "edge_index" not in data or data["edge_index"] is None:
            data["edge_index"] = e3tools.radius_graph(data["pos"], radius, data.get("batch", None))
        return data

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        **kwargs
    ) -> torch.Tensor:
        """Forward pass of the E3Conv model.

        Args:
            data: Batch providing ``pos``, ``edge_index``, ``z`` and ``batch`` by
                key and ``num_graphs`` as an attribute. ``edge_index`` must
                already be present; ``process_input`` is not called here.

        Returns:
            The per-graph output of ``output_head``. When ``for_canon`` is set it
            is instead passed through Gram-Schmidt (so it must be 6 wide), giving
            [num_graphs, 3, 3], or [num_graphs, 9] when ``reshape_out`` is set.
        """
        pos = data["pos"]
        edge_index = data["edge_index"]

        # Angular part of the edge embedding (bond information is not used).
        src, dst = edge_index
        edge_vec = pos[src] - pos[dst]
        edge_sh = self.sh(edge_vec)

        # Radial part: smooth one-hot of the edge length on [0, max_radius].
        edge_attr = e3nn.math.soft_one_hot_linspace(
            edge_vec.norm(dim=1),
            0.0,
            self.max_radius,
            self.radial_edge_attr_dim,
            basis="gaussian",
            cutoff=True,
        )

        # Initial node features from the atom types.
        node_attr = self.atom_embedder(data["z"])
        node_attr = self.initial_linear(node_attr)

        # Perform message passing.
        for layer in self.layers:
            node_attr = layer(node_attr, edge_index, edge_attr, edge_sh)

        # Pool over the nodes of each graph.
        global_attr = e3tools.scatter(
            node_attr,
            index=data["batch"],
            dim=0,
            dim_size=data.num_graphs,
        )

        global_attr = self.output_head(global_attr)

        if not self.for_canon:
            return global_attr

        gs_out = safe_gram_schmidt_to_rotation_matrix(global_attr)
        if self.reshape_out:
            return gs_out.reshape(gs_out.shape[0], -1)
        return gs_out
    

class E3ConvNetforModelNet(nn.Module):
    """A simple E(3)-equivariant convolutional network, similar to NequIP.

    This variant builds its own graph: ``process_input`` connects every point
    to its 16 nearest neighbours and gives all points the same type index (1),
    so only the geometry distinguishes nodes.
    """

    def __init__(
        self,
        irreps_out: Union[str, e3nn.o3.Irreps],
        irreps_hidden: Union[str, e3nn.o3.Irreps],
        irreps_sh: Union[str, e3nn.o3.Irreps],
        num_layers: int,
        edge_attr_dim: int,
        atom_type_embedding_dim: int,
        num_atom_types: int,
        max_radius: float,
        **kwargs
    ):
        """Build the embedding, the stack of conv blocks and the output head.

        Args:
            irreps_out: Irreps of the per-graph output.
            irreps_hidden: Irreps of the hidden node features.
            irreps_sh: Irreps of the spherical-harmonic edge features.
            num_layers: Number of ``e3tools.nn.ConvBlock`` layers.
            edge_attr_dim: Size of the radial edge embedding.
            atom_type_embedding_dim: Width of the atom-type embedding.
            num_atom_types: Number of rows in the atom-type embedding table.
            max_radius: Upper end of the radial basis range.
            **kwargs: ``for_canon`` (default ``False``) makes ``forward``
                return the Gram-Schmidt result instead of raw head features.
        """
        super().__init__()

        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_hidden = o3.Irreps(irreps_hidden)
        self.irreps_sh = o3.Irreps(irreps_sh)
        self.num_layers = num_layers
        self.edge_attr_dim = edge_attr_dim
        self.max_radius = max_radius
        self.for_canon = kwargs.get('for_canon', False)

        self.sh = o3.SphericalHarmonics(
            irreps_out=self.irreps_sh, normalize=True, normalization="component"
        )
        # No bondedness embedding is used, so the whole edge-attribute budget
        # goes to the radial basis.
        self.radial_edge_attr_dim = self.edge_attr_dim

        self.atom_embedder = nn.Embedding(
            num_embeddings=num_atom_types,
            embedding_dim=atom_type_embedding_dim,
        )
        self.initial_linear = o3.Linear(
            f"{atom_type_embedding_dim}x0e", self.irreps_hidden
        )

        self.layers = nn.ModuleList(
            [
                e3tools.nn.ConvBlock(
                    irreps_in=self.irreps_hidden,
                    irreps_out=self.irreps_hidden,
                    irreps_sh=self.irreps_sh,
                    edge_attr_dim=self.edge_attr_dim,
                )
                for _ in range(num_layers)
            ]
        )
        self.output_head = e3tools.nn.EquivariantMLP(
            irreps_in=self.irreps_hidden,
            irreps_out=self.irreps_out,
            irreps_hidden_list=[self.irreps_hidden],
        )

    def process_input(self, data):
        """Attach a 16-nearest-neighbour graph and constant node types to ``data``.

        ``data`` is modified in place (``edge_index`` and ``z`` are overwritten)
        and also returned.
        """
        points = data.pos
        data.edge_index = knn_graph(points, k=16, batch=data.batch)
        # Every point gets type index 1.
        node_types = torch.ones(points.shape[0], dtype=torch.long)
        data.z = node_types.to(points.device).detach()
        return data

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Build the graph, run message passing and pool to one output per graph.

        Args:
            data: Batch providing ``pos`` and ``batch`` (attribute and key
                access are both used) and a ``num_graphs`` attribute.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Output-head features per graph, or the result of
            ``safe_gram_schmidt_to_rotation_matrix`` on them when
            ``self.for_canon`` is set.
        """
        data = self.process_input(data)

        coords = data["pos"]
        edge_index = data["edge_index"]

        # Relative vector for every edge; bond information is not used.
        senders, receivers = edge_index
        rel_vec = coords[senders] - coords[receivers]
        edge_sh = self.sh(rel_vec)

        # Edge length expanded on a Gaussian basis over [0, max_radius].
        edge_attr = e3nn.math.soft_one_hot_linspace(
            rel_vec.norm(dim=1),
            0.0,
            self.max_radius,
            self.radial_edge_attr_dim,
            basis="gaussian",
            cutoff=True,
        )

        hidden = self.initial_linear(self.atom_embedder(data["z"]))
        for conv_block in self.layers:
            hidden = conv_block(hidden, edge_index, edge_attr, edge_sh)

        # Aggregate node features into one row per graph, then apply the head.
        pooled = e3tools.scatter(
            hidden,
            index=data["batch"],
            dim=0,
            dim_size=data.num_graphs,
        )
        pooled = self.output_head(pooled)

        if not self.for_canon:
            return pooled
        return safe_gram_schmidt_to_rotation_matrix(pooled)






class NN_swiss_roll(nn.Module):
    """Two-hidden-layer ReLU MLP mapping ``in_dim`` features to ``output_dim``."""

    def __init__(self, in_dim=3, hidden_dim=128, output_dim=2):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, **kwargs):
        """Apply fc1 -> ReLU -> fc2 -> ReLU -> fc3; extra kwargs are ignored."""
        first_hidden = torch.relu(self.fc1(x))
        second_hidden = torch.relu(self.fc2(first_hidden))
        return self.fc3(second_hidden)

class C2AveNN_swiss_roll(nn.Module):
    """Averages an ``NN_swiss_roll`` over the two-element reflection ``z -> 1 - z``.

    The prediction is the mean of the base model evaluated on the input and on
    the input with its third coordinate replaced by ``1 - z``, which makes the
    output invariant to that reflection.
    """

    def __init__(self, in_dim=3, hidden_dim=128, output_dim=2):
        super().__init__()
        self.base_model = NN_swiss_roll(in_dim, hidden_dim, output_dim)

    def forward(self, x, **kwargs):
        """Return the average of the base model on ``x`` and on its reflection.

        Only columns 0-2 of ``x`` are used to build the reflected input.
        """
        reflected_z = (1 - x[:, 2]).unsqueeze(1)
        x_reflected = torch.cat([x[:, :2], reflected_z], dim=1)

        pred_original = self.base_model(x)
        pred_reflected = self.base_model(x_reflected)
        return 0.5 * (pred_original + pred_reflected)

### Graphormer model ###
@torch.jit.script
def softmax_dropout(input, dropout_prob: float, is_training: bool):
    # Softmax over the last dimension followed by dropout; dropout is only
    # active when `is_training` is True.
    probs = F.softmax(input, -1)
    return F.dropout(probs, dropout_prob, is_training)

class SelfMultiheadAttention(nn.Module):
    """Multi-head self-attention with an optional additive attention bias.

    Query, key and value are produced by a single fused input projection.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        scaling_factor=1,
    ):
        """
        Args:
            embed_dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            dropout: Dropout probability applied to the attention weights.
            bias: Whether the input/output projections carry a bias term.
            scaling_factor: Extra factor inside the query scale,
                ``(head_dim * scaling_factor) ** -0.5``.
        """
        super().__init__()
        self.embed_dim = embed_dim

        self.num_heads = num_heads
        self.dropout = dropout

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = (self.head_dim * scaling_factor) ** -0.5

        self.in_proj: Callable[[Tensor], Tensor] = nn.Linear(
            embed_dim, embed_dim * 3, bias=bias
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        query: Tensor,
        attn_bias: Tensor = None,
    ) -> Tensor:
        """Apply self-attention.

        Args:
            query: Tensor unpacked as ``(n_node, n_graph, embed_dim)``.
            attn_bias: Optional bias added to the raw attention scores; it must
                broadcast against ``(n_graph * num_heads, n_node, n_node)``.
                When ``None`` no bias is applied.

        Returns:
            Tensor with the same shape as ``query``.
        """
        n_node, n_graph, embed_dim = query.size()
        q, k, v = self.in_proj(query).chunk(3, dim=-1)

        # Fold the heads into the batch dimension so one bmm handles all heads.
        _shape = (-1, n_graph * self.num_heads, self.head_dim)
        q = q.contiguous().view(_shape).transpose(0, 1) * self.scaling
        k = k.contiguous().view(_shape).transpose(0, 1)
        v = v.contiguous().view(_shape).transpose(0, 1)
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        # The bias is optional (the signature defaults it to None); adding None
        # to a tensor would raise a TypeError.
        if attn_bias is not None:
            attn_weights = attn_weights + attn_bias
        attn_probs = softmax_dropout(attn_weights, self.dropout, self.training)

        attn = torch.bmm(attn_probs, v)
        attn = attn.transpose(0, 1).contiguous().view(n_node, n_graph, embed_dim)
        attn = self.out_proj(attn)
        return attn


class Graphormer3DEncoderLayer(nn.Module):
    """
    One Graphormer-3D encoder layer: self-attention and a GELU feed-forward
    block, each preceded by LayerNorm and wrapped in a residual connection.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
    ) -> None:
        """
        Args:
            embedding_dim: Width of the token features.
            ffn_embedding_dim: Hidden width of the feed-forward block.
            num_attention_heads: Number of attention heads.
            dropout: Dropout on the output of each sub-block.
            attention_dropout: Dropout on the attention weights.
            activation_dropout: Dropout after the feed-forward activation.
        """
        super().__init__()

        # Hyper-parameters kept on the module for later inspection.
        self.embedding_dim = embedding_dim
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Attention sub-block and its pre-norm.
        self.self_attn = SelfMultiheadAttention(
            embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(embedding_dim)

        # Feed-forward sub-block and its pre-norm.
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
        self.final_layer_norm = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        x: Tensor,
        attn_bias: Tensor = None,
    ):
        """Apply the attention and feed-forward sub-blocks to ``x``.

        ``attn_bias`` is forwarded unchanged to the attention module.
        """
        # Attention sub-block.
        attn_out = self.self_attn(
            query=self.self_attn_layer_norm(x),
            attn_bias=attn_bias,
        )
        attn_out = F.dropout(attn_out, p=self.dropout, training=self.training)
        after_attn = x + attn_out

        # Feed-forward sub-block.
        ffn_hidden = F.gelu(self.fc1(self.final_layer_norm(after_attn)))
        ffn_hidden = F.dropout(
            ffn_hidden, p=self.activation_dropout, training=self.training
        )
        ffn_out = F.dropout(
            self.fc2(ffn_hidden), p=self.dropout, training=self.training
        )
        return after_attn + ffn_out

@torch.jit.script
def gaussian(x, mean, std):
    # Gaussian density of `x` for the given mean and standard deviation,
    # using the truncated constant pi = 3.14159.
    pi = 3.14159
    norm_const = (2 * pi) ** 0.5
    standardized = (x - mean) / std
    return torch.exp(-0.5 * (standardized ** 2)) / (norm_const * std)

class GaussianLayer(nn.Module):
    """Expands scalar inputs on ``K`` learnable Gaussian basis functions.

    Each input is first rescaled by a per-edge-type affine map
    (``mul * x + bias``) and then evaluated under ``K`` Gaussians with
    learnable means and standard deviations.
    """

    def __init__(self, K=128, edge_types=1024):
        super().__init__()
        self.K = K
        # Means and stds live in one-row embedding tables of width K.
        self.means = nn.Embedding(1, K)
        self.stds = nn.Embedding(1, K)
        # Per-edge-type scale and shift.
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)

        nn.init.uniform_(self.means.weight, 0, 3)
        nn.init.uniform_(self.stds.weight, 0, 3)
        # The affine map starts as the identity: bias 0, scale 1.
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x, edge_types):
        """Return the Gaussian features of ``x`` with a trailing dimension of ``K``.

        ``x`` must be 3-D (it is unsqueezed and expanded over four dimensions);
        ``edge_types`` indexes the scale/shift tables and must match ``x``.
        """
        scale = self.mul(edge_types)
        shift = self.bias(edge_types)
        rescaled = (scale * x.unsqueeze(-1) + shift).expand(-1, -1, -1, self.K)

        centers = self.means.weight.float().view(-1)
        # abs() + 1e-5 keeps the standard deviation strictly positive.
        widths = self.stds.weight.float().view(-1).abs() + 1e-5
        return gaussian(rescaled.float(), centers, widths).type_as(self.means.weight)

class RBF(nn.Module):
    """Radial basis expansion ``exp(-|temp| * (x' - mean)^2)`` with ``K`` kernels.

    ``x'`` is the input after a per-edge-type affine map ``mul * x + bias``.
    """

    def __init__(self, K, edge_types):
        super().__init__()
        self.K = K
        self.means = nn.parameter.Parameter(torch.empty(K))
        self.temps = nn.parameter.Parameter(torch.empty(K))
        # Per-edge-type scale and shift.
        self.mul: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
        self.bias: Callable[..., Tensor] = nn.Embedding(edge_types, 1)

        nn.init.uniform_(self.means, 0, 3)
        nn.init.uniform_(self.temps, 0.1, 10)
        # The affine map starts as the identity: bias 0, scale 1.
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x: Tensor, edge_types):
        """Return the kernel activations of ``x`` with a trailing dimension of ``K``."""
        scale = self.mul(edge_types)
        shift = self.bias(edge_types)
        rescaled = scale * x.unsqueeze(-1) + shift

        centers = self.means.float()
        # abs() keeps the exponent non-positive whatever the sign of temps.
        sharpness = self.temps.float().abs()
        activations = ((rescaled - centers).square() * (-sharpness)).exp()
        return activations.type_as(self.means)


class NonLinear(nn.Module):
    """Two-layer MLP with a GELU in between.

    Args:
        input: Input feature width.
        output_size: Output feature width.
        hidden: Hidden width; defaults to ``input`` when ``None``.
    """

    def __init__(self, input, output_size, hidden=None):
        super().__init__()
        hidden_width = input if hidden is None else hidden
        self.layer1 = nn.Linear(input, hidden_width)
        self.layer2 = nn.Linear(hidden_width, output_size)

    def forward(self, x, **kwargs):
        """Apply layer1 -> GELU -> layer2; extra kwargs are ignored."""
        activated = F.gelu(self.layer1(x))
        return self.layer2(activated)


class NodeTaskHead(nn.Module):
    """Attention-based head that outputs one 3-component vector per node.

    Attention probabilities between nodes are multiplied by the pairwise
    direction tensor ``delta_pos`` and used to aggregate value features
    separately for each of the three spatial components; each component is
    then projected to a scalar by its own linear layer.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.q_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
        self.k_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
        self.v_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads
        # Query scale uses the per-head width.
        self.scaling = (embed_dim // num_heads) ** -0.5
        # One scalar projection per spatial component (x, y, z).
        self.force_proj1: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
        self.force_proj2: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
        self.force_proj3: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)

    def forward(
        self,
        query: Tensor,
        attn_bias: Tensor,
        delta_pos: Tensor,
    ) -> Tensor:
        """Compute a per-node 3-vector.

        Args:
            query: Node features unpacked as ``(bsz, n_node, embed_dim)``.
            attn_bias: Bias added to the attention scores after they are
                viewed as ``(bsz * num_heads, n_node, n_node)``.
            delta_pos: Pairwise direction tensor; presumably
                ``(bsz, n_node, n_node, 3)`` given the shape comments below —
                TODO confirm against the caller.

        Returns:
            Float tensor of shape ``(bsz, n_node, 3)``.
        """
        bsz, n_node, _ = query.size()
        q = (
            self.q_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
            * self.scaling
        )
        k = self.k_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
        v = self.v_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
        attn = q @ k.transpose(-1, -2)  # [bsz, head, n, n]
        # NOTE: the dropout probability on the attention weights is hard-coded
        # to 0.1 here rather than taken from a constructor argument.
        attn_probs = softmax_dropout(
            attn.view(-1, n_node, n_node) + attn_bias, 0.1, self.training
        ).view(bsz, self.num_heads, n_node, n_node)
        # Weight every pairwise direction by its attention probability.
        rot_attn_probs = attn_probs.unsqueeze(-1) * delta_pos.unsqueeze(1).type_as(
            attn_probs
        )  # [bsz, head, n, n, 3]
        # Move the spatial axis in front of the node axes so the matmul below
        # aggregates values independently per spatial component.
        rot_attn_probs = rot_attn_probs.permute(0, 1, 4, 2, 3)
        x = rot_attn_probs @ v.unsqueeze(2)  # [bsz, head , 3, n, d]
        # Merge heads back into the feature dimension: [bsz, n, 3, head * d].
        x = x.permute(0, 3, 2, 1, 4).contiguous().view(bsz, n_node, 3, -1)
        f1 = self.force_proj1(x[:, :, 0, :]).view(bsz, n_node, 1)
        f2 = self.force_proj2(x[:, :, 1, :]).view(bsz, n_node, 1)
        f3 = self.force_proj3(x[:, :, 2, :]).view(bsz, n_node, 1)
        cur_force = torch.cat([f1, f2, f3], dim=-1).float()
        return cur_force


class Graphormer3D(nn.Module):
    """Graphormer-3D style transformer producing ``num_outputs`` values per graph.

    Node tokens are built from an atom-type embedding, a projection of
    Gaussian-expanded pairwise distances and a linear embedding of the
    coordinates. The same stack of encoder layers is applied ``blocks`` times,
    and the per-node outputs are averaged over the node dimension.
    """

    def __init__(self,
                 num_inputs_without_pos,
                 num_outputs,
                 embed_dim=128,
                 ffn_embed_dim=128,
                 attention_heads=32,
                 input_dropout=0.0,
                 dropout=0.1,
                 attention_dropout=0.1,
                 activation_dropout=0.0,
                 blocks=4,
                 layers=8,
                 num_kernel=32,
                 #num_outputs=1,  # this should match len(target_names)
                 #num_inputs_without_pos=11,
                 use_absolute_pos=True, #for energies, the model is invariant if we don't use absolute positions
                 **kwargs):
        """
        Args:
            num_inputs_without_pos: Number of rows of the (currently unused)
                tag embedding.
            num_outputs: Number of values predicted per graph.
            embed_dim: Token width.
            ffn_embed_dim: Hidden width of each encoder feed-forward block.
            attention_heads: Number of attention heads.
            input_dropout: Dropout applied to the initial node tokens.
            dropout: Dropout inside the encoder layers.
            attention_dropout: Dropout on attention weights.
            activation_dropout: Dropout after the feed-forward activation.
            blocks: How many times the shared layer stack is applied.
            layers: Number of encoder layers in the stack.
            num_kernel: Number of Gaussian kernels for the distance expansion.
            use_absolute_pos: Stored on the module.
                NOTE(review): ``forward`` adds the position embedding
                regardless of this flag.
            **kwargs: Ignored.
        """
        super().__init__()

        # not sure this is correct atm?
        self.atom_types = 64
        self.edge_types = 64 * 64
        self.input_dropout = input_dropout
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        self.blocks = blocks
        self.layers_num = layers  # avoid shadowing `self.layers`
        self.num_outputs = num_outputs
        self.num_inputs_without_pos = num_inputs_without_pos

        self.atom_encoder = nn.Embedding(self.atom_types, embed_dim, padding_idx=0)
        self.tag_encoder = nn.Embedding(num_inputs_without_pos, embed_dim)
        self.pos_embedding = nn.Linear(3, embed_dim)
        self.use_absolute_pos = use_absolute_pos

        self.layers = nn.ModuleList([
            Graphormer3DEncoderLayer(
                embed_dim,
                ffn_embed_dim,
                num_attention_heads=attention_heads,
                dropout=dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
            )
            for _ in range(layers)
        ])

        self.final_ln = nn.LayerNorm(embed_dim)
        self.engergy_proj = NonLinear(embed_dim, num_outputs)
        self.energe_agg_factor = nn.Embedding(3, 1)
        nn.init.normal_(self.energe_agg_factor.weight, 0, 0.01)

        self.gbf = GaussianLayer(num_kernel, self.edge_types)
        self.bias_proj = NonLinear(num_kernel, attention_heads)
        self.edge_proj = nn.Linear(num_kernel, embed_dim)
        self.node_proc = NodeTaskHead(embed_dim, attention_heads)


    def set_num_updates(self, num_updates):
        """Record the current number of optimizer updates on the module.

        ``nn.Module`` does not define ``set_num_updates``, so the parent hook
        is only invoked when a base class actually provides one; otherwise
        ``None`` is returned.
        """
        self.num_updates = num_updates
        parent_hook = getattr(super(), "set_num_updates", None)
        if parent_hook is None:
            return None
        return parent_hook(num_updates)

    def forward(self, data, **kwargs):
        """Predict ``num_outputs`` values for every graph in ``data``.

        Args:
            data: Batch handed to ``preprocess``.
            **kwargs: Must contain ``dataset_type`` and ``filter_mol``, both
                forwarded to ``preprocess``.

        Returns:
            Tensor of shape ``(n_graph, num_outputs)``.
        """
        x, mask = preprocess(data, dataset_type=kwargs['dataset_type'], max_nodes=29, filter_mol=kwargs['filter_mol'])
        # Embed nodes
        # Note this currently assumes only atom types are passed in
        atoms = x[:,:,0].long()
        pos = x[:,:,1:].float()
        # Atom index 0 marks padded node slots.
        padding_mask = atoms.eq(0)

        n_graph, n_node = atoms.size()
        delta_pos = pos.unsqueeze(1) - pos.unsqueeze(2)
        dist: Tensor = delta_pos.norm(dim=-1)
        delta_pos /= dist.unsqueeze(-1) + 1e-5
        # Pair type index = type_i * atom_types + type_j.
        edge_type = atoms.view(n_graph, n_node, 1) * self.atom_types + atoms.view(
            n_graph, 1, n_node
        ).long()
        # NOTE(review): `dist.long()` truncates every pairwise distance to an
        # integer before the Gaussian expansion — confirm this is intended.
        gbf_feature = self.gbf(dist.long(), edge_type.long())
        edge_features = gbf_feature.masked_fill(
            padding_mask.unsqueeze(1).unsqueeze(-1), 0.0
        )
        
        graph_node_feature = (
            self.atom_encoder(atoms)
            + self.edge_proj(edge_features.sum(dim=-2))
            + self.pos_embedding(pos)
        )

        # ===== MAIN MODEL =====
        output = F.dropout(
            graph_node_feature, p=self.input_dropout, training=self.training
        )
        # The encoder layers expect (n_node, n_graph, embed_dim).
        output = output.transpose(0, 1).contiguous()

        graph_attn_bias = self.bias_proj(gbf_feature).permute(0, 3, 1, 2).contiguous()
        # Padded key positions get -inf so softmax assigns them zero weight.
        graph_attn_bias.masked_fill_(
            padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
        )

        graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
        # The same layer stack (shared weights) is applied `blocks` times.
        for _ in range(self.blocks):
            for enc_layer in self.layers:
                output = enc_layer(output, attn_bias=graph_attn_bias)

        output = self.final_ln(output)
        output = output.transpose(0, 1)

        eng_output = F.dropout(output, p=0.1, training=self.training)
        eng_output = (
            self.engergy_proj(eng_output)
        )
        # not sure about this?
        # NOTE(review): this mean runs over all n_node slots without applying
        # padding_mask, so padded slots contribute to the average.
        eng_output = eng_output.mean(dim=1)
        #node_output = self.node_proc(output, graph_attn_bias, delta_pos)

        #node_target_mask = output_mask.unsqueeze(-1)
        return eng_output#, node_output, node_target_mask


class Graphormer3D_SO3Ave(Graphormer3D):
    def __init__(self, num_inputs_without_pos, num_outputs, n_rotations: int = 10, output_is_vector: bool = False):
        """
        Graphormer3D whose forward pass averages predictions over random SO(3) rotations.

        Args:
            num_inputs_without_pos: Forwarded to ``Graphormer3D``.
            num_outputs: Forwarded to ``Graphormer3D``.
            n_rotations (int): Number of random rotations sampled per forward call.
            output_is_vector (bool): If True, 2-D and 3-D outputs are rotated back
                before averaging.
        """
        super().__init__(num_inputs_without_pos=num_inputs_without_pos, num_outputs=num_outputs)
        self.n_rotations = n_rotations
        self.output_is_vector = output_is_vector

    def forward(self, data, **kwargs):
        """
        Average the base model's prediction over freshly sampled rotations.

        Args:
            data: Batch object exposing ``pos`` and ``keys()``; every field other
                than ``pos`` is passed through unchanged.
            kwargs: Forwarded to ``Graphormer3D.forward``.
        Returns:
            Mean of the per-rotation outputs.
        """
        device = next(super().parameters()).device

        # (n_rot, 3, 3) rotation matrices on the model's device.
        rotations = utils.sample_random_rotations(self.n_rotations).to(device)

        per_rotation_outputs = []
        for rot in rotations:
            # Rotate every position: p -> rot @ p, written for row vectors.
            rotated_pos = torch.matmul(data.pos, rot.T)

            # Build a new batch with the rotated positions and all other fields shared.
            rotated_batch = Batch()
            for key in data.keys():
                rotated_batch[key] = rotated_pos if key == 'pos' else data[key]

            prediction = super().forward(rotated_batch, **kwargs)

            # Undo the rotation on vector-valued outputs ((B, 3) or (B, N, 3)).
            if self.output_is_vector and prediction.ndim in (2, 3):
                prediction = torch.matmul(prediction, rot)

            per_rotation_outputs.append(prediction)

        stacked = torch.stack(per_rotation_outputs, dim=0)  # (n_rot, B, ...)
        return stacked.mean(dim=0)


# Small constant added to norms and denominators in the VN* layers below so
# that divisions stay finite for zero-length vectors.
EPS = 1e-6

class VNLinear(nn.Module):
    """Bias-free linear map that mixes the feature channels of vector features."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x, **kwargs):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]

        The channel axis (dim 1) is swapped to the last position, mixed by the
        linear layer, and swapped back.
        '''
        channels_last = x.transpose(1, -1)
        mixed = self.map_to_feat(channels_last)
        return mixed.transpose(1, -1)


class VNBilinear(nn.Module):
    """Bias-free bilinear map between vector features and a label vector."""

    def __init__(self, in_channels1, in_channels2, out_channels):
        super().__init__()
        self.map_to_feat = nn.Bilinear(in_channels1, in_channels2, out_channels, bias=False)

    def forward(self, x, labels):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]

        labels is tiled along its dim 1 by x.shape[2] and cast to float before
        it enters the bilinear layer.
        '''
        tiled_labels = labels.repeat(1, x.shape[2], 1).float()
        channels_last = x.transpose(1, -1)
        mixed = self.map_to_feat(channels_last, tiled_labels)
        return mixed.transpose(1, -1)


class VNSoftplus(nn.Module):
    """Smooth vector-neuron nonlinearity.

    For every feature vector ``x`` a direction ``d`` is predicted. The output
    blends ``x`` with its component orthogonal to ``d``, weighted smoothly by
    the angle between ``x`` and ``d``: aligned vectors pass unchanged,
    opposed vectors lose their component along ``d``.
    """

    def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.0):
        super(VNSoftplus, self).__init__()
        # A shared nonlinearity predicts one direction for all channels.
        if share_nonlinearity == True:
            self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
        else:
            self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
        self.negative_slope = negative_slope
    
    def forward(self, x, **kwargs):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        '''
        d = self.map_to_dir(x.transpose(1,-1)).transpose(1,-1)
        dotprod = (x*d).sum(2, keepdim=True)
        cos_angle = dotprod / (torch.norm(x, dim=2, keepdim=True) * torch.norm(d, dim=2, keepdim=True) + EPS)
        # Smooth weight in [0, 1]: cos^2(angle / 2) == (1 + cos(angle)) / 2.
        # Using the identity avoids acos, which returns NaN when rounding
        # pushes the ratio outside [-1, 1] and has an infinite derivative at
        # +/-1 (yielding NaN gradients for parallel vectors).
        mask = 0.5 * (1.0 + cos_angle)
        d_norm_sq = (d*d).sum(2, keepdim=True)
        x_out = self.negative_slope * x + (1-self.negative_slope) * (mask*x + (1-mask)*(x-(dotprod/(d_norm_sq+EPS))*d))
        return x_out


class VNLeakyReLU(nn.Module):
    """Vector-neuron leaky ReLU.

    A direction is predicted for every feature vector. Vectors with a
    non-negative projection on their direction pass unchanged; the others
    have the component along the direction removed. ``negative_slope`` mixes
    the raw input back in.
    """

    def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.2):
        super().__init__()
        # A shared nonlinearity predicts one direction for all channels.
        dir_channels = 1 if share_nonlinearity == True else in_channels
        self.map_to_dir = nn.Linear(in_channels, dir_channels, bias=False)
        self.negative_slope = negative_slope

    def forward(self, x, **kwargs):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        '''
        direction = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
        projection = (x * direction).sum(2, keepdim=True)
        keep = (projection >= 0).float()
        direction_sq = (direction * direction).sum(2, keepdim=True)

        # x with its component along `direction` removed.
        rejected = x - (projection / (direction_sq + EPS)) * direction
        gated = keep * x + (1 - keep) * rejected
        return self.negative_slope * x + (1 - self.negative_slope) * gated


class VNLinearLeakyReLU(nn.Module):
    """Fused vector-neuron block: linear map, batch norm, then leaky ReLU.

    The gating direction is predicted from the block input, while the gate is
    applied to the normalised linear output.
    """

    def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, negative_slope=0.2):
        super().__init__()
        self.dim = dim
        self.negative_slope = negative_slope

        self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)
        self.batchnorm = VNBatchNorm(out_channels, dim=dim)

        # A shared nonlinearity predicts one direction for all channels.
        dir_channels = 1 if share_nonlinearity == True else out_channels
        self.map_to_dir = nn.Linear(in_channels, dir_channels, bias=False)

    def forward(self, x, **kwargs):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        '''
        channels_last = x.transpose(1, -1)

        # Linear map followed by batch norm.
        features = self.batchnorm(self.map_to_feat(channels_last).transpose(1, -1))

        # Leaky ReLU along the direction predicted from the input.
        direction = self.map_to_dir(channels_last).transpose(1, -1)
        projection = (features * direction).sum(2, keepdims=True)
        keep = (projection >= 0).float()
        direction_sq = (direction * direction).sum(2, keepdims=True)

        rejected = features - (projection / (direction_sq + EPS)) * direction
        gated = keep * features + (1 - keep) * rejected
        return self.negative_slope * features + (1 - self.negative_slope) * gated


class VNLinearAndLeakyReLU(nn.Module):
    """Vector-neuron block composed of ``VNLinear``, optional ``VNBatchNorm`` and ``VNLeakyReLU``.

    Args:
        in_channels: Number of input vector channels.
        out_channels: Number of output vector channels.
        dim: Rank hint forwarded to ``VNBatchNorm``.
        share_nonlinearity: Forwarded to ``VNLeakyReLU``.
        use_batchnorm: ``'none'`` disables normalisation; any other value
            enables ``VNBatchNorm``.
        negative_slope: Forwarded to ``VNLeakyReLU``.
    """

    def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, use_batchnorm='norm', negative_slope=0.2):
        # Zero-argument super(): naming another class here raises TypeError
        # because `self` is not an instance of it.
        super().__init__()
        self.dim = dim
        self.share_nonlinearity = share_nonlinearity
        self.use_batchnorm = use_batchnorm
        self.negative_slope = negative_slope
        
        self.linear = VNLinear(in_channels, out_channels)
        self.leaky_relu = VNLeakyReLU(out_channels, share_nonlinearity=share_nonlinearity, negative_slope=negative_slope)
        
        # BatchNorm. VNBatchNorm takes only (num_features, dim); it has no
        # `mode` parameter, so `use_batchnorm` acts purely as an on/off switch.
        if use_batchnorm != 'none':
            self.batchnorm = VNBatchNorm(out_channels, dim=dim)
    
    def forward(self, x, **kwargs):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]
        '''
        # Linear
        x = self.linear(x)
        # BatchNorm (skipped when use_batchnorm == 'none')
        if self.use_batchnorm != 'none':
            x = self.batchnorm(x)
        # LeakyReLU
        x_out = self.leaky_relu(x)
        return x_out


class VNBatchNorm(nn.Module):
    """Batch normalisation of vector features that acts on the vector norms.

    The length of every 3-vector is batch-normalised and the vector is
    rescaled to the new length, so its direction is left untouched.
    """

    def __init__(self, num_features, dim):
        super().__init__()
        self.dim = dim
        # `dim` selects the BatchNorm flavour; other values create no `bn`.
        if dim in (3, 4):
            self.bn = nn.BatchNorm1d(num_features)
        elif dim == 5:
            self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x, mask=None):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]

        mask: optional boolean tensor expandable to the norm's shape; positions
        where it is False get norm 1 before batch norm and keep that value
        afterwards, so they are returned unscaled.
        '''
        magnitude = torch.norm(x, dim=2) + EPS  # [B, N_feat, N_samples, ...]

        if mask is None:
            normalized = self.bn(magnitude)
        else:
            valid = mask.expand_as(magnitude)
            magnitude = torch.where(valid, magnitude, torch.ones_like(magnitude))
            normalized = torch.where(valid, self.bn(magnitude), magnitude)

        # Divide out the old length and multiply in the normalised one.
        return x / magnitude.unsqueeze(2) * normalized.unsqueeze(2)


class VNMaxPool(nn.Module):
    """Vector-neuron max pooling over the last dimension.

    For every channel a direction is predicted; along the last axis the
    vector with the largest projection on its direction is selected.
    """

    def __init__(self, in_channels):
        super(VNMaxPool, self).__init__()
        self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
    
    def forward(self, x, **kwargs):
        '''
        x: point features of shape [B, N_feat, 3, N_samples, ...]

        Returns a tensor with the last dimension of x removed.
        '''
        d = self.map_to_dir(x.transpose(1,-1)).transpose(1,-1)
        dotprod = (x*d).sum(2, keepdim=True)
        # Winning position along the last axis; the vector axis has size 1 here.
        idx = dotprod.max(dim=-1, keepdim=False)[1]
        # Broadcast the winner over the 3 vector components and gather along
        # the last axis. This replaces the meshgrid-based advanced indexing,
        # which built index grids on the CPU and triggered the
        # "indexing argument" deprecation warning of torch.meshgrid.
        gather_idx = idx.expand(x.shape[:-1]).unsqueeze(-1)
        x_max = x.gather(-1, gather_idx).squeeze(-1)
        return x_max


def mean_pool(x, dim=-1, keepdim=False):
    """Average ``x`` over ``dim`` (last dimension by default)."""
    return torch.mean(x, dim=dim, keepdim=keepdim)


def knn(x, k, mask=None):
    """Return the indices of the ``k`` nearest points for every point.

    Args:
        x: Coordinates/features laid out as [B, C, N].
        k: Number of neighbours to return (a point is its own nearest one).
        mask: Optional boolean tensor of shape [B, 1, N]; False marks points
            that receive a large penalty so they rank last as neighbours.

    Returns:
        Long tensor of shape (batch_size, num_points, k).
    """
    # Negative squared distance: -(|a|^2 - 2 a.b + |b|^2), so topk picks the closest.
    cross_term = -2 * torch.matmul(x.transpose(2, 1), x)
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)
    neg_sq_dist = -sq_norm - cross_term - sq_norm.transpose(2, 1)  # [B, N, N]

    if mask is not None:
        valid = mask.squeeze(1)  # [B, N]
        invalid = (~valid).float()

        # Push every pair that involves a masked point far away.
        penalty = (invalid * -1e9).unsqueeze(1)  # [B, 1, N]
        neg_sq_dist = neg_sq_dist + penalty + penalty.transpose(1, 2)

        # Reset the diagonal entry of masked points to 0 so that such a point
        # still selects itself first.
        diagonal = torch.eye(valid.size(1)).unsqueeze(0).to(x.device)  # [1, N, N]
        diagonal = diagonal * invalid.unsqueeze(1)  # [B, N, N]
        neg_sq_dist = neg_sq_dist * (1 - diagonal) + diagonal * 0

    return neg_sq_dist.topk(k=k, dim=-1)[1]


def get_graph_feature_cross(x, k=20, idx=None, mask=None):
    """Build per-edge vector features from each point and its ``k`` neighbours.

    For every (point, neighbour) pair three vectors per channel are stacked:
    ``neighbour - point``, ``point`` and ``neighbour x point``.

    Args:
        x: Tensor with batch on dim 0 and points on dim 2; the remaining
            dimension(s) are flattened into ``3 * num_dims`` coordinates.
        k: Number of neighbours per point.
        idx: Optional precomputed neighbour indices of shape [B, N, k];
            computed with ``knn`` when omitted.
        mask: Optional boolean tensor [B, 1, N]; for points where it is False
            the neighbours are replaced by the point itself.

    Returns:
        Contiguous tensor of shape [B, 3 * num_dims, 3, N, k].
    """
    batch_size, num_points = x.size(0), x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x, k=k, mask=mask)

    # Offset the per-cloud indices so they address the flattened point list.
    batch_offset = torch.arange(0, batch_size).type_as(idx).view(-1, 1, 1) * num_points
    flat_idx = (idx + batch_offset).view(-1)

    num_dims = x.size(1) // 3

    points = x.transpose(2, 1).contiguous()
    neighbours = points.view(batch_size * num_points, -1)[flat_idx, :]
    neighbours = neighbours.view(batch_size, num_points, k, num_dims, 3)
    centers = points.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1)

    if mask is not None:
        # Masked points use themselves as every neighbour.
        is_masked = ~mask.squeeze(1)  # [B, N]
        is_masked = is_masked.unsqueeze(2).unsqueeze(3).unsqueeze(4)  # [B, N, 1, 1, 1]
        neighbours = torch.where(is_masked, centers, neighbours)

    cross = torch.cross(neighbours, centers, dim=-1)
    stacked = torch.cat((neighbours - centers, centers, cross), dim=3)
    return stacked.permute(0, 3, 4, 1, 2).contiguous()

class VNSmall(nn.Module):
    """Small vector-neuron network mapping a point cloud to ``12 // 3`` vectors.

    Args:
        n_knn: Number of neighbours used to build the edge features.
        hidden_dim: Hidden size; the layers use ``hidden_dim // 3`` vector channels.
        pooling: ``"max"`` (``VNMaxPool``) or ``"mean"`` (``mean_pool``) over
            the neighbour dimension. Other values leave ``self.pool`` unset.
    """

    def __init__(self, n_knn=20, hidden_dim=64, pooling="max"):
        super().__init__()
        self.n_knn = n_knn
        self.pooling = pooling

        vn_channels = hidden_dim // 3
        self.conv_pos = VNLinearLeakyReLU(3, vn_channels, dim=5, negative_slope=0.0)
        self.conv1 = VNLinearLeakyReLU(vn_channels, vn_channels, dim=4, negative_slope=0.0)
        self.bn1 = VNBatchNorm(vn_channels, dim=4)
        self.conv2 = VNLinearLeakyReLU(vn_channels, 12 // 3, dim=4, negative_slope=0.0)
        # Registered but not applied in forward.
        self.dropout = nn.Dropout(p=0.5)

        if self.pooling == "max":
            self.pool = VNMaxPool(vn_channels)
        elif self.pooling == "mean":
            self.pool = mean_pool

    def forward(self, point_cloud, labels=None, mask=None):
        """Return the per-cloud output vectors (mean over the point dimension).

        ``labels`` is accepted but unused; ``mask`` is forwarded to the edge
        feature construction and to ``bn1``.
        """
        edge_features = get_graph_feature_cross(point_cloud, k=self.n_knn, mask=mask)

        # Embed the edge features, then pool over the neighbour dimension.
        per_point = self.pool(self.conv_pos(edge_features))

        hidden = self.bn1(self.conv1(per_point), mask)
        return self.conv2(hidden).mean(dim=-1)


class PointcloudCanonFunction(nn.Module):
    """Predicts a rotation matrix for a point cloud with a ``VNSmall`` network.

    The first three output vectors of the network are orthonormalised with
    Gram-Schmidt to form the rotation. With ``for_canon=True`` only the
    rotation is returned, otherwise the rotated point cloud is returned too.
    """

    def __init__(self, n_knn=5, hidden_dim=64, pooling="max", **kwargs):
        super().__init__()
        self.for_canon = kwargs.get('for_canon', False)
        self.model = VNSmall(n_knn, hidden_dim, pooling)

    def forward(self, points, mask=None):
        """Return the rotation matrix, plus the canonicalised cloud unless ``for_canon``.

        The ``mask`` argument is not used: it is replaced by the mask returned
        from ``preprocess``, which is itself not passed to the model.
        """
        coords, mask = preprocess(points, dataset_type='modelnet', max_nodes=1024) # [batch_size, num_sample_points, 3]
        coords = coords.transpose(2, 1)

        predicted = self.model(coords, None)
        rotation_matrix = self.gram_schmidt(predicted[:, :3])

        # Apply the rotation to every point (row-vector convention).
        rotated = torch.bmm(coords.transpose(1, 2), rotation_matrix.transpose(1, 2))
        canonical_point_cloud = rotated.transpose(1, 2)

        if self.for_canon:
            return rotation_matrix
        return rotation_matrix, canonical_point_cloud

    def gram_schmidt(self, vectors):
        """Orthonormalise three vectors per batch item and force determinant >= 0.

        Args:
            vectors: Tensor of shape [B, 3, 3] whose rows ``vectors[:, i]`` are
                the vectors to orthonormalise, in order.

        Returns:
            Tensor [B, 3, 3] with the orthonormal vectors as rows. Where the
            determinant is negative, the last column is negated.
        """
        def unit(v):
            return v / torch.norm(v, dim=1, keepdim=True)

        def project(v, onto):
            return torch.sum(v * onto, dim=1, keepdim=True) * onto

        e1 = unit(vectors[:, 0])
        e2 = unit(vectors[:, 1] - project(vectors[:, 1], e1))
        e3 = unit(vectors[:, 2] - project(vectors[:, 2], e1) - project(vectors[:, 2], e2))

        frame = torch.stack([e1, e2, e3], dim=1)
        needs_flip = (torch.det(frame) < 0).view(-1, 1)

        fixed = frame.clone()
        last_column = frame[:, :, 2]
        fixed[:, :, 2] = torch.where(needs_flip, -last_column, last_column)
        return fixed




def graphormer_base_architecture(args):
    """Fill in the default Graphormer hyper-parameters on ``args`` in place.

    Attributes already present on ``args`` are kept; missing ones are set to
    the defaults listed below. Nothing is returned.
    """
    defaults = (
        ("blocks", 4),
        ("layers", 8),
        ("embed_dim", 128),
        ("ffn_embed_dim", 128),
        ("attention_heads", 32),
        ("input_dropout", 0.0),
        ("dropout", 0.1),
        ("attention_dropout", 0.1),
        ("activation_dropout", 0.0),
        ("node_loss_weight", 1),
        ("min_node_loss_weight", 1),
        ("eng_loss_weight", 1),
        ("num_kernel", 32),
        ("num_outputs", 19),
    )
    for name, default in defaults:
        setattr(args, name, getattr(args, name, default))


# Step 2: Define a simple neural network
class SimpleCNN(nn.Module):
    """Small two-stage CNN classifier for single-channel 28x28 images.

    Can also be used for direct prediction.

    Args:
        num_classes: Number of output logits produced by the final linear
            layer. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv_net = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),   # 28x28 -> 28x28
            nn.ReLU(),
            nn.MaxPool2d(2),                              # 28x28 -> 14x14

            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # 14x14 -> 14x14
            nn.ReLU(),
            nn.MaxPool2d(2),                              # 14x14 -> 7x7
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            # Previously hard-coded to 10, which silently ignored num_classes.
            nn.Linear(128, num_classes),
        )

    def forward(self, x, **kwargs):
        """Return class logits for a batch of images.

        Args:
            x: Tensor of shape ``(B, 1, 28, 28)``; the first linear layer
                expects 64 * 7 * 7 features after the two 2x poolings.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(B, num_classes)`` with unnormalised logits.
        """
        features = self.conv_net(x)
        return self.fc(features)

class GroupAveragedSimpleCNN(SimpleCNN):
    """SimpleCNN whose output is averaged over rotated copies of the input.

    Args:
        num_classes: Forwarded to ``SimpleCNN``.
        group: Name of the rotation group to average over. Only ``'C4'``
            (rotations by 0, 90, 180 and 270 degrees) is supported.

    Raises:
        NotImplementedError: If ``group`` is anything other than ``'C4'``.
    """

    def __init__(self, num_classes=10, group='C4'):
        super().__init__(num_classes)
        if group != 'C4':
            raise NotImplementedError(f"Group {group} not supported")
        self.angles = [0, 90, 180, 270]

    def forward(self, x, **kwargs):
        """Run the parent network on every rotation of ``x`` and average.

        Args:
            x: Input batch accepted by ``SimpleCNN.forward``.
            **kwargs: Passed through to the parent ``forward``.

        Returns:
            Mean of the per-rotation outputs, shape ``[B, C]``.
        """
        # Bind the parent forward once; zero-argument super() is not usable
        # inside a comprehension on older Python versions.
        parent_forward = super().forward
        per_angle = [
            parent_forward(transforms.functional.rotate(x, theta), **kwargs)
            for theta in self.angles
        ]
        stacked = torch.stack(per_angle, dim=0)  # shape: [4, B, C]
        return stacked.mean(dim=0)               # shape: [B, C]
    
def rotate_batch(x, k):
    """Rotate every item of a batch by ``k`` quarter turns.

    Args:
        x: Tensor with at least four dimensions; the rotation acts in the
            plane spanned by dimensions 2 and 3.
        k: Number of 90-degree rotations to apply.

    Returns:
        The rotated tensor, as produced by ``torch.rot90``.
    """
    spatial_dims = [2, 3]
    return torch.rot90(x, k, spatial_dims)

class BaseCNN(nn.Module):
    """Compact CNN mapping a single-channel image batch to four scores.

    Two 3x3 convolutions (32 then 64 channels) are followed by global
    average pooling, so the spatial size of the input is not fixed.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            # Collapse each of the 64 feature maps to a single value.
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        """Return a ``(B, 4)`` score tensor for input of shape ``(B, 1, H, W)``."""
        pooled = self.conv(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

class RotationEquivariantClassifier(nn.Module):
    """Averages a shared ``BaseCNN`` over the four 90-degree input rotations.

    Args:
        softmax_temp: Temperature dividing the averaged scores before the
            softmax. Defaults to 1.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, softmax_temp=1, **kwargs):
        super().__init__()
        self.shared_cnn = BaseCNN()
        self.softmax_temp = softmax_temp

    def forward(self, x, **kwargs):
        """Return a softmax distribution over four classes per batch item.

        Args:
            x: Batch of square single-channel images, ``(B, 1, H, H)``
                (e.g. 28x28); the rotated copies are concatenated along the
                batch dimension, so all rotations must share one shape.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(B, 4)``: the softmax of the rotation-averaged
            CNN scores divided by ``softmax_temp``. Probabilities are
            returned instead of an argmax class index because the argmax
            would not be differentiable.
        """
        batch_size = x.size(0)

        # All four rotations in one forward pass: (4B, 1, H, H).
        all_rotations = torch.cat([rotate_batch(x, turns) for turns in range(4)], dim=0)
        scores = self.shared_cnn(all_rotations)                       # (4B, 4)
        per_sample = scores.view(4, batch_size, 4).permute(1, 0, 2)   # (B, 4 rotations, 4 classes)

        rotation_mean = per_sample.mean(dim=1)                        # (B, 4)
        return torch.softmax(rotation_mean / self.softmax_temp, dim=-1)
