import sys
import numpy as np
import torch
sys.path.append('.')


from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class HatLayer(torch.nn.Module):
    """Hat operator: map Lie-algebra coordinates to their matrix representation.

    The basis matrices of the chosen algebra are stacked into ``E_bases`` of
    shape ``[k, n, n]`` and registered as a buffer, so they move with the
    module on ``.to(...)`` and appear in ``state_dict``. ``forward`` maps a
    coordinate tensor ``v`` whose last dimension is ``k`` to
    ``sum_i v[..., i] * E_bases[i]``.

    Supported ``algebra_type`` values (k basis matrices of size n x n):

    * ``'so3'`` -- k=3, n=3: skew generators for rotations about x, y, z.
    * ``'gl2'``, ``'gl3'``, ``'gl4'`` -- k=n*n: standard basis E_ij in
      row-major order (E11, E12, ..., Enn).
    * ``'sl3'`` -- k=8, n=3.
    * ``'sl4'`` -- k=15, n=4: H1=E11-E22, H2=E22-E33, H3=E33-E44, then the
      off-diagonal E_ij (i != j) in row-major order.
    * ``'se3'`` -- k=6, n=4: coordinates ordered ``[t, omega]``.
    * ``'sp4'`` -- k=10, n=4.

    Every basis is stored as a floating-point tensor (float32 under PyTorch's
    default settings), including sp4.

    Raises:
        ValueError: if ``algebra_type`` is not one of the values above.
    """

    def __init__(self, algebra_type='sl3'):
        super(HatLayer, self).__init__()
        self.register_buffer('E_bases', self._make_bases(algebra_type))

    @staticmethod
    def _gl_bases(n):
        """Return the standard basis E_ij of gl(n, R) as an ``[n*n, n, n]`` tensor.

        Row k of the (n*n)-identity, viewed as an n x n matrix, has its single
        one at (k // n, k % n), so the stack is E11, E12, ..., Enn (row-major).
        """
        return torch.eye(n * n).reshape(n * n, n, n)

    @staticmethod
    def _sl_root_bases(n):
        """Return an ``[n*n - 1, n, n]`` basis of sl(n, R).

        Order: the n-1 diagonal elements H_k = E_kk - E_(k+1)(k+1), followed by
        the off-diagonal E_ij (i != j) in row-major order.
        """
        bases = []
        for k in range(n - 1):
            h = torch.zeros(n, n)
            h[k, k] = 1
            h[k + 1, k + 1] = -1
            bases.append(h)
        for i in range(n):
            for j in range(n):
                if i != j:
                    e = torch.zeros(n, n)
                    e[i, j] = 1
                    bases.append(e)
        return torch.stack(bases, dim=0)

    @classmethod
    def _make_bases(cls, algebra_type):
        """Return the stacked basis matrices ``[k, n, n]`` for ``algebra_type``."""
        if algebra_type == 'so3':
            return torch.Tensor([
                [[0, 0, 0], [0, 0, -1], [0, 1, 0]],    # Ex
                [[0, 0, 1], [0, 0, 0], [-1, 0, 0]],    # Ey
                [[0, -1, 0], [1, 0, 0], [0, 0, 0]],    # Ez
            ])  # [3, 3, 3]
        if algebra_type == 'gl2':
            return cls._gl_bases(2)  # [4, 2, 2]
        if algebra_type == 'gl3':
            return cls._gl_bases(3)  # [9, 3, 3]
        if algebra_type == 'gl4':
            return cls._gl_bases(4)  # [16, 4, 4]
        if algebra_type == 'sl3':
            return torch.Tensor([
                [[1, 0, 0], [0, -1, 0], [0, 0, 0]],
                [[0, 1, 0], [1, 0, 0], [0, 0, 0]],
                [[0, -1, 0], [1, 0, 0], [0, 0, 0]],
                [[1, 0, 0], [0, 1, 0], [0, 0, -2]],
                [[0, 0, 1], [0, 0, 0], [0, 0, 0]],
                [[0, 0, 0], [0, 0, 1], [0, 0, 0]],
                [[0, 0, 0], [0, 0, 0], [1, 0, 0]],
                [[0, 0, 0], [0, 0, 0], [0, 1, 0]],
            ])  # [8, 3, 3]
        if algebra_type == 'sl4':
            return cls._sl_root_bases(4)  # [15, 4, 4]
        if algebra_type == 'se3':
            # use the order of v = [t, omega]^T
            return torch.Tensor([
                [[0, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],    # tx
                [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0]],    # ty
                [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 0]],    # tz
                [[0, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 0]],   # wx
                [[0, 0, 1, 0], [0, 0, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 0]],   # wy
                [[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],   # wz
            ])  # [6, 4, 4]
        if algebra_type == 'sp4':
            # Built with torch.Tensor (float) like every other algebra; the
            # previous integer construction produced an int64 buffer.
            return torch.Tensor([
                [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, -1, 0], [0, 0, 0, 0]],
                [[0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, -1, 0]],
                [[0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, -1], [0, 0, 0, 0]],
                [[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, -1]],
                [[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]],
                [[0, 0, 0, 0], [0, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
                [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 1, 0, 0]],
                [[0, 0, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
                [[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
                [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0]],
            ])  # [10, 4, 4]
        raise ValueError('Invalid algebra type for the hat operation')

    def forward(self, v):
        """Return ``sum_i v[..., i] * E_bases[i]``.

        v: a tensor of arbitrary leading shape whose last dimension equals the
        number of basis elements k. The result has shape ``v.shape[:-1] + (n, n)``.
        """
        return (v[..., None, None] * self.E_bases).sum(dim=-3)

def vee(M, algebra_type='sl3'):
    """Map a Lie-algebra matrix back to its coordinate vector.

    Dispatches on ``algebra_type`` to the matching ``vee_<type>`` function.

    Args:
        M: matrix tensor accepted by the selected ``vee_<type>`` function.
        algebra_type: one of 'so3', 'sl3', 'sl4', 'gl2', 'gl3', 'gl4',
            'se3', 'sp4'.

    Raises:
        ValueError: if ``algebra_type`` is not supported.
    """
    # Ordered (name, function) pairs; matched with == exactly like an if/elif chain.
    dispatch = (
        ('so3', vee_so3),
        ('sl3', vee_sl3),
        ('sl4', vee_sl4),
        ('gl2', vee_gl2),
        ('gl3', vee_gl3),
        ('gl4', vee_gl4),
        ('se3', vee_se3),
        ('sp4', vee_sp4),
    )
    for name, handler in dispatch:
        if algebra_type == name:
            return handler(M)
    raise ValueError('Invalid algebra type for the vee operation')

def vee_so3(M):
    """Inverse of the so(3) hat map: read (x, y, z) from a skew-symmetric matrix.

    Expected layout (only M[2,1], M[0,2] and M[1,0] are read):
        [0 , -z,  y]
        [z ,  0, -x]
        [-y,  x,  0]

    Args:
        M: tensor of shape [..., 3, 3].

    Returns:
        Tensor of shape [..., 3] on ``M.device``. Floating-point inputs keep
        their dtype; other inputs yield the default floating dtype.
    """
    # Allocate directly on M's device (no CPU allocation + copy) and keep M's
    # floating precision instead of silently downcasting, e.g. float64 -> float32.
    dtype = M.dtype if M.is_floating_point() else torch.get_default_dtype()
    v = torch.zeros(M.shape[:-2] + (3,), dtype=dtype, device=M.device)
    v[..., 0] = M[..., 2, 1]
    v[..., 1] = M[..., 0, 2]
    v[..., 2] = M[..., 1, 0]
    return v

def vee_sl3(M):
    """Inverse of the sl(3) hat map: read the 8 coordinates from a 3x3 matrix.

    Layout produced by the sl3 basis (a1..a8 are the coordinates):
        [a1 + a4, a2 - a3,    a5]
        [a2 + a3, a4 - a1,    a6]
        [     a7,      a8, -2*a4]

    M[1, 1] is not read; for a traceless M it is implied by the other entries.

    Args:
        M: tensor of shape [..., 3, 3].

    Returns:
        Tensor of shape [..., 8] on ``M.device``. Floating-point inputs keep
        their dtype; other inputs yield the default floating dtype.
    """
    # Allocate on M's device and keep M's floating precision (no silent downcast).
    dtype = M.dtype if M.is_floating_point() else torch.get_default_dtype()
    v = torch.zeros(M.shape[:-2] + (8,), dtype=dtype, device=M.device)
    v[..., 3] = -0.5 * M[..., 2, 2]
    v[..., 4] = M[..., 0, 2]
    v[..., 5] = M[..., 1, 2]
    v[..., 6] = M[..., 2, 0]
    v[..., 7] = M[..., 2, 1]
    # a1 = M00 - a4, with a4 recovered from M22 above.
    v[..., 0] = M[..., 0, 0] - v[..., 3]
    # a2 / a3 are the symmetric / antisymmetric parts of the (0,1)/(1,0) pair.
    v[..., 1] = 0.5 * (M[..., 0, 1] + M[..., 1, 0])
    v[..., 2] = 0.5 * (M[..., 1, 0] - M[..., 0, 1])
    return v

def vee_sl4(M):
    """
    Maps a 4x4 traceless matrix (sl(4, R)) to a 15D vector based on a basis:
    H1 = E11 - E22, H2 = E22 - E33, H3 = E33 - E44, and E_ij (i != j).
    Basis order: H1, H2, H3, E12, E13, E14, E21, E23, E24, E31, E32, E34, E41, E42, E43

    Args:
        M: tensor of shape [..., 4, 4].

    Returns:
        Tensor of shape [..., 15] on ``M.device``. Floating-point inputs keep
        their dtype; other inputs yield the default floating dtype.
    """
    # Keep M's floating precision instead of silently downcasting to the default dtype.
    dtype = M.dtype if M.is_floating_point() else torch.get_default_dtype()
    v = torch.zeros(M.shape[:-2] + (15,), dtype=dtype, device=M.device)
    # Diagonal: M00 = h1, M11 = h2 - h1, M22 = h3 - h2, so accumulate running sums.
    v[..., 0] = M[..., 0, 0]  # H1 coefficient
    v[..., 1] = M[..., 1, 1] + v[..., 0]  # H2 coefficient
    v[..., 2] = M[..., 2, 2] + v[..., 1]  # H3 coefficient
    # M[3,3] = -h3 is implied by tracelessness and is not read.
    # Off-diagonal elements, row-major with the diagonal skipped.
    v[..., 3] = M[..., 0, 1]   # E12
    v[..., 4] = M[..., 0, 2]   # E13
    v[..., 5] = M[..., 0, 3]   # E14
    v[..., 6] = M[..., 1, 0]   # E21
    v[..., 7] = M[..., 1, 2]   # E23
    v[..., 8] = M[..., 1, 3]   # E24
    v[..., 9] = M[..., 2, 0]   # E31
    v[..., 10] = M[..., 2, 1]  # E32
    v[..., 11] = M[..., 2, 3]  # E34
    v[..., 12] = M[..., 3, 0]  # E41
    v[..., 13] = M[..., 3, 1]  # E42
    v[..., 14] = M[..., 3, 2]  # E43
    return v

def vee_gl2(M):
    """
    Maps a 2x2 matrix (gl(2, R)) to a 4D vector based on the standard basis E_ij.
    Basis order: E11, E12, E21, E22.

    Args:
        M: tensor of shape [..., 2, 2].

    Returns:
        A new tensor (never a view of M) of shape [..., 4] on ``M.device``.
        Floating-point inputs keep their dtype; other inputs yield the default
        floating dtype.
    """
    dtype = M.dtype if M.is_floating_point() else torch.get_default_dtype()
    v = torch.empty(M.shape[:-2] + (4,), dtype=dtype, device=M.device)
    # Row-major flattening of the 2x2 block is exactly the E11, E12, E21, E22 order.
    v.copy_(M[..., :2, :2].reshape(v.shape))
    return v

def vee_gl3(M):
    """
    Maps a 3x3 matrix (gl(3, R)) to a 9D vector based on the standard basis E_ij.
    Basis order: E11, E12, E13, E21, ..., E33.

    Args:
        M: tensor of shape [..., 3, 3].

    Returns:
        A new tensor (never a view of M) of shape [..., 9] on ``M.device``.
        Floating-point inputs keep their dtype; other inputs yield the default
        floating dtype.
    """
    dtype = M.dtype if M.is_floating_point() else torch.get_default_dtype()
    v = torch.empty(M.shape[:-2] + (9,), dtype=dtype, device=M.device)
    # Row-major flattening of the 3x3 block is exactly the E11, E12, ..., E33 order.
    v.copy_(M[..., :3, :3].reshape(v.shape))
    return v

def vee_gl4(M):
    """
    Maps a 4x4 matrix (gl(4, R)) to a 16D vector based on the standard basis E_ij.
    Basis order: E11, E12, E13, E14, E21, ..., E44.

    Args:
        M: tensor of shape [..., 4, 4].

    Returns:
        A new tensor (never a view of M) of shape [..., 16] on ``M.device``.
        Floating-point inputs keep their dtype; other inputs yield the default
        floating dtype.
    """
    dtype = M.dtype if M.is_floating_point() else torch.get_default_dtype()
    v = torch.empty(M.shape[:-2] + (16,), dtype=dtype, device=M.device)
    # Row-major flattening of the 4x4 block is exactly the E11, E12, ..., E44 order.
    v.copy_(M[..., :4, :4].reshape(v.shape))
    return v


def vee_se3(M):
    """Inverse of the se(3) hat map: read [t, omega] from a 4x4 twist matrix.

    Expected layout (coordinates ordered [tx, ty, tz, wx, wy, wz]):
        [0  , -wz,  wy, tx]
        [wz ,   0, -wx, ty]
        [-wy,  wx,   0, tz]
        [  0,   0,   0,  0]

    Args:
        M: tensor of shape [..., 4, 4].

    Returns:
        Tensor of shape [..., 6] on ``M.device``. Floating-point inputs keep
        their dtype; other inputs yield the default floating dtype.
    """
    # Allocate on M's device and keep M's floating precision (no silent downcast).
    dtype = M.dtype if M.is_floating_point() else torch.get_default_dtype()
    v = torch.zeros(M.shape[:-2] + (6,), dtype=dtype, device=M.device)

    v[..., 0] = M[..., 0, 3]
    v[..., 1] = M[..., 1, 3]
    v[..., 2] = M[..., 2, 3]
    v[..., 3] = M[..., 2, 1]
    v[..., 4] = M[..., 0, 2]
    v[..., 5] = M[..., 1, 0]
    return v

def vee_sp4(M):
    """Inverse of the sp(4) hat map: read the 10 coordinates from a 4x4 matrix.

    Layout produced by the sp4 basis (a1..a10 are the coordinates):
        [ a1,  a2,  a8,  a9]
        [ a3,  a4,  a9, a10]
        [ a5,  a6, -a1, -a3]
        [ a6,  a7, -a2, -a4]
    Only one entry per coordinate is read; the redundant entries are ignored.

    Args:
        M: tensor of shape [..., 4, 4].

    Returns:
        Tensor of shape [..., 10] on ``M.device``. Floating-point inputs keep
        their dtype; other inputs yield the default floating dtype.
    """
    # Allocate on M's device and keep M's floating precision (no silent downcast).
    dtype = M.dtype if M.is_floating_point() else torch.get_default_dtype()
    v = torch.zeros(M.shape[:-2] + (10,), dtype=dtype, device=M.device)

    v[..., 0] = M[..., 0, 0]
    v[..., 1] = M[..., 0, 1]
    v[..., 2] = M[..., 1, 0]
    v[..., 3] = M[..., 1, 1]
    v[..., 4] = M[..., 2, 0]
    v[..., 5] = M[..., 2, 1]
    v[..., 6] = M[..., 3, 1]
    v[..., 7] = M[..., 0, 2]
    v[..., 8] = M[..., 1, 2]
    v[..., 9] = M[..., 1, 3]
    return v

def killingform(x_hat, d_hat, algebra_type='sl3', feature_wise=False):
    """Evaluate the algebra's invariant bilinear form on two matrix tensors.

    Dispatches on ``algebra_type`` to the matching ``killingform_<type>``
    function, forwarding ``feature_wise`` unchanged.

    Args:
        x_hat, d_hat: matrix tensors accepted by the selected function.
        algebra_type: one of 'so3', 'sl3', 'sl4', 'sp4', 'gl2', 'gl3', 'gl4'.
        feature_wise: passed through to the selected function.

    Raises:
        ValueError: if ``algebra_type`` is not supported.
    """
    # Ordered (name, function) pairs; matched with == exactly like an if/elif chain.
    dispatch = (
        ('so3', killingform_so3),
        ('sl3', killingform_sl3),
        ('sl4', killingform_sl4),
        ('sp4', killingform_sp4),
        ('gl2', killingform_gl2),
        ('gl3', killingform_gl3),
        ('gl4', killingform_gl4),
    )
    for name, form in dispatch:
        if algebra_type == name:
            return form(x_hat, d_hat, feature_wise)
    raise ValueError('Invalid algebra type for the Killing form')

def killingform_gl2(x_hat, d_hat, feature_wise=False):
    """Invariant bilinear form on gl(2, R): 4 * tr(X Y) - tr(X) * tr(Y).

    NOTE(review): the true Killing form of gl(n) is 2n tr(XY) - 2 tr(X) tr(Y),
    which vanishes on the identity. The coefficient 1 on tr(X) tr(Y) used here
    gives a different ad-invariant form, presumably chosen to stay
    nondegenerate. Confirm this is intended.

    Args:
        x_hat: tensor [..., 2, 2]. With ``feature_wise=True`` it must be
            [B, F, N, 2, 2].
        d_hat: tensor [..., 2, 2], broadcastable with ``x_hat``. With
            ``feature_wise=True`` it must be [B, D, N, 2, 2].
        feature_wise: if True, evaluate the form for every (x feature,
            d feature) pair.

    Returns:
        [..., 1] when ``feature_wise`` is False, else [B, F*D, N, 1].
    """
    if feature_wise:
        # Singleton axes make every x feature pair with every d feature.
        x_hat = rearrange(x_hat, 'b f n m1 m2 -> b f 1 n m1 m2')
        d_hat = rearrange(d_hat, 'b d n m1 m2 -> b 1 d n m1 m2')
    # sum_ij X_ji * Y_ij == tr(X @ Y).
    tr_xy = (x_hat.transpose(-1, -2) * d_hat).sum(dim=(-1, -2))
    tr_x = x_hat.diagonal(dim1=-2, dim2=-1).sum(-1)
    tr_y = d_hat.diagonal(dim1=-2, dim2=-1).sum(-1)
    # For n=2, 2n = 4
    kf = 4 * tr_xy - tr_x * tr_y
    if feature_wise:
        return rearrange(kf, 'b f d n -> b (f d) n 1')
    return kf[..., None]
    
def killingform_gl3(x_hat, d_hat, feature_wise=False):
    """Invariant bilinear form on gl(3, R): 6 * tr(X Y) - tr(X) * tr(Y).

    NOTE(review): the true Killing form of gl(n) is 2n tr(XY) - 2 tr(X) tr(Y);
    this uses coefficient 1 on tr(X) tr(Y). Presumably intentional. Confirm.

    Args:
        x_hat: tensor [..., 3, 3]. With ``feature_wise=True`` it must be
            [B, F, N, 3, 3].
        d_hat: tensor [..., 3, 3], broadcastable with ``x_hat``. With
            ``feature_wise=True`` it must be [B, D, N, 3, 3].
        feature_wise: if True, evaluate the form for every (x feature,
            d feature) pair.

    Returns:
        [..., 1] when ``feature_wise`` is False, else [B, F*D, N, 1].
    """
    if feature_wise:
        # Broadcast features against each other: x over axis 1, d over axis 2.
        x_hat = rearrange(x_hat, 'b f n m1 m2 -> b f 1 n m1 m2')
        d_hat = rearrange(d_hat, 'b d n m1 m2 -> b 1 d n m1 m2')
    trace_of_product = (x_hat.transpose(-1, -2) * d_hat).sum(dim=(-1, -2))
    trace_x = x_hat.diagonal(dim1=-2, dim2=-1).sum(-1)
    trace_d = d_hat.diagonal(dim1=-2, dim2=-1).sum(-1)
    form = 6 * trace_of_product - trace_x * trace_d  # 2n = 6 for n = 3
    if feature_wise:
        return rearrange(form, 'b f d n -> b (f d) n 1')
    return form[..., None]  # [B, F, N, 1]
    
def killingform_gl4(x_hat, d_hat, feature_wise=False):
    """Invariant bilinear form on gl(4, R): 8 * tr(X Y) - tr(X) * tr(Y).

    NOTE(review): the true Killing form of gl(n) is 2n tr(XY) - 2 tr(X) tr(Y);
    this uses coefficient 1 on tr(X) tr(Y). Presumably intentional. Confirm.

    Args:
        x_hat: tensor [..., 4, 4]. With ``feature_wise=True`` it must be
            [B, F, N, 4, 4].
        d_hat: tensor [..., 4, 4], broadcastable with ``x_hat``. With
            ``feature_wise=True`` it must be [B, D, N, 4, 4].
        feature_wise: if True, evaluate the form for every (x feature,
            d feature) pair.

    Returns:
        [..., 1] when ``feature_wise`` is False, else [B, F*D, N, 1].
    """
    # (Debug print removed: it wrote to stdout on every forward pass.)
    if feature_wise:
        # Singleton axes make every x feature pair with every d feature.
        x_hat = rearrange(x_hat, 'b f n m1 m2 -> b f 1 n m1 m2')
        d_hat = rearrange(d_hat, 'b d n m1 m2 -> b 1 d n m1 m2')
    # sum_ij X_ji * Y_ij == tr(X @ Y).
    tr_xy = (x_hat.transpose(-1, -2) * d_hat).sum(dim=(-1, -2))
    tr_x = x_hat.diagonal(dim1=-2, dim2=-1).sum(-1)
    tr_y = d_hat.diagonal(dim1=-2, dim2=-1).sum(-1)
    # For n=4, 2n = 8
    kf = 8 * tr_xy - tr_x * tr_y
    if feature_wise:
        return rearrange(kf, 'b f d n -> b (f d) n 1')
    return kf[..., None]  # [B, F, N, 1]
    
def killingform_so3(x_hat, d_hat, feature_wise=False):
    """Killing form of so(3): tr(X @ Y), with a trailing singleton axis appended.

    Args:
        x_hat: tensor [..., 3, 3].
        d_hat: tensor [..., 3, 3], broadcastable with ``x_hat``.
        feature_wise: selects between two equivalent ways of computing the
            trace. NOTE(review): unlike the sl3/sl4/gl* variants, this does NOT
            form cross-feature pairs; both paths return the same per-element
            result. Confirm whether callers expect [B, F*D, N, 1] here.

    Returns:
        Tensor of shape [..., 1].
    """
    if feature_wise:
        trace_xd = torch.einsum('...ii', torch.matmul(x_hat, d_hat))
    else:
        # sum_ij X_ji * Y_ij == tr(X @ Y) without a matmul.
        trace_xd = (x_hat.transpose(-1, -2) * d_hat).sum(dim=(-1, -2))
    return trace_xd[..., None]  # [B,F,N,1]
    
def killingform_sl3(x_hat, d_hat, feature_wise=False):
    """Killing form of sl(3, R): 6 * tr(X @ Y).

    Args:
        x_hat: tensor [..., 3, 3]. With ``feature_wise=True`` it must be
            [B, F, N, 3, 3].
        d_hat: tensor [..., 3, 3], broadcastable with ``x_hat``. With
            ``feature_wise=True`` it must be [B, D, N, 3, 3].
        feature_wise: if True, evaluate the form for every (x feature,
            d feature) pair.

    Returns:
        [..., 1] when ``feature_wise`` is False, else [B, F*D, N, 1].
    """
    if feature_wise:
        # Broadcast features against each other: x over axis 1, d over axis 2.
        x_hat = rearrange(x_hat, 'b f n m1 m2 -> b f 1 n m1 m2')
        d_hat = rearrange(d_hat, 'b d n m1 m2 -> b 1 d n m1 m2')
    form = 6 * (x_hat.transpose(-1, -2) * d_hat).sum(dim=(-1, -2))
    if feature_wise:
        return rearrange(form, 'b f d n -> b (f d) n 1')
    return form[..., None]  # [B,F,N,1]

def killingform_sl4(x_hat, d_hat, feature_wise=False):
    """Killing form of sl(4, R): 8 * tr(X @ Y).

    Args:
        x_hat: tensor [..., 4, 4]. With ``feature_wise=True`` it must be
            [B, F, N, 4, 4].
        d_hat: tensor [..., 4, 4], broadcastable with ``x_hat``. With
            ``feature_wise=True`` it must be [B, D, N, 4, 4].
        feature_wise: if True, evaluate the form for every (x feature,
            d feature) pair.

    Returns:
        [..., 1] when ``feature_wise`` is False, else [B, F*D, N, 1].
    """
    if feature_wise:
        # Broadcast features against each other: x over axis 1, d over axis 2.
        x_hat = rearrange(x_hat, 'b f n m1 m2 -> b f 1 n m1 m2')
        d_hat = rearrange(d_hat, 'b d n m1 m2 -> b 1 d n m1 m2')
    form = 8 * (x_hat.transpose(-1, -2) * d_hat).sum(dim=(-1, -2))
    if feature_wise:
        return rearrange(form, 'b f d n -> b (f d) n 1')
    return form[..., None]  # [B,F,N,1]


def killingform_sp4(x_hat, d_hat, feature_wise=False):
    """Killing form of sp(4, R): 6 * tr(X @ Y).

    Args:
        x_hat: tensor [..., 4, 4].
        d_hat: tensor [..., 4, 4], broadcastable with ``x_hat``.
        feature_wise: accepted for API parity with the other algebras but has
            no effect. NOTE(review): no cross-feature pairing is performed, so
            the result stays [..., 1] either way. Confirm callers expect this.

    Returns:
        Tensor of shape [..., 1].
    """
    trace_xd = (x_hat.transpose(-1, -2) * d_hat).sum(dim=(-1, -2))
    return 6 * trace_xd[..., None]  # [B,F,N,1]

def lie_bracket(x, y):
    """Return the matrix commutator [x, y] = x @ y - y @ x.

    Works for any operands supporting ``@`` and ``-``. For torch tensors ``@``
    is ``torch.matmul``, so leading batch dimensions broadcast.
    """
    forward_product = x @ y
    reverse_product = y @ x
    return forward_product - reverse_product