import torch
from torch import nn

from libs import GCL, E_GCL, E_GCL_vel, GCL_rf_vel, Clof_GCL, GMNLayer, CFIN_layer, CFIN_layer_diff_invar
import numpy as np
import logging


class CFINs(nn.Module):
    """Network built on complete inner-/inter-agent invariants (CFIN layers).

    Positions are centred on the per-graph centroid; scalar invariants of the
    centred positions and the velocities are concatenated to the node / edge
    features, and the result is passed through ``n_layers`` ``CFIN_layer``
    modules that update node features, positions and velocities.
    """

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', dropout_rate = 0.5, act_fn= nn.SiLU(), n_layers=4, recurrent = False, norm_diff = False, dim=3, no_infer = False, order = 2):
        """
        Args:
            in_node_nf: dimension of the input node features ``h``.
            in_edge_nf: dimension of the input edge features ``edge_attr``.
            hidden_nf: hidden dimension.
            device: device the module is moved to.
            dropout_rate: dropout probability used in the MLPs.
            act_fn: activation module used by the MLPs and the layers.
            n_layers: number of ``CFIN_layer`` modules.
            recurrent: residual connection in ``node_model`` (the CFIN layers
                themselves are always built with ``recurrent=False``).
            norm_diff: scale positions/velocities to unit length before
                computing the invariants.
            dim: coordinate dimension, 2 or 3.
            no_infer: currently ignored (see the note below).
            order: forwarded to ``CFIN_layer``.

        Raises:
            ValueError: if ``dim`` is not 2 or 3.
        """
        super(CFINs, self).__init__()
        # Fail fast: an unsupported ``dim`` would otherwise leave the edge
        # embedding width undefined and crash later with an UnboundLocalError.
        if dim not in (2, 3):
            raise ValueError(f"Unsupported coordinate dimension: {dim}. Only 2D and 3D coordinates are supported.")

        self.recurrent = recurrent
        self.dim = dim
        self.hidden_nf = hidden_nf
        self.in_h_nf = in_node_nf
        self.in_edge_nf = in_edge_nf
        self.dropout_rate = dropout_rate
        self.norm_diff = norm_diff

        # 3 inner-agent invariants (|p|^2, |v|^2, p.v) + initial node features.
        node_embed_nf = 3 + self.in_h_nf
        # Inter-agent invariants: 2 in 2-D (p_i.p_j, v_i.v_j), 3 in 3-D
        # (adds v_i.(p_i x p_j)), plus the initial edge features.
        edge_embed_nf = self.dim + self.in_edge_nf

        self.fuse_node = nn.Sequential(
            nn.Linear(node_embed_nf, self.hidden_nf),
            act_fn,
            nn.Dropout(self.dropout_rate),
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            nn.LayerNorm(hidden_nf)
            )

        self.order = order
        self.category_edge = 1
        # NOTE(review): the ``no_infer`` argument is ignored; the interaction
        # inference MLPs below are always built. ``forward`` does not use them
        # (it fixes c = 1).
        self.no_infer = False
        if not self.no_infer:
            # node mlp: input is [h, aggregated edge messages]
            self.node_mlp = nn.Sequential(
                nn.Linear(hidden_nf + int(1*hidden_nf), hidden_nf),
                act_fn,
                nn.Dropout(self.dropout_rate),
                nn.Linear(hidden_nf, hidden_nf),
                act_fn,
                nn.LayerNorm(hidden_nf)
                )

            # edge mlp: input is [h_i, h_j, edge_attr of width hidden_nf // 2]
            self.edge_mlp = nn.Sequential(
                nn.Linear(self.hidden_nf*2 + self.hidden_nf // 2, self.hidden_nf),
                act_fn,
                nn.Dropout(self.dropout_rate),
                nn.Linear(self.hidden_nf, self.hidden_nf),
                act_fn,
                nn.LayerNorm(hidden_nf))

            # interaction infer mlp: input is [h_i, h_j, edge message]
            self.infer_mlp = nn.Sequential(
                nn.Linear(self.hidden_nf*2 + self.hidden_nf, self.hidden_nf),
                act_fn,
                nn.Linear(self.hidden_nf, self.category_edge),
                nn.Sigmoid())

        self.device = device
        self.n_layers = n_layers
        # NOTE(review): not used by ``forward`` (which uses ``fuse_node``);
        # kept so parameter counts / checkpoints stay compatible.
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        for i in range(0, n_layers):
            self.add_module("cfins_%d" % i, CFIN_layer(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=edge_embed_nf, dropout_rate=self.dropout_rate, act_fn=act_fn, recurrent=False, norm_diff = self.norm_diff, tanh=False, order = self.order, out_dim = self.dim))
        self.to(self.device)
        self.params = self.__str__()

    def __str__(self):
        """Print/log the number of trainable parameters and return it as a string."""
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        print('Network Size', params)
        logging.info('Network Size {}'.format(params))
        return str(params)

    def forward(self, h, loc, vel, edges, edge_attr, node_attr=None, num_agents = 5):
        """Run the CFIN stack and return updated positions and velocities.

        Args:
            h: node features; first dimension is batchsize * num_agents.
            loc: positions; reshaped internally to (-1, num_agents, dim).
            vel: velocities; flattened to (-1, vel.size(-1)).
            edges: (row, col) index tensors into the flattened node axis.
            edge_attr: edge features with first dimension num_edges * batchsize,
                or None to use only the inter-agent invariants.
            node_attr: unused.
            num_agents: agents per graph, used to compute each centroid.

        Returns:
            (loc, vel): ``loc`` has shape (batchsize * num_agents, dim); ``vel``
            is whatever the last CFIN layer returns.
        """
        loc = loc.reshape(-1, num_agents, self.dim)
        loc_central, center = self.centralized(loc)
        vel = vel.view(-1, vel.size(-1))

        # Invariants are computed on centred coordinates (translation invariant).
        i_nt = self.coord2inter_diff_invar(edges, loc_central, vel, self.norm_diff)
        i_nn = self.coord2inner_diff_invar(loc_central, vel, self.norm_diff)
        h = torch.cat([h, i_nn], dim=-1)                # (batchsize * num_agents, 3 + dim_h)
        h = self.fuse_node(h)                           # (batchsize * num_agents, hidden_nf)

        if edge_attr is None:
            edge_attr = i_nt
        else:
            edge_attr = torch.cat([edge_attr, i_nt], dim=-1)  # (num_edges * batchsize, dim_edge_attr + dim_inter)
        c = 1  # no interaction inference: every edge has weight 1

        for i in range(0, self.n_layers):
            h, loc_central, vel, _ = self._modules["cfins_%d" % i](h, edges, c, loc_central, vel, edge_attr=edge_attr)
        loc = loc_central.reshape(-1, num_agents, self.dim) + center
        return loc.reshape(-1, self.dim), vel

    def centralized(self, loc):
        """Subtract the per-graph centroid.

        Args:
            loc: positions of shape (batchsize, num_agents, dim).

        Returns:
            (loc_central, center): centred positions flattened to
            (batchsize * num_agents, dim) and the centroid (batchsize, 1, dim).
        """
        center = torch.mean(loc, dim=1, keepdim=True)
        loc_central = (loc - center).view(-1, loc.size(-1))
        return loc_central, center

    def infer_model(self, source, target, edge_attr):
        """Score each edge with ``infer_mlp``.

        ``infer_mlp`` expects 3 * hidden_nf inputs, so ``edge_attr`` must be
        hidden_nf wide (``interaction_graph`` passes the edge_mlp output).

        Returns:
            Tensor of shape (num_edges * batchsize, category_edge) in (0, 1).
        """
        edge_in = torch.cat([source, target], dim=-1)  # (num_edges * batchsize, 2 * hidden_nf)
        if edge_attr is not None:
            edge_in = torch.cat([edge_in, edge_attr], dim=-1)  # (num_edges * batchsize, 3 * hidden_nf)
        return self.infer_mlp(edge_in)

    def edge_model(self, source, target, edge_attr):
        """Compute edge messages with ``edge_mlp`` from [h_i, h_j, edge_attr]."""
        edge_in = torch.cat([source, target], dim=-1)  # (num_edges * batchsize, 2 * hidden_nf)
        if edge_attr is not None:
            edge_in = torch.cat([edge_in, edge_attr], dim=-1)
        return self.edge_mlp(edge_in)  # (num_edges * batchsize, hidden_nf)

    def node_model(self, h, edge_index, edge_attr):
        """Sum incoming edge messages per source node and update ``h``.

        NOTE(review): ``unsorted_segment_sum`` is not imported in the visible
        part of this file; confirm it is defined elsewhere.
        """
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=h.size(0))
        out = self.node_mlp(torch.cat([h, agg], dim=-1))
        if self.recurrent:
            out = out + h
        return out

    def interaction_graph(self, h, edge_index, edge_attr = None):
        """Infer per-edge importance ``c`` (shape (num_edges * batchsize, 1))."""
        row, col = edge_index
        edge_m = self.edge_model(h[row], h[col], edge_attr)  # (num_edges * batchsize, hidden_nf)
        h = self.node_model(h, edge_index, edge_m)           # (batchsize * num_agents, hidden_nf)
        return self.infer_model(h[row], h[col], edge_m)

    @staticmethod
    def _safe_normalize(x, epsilon=1e-8):
        """Return ``x / max(||x||_2, epsilon)`` along the last dimension."""
        x_norm = torch.norm(x, p=2, dim=-1, keepdim=True)
        x_norm = torch.max(x_norm, torch.tensor(epsilon, dtype=x_norm.dtype, device=x_norm.device))  # avoid division by 0
        return x / x_norm

    def coord2inner_diff_invar(self, loc, vel, norm):
        """Per-node invariants [|p_i|^2, |v_i|^2, p_i . v_i].

        Args:
            loc, vel: tensors whose last dimension holds the coordinates
                (``forward`` passes (batchsize * num_agents, dim) tensors).
            norm: if truthy, ``loc`` and ``vel`` are first scaled to unit length.

        Returns:
            Tensor with the same leading shape and a last dimension of 3.
        """
        if norm:
            loc = self._safe_normalize(loc)
            vel = self._safe_normalize(vel)
        radius = torch.sum(loc ** 2, dim=-1, keepdim=True)
        velocity = torch.sum(vel ** 2, dim=-1, keepdim=True)
        inner = torch.sum(loc * vel, dim=-1, keepdim=True)
        return torch.cat([radius, velocity, inner], dim=-1)

    def coord2inter_diff_invar(self, edge_index, loc, vel, norm):
        """Per-edge invariants.

        2-D: [p_i . p_j, v_i . v_j]; 3-D additionally v_i . (p_i x p_j).

        Args:
            edge_index: (row, col) index tensors.
            loc, vel: (num_nodes, dim) tensors.
            norm: if truthy, ``loc`` and ``vel`` are first scaled to unit length.

        Returns:
            Tensor of shape (num_edges, 2) in 2-D or (num_edges, 3) in 3-D.

        Raises:
            ValueError: if the coordinate dimension is neither 2 nor 3.
        """
        row, col = edge_index
        if norm:
            loc = self._safe_normalize(loc)
            vel = self._safe_normalize(vel)

        coord_dim = loc.shape[-1]
        if coord_dim not in (2, 3):
            raise ValueError(f"Unsupported coordinate dimension: {loc.shape[-1]}. Only 2D and 3D coordinates are supported.")

        loc_i, loc_j = loc[row], loc[col]
        vel_i, vel_j = vel[row], vel[col]
        inter_loc = torch.sum(loc_i * loc_j, dim=-1, keepdim=True)  # p_i . p_j
        inter_vel = torch.sum(vel_i * vel_j, dim=-1, keepdim=True)  # v_i . v_j
        if coord_dim == 2:
            return torch.cat([inter_loc, inter_vel], dim=-1)

        # v_i . (p_i x p_j)
        inter_thr = torch.sum(vel_i * torch.cross(loc_i, loc_j, dim=-1), dim=-1, keepdim=True)
        return torch.cat([inter_loc, inter_vel, inter_thr], dim=-1)

