from torch import nn
import torch

class MLP(nn.Module):
    """Feed-forward network made of four linear layers.

    A ``LeakyReLU`` with negative slope 0.2 follows every linear layer except
    the last one, whose output is returned as-is.

    Args:
        nin: Number of input features.
        nout: Number of output features.
        nh: Number of features of every intermediate layer.
    """

    def __init__(self, nin, nout, nh):
        super().__init__()
        widths = (nin, nh, nh, nh, nout)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2))
        # The output layer stays linear: drop the activation appended after it.
        stack.pop()
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        out = self.net(x)
        return out

class GCL_basic(nn.Module):
    """Skeleton of a message-passing layer.

    ``forward`` gathers the features of both endpoints of every edge, hands
    them to ``edge_model`` and then passes the resulting edge features to
    ``node_model``. Both hooks do nothing here (they return ``None``);
    subclasses are expected to override them.
    """

    def __init__(self):
        super().__init__()

    def edge_model(self, source, target, edge_attr):
        """Hook for computing per-edge features; the base version returns ``None``."""
        return None

    def node_model(self, h, edge_index, edge_attr):
        """Hook for updating node features; the base version returns ``None``."""
        return None

    def forward(self, x, edge_index, edge_attr=None):
        """Run one edge update followed by one node update.

        Args:
            x: Per-node tensor, indexed by the entries of ``edge_index``.
            edge_index: Pair ``(row, col)`` of node-index tensors, one entry per edge.
            edge_attr: Optional per-edge attributes forwarded to ``edge_model``.

        Returns:
            Tuple ``(node_model output, edge_model output)``.
        """
        first_end, second_end = edge_index
        messages = self.edge_model(x[first_end], x[second_end], edge_attr)
        updated = self.node_model(x, edge_index, messages)
        return updated, messages

class GCL(GCL_basic):
    """Message-passing layer that works on node features only.

    Edge messages come from an MLP over the concatenated endpoint features
    (plus optional edge attributes). Node features are updated by a second MLP
    over the node's own features and the sum of its edge messages.

    Args:
        input_nf: Number of input node features.
        output_nf: Number of output node features.
        hidden_nf: Width of the hidden layers and of the edge messages.
        edges_in_nf: Number of edge-attribute features (0 when none are passed).
        act_fn: Activation module placed between the linear layers.
        bias: Whether the linear layers carry a bias term.
        attention: If true, every message is scaled by a sigmoid gate computed
            from ``|source - target|``.
        t_eq: Stored on the instance; not read anywhere in this class.
        recurrent: If true, the input node features are added to the node
            update (this needs ``output_nf == input_nf``).
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_nf=0, act_fn=nn.ReLU(), bias=True, attention=False, t_eq=False, recurrent=True):
        super(GCL, self).__init__()
        self.attention = attention
        self.t_eq = t_eq
        self.recurrent = recurrent

        # Edge network input: both endpoints' features plus the edge attributes.
        edge_input_width = input_nf * 2 + edges_in_nf
        edge_layers = [
            nn.Linear(edge_input_width, hidden_nf, bias=bias),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf, bias=bias),
            act_fn,
        ]
        self.edge_mlp = nn.Sequential(*edge_layers)

        if self.attention:
            gate_layers = [
                nn.Linear(input_nf, hidden_nf, bias=bias),
                act_fn,
                nn.Linear(hidden_nf, 1, bias=bias),
                nn.Sigmoid(),
            ]
            self.att_mlp = nn.Sequential(*gate_layers)

        node_layers = [
            nn.Linear(hidden_nf + input_nf, hidden_nf, bias=bias),
            act_fn,
            nn.Linear(hidden_nf, output_nf, bias=bias),
        ]
        self.node_mlp = nn.Sequential(*node_layers)

    def edge_model(self, source, target, edge_attr):
        """Compute one message per edge from its endpoint features."""
        pieces = [source, target]
        if edge_attr is not None:
            pieces.append(edge_attr)
        messages = self.edge_mlp(torch.cat(pieces, dim=1))
        if not self.attention:
            return messages
        gate = self.att_mlp((source - target).abs())
        return messages * gate

    def node_model(self, h, edge_index, edge_attr):
        """Sum the messages per node (grouped by ``edge_index[0]``) and update ``h``."""
        receivers, _ = edge_index
        pooled = unsorted_segment_sum(edge_attr, receivers, num_segments=h.size(0))
        update = self.node_mlp(torch.cat([h, pooled], dim=1))
        if self.recurrent:
            update = update + h
        return update

class GCL_rf(GCL_basic):
    """Radial-field layer that updates a per-node vector (e.g. coordinates).

    For every edge the difference ``source - target`` is scaled by a scalar
    produced by ``phi`` from the Euclidean distance (and optional edge
    attributes). Each node then receives the mean of its edge messages.

    Args:
        nf: Hidden width of ``phi``.
        edge_attr_nf: Number of edge-attribute features fed to ``phi`` together
            with the distance (0 when ``edge_attr`` is not passed).
        reg: Coefficient of the ``- x * reg`` term in the node update.
        act_fn: Activation module used inside ``phi``.
        clamp: If true, messages are clamped to ``[-100, 100]``.
    """

    def __init__(self, nf=64, edge_attr_nf=0, reg=0, act_fn=nn.LeakyReLU(0.2), clamp=False):
        super(GCL_rf, self).__init__()

        self.clamp = clamp
        # Output layer initialised with a very small gain (0.001).
        layer = nn.Linear(nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        self.phi = nn.Sequential(nn.Linear(edge_attr_nf + 1, nf),
                                 act_fn,
                                 layer)
        self.reg = reg

    def edge_model(self, source, target, edge_attr):
        """Return the per-edge message ``(source - target) * phi(distance, edge_attr)``.

        ``edge_attr`` may be ``None`` (the default of the inherited ``forward``);
        in that case only the distance is fed to ``phi``.
        """
        x_diff = source - target
        radial = torch.sqrt(torch.sum(x_diff ** 2, dim=1)).unsqueeze(1)
        if edge_attr is None:
            e_input = radial
        else:
            e_input = torch.cat([radial, edge_attr], dim=1)
        e_out = self.phi(e_input)
        m_ij = x_diff * e_out
        if self.clamp:
            m_ij = torch.clamp(m_ij, min=-100, max=100)
        return m_ij

    def node_model(self, x, edge_index, edge_attr):
        """Add the mean incoming message to ``x`` and subtract ``x * reg``."""
        row, col = edge_index
        agg = unsorted_segment_mean(edge_attr, row, num_segments=x.size(0))
        x_out = x + agg - x * self.reg
        return x_out

class E_GCL(nn.Module):
    """Graph layer that jointly updates node features ``h`` and coordinates.

    Edge messages are computed from both endpoints' features, the squared
    distance between their coordinates and optional edge attributes. The
    messages drive a coordinate update (mean over edges of scaled coordinate
    differences) and a node-feature update (MLP over the summed messages).

    Args:
        input_nf: Number of input node features.
        output_nf: Number of output node features.
        hidden_nf: Width of hidden layers and of the edge messages.
        edges_in_d: Number of edge-attribute features (0 when none are passed).
        nodes_att_dim: Number of extra node-attribute features fed to the node MLP.
        act_fn: Activation module used between linear layers.
        recurrent: If true, ``h`` is added to the node update
            (this needs ``output_nf == input_nf``).
        coords_weight: Scalar multiplier applied to the coordinate update.
        attention: If true, messages are scaled by a sigmoid gate.
        clamp: Stored on the instance; ``coord_model`` clamps regardless of it.
        norm_diff: If true, coordinate differences are divided by ``distance + 1``.
        tanh: If true, a ``Tanh`` is appended to the coordinate MLP.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU(), recurrent=True, coords_weight=1.0, attention=False, clamp=False, norm_diff=False, tanh=False):
        super(E_GCL, self).__init__()
        input_edge = input_nf * 2
        self.coords_weight = coords_weight
        self.recurrent = recurrent
        self.attention = attention
        self.norm_diff = norm_diff
        self.tanh = tanh
        edge_coords_nf = 1  # the squared distance

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        # Output layer of the coordinate MLP, initialised with a very small gain.
        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.clamp = clamp
        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        if self.tanh:
            coord_mlp.append(nn.Tanh())
            # Previously `nn.Parameter(torch.ones(1))*3`, which is a non-leaf
            # tensor: it never followed `.to(device)` and made
            # `copy.deepcopy(module)` raise. The value is not read in this
            # class, so it is kept as a constant, non-persistent buffer
            # (state_dict keys and the parameter list stay unchanged).
            self.register_buffer("coords_range", torch.ones(1) * 3, persistent=False)
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, radial, edge_attr):
        """Compute one message per edge from endpoint features and ``radial``."""
        if edge_attr is None:
            out = torch.cat([source, target, radial], dim=1)
        else:
            out = torch.cat([source, target, radial, edge_attr], dim=1)
        out = self.edge_mlp(out)
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Update node features from the per-node sum of edge messages.

        Returns:
            Tuple ``(updated features, node-MLP input)``.
        """
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = self.node_mlp(agg)
        if self.recurrent:
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
        """Return updated coordinates; the input ``coord`` tensor is left untouched."""
        row, col = edge_index
        trans = coord_diff * self.coord_mlp(edge_feat)
        # Safety clamp against exploding coordinate updates.
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        # Out-of-place on purpose: `coord += ...` would modify the caller's tensor.
        return coord + agg * self.coords_weight

    def coord2radial(self, edge_index, coord):
        """Return per-edge squared distances and coordinate differences.

        With ``norm_diff`` the differences are divided by ``distance + 1``.
        """
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)

        if self.norm_diff:
            norm = torch.sqrt(radial) + 1
            coord_diff = coord_diff / norm

        return radial, coord_diff

    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None):
        """Run one layer step.

        Returns:
            Tuple ``(h, coord, edge_attr)``: updated node features, updated
            coordinates, and the ``edge_attr`` argument passed through unchanged.
        """
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        return h, coord, edge_attr

class E_GCL_vel(E_GCL):
    """``E_GCL`` variant whose coordinate update also uses a per-node velocity.

    After the edge-driven coordinate update, ``coord_mlp_vel(h) * vel`` is
    added to the coordinates, where ``coord_mlp_vel`` maps node features to one
    scalar per node.

    Args:
        input_nf, output_nf, hidden_nf, edges_in_d, nodes_att_dim, act_fn,
        recurrent, coords_weight, attention, norm_diff, tanh: Forwarded
            unchanged to ``E_GCL.__init__``.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU(), recurrent=True, coords_weight=1.0, attention=False, norm_diff=False, tanh=False):
        super(E_GCL_vel, self).__init__(
            input_nf, output_nf, hidden_nf,
            edges_in_d=edges_in_d, nodes_att_dim=nodes_att_dim, act_fn=act_fn,
            recurrent=recurrent, coords_weight=coords_weight,
            attention=attention, norm_diff=norm_diff, tanh=tanh)
        # Scalar gate on the velocity, computed from the node features.
        self.coord_mlp_vel = nn.Sequential(
            nn.Linear(input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 1))

    def forward(self, h, edge_index, coord, vel, edge_attr=None, node_attr=None):
        """Run one layer step.

        Returns:
            Tuple ``(h, coord, edge_attr)``: updated node features, updated
            coordinates, and the ``edge_attr`` argument passed through unchanged.
        """
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)

        # Out-of-place on purpose: `coord += ...` writes into whatever tensor
        # `coord_model` returned, which may be the caller's own input.
        coord = coord + self.coord_mlp_vel(h) * vel
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        return h, coord, edge_attr

