class ForcePredictionLayer(nn.Module):
    """Predict a 3-component output per node from pairwise node features.

    Every forward pass treats the ``n`` input nodes as a fully connected
    graph. Each ordered pair ``(i, j)`` with ``i != j`` is encoded by
    ``edge_mlp`` from the concatenated features ``[x_i, x_j]``. The encodings
    are mean-pooled per source node ``i``, and ``node_mlp`` maps each pooled
    vector to a 3-wide output (the per-node force).

    Args:
        x_dim: Width of one node's feature vector ``cat(pos, vel)``.
            ``edge_mlp`` is built for ``2 * x_dim`` inputs, so this should
            equal ``pos.size(-1) + vel.size(-1)``.
        hidden_dim: Output width of ``edge_mlp`` (and input width of
            ``node_mlp``); also passed as ``hidden_dim`` to both MLPs.
        activation: Forwarded unchanged to both ``BaseMLP`` instances.
    """

    def __init__(self, x_dim, hidden_dim, activation):
        super().__init__()
        self.x_dim = x_dim
        self.hidden_dim = hidden_dim
        self.activation = activation
        # Encodes one concatenated (source, neighbour) feature pair.
        self.edge_mlp = BaseMLP(input_dim=2 * x_dim,
                                hidden_dim=hidden_dim,
                                output_dim=hidden_dim,
                                activation=activation)
        # Maps a node's pooled edge encoding to a 3-wide (3D force) output.
        self.node_mlp = BaseMLP(input_dim=hidden_dim,
                                hidden_dim=hidden_dim,
                                output_dim=3,
                                activation=activation)

    def forward(self, pos, vel):
        """Return one 3-wide prediction per node.

        Args:
            pos: Node positions. Dim 0 is indexed as the node axis; assumed
                shape ``(n, p)`` -- TODO confirm against callers.
            vel: Node velocities, concatenated with ``pos`` on the last dim;
                assumed shape ``(n, v)``.

        Returns:
            Output of ``node_mlp`` for each of the ``n`` nodes.

        Note:
            The graph has ``n * (n - 1)`` edges, so time and memory grow
            quadratically with the number of nodes.
        """
        x = torch.cat((pos, vel), dim=-1)  # per-node features [pos | vel]
        n = x.size(0)

        # All ordered pairs (i, j) with i != j, created on x's device so that
        # indexing and scatter stay on one device. Using ordered pairs (not
        # torch.combinations' i < j pairs) gives every node all n - 1
        # neighbours. Row-major nonzero() yields an (n*(n-1), 2) int64 tensor,
        # and an empty (0, 2) tensor when n <= 1.
        not_self = ~torch.eye(n, dtype=torch.bool, device=x.device)
        edge_index = not_self.nonzero()
        src, dst = edge_index[:, 0], edge_index[:, 1]

        # Encode each (source, neighbour) pair.
        edge_features = self.edge_mlp(torch.cat((x[src], x[dst]), dim=-1))

        # Mean-pool edge encodings per source node. dim_size=n keeps one row
        # per node even when no edges exist (n <= 1); such rows come back as
        # zeros -- assumes torch_scatter / PyG `scatter` semantics.
        agg_edge_features = scatter(edge_features, src, dim=0,
                                    reduce='mean', dim_size=n)

        return self.node_mlp(agg_edge_features)