class ClofNet(nn.Module):
    """ClofNet baseline: message passing with edge-wise local frames (3-D only).

    ``forward`` centres positions on the per-graph centroid, scalarizes them in
    a local frame built for every edge, embeds those scalars together with the
    given edge features, and runs ``n_layers`` ``Clof_GCL`` layers.
    """
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4,
        coords_weight=1.0, recurrent=True, norm_diff=True, tanh=False,
    ):
        super(ClofNet, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.embedding_node = nn.Linear(in_node_nf, self.hidden_nf)

        # Width of [edge_attr, coff_feat]; ``scalarization`` yields 8 features,
        # so edge_attr must have 2 columns.
        # NOTE(review): ``in_edge_nf`` is not used here — confirm callers always
        # pass 2-column edge_attr.
        edge_embed_dim = 10
        self.fuse_edge = nn.Sequential(
            nn.Linear(edge_embed_dim, self.hidden_nf // 2), act_fn,
            nn.Linear(self.hidden_nf // 2, self.hidden_nf // 2), act_fn)

        # NOTE(review): local-frame normalization is always enabled; the
        # ``norm_diff`` argument is only forwarded to the Clof_GCL layers.
        self.norm_diff = True
        for i in range(0, self.n_layers):
            self.add_module(
                "gcl_%d" % i,
                Clof_GCL(
                    input_nf=self.hidden_nf,
                    output_nf=self.hidden_nf,
                    hidden_nf=self.hidden_nf,
                    edges_in_d=self.hidden_nf // 2,
                    act_fn=act_fn,
                    recurrent=recurrent,
                    coords_weight=coords_weight,
                    norm_diff=norm_diff,
                    tanh=tanh,
                ),
            )
        self.to(self.device)
        self.params = self.__str__()

    def __str__(self):
        """Print/log the number of trainable parameters and return it as a string."""
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        print('Network Size', params)
        logging.info('Network Size {}'.format(params))
        return str(params)

    def coord2localframe(self, edge_index, coord):
        """Build a local frame (diff, cross, diff x cross) for every edge.

        Args:
            edge_index: (row, col) index tensors.
            coord: positions of shape (N, 3).

        Returns:
            Three tensors of shape (E, 1, 3). When ``self.norm_diff`` is set,
            ``diff`` and ``cross`` are divided by (their norm + 1).
        """
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum((coord_diff)**2, 1).unsqueeze(1)
        coord_cross = torch.cross(coord[row], coord[col], dim=-1)
        if self.norm_diff:
            norm = torch.sqrt(radial) + 1
            coord_diff = coord_diff / norm
            cross_norm = (torch.sqrt(
                torch.sum((coord_cross)**2, 1).unsqueeze(1))) + 1
            coord_cross = coord_cross / cross_norm
        coord_vertical = torch.cross(coord_diff, coord_cross, dim=-1)
        return coord_diff.unsqueeze(1), coord_cross.unsqueeze(1), coord_vertical.unsqueeze(1)

    def scalarization(self, edges, x, vel):
        """Express both endpoint positions of every edge in the edge's local frame.

        Args:
            edges: (row, col) index tensors.
            x: centred positions of shape (N, 3).
            vel: velocities; currently unused.

        Returns:
            Tensor of shape (E, 8): [pseudo_sin, pseudo_cos, coff_i (3), coff_j (3)].
        """
        coord_diff, coord_cross, coord_vertical = self.coord2localframe(edges, x)
        row, col = edges
        # (E, 3, 3): each row is one basis vector of the local frame.
        edge_basis = torch.cat([coord_diff, coord_cross, coord_vertical], dim=1)
        coff_i = torch.matmul(edge_basis, x[row].unsqueeze(-1)).squeeze(-1)  # (E, 3)
        coff_j = torch.matmul(edge_basis, x[col].unsqueeze(-1)).squeeze(-1)  # (E, 3)
        # Angle information between the two projected positions.
        coff_mul = coff_i * coff_j
        coff_i_norm = coff_i.norm(dim=-1, keepdim=True)
        coff_j_norm = coff_j.norm(dim=-1, keepdim=True)
        pesudo_cos = coff_mul.sum(
            dim=-1, keepdim=True) / (coff_i_norm + 1e-5) / (coff_j_norm + 1e-5)
        # When the 1e-5 guards fall below float resolution (large norms),
        # rounding can push |cos| slightly above 1; clamp so sqrt never gets a
        # negative argument (which would yield NaN).
        pesudo_sin = torch.sqrt(torch.clamp(1 - pesudo_cos**2, min=0))
        pesudo_angle = torch.cat([pesudo_sin, pesudo_cos], dim=-1)
        return torch.cat([pesudo_angle, coff_i, coff_j], dim=-1)  # (E, 8)

    def forward(self, h, x, edges, vel, edge_attr, node_attr=None, n_nodes=5):
        """Predict updated positions.

        Args:
            h: node features (N, in_node_nf).
            x: positions; reshaped to (-1, n_nodes, 3).
            edges: (row, col) index tensors.
            vel: velocities (N, 3).
            edge_attr: edge features (E, 2), concatenated with the 8 local-frame scalars.
            node_attr: forwarded to each layer.
            n_nodes: nodes per graph, used to compute each centroid.

        Returns:
            Positions of shape (N, 3).
        """
        h = self.embedding_node(h)
        x = x.reshape(-1, n_nodes, 3)
        centroid = torch.mean(x, dim=1, keepdim=True)
        x_center = (x - centroid).reshape(-1, 3)

        coff_feat = self.scalarization(edges, x_center, vel)
        edge_feat = torch.cat([edge_attr, coff_feat], dim=-1)
        edge_feat = self.fuse_edge(edge_feat)

        for i in range(0, self.n_layers):
            h, x_center, _ = self._modules["gcl_%d" % i](
                h, edges, x_center, vel, edge_attr=edge_feat, node_attr=node_attr)

        x = x_center.reshape(-1, n_nodes, 3) + centroid
        x = x.reshape(-1, 3)
        return x

class ClofNet_CIs(nn.Module):
    """ClofNet augmented with complete invariants (3-D only).

    Inner-agent invariants (|x_i|^2, |v_i|^2, x_i.v_i) are appended to the
    node features and inter-agent invariants (x_i.x_j, v_i.v_j,
    v_i.(x_i x x_j)) to the edge features, all computed on centroid-centred
    positions, before the usual local-frame scalarization and Clof_GCL stack.
    """
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4,
        coords_weight=1.0, recurrent=True, norm_diff=True, tanh=False,
    ):
        super(ClofNet_CIs, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        in_node_nf = in_node_nf + 3  # + 3 inner-agent invariants
        self.embedding_node = nn.Linear(in_node_nf, self.hidden_nf)

        # 2 raw edge features + 8 local-frame scalars + 3 inter-agent invariants.
        # NOTE(review): ``in_edge_nf`` is not used here — confirm callers always
        # pass 2-column edge_attr.
        edge_embed_dim = 10 + 3
        self.fuse_edge = nn.Sequential(
            nn.Linear(edge_embed_dim, self.hidden_nf // 2), act_fn,
            nn.Linear(self.hidden_nf // 2, self.hidden_nf // 2), act_fn)

        # NOTE(review): local-frame normalization is always enabled; the
        # ``norm_diff`` argument is only forwarded to the Clof_GCL layers.
        self.norm_diff = True
        for i in range(0, self.n_layers):
            self.add_module(
                "gcl_%d" % i,
                Clof_GCL(
                    input_nf=self.hidden_nf,
                    output_nf=self.hidden_nf,
                    hidden_nf=self.hidden_nf,
                    edges_in_d=self.hidden_nf // 2,
                    act_fn=act_fn,
                    recurrent=recurrent,
                    coords_weight=coords_weight,
                    norm_diff=norm_diff,
                    tanh=tanh,
                ),
            )
        self.to(self.device)
        self.params = self.__str__()

    def __str__(self):
        """Print/log the number of trainable parameters and return it as a string."""
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        print('Network Size', params)
        logging.info('Network Size {}'.format(params))
        return str(params)

    def coord2localframe(self, edge_index, coord):
        """Build a local frame (diff, cross, diff x cross) for every edge.

        Returns three tensors of shape (E, 1, 3); when ``self.norm_diff`` is
        set, ``diff`` and ``cross`` are divided by (their norm + 1).
        """
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum((coord_diff)**2, 1).unsqueeze(1)
        coord_cross = torch.cross(coord[row], coord[col], dim=-1)
        if self.norm_diff:
            norm = torch.sqrt(radial) + 1
            coord_diff = coord_diff / norm
            cross_norm = (torch.sqrt(
                torch.sum((coord_cross)**2, 1).unsqueeze(1))) + 1
            coord_cross = coord_cross / cross_norm
        coord_vertical = torch.cross(coord_diff, coord_cross, dim=-1)
        return coord_diff.unsqueeze(1), coord_cross.unsqueeze(1), coord_vertical.unsqueeze(1)

    def coord2inner_diff_invar(self, loc, vel):
        """Per-node invariants [|p_i|^2, |v_i|^2, p_i . v_i]; last dimension 3."""
        radius = torch.sum(loc ** 2, dim=-1, keepdim=True)
        velocity = torch.sum(vel ** 2, dim=-1, keepdim=True)
        inner = torch.sum(loc * vel, dim=-1, keepdim=True)
        return torch.cat([radius, velocity, inner], dim=-1)

    def coord2inter_diff_invar(self, edge_index, loc, vel):
        """Per-edge invariants.

        2-D: [p_i . p_j, v_i . v_j] -> (E, 2); 3-D additionally
        v_i . (p_i x p_j) -> (E, 3).

        Raises:
            ValueError: if the coordinate dimension is neither 2 nor 3.
        """
        row, col = edge_index
        coord_dim = loc.shape[-1]
        if coord_dim not in (2, 3):
            raise ValueError(f"Unsupported coordinate dimension: {loc.shape[-1]}. Only 2D and 3D coordinates are supported.")

        loc_i, loc_j = loc[row], loc[col]
        vel_i, vel_j = vel[row], vel[col]
        inter_loc = torch.sum(loc_i * loc_j, dim=-1, keepdim=True)  # p_i . p_j
        inter_vel = torch.sum(vel_i * vel_j, dim=-1, keepdim=True)  # v_i . v_j
        if coord_dim == 2:
            return torch.cat([inter_loc, inter_vel], dim=-1)

        # v_i . (p_i x p_j)
        inter_thr = torch.sum(vel_i * torch.cross(loc_i, loc_j, dim=-1), dim=-1, keepdim=True)
        return torch.cat([inter_loc, inter_vel, inter_thr], dim=-1)

    def scalarization(self, edges, x, vel):
        """Express both endpoint positions of every edge in the edge's local frame.

        ``vel`` is currently unused.

        Returns:
            Tensor of shape (E, 8): [pseudo_sin, pseudo_cos, coff_i (3), coff_j (3)].
        """
        coord_diff, coord_cross, coord_vertical = self.coord2localframe(edges, x)
        row, col = edges
        edge_basis = torch.cat([coord_diff, coord_cross, coord_vertical], dim=1)  # (E, 3, 3)
        coff_i = torch.matmul(edge_basis, x[row].unsqueeze(-1)).squeeze(-1)  # (E, 3)
        coff_j = torch.matmul(edge_basis, x[col].unsqueeze(-1)).squeeze(-1)  # (E, 3)
        coff_mul = coff_i * coff_j
        coff_i_norm = coff_i.norm(dim=-1, keepdim=True)
        coff_j_norm = coff_j.norm(dim=-1, keepdim=True)
        pesudo_cos = coff_mul.sum(
            dim=-1, keepdim=True) / (coff_i_norm + 1e-5) / (coff_j_norm + 1e-5)
        # When the 1e-5 guards fall below float resolution (large norms),
        # rounding can push |cos| slightly above 1; clamp so sqrt never gets a
        # negative argument (which would yield NaN).
        pesudo_sin = torch.sqrt(torch.clamp(1 - pesudo_cos**2, min=0))
        pesudo_angle = torch.cat([pesudo_sin, pesudo_cos], dim=-1)
        return torch.cat([pesudo_angle, coff_i, coff_j], dim=-1)  # (E, 8)

    def forward(self, h, x, edges, vel, edge_attr, node_attr=None, n_nodes=5):
        """Predict updated positions.

        Args:
            h: node features (N, in_node_nf).
            x: positions; reshaped to (-1, n_nodes, 3).
            edges: (row, col) index tensors.
            vel: velocities (N, 3).
            edge_attr: edge features (E, 2) or None.
            node_attr: forwarded to each layer.
            n_nodes: nodes per graph, used to compute each centroid.

        Returns:
            Positions of shape (N, 3).
        """
        x = x.reshape(-1, n_nodes, 3)
        centroid = torch.mean(x, dim=1, keepdim=True)
        x_center = (x - centroid).reshape(-1, 3)

        # Invariants on centred coordinates (translation invariant).
        i_nt = self.coord2inter_diff_invar(edges, x_center, vel)
        i_nn = self.coord2inner_diff_invar(x_center, vel)

        h = torch.cat([h, i_nn], dim=-1)  # (batchsize * num_agents, dim_h + 3)
        if edge_attr is None:
            edge_attr = i_nt
        else:
            edge_attr = torch.cat([edge_attr, i_nt], dim=-1)  # (num_edges * batchsize, dim_edge_attr + 3)

        h = self.embedding_node(h)
        coff_feat = self.scalarization(edges, x_center, vel)
        edge_feat = torch.cat([edge_attr, coff_feat], dim=-1)
        edge_feat = self.fuse_edge(edge_feat)

        for i in range(0, self.n_layers):
            h, x_center, _ = self._modules["gcl_%d" % i](
                h, edges, x_center, vel, edge_attr=edge_feat, node_attr=node_attr)

        x = x_center.reshape(-1, n_nodes, 3) + centroid
        x = x.reshape(-1, 3)
        return x

class GNN(nn.Module):
    """Plain (non-equivariant) message-passing network with a 3-value output head.

    Node inputs are embedded to ``hidden_nf``, refined by ``n_layers`` GCL
    layers and decoded to 3 values per node.
    """
    def __init__(self, input_dim, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, attention=0, recurrent=False):
        super(GNN, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers

        # Message-passing stack. It is registered before the decoder and the
        # embedding; keep this order so seeded initialisation is unchanged.
        for layer_idx in range(n_layers):
            layer = GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf,
                        edges_in_nf=1, act_fn=act_fn,
                        attention=attention, recurrent=recurrent)
            self.add_module("gcl_%d" % layer_idx, layer)

        head_layers = [nn.Linear(hidden_nf, hidden_nf), act_fn, nn.Linear(hidden_nf, 3)]
        self.decoder = nn.Sequential(*head_layers)
        self.embedding = nn.Sequential(nn.Linear(input_dim, hidden_nf))
        self.to(self.device)

    def forward(self, nodes, edges, edge_attr=None):
        """Embed ``nodes``, apply every GCL layer, and decode to (N, 3)."""
        hidden = self.embedding(nodes)
        for layer_idx in range(self.n_layers):
            gcl = self._modules["gcl_%d" % layer_idx]
            hidden, _ = gcl(hidden, edges, edge_attr=edge_attr)
        return self.decoder(hidden)

def get_velocity_attr(loc, vel, rows, cols):
    """Project each source-node velocity onto the unit edge direction.

    Args:
        loc: node positions (N, D).
        vel: node velocities (N, D).
        rows: source node index per edge.
        cols: target node index per edge.

    Returns:
        Tensor of shape (E, 1) holding
        ``vel[rows] . (loc[cols] - loc[rows]) / ||loc[cols] - loc[rows]||``.
        Coincident endpoints give a zero norm (and hence non-finite values).
    """
    direction = loc[cols] - loc[rows]
    length = torch.norm(direction, p=2, dim=1).unsqueeze(1)
    unit = direction / length
    return (vel[rows] * unit).sum(dim=1).unsqueeze(1)

class EGNN(nn.Module):
    """E(n)-equivariant GNN: embeds node features and stacks E_GCL layers.

    ``forward`` returns only the updated coordinates.
    """
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.LeakyReLU(0.2), n_layers=4, coords_weight=1.0):
        super(EGNN, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        width = self.hidden_nf
        for layer_idx in range(n_layers):
            self.add_module(
                "gcl_%d" % layer_idx,
                E_GCL(width, width, width, edges_in_d=in_edge_nf, act_fn=act_fn,
                      recurrent=True, coords_weight=coords_weight),
            )
        self.to(self.device)

    def forward(self, h, x, edges, edge_attr, vel=None):
        """Apply every E_GCL layer and return the coordinates; ``vel`` is unused."""
        features = self.embedding(h)
        coords = x
        for layer_idx in range(self.n_layers):
            layer = self._modules["gcl_%d" % layer_idx]
            features, coords, _ = layer(features, edges, coords, edge_attr=edge_attr)
        return coords

class EGNN_vel(nn.Module):
    """EGNN variant whose layers (E_GCL_vel) also consume node velocities.

    ``forward`` returns only the updated coordinates.
    """
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, coords_weight=1.0, recurrent=False, norm_diff=False, tanh=False):
        super(EGNN_vel, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        width = self.hidden_nf
        for layer_idx in range(n_layers):
            layer = E_GCL_vel(width, width, width, edges_in_d=in_edge_nf, act_fn=act_fn,
                              coords_weight=coords_weight, recurrent=recurrent,
                              norm_diff=norm_diff, tanh=tanh)
            self.add_module("gcl_%d" % layer_idx, layer)
        self.to(self.device)

    def forward(self, h, x, edges, vel, edge_attr):
        """Apply every E_GCL_vel layer and return the coordinates."""
        features = self.embedding(h)
        coords = x
        for layer_idx in range(self.n_layers):
            features, coords, _ = self._modules["gcl_%d" % layer_idx](
                features, edges, coords, vel, edge_attr=edge_attr)
        return coords
    
class EGNN_vel_CIs(nn.Module):
    """
        EGNN_vel with complete invariants.

        Inner-agent invariants (|x_i|^2, |v_i|^2, x_i.v_i) are appended to the
        node features and inter-agent invariants to the edge features before the
        E_GCL_vel stack. The invariants are computed on centroid-centred
        positions so they do not depend on the absolute origin. ``forward``
        assumes 3-D coordinates.
    """
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, coords_weight=1.0, recurrent=False, norm_diff=False, tanh=False):
        super(EGNN_vel_CIs, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        node_embed_nf = 3 + in_node_nf # dim of inner-agent invariants and dim of h (initial node features' dimension)
        edge_embed_nf = 3 + in_edge_nf # dim of inter-agent invariants and dim of aij (initial edge features' dimension)

        self.embedding = nn.Linear(node_embed_nf, self.hidden_nf)
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, E_GCL_vel(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=edge_embed_nf, act_fn=act_fn, coords_weight=coords_weight, recurrent=recurrent, norm_diff=norm_diff, tanh=tanh))
        self.to(self.device)

    def coord2inner_diff_invar(self, edge_index, loc, vel):
        """Per-node invariants [|p_i|^2, |v_i|^2, p_i . v_i]; last dimension 3.

        ``edge_index`` is unused (kept for interface compatibility).
        """
        radius = torch.sum(loc ** 2, dim=-1, keepdim=True)
        velocity = torch.sum(vel ** 2, dim=-1, keepdim=True)
        inner = torch.sum(loc * vel, dim=-1, keepdim=True)
        return torch.cat([radius, velocity, inner], dim=-1)

    def coord2inter_diff_invar(self, edge_index, loc, vel):
        """Per-edge invariants.

        2-D: [p_i . p_j, v_i . v_j] -> (E, 2); 3-D additionally
        v_i . (p_i x p_j) -> (E, 3).

        Raises:
            ValueError: if the coordinate dimension is neither 2 nor 3.
        """
        row, col = edge_index
        coord_dim = loc.shape[-1]
        if coord_dim not in (2, 3):
            raise ValueError(f"Unsupported coordinate dimension: {loc.shape[-1]}. Only 2D and 3D coordinates are supported.")

        loc_i, loc_j = loc[row], loc[col]
        vel_i, vel_j = vel[row], vel[col]
        inter_loc = torch.sum(loc_i * loc_j, dim=-1, keepdim=True)  # p_i . p_j
        inter_vel = torch.sum(vel_i * vel_j, dim=-1, keepdim=True)  # v_i . v_j
        if coord_dim == 2:
            return torch.cat([inter_loc, inter_vel], dim=-1)

        # v_i . (p_i x p_j)
        inter_thr = torch.sum(vel_i * torch.cross(loc_i, loc_j, dim=-1), dim=-1, keepdim=True)
        return torch.cat([inter_loc, inter_vel, inter_thr], dim=-1)

    def forward(self, h, x, edges, vel, edge_attr, num_agents = 5):
        """Predict updated positions.

        Args:
            h: node features (N, in_node_nf).
            x: positions; reshaped to (-1, num_agents, 3).
            edges: (row, col) index tensors.
            vel: velocities (N, 3).
            edge_attr: edge features (E, in_edge_nf) or None.
            num_agents: agents per graph, used to compute each centroid.

        Returns:
            Positions of shape (N, 3).
        """
        # Centre first: |x_i|^2, x_i.v_i, x_i.x_j and v_i.(x_i x x_j) are only
        # translation invariant when measured from the per-graph centroid.
        x = x.reshape(-1, num_agents, 3)
        centroid = torch.mean(x, dim=1, keepdim=True)
        x_center = (x - centroid).reshape(-1, 3)

        i_nt = self.coord2inter_diff_invar(edges, x_center, vel)
        i_nn = self.coord2inner_diff_invar(edges, x_center, vel)

        h = torch.cat([h, i_nn], dim=-1)  # (batchsize * num_agents, dim_h + 3)
        if edge_attr is None:
            edge_attr = i_nt
        else:
            edge_attr = torch.cat([edge_attr, i_nt], dim=-1)  # (num_edges * batchsize, dim_edge_attr + 3)
        h = self.embedding(h)

        for i in range(0, self.n_layers):
            h, x_center, _ = self._modules["gcl_%d" % i](h, edges, x_center, vel, edge_attr=edge_attr)
        x = x_center.reshape(-1, num_agents, 3) + centroid
        return x.reshape(-1, 3)

class GMN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, coords_weight=1.0,
                 recurrent=False, norm_diff=False, tanh=False):
        """
        Graph Mechanics Networks.

        Embeds node features to ``hidden_nf`` and stacks ``n_layers``
        GMNLayer modules that update node features and coordinates.

        :param in_node_nf: input node feature dimension
        :param in_edge_nf: input edge feature dimension
        :param hidden_nf: hidden dimension
        :param device: device the module is moved to
        :param act_fn: activation function
        :param n_layers: number of GMN layers
        :param coords_weight: coords weight, inherited from EGNN
        :param recurrent: residual connection on x
        :param norm_diff: normalize the distance, inherited
        :param tanh: Tanh activation, inherited
        """
        super(GMN, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        width = self.hidden_nf
        for layer_idx in range(n_layers):
            layer = GMNLayer(width, width, width,
                             edges_in_d=in_edge_nf, act_fn=act_fn,
                             coords_weight=coords_weight,
                             recurrent=recurrent, norm_diff=norm_diff, tanh=tanh)
            self.add_module("gcl_%d" % layer_idx, layer)

        self.to(self.device)

    def forward(self, h, x, edges, v, edge_attr):
        """Apply every GMN layer and return the updated coordinates."""
        features = self.embedding(h)
        coords = x
        for layer_idx in range(self.n_layers):
            layer = self._modules["gcl_%d" % layer_idx]
            features, coords, _ = layer(features, edges, coords, v, edge_attr=edge_attr)
        return coords
    
class GMN_CIs(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, coords_weight=1.0,
                 recurrent=False, norm_diff=False, tanh=False):
        """
        Graph Mechanics Network fed with complete (differential) invariants.

        Before the GMN layers run, node features are extended with 3 per-node
        invariants (``coord2inner_diff_invar``) and edge features with 3
        per-edge invariants (``coord2inter_diff_invar``, 3-D case). The layers
        then operate on positions centred per system.
        :param in_node_nf: input node feature dimension (before the +3 invariants)
        :param in_edge_nf: input edge feature dimension (before the +3 invariants)
        :param hidden_nf: hidden dimension
        :param device: device
        :param act_fn: activation function
        :param n_layers: the number of layers
        :param coords_weight: coords weight, inherited from EGNN
        :param recurrent: residual connection on x
        :param norm_diff: normalize the distance, inherited
        :param tanh: Tanh activation, inherited
        """
        super(GMN_CIs, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers

        # +3: radius, speed and p·v invariants appended to every node.
        self.embedding = nn.Linear(in_node_nf + 3, self.hidden_nf)
        for i in range(n_layers):
            # +3: p_i·p_j, v_i·v_j and v_i·(p_i x p_j) appended to every edge.
            self.add_module("gcl_%d" % i, GMNLayer(self.hidden_nf, self.hidden_nf, self.hidden_nf,
                                                   edges_in_d=in_edge_nf + 3, act_fn=act_fn,
                                                   coords_weight=coords_weight,
                                                   recurrent=recurrent, norm_diff=norm_diff, tanh=tanh)
                            )

        self.to(self.device)

    def coord2inner_diff_invar(self, loc, vel):
        """Per-node invariants ``[|p|^2, |v|^2, p·v]``.

        :param loc: positions, last dim is the coordinate dim
        :param vel: velocities, same shape as ``loc``
        :return: tensor with the same leading dims as ``loc`` and last dim 3
        """
        radius = torch.sum(loc ** 2, dim=-1, keepdim=True)       # |p_i|^2, shape (..., 1)
        velocity = torch.sum(vel ** 2, dim=-1, keepdim=True)     # |v_i|^2, shape (..., 1)
        inner = torch.sum(loc * vel, dim=-1, keepdim=True)       # p_i·v_i, shape (..., 1)

        return torch.cat([radius, velocity, inner], dim=-1)      # shape (..., 3)

    def coord2inter_diff_invar(self, edge_index, loc, vel):
        """Per-edge invariants for edges ``(row, col)``.

        2-D coordinates: ``[p_i·p_j, v_i·v_j]`` (2 columns).
        3-D coordinates: ``[p_i·p_j, v_i·v_j, v_i·(p_i x p_j)]`` (3 columns).
        The 3-D width matches the ``in_edge_nf + 3`` edge input of the layers
        built in ``__init__``. The ``(p_i x p_j)·(v_i x v_j)`` term is not included.

        :param edge_index: pair ``(row, col)`` of index tensors into ``loc``/``vel``
        :param loc: positions, shape (num_nodes, dim)
        :param vel: velocities, shape (num_nodes, dim)
        :raises ValueError: if ``dim`` is neither 2 nor 3
        """
        row, col = edge_index
        dim = loc.shape[-1]
        if dim not in (2, 3):
            raise ValueError(f"Unsupported coordinate dimension: {loc.shape[-1]}. Only 2D and 3D coordinates are supported.")

        loc_i, loc_j = loc[row], loc[col]
        vel_i, vel_j = vel[row], vel[col]
        inter_loc = torch.sum(loc_i * loc_j, dim=-1, keepdim=True)  # p_i·p_j, shape (num_edges, 1)
        inter_vel = torch.sum(vel_i * vel_j, dim=-1, keepdim=True)  # v_i·v_j, shape (num_edges, 1)
        if dim == 2:
            return torch.cat([inter_loc, inter_vel], dim=-1)        # shape (num_edges, 2)

        # Triple product v_i·(p_i x p_j).
        cross_loc = torch.cross(loc_i, loc_j, dim=-1)
        inter_thr = torch.sum(vel_i * cross_loc, dim=-1, keepdim=True)  # shape (num_edges, 1)
        return torch.cat([inter_loc, inter_vel, inter_thr], dim=-1)     # shape (num_edges, 3)

    def forward(self, h, x, edges, v, edge_attr, n_nodes=5):
        """Predict new positions for a batch of ``n_nodes``-body 3-D systems.

        :param h: node features, shape (batch * n_nodes, in_node_nf)
        :param x: positions, reshaped to (-1, n_nodes, 3)
        :param edges: pair ``(row, col)`` of index tensors into the flattened nodes
        :param v: velocities, shape (batch * n_nodes, 3)
        :param edge_attr: edge features or ``None``
        :param n_nodes: nodes per system, used for centring
        :return: positions in the original frame, shape (batch * n_nodes, 3)
        """
        x = x.reshape(-1, n_nodes, 3)
        centroid = torch.mean(x, dim=1, keepdim=True)
        x_center = (x - centroid).reshape(-1, 3)

        # Invariants are computed in the centred frame.
        i_nt = self.coord2inter_diff_invar(edges, x_center, v)
        i_nn = self.coord2inner_diff_invar(x_center, v)

        h = torch.cat([h, i_nn], dim=-1)  # (batch * n_nodes, in_node_nf + 3)
        if edge_attr is None:
            edge_attr = i_nt
        else:
            edge_attr = torch.cat([edge_attr, i_nt], dim=-1)  # (num_edges, in_edge_nf + 3)

        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, x_center, _ = self._modules["gcl_%d" % i](h, edges, x_center, v, edge_attr=edge_attr)

        # Undo the centring so positions are returned in the input frame.
        x = x_center.reshape(-1, n_nodes, 3) + centroid
        x = x.reshape(-1, 3)
        return x

class RF_vel(nn.Module):
    """Stack of ``GCL_rf_vel`` layers that iteratively update positions ``x``.

    Each layer receives the current positions together with ``vel_norm``,
    ``vel``, ``edges`` and ``edge_attr``. Only the final positions are returned.
    """

    def __init__(self, hidden_nf, edge_attr_nf=0, device='cpu', act_fn=nn.SiLU(), n_layers=4):
        super().__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        for layer_idx in range(n_layers):
            layer = GCL_rf_vel(nf=hidden_nf, edge_attr_nf=edge_attr_nf, act_fn=act_fn)
            self.add_module(f"gcl_{layer_idx}", layer)
        self.to(self.device)

    def forward(self, vel_norm, x, edges, vel, edge_attr):
        """Apply every layer in order and return the updated positions."""
        for layer_idx in range(self.n_layers):
            x, _ = self._modules[f"gcl_{layer_idx}"](x, vel_norm, vel, edges, edge_attr)
        return x

class Baseline(nn.Module):
    """Identity model: ``forward`` returns its input unchanged.

    A 1x1 ``nn.Linear`` is registered as ``dummy`` so the module owns a
    parameter. Presumably this is so generic training code (e.g. optimizer
    construction) works; TODO confirm.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.dummy = nn.Linear(1, 1)
        self.device = device
        self.to(device)

    def forward(self, loc):
        """Return ``loc`` as-is (the same object, not a copy)."""
        return loc

class Linear(nn.Module):
    """Thin wrapper around a single ``nn.Linear(input_nf, output_nf)`` layer."""

    def __init__(self, input_nf, output_nf, device='cpu'):
        super().__init__()
        self.linear = nn.Linear(input_nf, output_nf)
        self.device = device
        self.to(device)

    def forward(self, input):
        """Apply the wrapped affine map to ``input``."""
        layer = self.linear
        return layer(input)

class Linear_dynamics(nn.Module):
    """Constant-velocity extrapolation ``x + v * t`` with a learnable scalar step ``t``.

    ``t`` is stored as a 1-element parameter ``time`` initialised to 0.7.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.time = nn.Parameter(0.7 * torch.ones(1))
        self.device = device
        self.to(device)

    def forward(self, x, v):
        """Advance positions ``x`` along velocities ``v`` by the learned step."""
        step = self.time
        return x + v * step
    
class CFINs_diff_invar(nn.Module):
    """Stack of ``CFIN_layer_diff_invar`` layers over centred coordinates.

    ``forward`` embeds node features with ``fuse_node``, centres each system
    of ``num_agents`` agents and runs ``n_layers`` CFIN layers on the centred
    positions and velocities. The ``node_mlp``/``edge_mlp``/``infer_mlp``
    modules are used only by ``interaction_graph`` and its helpers, which
    ``forward`` does not call.
    """

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', dropout_rate = 0.5, act_fn= nn.SiLU(),
                 n_layers=4, recurrent = False, dim=3, no_infer = False, order = 2, case = 0, frame = 0):
        super(CFINs_diff_invar, self).__init__()
        self.recurrent = recurrent
        self.dim = dim
        self.hidden_nf = hidden_nf
        self.in_h_nf = in_node_nf
        self.in_edge_nf = in_edge_nf
        self.dropout_rate = dropout_rate
        self.case = case
        self.frame = frame

        # Node-feature embedding used by forward(). Invariant features are not
        # concatenated to h here; the CFIN layers receive the centred coordinates.
        self.fuse_node = nn.Linear(self.in_h_nf, self.hidden_nf)

        self.order = order
        self.category_edge = 1
        # NOTE(review): the ``no_infer`` argument is ignored and the inference
        # MLPs are always built. Kept so parameter counts and state_dict keys
        # stay unchanged.
        self.no_infer = False
        if not self.no_infer:
            # node mlp: input is [h, aggregated edge message] -> 2 * hidden_nf
            self.node_mlp = nn.Sequential(
                nn.Linear(hidden_nf + int(1*hidden_nf), hidden_nf),
                act_fn,
                nn.Dropout(self.dropout_rate),
                nn.Linear(hidden_nf, hidden_nf),
                nn.LayerNorm(hidden_nf),
                )

            # edge mlp: input is [h_row, h_col, edge_attr]. This expects
            # edge_attr of width hidden_nf // 2, and nothing in this class
            # projects raw edge features to that width.
            self.edge_mlp = nn.Sequential(
                nn.Linear(self.hidden_nf*2 + self.hidden_nf // 2, self.hidden_nf),
                act_fn,
                nn.Dropout(self.dropout_rate),
                nn.Linear(self.hidden_nf, self.hidden_nf),
                nn.LayerNorm(hidden_nf),
                act_fn)

            # interaction infer mlp: input is [h_row, h_col, edge message] -> 3 * hidden_nf
            self.infer_mlp = nn.Sequential(
                nn.Linear(self.hidden_nf*2 + self.hidden_nf, self.hidden_nf),
                act_fn,
                nn.Linear(self.hidden_nf, self.category_edge),
                nn.Sigmoid())

        self.device = device
        self.n_layers = n_layers
        # NOTE(review): ``embedding`` is not used by forward() (``fuse_node``
        # is). Kept for state_dict compatibility.
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        for i in range(0, n_layers):
            self.add_module("cfins_%d" % i, CFIN_layer_diff_invar(
                self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=self.in_edge_nf,
                dropout_rate=self.dropout_rate, act_fn=act_fn, recurrent=False, tanh=True,
                order = self.order, out_dim = self.dim, case = self.case, frame = self.frame))
        self.to(self.device)
        self.params = self.__str__()

    def __str__(self):
        """Print/log the number of trainable parameters and return it as a string."""
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print('Network Size', params)
        logging.info('Network Size %s', params)
        return str(params)

    def forward(self, h, loc, vel, edges, edge_attr, node_attr=None, num_agents = 5):
        """Run the CFIN stack on a batch of ``num_agents``-agent systems.

        Args:
            h: node features, first dim ``batchsize * num_agents``.
            loc: positions, first dim ``batchsize * num_agents``; reshaped to
                ``(-1, num_agents, dim)`` for per-system centring.
            vel: velocities, first dim ``batchsize * num_agents``.
            edges: edge index pair passed unchanged to every CFIN layer.
            edge_attr: edge features, first dim ``batchsize * num_edges``.
            node_attr: unused.
            num_agents: number of agents per system.

        Returns:
            ``(loc, vel)``. ``loc`` holds the positions back in the input
            frame, with shape ``(batchsize * num_agents, dim)``. ``vel`` holds
            the velocities from the last layer, or the flattened input when
            ``n_layers == 0``.
        """
        loc = loc.reshape(-1, num_agents, self.dim)
        loc_central, center = self.centralized(loc)
        vel = vel.view(-1, vel.size(-1))

        h = self.fuse_node(h)  # (batchsize * num_agents, hidden_nf)
        # Uniform edge importance (cf. interaction_graph); not inferred here.
        c = 1
        for i in range(0, self.n_layers):
            h, loc_central, vel, _ = self._modules["cfins_%d" % i](h, edges, c, loc_central, vel, edge_attr=edge_attr)
        # Undo the centring.
        loc = loc_central.reshape(-1, num_agents, self.dim) + center
        loc = loc.reshape(-1, self.dim)
        return loc, vel

    def centralized(self, loc):
        """Subtract each system's mean position.

        Args:
            loc: positions of shape ``(batchsize, num_agents, dim)``.

        Returns:
            ``(loc_central, center)``: centred positions flattened to
            ``(batchsize * num_agents, dim)``, and the per-system means of
            shape ``(batchsize, 1, dim)``.
        """
        center = torch.mean(loc, dim=1, keepdim=True)
        loc_central = loc - center
        loc_central = loc_central.view(-1, loc_central.size(-1))
        return loc_central, center

    def infer_model(self, source, target, edge_attr):
        """Score each edge from ``[source, target, edge_attr]``.

        Returns a ``(num_edges, category_edge)`` tensor squashed by a sigmoid.
        """
        edge_in = torch.cat([source, target], dim=-1)  # (num_edges, 2 * hidden_nf)
        if edge_attr is not None:
            edge_in = torch.cat([edge_in, edge_attr], dim=-1)  # (num_edges, 3 * hidden_nf) for infer_mlp
        out = self.infer_mlp(edge_in)  # (num_edges, category_edge)
        return out

    def edge_model(self, source, target, edge_attr):
        """Compute edge messages from ``[source, target, edge_attr]`` via ``edge_mlp``."""
        edge_in = torch.cat([source, target], dim=-1)  # (num_edges, 2 * hidden_nf)
        if edge_attr is not None:
            edge_in = torch.cat([edge_in, edge_attr], dim=-1)
        out = self.edge_mlp(edge_in)  # (num_edges, hidden_nf)
        return out

    def node_model(self, h, edge_index, edge_attr):
        """Update node features from messages summed per ``row`` index.

        ``edge_attr`` here is the per-edge message tensor. When
        ``self.recurrent`` is set, a residual connection to ``h`` is added.
        """
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=h.size(0))
        out = torch.cat([h, agg], dim=-1)
        out = self.node_mlp(out)
        if self.recurrent:
            out = out + h
        return out

    def interaction_graph(self, h, edge_index, edge_attr = None):
        """Infer per-edge importance ``c`` (one value per edge, in (0, 1))."""
        row, col = edge_index
        edge_m = self.edge_model(h[row], h[col], edge_attr)  # (num_edges, hidden_nf)
        h = self.node_model(h, edge_index, edge_m)            # (num_nodes, hidden_nf)

        c = self.infer_model(h[row], h[col], edge_m)          # (num_edges, category_edge == 1)
        return c

    @staticmethod
    def _safe_normalize(v, epsilon=1e-8):
        """Scale each row of ``v`` to unit L2 norm, clamping norms below ``epsilon``.

        Zero rows therefore stay zero instead of producing NaNs.
        """
        norm = torch.norm(v, p=2, dim=-1, keepdim=True)
        norm = torch.max(norm, torch.tensor(epsilon, dtype=norm.dtype, device=norm.device))
        return v / norm

    def coord2inner_diff_invar(self, loc, vel):
        """Per-node invariants ``[|p̂|^2, |v̂|^2, p̂·v̂]`` of the unit-normalised vectors.

        Because both vectors are normalised first, the first two columns are
        1 for non-zero inputs and 0 for zero inputs. Only ``p̂·v̂`` varies
        continuously.
        """
        loc = self._safe_normalize(loc)
        vel = self._safe_normalize(vel)

        radius = torch.sum(loc ** 2, dim=-1, keepdim=True)    # (..., 1)
        velocity = torch.sum(vel ** 2, dim=-1, keepdim=True)  # (..., 1)
        inner = torch.sum(loc * vel, dim=-1, keepdim=True)    # (..., 1)

        return torch.cat([radius, velocity, inner], dim=-1)   # (..., 3)

    def coord2inter_diff_invar(self, edge_index, loc, vel):
        """Per-edge invariants of unit-normalised positions/velocities for edges ``(row, col)``.

        2-D: ``[p_i·p_j, v_i·v_j]`` (2 columns).
        3-D: ``[p_i·p_j, v_i·v_j, v_i·(p_i x p_j), (p_i x p_j)·(v_i x v_j)]`` (4 columns).

        Raises:
            ValueError: if the coordinate dimension is neither 2 nor 3.
        """
        row, col = edge_index

        loc = self._safe_normalize(loc)
        vel = self._safe_normalize(vel)

        # loc[row] etc. have shape (num_edges, dim_loc)
        if loc.shape[-1] == 2:
            inter_loc = torch.sum(loc[row] * loc[col], dim=-1, keepdim=True)  # (num_edges, 1)
            inter_vel = torch.sum(vel[row] * vel[col], dim=-1, keepdim=True)  # (num_edges, 1)
            return torch.cat([inter_loc, inter_vel], dim=-1)                  # (num_edges, 2)

        if loc.shape[-1] == 3:
            loc_i, loc_j = loc[row], loc[col]
            vel_i, vel_j = vel[row], vel[col]
            inter_loc = torch.sum(loc_i * loc_j, dim=-1, keepdim=True)  # (num_edges, 1)
            inter_vel = torch.sum(vel_i * vel_j, dim=-1, keepdim=True)  # (num_edges, 1)

            # p_i x p_j is shared by the triple product and the cross-cross term.
            cross_loc = torch.cross(loc_i, loc_j, dim=-1)
            cross_vel = torch.cross(vel_i, vel_j, dim=-1)
            inter_thr = torch.sum(vel_i * cross_loc, dim=-1, keepdim=True)       # v_i·(p_i x p_j)
            inter_four = torch.sum(cross_loc * cross_vel, dim=-1, keepdim=True)  # (p_i x p_j)·(v_i x v_j)

            return torch.cat([inter_loc, inter_vel, inter_thr, inter_four], dim=-1)  # (num_edges, 4)

        raise ValueError(f"Unsupported coordinate dimension: {loc.shape[-1]}. Only 2D and 3D coordinates are supported.")

def unsorted_segment_sum(data, segment_ids, num_segments):
    """Sum rows of ``data`` into ``num_segments`` buckets (TensorFlow ``unsorted_segment_sum``).

    Args:
        data: 2-D tensor of shape ``(N, F)``.
        segment_ids: 1-D integer tensor of length ``N``; row ``k`` of ``data``
            is added to output row ``segment_ids[k]``.
        num_segments: number of output rows.

    Returns:
        Tensor of shape ``(num_segments, F)`` with the same dtype/device as
        ``data``. Segments that receive no rows are zero.
    """
    width = data.size(1)
    index = segment_ids.unsqueeze(-1).expand(-1, width)
    totals = data.new_zeros((num_segments, width))
    return totals.scatter_add_(0, index, data)

def unsorted_segment_mean(data, segment_ids, num_segments):
    """Average rows of ``data`` per segment.

    Args:
        data: 2-D tensor of shape ``(N, F)``.
        segment_ids: 1-D integer tensor of length ``N`` giving each row's segment.
        num_segments: number of output rows.

    Returns:
        Tensor of shape ``(num_segments, F)``. Empty segments yield 0, because
        counts are clamped to at least 1 before dividing.
    """
    width = data.size(1)
    index = segment_ids.unsqueeze(-1).expand(-1, width)
    totals = data.new_zeros((num_segments, width)).scatter_add_(0, index, data)
    counts = data.new_zeros((num_segments, width)).scatter_add_(0, index, torch.ones_like(data))
    return totals / counts.clamp(min=1)