class GCL_rf_vel(nn.Module):
    """Radial-field layer with an additional velocity term.

    Each node receives the mean of its edge messages
    ``(x_i - x_j) * phi(distance, edge_attr)`` scaled by ``coords_weight``,
    plus ``vel * coord_mlp_vel(vel_norm)``.

    Args:
        nf: Hidden width of ``phi`` and ``coord_mlp_vel``.
        edge_attr_nf: Number of edge-attribute features fed to ``phi`` together
            with the distance (0 when ``edge_attr`` is not passed).
        act_fn: Activation module used inside both MLPs.
        coords_weight: Scalar multiplier applied to the aggregated messages.
    """
    def __init__(self,  nf=64, edge_attr_nf=0, act_fn=nn.LeakyReLU(0.2), coords_weight=1.0):
        super(GCL_rf_vel, self).__init__()
        self.coords_weight = coords_weight
        # Maps the (single-feature) velocity norm to a scalar gate.
        self.coord_mlp_vel = nn.Sequential(
            nn.Linear(1, nf),
            act_fn,
            nn.Linear(nf, 1))

        # Output layer of phi, initialised with a very small gain.
        layer = nn.Linear(nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        self.phi = nn.Sequential(nn.Linear(1 + edge_attr_nf, nf),
                                 act_fn,
                                 layer,
                                 nn.ReLU())  # final activation is ReLU (an older comment here referred to tanh)

    def forward(self, x, vel_norm, vel, edge_index, edge_attr=None):
        """Run one layer step.

        Returns:
            Tuple ``(x, edge_attr)``: the updated per-node vectors and the
            ``edge_attr`` argument passed through unchanged.
        """
        row, col = edge_index
        edge_m = self.edge_model(x[row], x[col], edge_attr)
        x = self.node_model(x, edge_index, edge_m)
        x = x + vel * self.coord_mlp_vel(vel_norm)
        return x, edge_attr

    def edge_model(self, source, target, edge_attr):
        """Return the per-edge message ``(source - target) * phi(distance, edge_attr)``.

        ``edge_attr`` may be ``None`` (the default of ``forward``); in that
        case only the distance is fed to ``phi``.
        """
        x_diff = source - target
        radial = torch.sqrt(torch.sum(x_diff ** 2, dim=1)).unsqueeze(1)
        if edge_attr is None:
            e_input = radial
        else:
            e_input = torch.cat([radial, edge_attr], dim=1)
        e_out = self.phi(e_input)
        m_ij = x_diff * e_out
        return m_ij

    def node_model(self, x, edge_index, edge_m):
        """Add the mean incoming message, scaled by ``coords_weight``, to ``x``."""
        row, col = edge_index
        agg = unsorted_segment_mean(edge_m, row, num_segments=x.size(0))
        x_out = x + agg * self.coords_weight
        return x_out
    
class E_GCL_clof(nn.Module):
    """Graph layer that jointly updates node features ``h`` and coordinates.

    Same structure as an edge/node/coordinate message-passing layer, except
    that the coordinate MLP ends in a ``ReLU`` and emits ``out_basis_dim``
    coefficients per edge.

    Args:
        input_nf: Number of input node features.
        output_nf: Number of output node features.
        hidden_nf: Width of hidden layers and of the edge messages.
        edges_in_d: Number of edge-attribute features (0 when none are passed).
        nodes_att_dim: Number of extra node-attribute features fed to the node MLP.
        act_fn: Activation module used between linear layers.
        recurrent: If true, ``h`` is added to the node update
            (this needs ``output_nf == input_nf``).
        coords_weight: Scalar multiplier applied to the coordinate update.
        attention: If true, messages are scaled by a sigmoid gate.
        clamp: Stored on the instance; ``coord_model`` clamps regardless of it.
        norm_diff: If true, coordinate differences are divided by ``distance + 1``.
        tanh: Stored on the instance; not otherwise used by this class.
        out_basis_dim: Number of outputs of the coordinate MLP.
    """
    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU(), recurrent=True, coords_weight=1.0, attention=False, clamp=False, norm_diff=False, tanh=False, out_basis_dim=1):
        super(E_GCL_clof, self).__init__()
        input_edge = input_nf * 2
        self.coords_weight = coords_weight
        self.recurrent = recurrent
        self.attention = attention
        self.norm_diff = norm_diff
        self.tanh = tanh
        edge_coords_nf = 1  # the squared distance

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        # Output layer of the coordinate MLP, initialised with a very small gain.
        layer = nn.Linear(hidden_nf, out_basis_dim, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.clamp = clamp
        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        coord_mlp.append(nn.ReLU())
        # Previously `nn.Parameter(torch.ones(1))*3`, which is a non-leaf
        # tensor: it never followed `.to(device)` and made
        # `copy.deepcopy(module)` raise. The value is not read in this class,
        # so it is kept as a constant, non-persistent buffer (state_dict keys
        # and the parameter list stay unchanged).
        self.register_buffer("coords_range", torch.ones(1) * 3, persistent=False)
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, radial, edge_attr):
        """Compute one message per edge from endpoint features and ``radial``."""
        if edge_attr is None:
            out = torch.cat([source, target, radial], dim=1)
        else:
            out = torch.cat([source, target, radial, edge_attr], dim=1)
        out = self.edge_mlp(out)
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Update node features from the per-node sum of edge messages.

        Returns:
            Tuple ``(updated features, node-MLP input)``.
        """
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = self.node_mlp(agg)
        if self.recurrent:
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
        """Return updated coordinates; the input ``coord`` tensor is left untouched."""
        row, col = edge_index
        trans = coord_diff * self.coord_mlp(edge_feat)
        # Safety clamp against exploding coordinate updates.
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        # Out-of-place on purpose: `coord += ...` would modify the caller's tensor.
        return coord + agg * self.coords_weight

    def coord2radial(self, edge_index, coord):
        """Return per-edge squared distances and coordinate differences.

        With ``norm_diff`` the differences are divided by ``distance + 1``.
        """
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)

        if self.norm_diff:
            norm = torch.sqrt(radial) + 1
            coord_diff = coord_diff / norm

        return radial, coord_diff

    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None):
        """Run one layer step.

        Returns:
            Tuple ``(h, coord, edge_attr)``: updated node features, updated
            coordinates, and the ``edge_attr`` argument passed through unchanged.
        """
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        return h, coord, edge_attr

class Clof_GCL(E_GCL_clof):
    """Basic message passing module of ClofNet.

    Builds, for every edge, three vectors from the endpoint coordinates
    (difference, cross product, and the cross product of those two) and
    updates coordinates with a per-edge linear combination of them whose three
    coefficients come from the coordinate MLP. A velocity term
    ``coord_mlp_vel(h) * vel`` is added afterwards.

    Args:
        input_nf, output_nf, hidden_nf, edges_in_d, nodes_att_dim, act_fn,
        recurrent, coords_weight, attention, norm_diff, tanh: Forwarded
            unchanged to ``E_GCL_clof.__init__`` (with ``out_basis_dim=3``).

    Note:
        ``forward`` adds the input ``h`` to the node-model output and applies
        ``LayerNorm(hidden_nf)``, so ``input_nf``, ``output_nf`` and
        ``hidden_nf`` have to be equal.
    """
    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU(), recurrent=True, coords_weight=1.0, attention=False, norm_diff=False, tanh=False):
        super(Clof_GCL, self).__init__(
            input_nf, output_nf, hidden_nf,
            edges_in_d=edges_in_d, nodes_att_dim=nodes_att_dim, act_fn=act_fn,
            recurrent=recurrent, coords_weight=coords_weight,
            attention=attention, norm_diff=norm_diff, tanh=tanh, out_basis_dim=3)
        # Scalar gate on the velocity, computed from the node features.
        self.coord_mlp_vel = nn.Sequential(
            nn.Linear(input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 1))

        # Replaces the parent's two-layer edge MLP with a three-layer one.
        self.edge_mlp = nn.Sequential(
            nn.Linear(input_nf * 2 + 1 + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)
        self.layer_norm = nn.LayerNorm(hidden_nf)

    def coord2localframe(self, edge_index, coord):
        """Return the squared distance and the three per-edge frame vectors.

        Returns:
            Tuple ``(radial, coord_diff, coord_cross, coord_vertical)`` where
            ``coord_cross = coord[row] x coord[col]`` and
            ``coord_vertical = coord_diff x coord_cross``. With ``norm_diff``
            the first two vectors are divided by ``their norm + 1`` before the
            last one is computed.
        """
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)
        coord_cross = torch.cross(coord[row], coord[col], dim=-1)
        if self.norm_diff:
            norm = torch.sqrt(radial) + 1
            coord_diff = coord_diff / norm
            cross_norm = (
                torch.sqrt(torch.sum((coord_cross) ** 2, 1).unsqueeze(1))) + 1
            coord_cross = coord_cross / cross_norm

        coord_vertical = torch.cross(coord_diff, coord_cross, dim=-1)

        return radial, coord_diff, coord_cross, coord_vertical

    def coord_model(self, coord, edge_index, coord_diff, coord_cross, coord_vertical, edge_feat):
        """Return updated coordinates; the input ``coord`` tensor is left untouched."""
        row, col = edge_index
        coff = self.coord_mlp(edge_feat)
        trans = coord_diff * coff[:, :1] + coord_cross * coff[:, 1:2] + coord_vertical * coff[:, 2:3]
        # Safety clamp against exploding coordinate updates.
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        # Out-of-place on purpose: `coord += ...` would modify the caller's tensor.
        return coord + agg * self.coords_weight

    def forward(self, h, edge_index, coord, vel, edge_attr=None, node_attr=None):
        """Run one layer step.

        Returns:
            Tuple ``(h, coord, edge_attr)``: updated node features, updated
            coordinates, and the ``edge_attr`` argument passed through unchanged.
        """
        row, col = edge_index
        residue = h

        radial, coord_diff, coord_cross, coord_vertical = self.coord2localframe(edge_index, coord)
        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, coord_cross, coord_vertical, edge_feat)

        # Out-of-place so the tensor handed in by the caller is never written to.
        coord = coord + self.coord_mlp_vel(h) * vel
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        # NOTE(review): with recurrent=True, node_model has already added the
        # input h, so the input is added twice here — confirm this is intended.
        h = residue + h
        h = self.layer_norm(h)
        return h, coord, edge_attr

class CFIN_layer(nn.Module):
    """Message-passing layer that updates node features, positions and velocities.

    Per edge, the layer builds two local frames (one from positions, one from
    velocities) and scalar invariants of positions/velocities. Edge messages
    are computed from node features plus those invariants, weighted by ``c``,
    and turned into position/velocity increments expressed in the six frame
    vectors.

    Args:
        input_nf: Number of input node features.
        output_nf: Number of output node features.
        hidden_nf: Width of hidden layers and of the edge messages.
        edges_in_d: Number of edge-attribute features (0 when none are passed).
        nodes_att_dim: Number of extra node-attribute features fed to the node MLP.
        dropout_rate: Dropout probability used inside the MLPs.
        act_fn: Activation module used between linear layers.
        recurrent, clamp, tanh, order: Stored on the instance; not read by
            this class.
        attention: If true, messages are scaled by a sigmoid gate.
        norm_diff: If true, positions and velocities are scaled to unit norm
            before the invariants are computed.
        out_dim: Coordinate dimension, 2 or 3. Only 3 is usable end to end:
            the local frames are not implemented for 2-D input.

    Raises:
        ValueError: If ``out_dim`` is neither 2 nor 3.

    Note:
        ``forward`` adds the input ``h`` to the node-model output and applies
        ``LayerNorm(hidden_nf)``, so ``input_nf``, ``output_nf`` and
        ``hidden_nf`` have to be equal.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, dropout_rate=0.5, act_fn=nn.ReLU(), recurrent=False, attention=False, clamp=False, norm_diff=False, tanh=False, order=1, out_dim=3):
        super(CFIN_layer, self).__init__()

        if out_dim not in (2, 3):
            # Previously this surfaced later as an UnboundLocalError.
            raise ValueError(f"Unsupported out_dim: {out_dim}. Only 2 and 3 are supported.")

        self.recurrent = recurrent
        self.attention = attention
        self.norm_diff = norm_diff
        self.tanh = tanh
        self.dropout_rate = dropout_rate

        # Number of per-edge invariants (see coord2inter_diff_invar) and
        # per-node invariants (see coord2inner_diff_invar).
        inter_invar_d = out_dim
        inner_invar_d = 3

        self.coord_mlp_vel = nn.Sequential(
            nn.Linear(input_nf, hidden_nf),
            nn.LayerNorm(hidden_nf),
            act_fn,
            nn.Dropout(self.dropout_rate),
            nn.Linear(hidden_nf, out_dim),
            )

        self.coord_mlp_loc = nn.Sequential(
            nn.Linear(input_nf, hidden_nf),
            nn.LayerNorm(hidden_nf),
            act_fn,
            nn.Dropout(self.dropout_rate),
            nn.Linear(hidden_nf, out_dim),
            )

        # velocity delta: one coefficient per frame vector (2 frames x out_dim)
        layer_v = nn.Linear(hidden_nf, out_dim * 2, bias=False)
        torch.nn.init.kaiming_uniform_(layer_v.weight, nonlinearity='relu')
        vel_mlp = []
        vel_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        vel_mlp.append(nn.LayerNorm(hidden_nf))
        vel_mlp.append(act_fn)
        vel_mlp.append(nn.Dropout(self.dropout_rate))
        vel_mlp.append(layer_v)
        self.vel_range = nn.Parameter(torch.ones(1) * 0.01)
        self.vel_mlp = nn.Sequential(*vel_mlp)

        # coord delta: one coefficient per frame vector (2 frames x out_dim)
        layer = nn.Linear(hidden_nf, out_dim * 2, bias=False)
        torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(nn.LayerNorm(hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(nn.Dropout(self.dropout_rate))
        coord_mlp.append(layer)
        self.coords_range = nn.Parameter(torch.ones(1) * 0.01)
        self.coord_mlp = nn.Sequential(*coord_mlp)

        # used for edge model
        self.edge_mlp = nn.Sequential(
            nn.Linear((input_nf + inner_invar_d) * 2 + edges_in_d + inter_invar_d, hidden_nf),
            act_fn,
            nn.Dropout(self.dropout_rate),
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)
        self.edge_norm = nn.LayerNorm(hidden_nf)

        # used for node model
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Dropout(self.dropout_rate),
            nn.Linear(hidden_nf, output_nf),
            act_fn,
            nn.LayerNorm(output_nf)
            )

        self.clamp = clamp
        self.order = order
        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())
        self.layer_norm = nn.LayerNorm(hidden_nf)

    @staticmethod
    def _row_norm(v):
        """Euclidean norm over the last axis, kept as a trailing axis of size 1."""
        return torch.norm(v, p=2, dim=-1, keepdim=True)

    @staticmethod
    def _scale_to_unit(v, eps):
        """Divide every row of ``v`` by its norm, with the norm floored at ``eps``."""
        return v / torch.norm(v, p=2, dim=-1, keepdim=True).clamp(min=eps)

    @staticmethod
    def _mix_frames(frames, coeffs):
        """Return ``sum_k frames[k] * coeffs[:, k:k+1]``, accumulated left to right."""
        mixed = frames[0] * coeffs[:, 0:1]
        for k, frame in enumerate(frames[1:], start=1):
            mixed = mixed + frame * coeffs[:, k:k + 1]
        return mixed

    def edge_model(self, source, target, i_nt, edge_attr):
        """Compute one layer-normalised message per edge."""
        if edge_attr is None:
            out = torch.cat([source, target, i_nt], dim=-1)
        else:
            out = torch.cat([source, target, i_nt, edge_attr], dim=-1)
        out = self.edge_mlp(out)
        out = self.edge_norm(out)
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Update node features from the per-node sum of edge messages.

        Returns:
            Tuple ``(updated features, node-MLP input)``.
        """
        row, col = edge_index
        # One row per node; same width as the edge messages.
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))

        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=-1)
        else:
            agg = torch.cat([x, agg], dim=-1)

        out = self.node_mlp(agg)

        return out, agg

    def complete_local_frames(self, loc, vel, edge_index):
        """Build the per-edge position frame and velocity frame.

        For each of ``loc`` and ``vel`` the frame consists of the normalised
        endpoint difference, the normalised endpoint cross product, and the
        cross product of those two. Normalisation divides by ``norm + 1e-8``.

        Returns:
            ``(loc_a, loc_b, loc_c, vel_a, vel_b, vel_c)``, one row per edge.

        Raises:
            NotImplementedError: For 2-D input (no frame construction exists).
            ValueError: For any other last-axis size than 2 or 3.
        """
        dim = loc.shape[-1]
        if dim == 2:
            # The 2-D branch used to fall through and return None, which only
            # failed later when the caller unpacked the result.
            raise NotImplementedError("Local frames are not implemented for 2D coordinates.")
        if dim != 3:
            raise ValueError(f"Unsupported coordinate dimension: {loc.shape[-1]}. Only 2D and 3D coordinates are supported.")

        row, col = edge_index

        def build(points):
            diff = points[row] - points[col]
            cross = torch.cross(points[row], points[col], dim=-1)
            # torch.norm has a zero (sub)gradient at the zero vector, whereas
            # sqrt(sum(x**2)) back-propagates NaN there (e.g. two nodes with
            # identical velocity, or parallel position vectors).
            diff = diff / (self._row_norm(diff) + 1e-8)
            cross = cross / (self._row_norm(cross) + 1e-8)
            vertical = torch.cross(diff, cross, dim=-1)
            return diff, cross, vertical

        loc_frm_a, loc_frm_b, loc_frm_c = build(loc)
        vel_frm_a, vel_frm_b, vel_frm_c = build(vel)

        return loc_frm_a, loc_frm_b, loc_frm_c, vel_frm_a, vel_frm_b, vel_frm_c

    def coord2inner_diff_invar(self, loc, vel, norm):
        """Return the per-node invariants ``[|p|^2, |v|^2, p.v]``.

        With ``norm`` true, ``loc`` and ``vel`` are first scaled to unit norm
        (norm floored at 1e-8 so a zero vector stays zero).
        """
        if norm:
            epsilon = 1e-8
            loc = self._scale_to_unit(loc, epsilon)
            vel = self._scale_to_unit(vel, epsilon)

        radius = torch.sum(loc ** 2, dim=-1, keepdim=True)
        velocity = torch.sum(vel ** 2, dim=-1, keepdim=True)
        inner = torch.sum(loc * vel, dim=-1, keepdim=True)

        # One row per node, three columns.
        I_nn = torch.cat([radius, velocity, inner], dim=-1)

        return I_nn

    def coord2inter_diff_invar(self, edge_index, loc, vel, norm):
        """Return the per-edge invariants.

        2-D input gives ``[p_i.p_j, v_i.v_j]``; 3-D input additionally gives
        ``v_i.(p_i x p_j)``. With ``norm`` true, ``loc`` and ``vel`` are first
        scaled to unit norm (norm floored at 1e-6).

        Raises:
            ValueError: If the last axis of ``loc`` is neither 2 nor 3.
        """
        row, col = edge_index

        if norm:
            epsilon = 1e-6
            loc = self._scale_to_unit(loc, epsilon)
            vel = self._scale_to_unit(vel, epsilon)

        dim = loc.shape[-1]
        if dim not in (2, 3):
            raise ValueError(f"Unsupported coordinate dimension: {loc.shape[-1]}. Only 2D and 3D coordinates are supported.")

        inter_loc = torch.sum(loc[row] * loc[col], dim=-1, keepdim=True)  # p_i . p_j
        inter_vel = torch.sum(vel[row] * vel[col], dim=-1, keepdim=True)  # v_i . v_j
        if dim == 2:
            return torch.cat([inter_loc, inter_vel], dim=-1)

        cross = torch.cross(loc[row], loc[col], dim=-1)
        inter_thr = torch.sum(vel[row] * cross, dim=-1, keepdim=True)  # v_i . (p_i x p_j)
        return torch.cat([inter_loc, inter_vel, inter_thr], dim=-1)

    def delta_model(self, loc, vel, edge_index, frame_loc_a, frame_loc_b, frame_loc_c, frame_vel_a, frame_vel_b, frame_vel_c, edge_feat):
        """Compute the aggregated position and velocity increments.

        Two MLPs map every edge message to six coefficients each; the six
        frame vectors are combined with them per edge and summed per node
        (grouped by ``edge_index[0]``).

        Parameters
        ----------
            loc, vel:     per-node tensors; only their first dimension is used
            edge_index:   pair of per-edge node-index tensors
            edge_feat:    per-edge messages (num_edges, hidden_nf)
            frame_loc_*:  per-edge position-frame vectors (num_edges, dim_coord)
            frame_vel_*:  per-edge velocity-frame vectors (num_edges, dim_coord)
        Returns
        -------
            delta_loc:    (num_nodes, dim_coord)
            delta_vel:    (num_nodes, dim_coord)
        """
        row, col = edge_index
        coff_loc = self.coord_mlp(edge_feat)
        coff_vel = self.vel_mlp(edge_feat)
        frames = (frame_loc_a, frame_loc_b, frame_loc_c, frame_vel_a, frame_vel_b, frame_vel_c)
        trans_loc = self._mix_frames(frames, coff_loc)
        trans_vel = self._mix_frames(frames, coff_vel)
        agg_loc = unsorted_segment_sum(trans_loc, row, num_segments=loc.size(0))
        agg_vel = unsorted_segment_sum(trans_vel, row, num_segments=vel.size(0))

        return agg_loc, agg_vel

    def forward(self, h, edge_index, c, loc, vel, edge_attr=None, node_attr=None):
        """Run one layer step.

        Args:
            h: (num_nodes, input_nf) node features.
            edge_index: Pair of per-edge node-index tensors.
            c: Per-edge weight multiplied onto the messages, e.g. (num_edges, 1).
            loc, vel: (num_nodes, dim_coord) positions and velocities.
            edge_attr: Optional (num_edges, edges_in_d) edge attributes.
            node_attr: Optional extra node attributes for the node MLP.

        Returns:
            Tuple ``(h, loc, vel, edge_feat)`` with the updated node features,
            positions, velocities and the weighted edge messages.
        """
        row, col = edge_index
        h_res = h
        loc_frm_a, loc_frm_b, loc_frm_c, vel_frm_a, vel_frm_b, vel_frm_c = self.complete_local_frames(loc, vel, edge_index)

        i_nn = self.coord2inner_diff_invar(loc, vel, self.norm_diff)  # (num_nodes, 3)
        i_nt = self.coord2inter_diff_invar(edge_index, loc, vel, self.norm_diff)  # (num_edges, 3) for 3-D input
        h_new_row = torch.cat([h[row], i_nn[row]], dim=-1)  # (num_edges, input_nf + 3)
        h_new_col = torch.cat([h[col], i_nn[col]], dim=-1)  # (num_edges, input_nf + 3)

        edge_feat = self.edge_model(h_new_row, h_new_col, i_nt, edge_attr)
        edge_feat = c * edge_feat  # weight every message by its edge coefficient

        delta_loc, delta_vel = self.delta_model(loc, vel, edge_index, loc_frm_a, loc_frm_b, loc_frm_c, vel_frm_a, vel_frm_b, vel_frm_c, edge_feat)

        vel = vel + self.coord_mlp_vel(h) * delta_vel * self.vel_range
        loc = loc + self.coord_mlp_loc(h) * delta_loc * self.coords_range

        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        h = h_res + h
        h = self.layer_norm(h)

        return h, loc, vel, edge_feat

class GMNLayer(nn.Module):
    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU(),
                 recurrent=True, coords_weight=1.0, attention=False, norm_diff=False, tanh=False):
        """
        The layer of Graph Mechanics Networks.
        :param input_nf: input node feature dimension
        :param output_nf: output node feature dimension
        :param hidden_nf: hidden dimension
        :param edges_in_d: input edge dimension
        :param nodes_att_dim: attentional dimension, inherited
        :param act_fn: activation function
        :param recurrent: residual connection on x
        :param coords_weight: coords weight, inherited
        :param attention: use attention on edges, inherited
        :param norm_diff: normalize the distance, inherited
        :param tanh: Tanh activation, inherited
        """
        super(GMNLayer, self).__init__()
        input_edge = input_nf * 2
        self.coords_weight = coords_weight
        self.recurrent = recurrent
        self.attention = attention
        self.norm_diff = norm_diff
        self.tanh = tanh
        # multi_channel_vec() yields a flattened 2x2 matrix per edge.
        edge_coords_nf = 4

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        # Input is [others, x, agg, node_attr] (see node_model). forward()
        # passes others=h, so the trailing "+ hidden_nf" appears to assume
        # input_nf == hidden_nf -- TODO confirm with callers.
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim + hidden_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        # Final coordinate projection, initialised with a very small gain.
        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        if self.tanh:
            coord_mlp.append(nn.Tanh())
            # NOTE(review): multiplying the Parameter by 3 yields a plain
            # tensor, so this attribute is NOT registered as a parameter (it
            # is not trained, saved, or moved by .to()). No method of this
            # class reads it; it is left unchanged so that state_dict keys
            # stay stable.
            self.coords_range = nn.Parameter(torch.ones(1))*3
        self.coord_mlp = nn.Sequential(*coord_mlp)

        self.coord_mlp_vel = nn.Sequential(
            nn.Linear(input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 1))

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, multi_channel_vec, edge_attr):
        """
        Compute one message per edge.
        :param source: features of the edge's first endpoint  [M, input_nf]
        :param target: features of the edge's second endpoint  [M, input_nf]
        :param multi_channel_vec: geometric edge features  [M, 4]
        :param edge_attr: optional extra edge features  [M, edges_in_d] or None
        :return: edge messages  [M, hidden_nf], gated by att_mlp when attention is on
        """
        pieces = [source, target, multi_channel_vec]
        if edge_attr is not None:
            pieces.append(edge_attr)
        out = self.edge_mlp(torch.cat(pieces, dim=1))
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr, others=None):
        """
        Aggregate edge messages per node and update the node features.
        :param x: node features  [N, input_nf]
        :param edge_index: pair (row, col); messages are summed at index row
        :param edge_attr: edge messages  [M, hidden_nf]
        :param node_attr: optional extra node features, or None
        :param others: optional tensor prepended to the MLP input, or None
        :return: (out, agg) where agg is the concatenated tensor fed to node_mlp
        """
        row, _ = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        if others is not None:  # can concat h here
            agg = torch.cat([others, agg], dim=1)
        out = self.node_mlp(agg)
        if self.recurrent:
            # Residual connection; requires output_nf == x.size(1).
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
        """
        Move each node along the mean of its scaled edge directions.
        :param coord: node coordinates  [N, D]; not modified
        :param edge_index: pair (row, col); updates are averaged at index row
        :param coord_diff: per-edge direction vectors  [M, D]
        :param edge_feat: edge messages  [M, hidden_nf]
        :return: a new tensor with the updated coordinates  [N, D]
        """
        row, _ = edge_index
        trans = coord_diff * self.coord_mlp(edge_feat)
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        # Out-of-place on purpose: `coord += ...` would silently overwrite the
        # tensor owned by the caller.
        return coord + agg * self.coords_weight

    def coord2radial(self, edge_index, coord):
        """
        :param edge_index: pair (row, col) of node indices
        :param coord: node coordinates  [N, D]
        :return: (radial, coord_diff): squared edge length  [M, 1] and
                 coord[row] - coord[col]  [M, D]; when norm_diff is set,
                 coord_diff is divided by (length + 1)
        """
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum(coord_diff ** 2, 1).unsqueeze(1)

        if self.norm_diff:
            norm = torch.sqrt(radial) + 1
            coord_diff = coord_diff / norm

        return radial, coord_diff

    def multi_channel_vec(self, edge_index, coord, vel, eps=1e-12):
        """
        Build the geometric edge features from positions and velocities.

        For every edge the position and velocity differences are stacked into
        V_ij  [D, 2]; the 2x2 matrix V_ij^T V_ij is scaled by its Frobenius
        norm and flattened.
        :param edge_index: pair (row, col) of node indices
        :param coord: node coordinates  [N, D]
        :param vel: node velocities  [N, D]
        :param eps: lower bound on the Frobenius norm before dividing
        :return: flattened normalised matrices  [M, 4]
        """
        row, col = edge_index
        V = torch.stack([coord, vel], dim=2)        # [N, D, 2]
        V_ij = V[row] - V[col]                      # [M, D, 2]

        gram = torch.matmul(V_ij.transpose(1, 2), V_ij)                   # V_ij.T @ V_ij  [M, 2, 2]
        frobenius_norm = torch.norm(gram, p='fro', dim=(1, 2), keepdim=True)  # [M, 1, 1]
        # Coincident nodes / self-loops give an all-zero matrix; without the
        # clamp the division would be 0/0 = NaN.
        normalized_matrix = gram / frobenius_norm.clamp(min=eps)          # [M, 2, 2]

        return torch.reshape(normalized_matrix, (normalized_matrix.shape[0], -1))  # [M, 4]

    def forward(self, h, edge_index, x, v, edge_attr=None, node_attr=None):
        """
        :param h: the node aggregated feature  [N, n_hidden]
        :param edge_index:  [2, M], M is the number of edges
        :param x: input coordinate  [N, 3]
        :param v: input velocity  [N, 3]
        :param edge_attr: edge feature  [M, n_edge]
        :param node_attr: the node input feature  [N, n_in]
        :return: (h, coord, edge_attr): the updated node features, the updated
                 coordinates (a new tensor; x is left untouched) and the
                 edge_attr argument passed through unchanged
        """
        # aggregate force (equivariant message passing on the whole graph)
        row, col = edge_index
        _, coord_diff = self.coord2radial(edge_index, x)
        multi_channel_vec = self.multi_channel_vec(edge_index, x, v)
        edge_feat = self.edge_model(h[row], h[col], multi_channel_vec, edge_attr)  # [M, hidden_nf]

        coord = self.coord_model(x, edge_index, coord_diff, edge_feat)  # [N, D]
        coord = coord + self.coord_mlp_vel(h) * v  # velocity term, gated per node

        h, _ = self.node_model(h, edge_index, edge_feat, node_attr, others=h)

        return h, coord, edge_attr
    
class CFIN_layer_diff_invar(nn.Module):
    """Message-passing layer driven by scalar invariants and per-edge frames.

    Edge messages are computed from node features plus scalar invariants of
    positions / velocities. Position and velocity updates are linear
    combinations of per-edge frame vectors whose coefficients are predicted
    from the edge messages.

    Args:
          input_nf: Node feature dimension.
          output_nf: Output node feature dimension.
          hidden_nf: Hidden / edge message dimension.
          edges_in_d: Dimension of the optional ``edge_attr`` input.
          nodes_att_dim: Dimension of the optional ``node_attr`` input.
          dropout_rate: Dropout probability inside the edge and node MLPs.
          act_fn: Activation module shared by all MLPs.
          attention: Gate edge messages with a sigmoid attention value.
          norm_diff: Normalise positions / velocities before computing invariants.
          case: Selects the invariant set (see ``__init__``); values outside
              1-5 select the full set.
          recurrent, clamp, tanh, order, frame: Stored on the instance but not
              read by any method of this class.
          out_dim: Accepted but unused.
    """
    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, dropout_rate=0.5,
                 act_fn=nn.ReLU(), recurrent=False, attention=False, clamp=False, norm_diff=False,
                 tanh=False, order=1, out_dim=3, case=1, frame=0):
        super(CFIN_layer_diff_invar, self).__init__()

        self.recurrent = recurrent
        self.attention = attention
        self.norm_diff = norm_diff
        self.tanh = tanh
        self.dropout_rate = dropout_rate
        self.case = case
        self.frame = frame # 0 radial 1 positional 2 differential

        # case -> (number of per-edge invariants, number of per-node invariants)
        invar_dims = {
            1: (1, 0),  # only use pipj
            2: (1, 0),  # only use ||pi-pj||
            3: (1, 1),  # ||pi-pj|| ||pi|| and ||pj||
            4: (2, 2),  # ||pi-pj|| ||pi|| and ||pj||, ||vi-vj|| ||vi|| and ||vj||
            5: (3, 3),  # CFIN
        }
        # any other case: numerous different invariants
        inter_invar_d, inner_invar_d = invar_dims.get(self.case, (13, 4))

        # Per-node scalar gates applied to the aggregated deltas in forward().
        self.coord_mlp_vel = nn.Sequential(
            nn.Linear(input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 1))

        self.coord_mlp_loc = nn.Sequential(
            nn.Linear(input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 1))

        # velocity delta: six coefficients per edge, one per frame vector
        layer_v = nn.Linear(hidden_nf, 6, bias=False)
        torch.nn.init.kaiming_uniform_(layer_v.weight, nonlinearity='relu')
        vel_mlp = []
        vel_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        vel_mlp.append(act_fn)
        vel_mlp.append(layer_v)
        # NOTE(review): multiplying the Parameter by 3 yields a plain tensor,
        # so vel_range (and coords_range below) are NOT registered parameters.
        # No method of this class reads them; they are left unchanged so that
        # state_dict keys stay stable.
        self.vel_range = nn.Parameter(torch.ones(1))*3
        self.vel_mlp = nn.Sequential(*vel_mlp)

        # coord delta: six coefficients per edge, one per frame vector
        layer = nn.Linear(hidden_nf, 6, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        self.coords_range = nn.Parameter(torch.ones(1))*3
        self.coord_mlp = nn.Sequential(*coord_mlp)

        # used for edge model: [h_i, inner_i, h_j, inner_j, inter_ij, edge_attr]
        self.edge_mlp = nn.Sequential(
            nn.Linear((input_nf+inner_invar_d) * 2 + edges_in_d + inter_invar_d, hidden_nf),
            act_fn,
            nn.Dropout(self.dropout_rate),
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        # used for node model: [h, aggregated messages, node_attr]
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Dropout(self.dropout_rate),
            nn.Linear(hidden_nf, output_nf))

        self.clamp = clamp
        self.order = order
        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())
        # Registered but not applied by forward() in this class.
        self.layer_norm = nn.LayerNorm(hidden_nf)

    def edge_model(self, source, target, i_nt, edge_attr):
        """Compute one message per edge.

        Concatenates ``source``, ``target``, the per-edge invariants ``i_nt``
        and, when given, ``edge_attr`` along the last dimension and feeds the
        result to ``edge_mlp``. With attention enabled the message is scaled
        by ``att_mlp(message)``.
        """
        pieces = [source, target, i_nt]
        if edge_attr is not None:
            pieces.append(edge_attr)
        out = self.edge_mlp(torch.cat(pieces, dim=-1))
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Sum incoming edge messages per node and update the node features.

        Returns ``(out, agg)`` where ``agg`` is the concatenation
        ``[x, summed messages(, node_attr)]`` that was fed to ``node_mlp``.
        No residual connection is applied here.
        """
        row, _ = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))

        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=-1)
        else:
            agg = torch.cat([x, agg], dim=-1)

        out = self.node_mlp(agg)

        return out, agg

    @staticmethod
    def _edge_frame(p_row, p_col):
        """Return (difference, cross product, their cross product) per edge.

        The difference and the cross product are each divided by
        ``(their length + 1)``; the third vector is the cross product of the
        two scaled vectors and is not rescaled.
        """
        diff = p_row - p_col
        diff = diff / (torch.sqrt(torch.sum(diff ** 2, 1).unsqueeze(1)) + 1)

        cross = torch.cross(p_row, p_col, dim=-1)
        cross = cross / (torch.sqrt(torch.sum(cross ** 2, 1).unsqueeze(1)) + 1)

        vertical = torch.cross(diff, cross, dim=-1)
        return diff, cross, vertical

    def complete_local_frames(self, loc, vel, edge_index):
        """Build three frame vectors per edge from positions and from velocities.

        Args:
            loc, vel: (num_nodes, 3) positions and velocities.
            edge_index: pair ``(row, col)`` of node indices.

        Returns:
            ``(loc_frm_a, loc_frm_b, loc_frm_c, vel_frm_a, vel_frm_b, vel_frm_c)``,
            each of shape (num_edges, 3).

        Raises:
            NotImplementedError: for 2-D coordinates (no frame is defined yet).
            ValueError: for any other coordinate dimension.
        """
        row, col = edge_index
        dim_coord = loc.shape[-1]
        if dim_coord == 2:
            # Previously this path fell through and returned None, which made
            # forward() fail with an opaque tuple-unpacking error.
            raise NotImplementedError("Local frames are not implemented for 2D coordinates.")
        if dim_coord != 3:
            raise ValueError(f"Unsupported coordinate dimension: {dim_coord}. Only 3D coordinates are supported for local frames.")

        loc_frm_a, loc_frm_b, loc_frm_c = self._edge_frame(loc[row], loc[col])
        vel_frm_a, vel_frm_b, vel_frm_c = self._edge_frame(vel[row], vel[col])
        return loc_frm_a, loc_frm_b, loc_frm_c, vel_frm_a, vel_frm_b, vel_frm_c

    def coord2inner_diff_invar(self, loc, vel, norm):
        """Compute the per-node scalar invariants selected by ``self.case``.

        Args:
            loc, vel: (num_nodes, dim_coord) positions and velocities.
            norm: when truthy, ``loc`` and ``vel`` are first scaled to unit
                length (lengths are floored at 1e-8).

        Returns:
            ``None`` for cases 1 and 2; otherwise a (num_nodes, k) tensor:
            case 3 -> [|p|^2]; case 4 -> [|p|^2, |v|^2];
            case 5 -> [|p|^2, |v|^2, p.v];
            other  -> [|p|^2, |v|^2, p.v, |p x v|^2] (needs 3-D input).
        """
        if norm:
            epsilon = 1e-8
            # prevent division by zero when ||loc|| = 0 or ||vel|| = 0
            loc = loc / torch.norm(loc, p=2, dim=-1, keepdim=True).clamp(min=epsilon)
            vel = vel / torch.norm(vel, p=2, dim=-1, keepdim=True).clamp(min=epsilon)

        if self.case in (1, 2):
            return None

        radius = torch.sum(loc ** 2, dim=-1, keepdim=True)      # (num_nodes, 1)
        if self.case == 3:
            return radius

        velocity = torch.sum(vel ** 2, dim=-1, keepdim=True)    # (num_nodes, 1)
        if self.case == 4:
            return torch.cat([radius, velocity], dim=-1)

        inner = torch.sum(loc * vel, dim=-1, keepdim=True)      # (num_nodes, 1)
        if self.case == 5:
            return torch.cat([radius, velocity, inner], dim=-1)

        # The cross product only exists for 3-D vectors, so it is computed
        # only for the case that actually uses it.
        cross = torch.cross(loc, vel, dim=-1)
        cross_sqr = torch.sum(cross ** 2, dim=-1, keepdim=True)
        return torch.cat([radius, velocity, inner, cross_sqr], dim=-1)

    def coord2inter_diff_invar(self, edge_index, loc, vel, norm):
        """Compute the per-edge scalar invariants selected by ``self.case``.

        Args:
            edge_index: pair ``(row, col)`` of node indices (i = row, j = col).
            loc, vel: (num_nodes, dim_coord) positions and velocities.
            norm: when truthy, ``loc`` and ``vel`` are first divided by
                ``(their length + 1)``.

        Returns:
            (num_edges, k) tensor. For 2-D input always [pi.pj, vi.vj]
            regardless of ``case``. For 3-D input:
            case 1 -> [pi.pj]; cases 2, 3 -> [|pi-pj|^2];
            case 4 -> [|pi-pj|^2, |vi-vj|^2];
            case 5 -> [pi.pj, vi.vj, vi.(pi x pj)];
            other  -> 13 values: pi.pj, vi.vj, eight triple products,
                      (pi x pj).(vi x vj), |pi-pj|^2, |vi-vj|^2.

        Raises:
            ValueError: if dim_coord is neither 2 nor 3.
        """
        row, col = edge_index
        if norm:
            loc = loc / (torch.norm(loc, p=2, dim=-1, keepdim=True) + 1)
            vel = vel / (torch.norm(vel, p=2, dim=-1, keepdim=True) + 1)

        dim_coord = loc.shape[-1]
        if dim_coord not in (2, 3):
            raise ValueError(f"Unsupported coordinate dimension: {dim_coord}. Only 2D and 3D coordinates are supported.")

        p_i, p_j = loc[row], loc[col]
        v_i, v_j = vel[row], vel[col]

        def dot(a, b):
            return torch.sum(a * b, dim=-1, keepdim=True)   # (num_edges, 1)

        def sq_dist(a, b):
            return torch.sum((a - b) ** 2, dim=-1, keepdim=True)

        if dim_coord == 2:
            return torch.cat([dot(p_i, p_j), dot(v_i, v_j)], dim=-1)

        # 3-D: compute only the invariants the selected case needs.
        if self.case == 1:
            return dot(p_i, p_j)
        if self.case in (2, 3):
            return sq_dist(p_i, p_j)
        if self.case == 4:
            return torch.cat([sq_dist(p_i, p_j), sq_dist(v_i, v_j)], dim=-1)

        cross_loc = torch.cross(p_i, p_j, dim=-1)
        if self.case == 5:
            return torch.cat([dot(p_i, p_j), dot(v_i, v_j), dot(v_i, cross_loc)], dim=-1)

        cross_vel = torch.cross(v_i, v_j, dim=-1)
        cross_row = torch.cross(p_i, v_i, dim=-1)
        cross_col = torch.cross(p_j, v_j, dim=-1)
        inter_thr = torch.cat([
            dot(v_i, cross_loc),   # vi*(pi x pj)
            dot(v_j, cross_loc),   # vj*(pi x pj)
            dot(p_i, cross_vel),   # pi*(vi x vj)
            dot(p_j, cross_vel),   # pj*(vi x vj)
            dot(p_j, cross_row),   # pj*(pi x vi)
            dot(v_j, cross_row),   # vj*(pi x vi)
            dot(p_i, cross_col),   # pi*(pj x vj)
            dot(v_i, cross_col),   # vi*(pj x vj)
        ], dim=-1)                 # (num_edges, 8)
        inter_four = dot(cross_loc, cross_vel)  # (pi x pj)*(vi x vj)
        return torch.cat([dot(p_i, p_j), dot(v_i, v_j), inter_thr, inter_four,
                          sq_dist(p_i, p_j), sq_dist(v_i, v_j)], dim=-1)  # (num_edges, 13)

    @staticmethod
    def _combine_frames(frames, coeffs):
        """Return sum_k frames[k] * coeffs[:, k:k+1], accumulated left to right."""
        combined = frames[0] * coeffs[:, 0:1]
        for k in range(1, len(frames)):
            combined = combined + frames[k] * coeffs[:, k:k + 1]
        return combined

    def delta_model(self, loc, vel, edge_index, frame_loc_a, frame_loc_b, frame_loc_c, frame_vel_a, frame_vel_b, frame_vel_c, edge_feat):
        """
        delta model to compute the delta of loc and vel
        Parameters
        ----------
            loc, vel:    (num_nodes, dim_coord); only their first dimension is used
            edge_index:  pair (row, col); deltas are summed at index row
            edge_feat:   (num_edges, hidden_nf)
            frame_loc_a, frame_loc_b, frame_loc_c,
            frame_vel_a, frame_vel_b, frame_vel_c:
                         (num_edges, dim_coord) per-edge frame vectors
        Returns
        -------
            delta_loc:   (num_nodes, dim_coord)
            delta_vel:   (num_nodes, dim_coord)
        """
        row, _ = edge_index
        frames = (frame_loc_a, frame_loc_b, frame_loc_c, frame_vel_a, frame_vel_b, frame_vel_c)
        # Six coefficients per edge, one for each frame vector.
        trans_loc = self._combine_frames(frames, self.coord_mlp(edge_feat))
        trans_vel = self._combine_frames(frames, self.vel_mlp(edge_feat))
        agg_loc = unsorted_segment_sum(trans_loc, row, num_segments=loc.size(0))
        agg_vel = unsorted_segment_sum(trans_vel, row, num_segments=vel.size(0))

        return agg_loc, agg_vel

    def forward(self, h, edge_index, c, loc, vel, edge_attr=None, node_attr=None):
        """Update node features, positions and velocities for one layer.

            h:         (num_nodes, input_nf)
            c:         per-edge weight multiplied into the edge messages;
                       (num_edges, 1) according to the original author
            loc, vel:  (num_nodes, dim_coord)
            edge_attr: (num_edges, edges_in_d) or None
            node_attr: (num_nodes, nodes_att_dim) or None

        Returns ``(h, loc, vel, edge_feat)``: the ``node_mlp`` output (no
        residual or layer norm is applied here), the updated positions and
        velocities (new tensors), and the ``c``-weighted edge messages.
        """
        row, col = edge_index

        i_nn = self.coord2inner_diff_invar(loc, vel, self.norm_diff)  # (num_nodes, inner_invar_d) or None
        i_nt = self.coord2inter_diff_invar(edge_index, loc, vel, self.norm_diff)  # (num_edges, inter_invar_d)
        if i_nn is None:
            h_new_row = h[row]
            h_new_col = h[col]
        else:
            h_new_row = torch.cat([h[row], i_nn[row]], dim=-1)
            h_new_col = torch.cat([h[col], i_nn[col]], dim=-1)

        edge_feat = self.edge_model(h_new_row, h_new_col, i_nt, edge_attr)
        edge_feat = c * edge_feat  # infer edges weighted

        loc_frm_a, loc_frm_b, loc_frm_c, vel_frm_a, vel_frm_b, vel_frm_c = self.complete_local_frames(loc, vel, edge_index)
        delta_loc, delta_vel = self.delta_model(loc, vel, edge_index, loc_frm_a, loc_frm_b, loc_frm_c, vel_frm_a, vel_frm_b, vel_frm_c, edge_feat)
        # Each aggregated delta is gated by a per-node scalar predicted from h.
        vel = vel + self.coord_mlp_vel(h) * delta_vel
        loc = loc + self.coord_mlp_loc(h) * delta_loc

        h, _ = self.node_model(h, edge_index, edge_feat, node_attr)

        return h, loc, vel, edge_feat

def unsorted_segment_sum(data, segment_ids, num_segments):
    """Sum the rows of ``data`` that share the same segment id.

    ``data`` is 2-D (rows, features) and ``segment_ids`` holds one target row
    index per row of ``data``. The result has shape
    (num_segments, features) with the dtype and device of ``data``; segments
    that receive no rows stay zero.
    """
    num_features = data.size(1)
    # scatter_add_ needs an index of the same shape as the source, so repeat
    # each segment id across the feature dimension.
    index = segment_ids.unsqueeze(-1).expand(-1, num_features)
    totals = data.new_zeros((num_segments, num_features))
    return totals.scatter_add_(0, index, data)


def unsorted_segment_mean(data, segment_ids, num_segments):
    """Average the rows of ``data`` that share the same segment id.

    ``data`` is 2-D (rows, features) and ``segment_ids`` holds one target row
    index per row of ``data``. The result has shape
    (num_segments, features); segments that receive no rows are zero because
    the divisor is floored at 1.
    """
    shape = (num_segments, data.size(1))
    index = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    totals = data.new_zeros(shape).scatter_add_(0, index, data)
    counts = data.new_zeros(shape).scatter_add_(0, index, torch.ones_like(data))
    return totals / counts.clamp(min=1)