# -*- coding:UTF-8 -*-


import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import numpy as np
from conv_util import PointNetSaModule, cost_volume, set_upconv_module, FlowPredictor, Conv1d, BasicBlock, Feature_Gather
from pwclonet_model_utils import *
from deformable_mamba import Deformable_Mamba
from cross_swin_transformer import Cross_BasicLayer
from Fusion_module import GlobalFuser
import numpy as np
import matplotlib.pyplot as plt
from mamba_ssm.modules.mamba_simple import Mamba
# Module-level switches. Only `composation_num` is read in the code visible here;
# the consumers of the other three flags are outside this view — TODO confirm.
visualization = False
resuse_feature = False  # (sic) name kept as-is; other code may reference it
pred_composation = True  # presumably enables the multi-frame "composa" branch — confirm
# Length of the pose-composition window: init_history allocates composation_num + 1
# frames per history buffer and compute_cumulative_gt_pred chains the last
# composation_num of them.
composation_num = 20
def compute_pose_diff(cumulative_q_pred, cumulative_t_pred, cumulative_q_gt, cumulative_t_gt):
    """
    Compute the quaternion and translation differences from predicted poses to ground truth (GT) poses.

    Quaternions are stored scalar-first as (w, x, y, z) and multiplied with the
    Hamilton product implemented by ``quat_multiply`` below. All intermediate
    tensors are created on the device and with the dtype of the inputs, so the
    function works on CPU as well as GPU tensors.

    Parameters:
    - cumulative_q_pred: Predicted cumulative quaternions (batch_size, 1, 4)
    - cumulative_t_pred: Predicted cumulative translations (batch_size, 1, 3)
    - cumulative_q_gt: Ground truth cumulative quaternions (batch_size, 1, 4)
    - cumulative_t_gt: Ground truth cumulative translations (batch_size, 1, 3)

    Returns:
    - q_diff: Quaternion difference q_gt * q_pred^-1 (batch_size, 1, 4)
    - t_diff: Translation difference t_gt - rotate(q_diff, t_pred) (batch_size, 1, 3)
    """

    def quat_inv(q):
        """Return q^-1 = conj(q) / |q|^2 along the last axis (any leading shape)."""
        q_conjugate = torch.cat((q[..., :1], -q[..., 1:]), dim=-1)  # negate x, y, z
        q_sq_norm = torch.sum(q ** 2, dim=-1, keepdim=True)
        return q_conjugate / q_sq_norm

    def quat_multiply(q1, q2):
        """Hamilton product q1 * q2 of scalar-first quaternions."""
        w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3]
        w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3]

        w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
        x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
        y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
        z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2

        return torch.stack((w, x, y, z), dim=-1)

    def quat_rotate(q, v):
        """Rotate 3-vectors v by q: vector part of q * (0, v) * q^-1."""
        # Zeros follow v's device/dtype instead of being forced onto CUDA.
        v_quat = torch.cat((v.new_zeros(v.shape[:-1] + (1,)), v), dim=-1)
        v_rotated = quat_multiply(quat_multiply(q, v_quat), quat_inv(q))
        return v_rotated[..., 1:]  # drop the scalar part

    # Rotation difference: q_diff * q_pred == q_gt
    q_pred_inv = quat_inv(cumulative_q_pred)
    q_diff = quat_multiply(cumulative_q_gt, q_pred_inv)

    # Translation difference after rotating the prediction by q_diff
    cumulative_t_pred_in_gt_frame = quat_rotate(q_diff, cumulative_t_pred)
    t_diff = cumulative_t_gt - cumulative_t_pred_in_gt_frame

    return q_diff, t_diff

def inv_q_batch(q):
    """
    Invert a batch of scalar-first quaternions (w, x, y, z).

    The inverse is conj(q) / |q|^2 with conj(q) = (w, -x, -y, -z). A 1e-10 term is
    added to |q|^2 so that an all-zero quaternion does not divide by zero.

    Parameters:
    - q: quaternion tensor whose last dimension has size 4, e.g. (b, n, 4)

    Returns:
    - tensor with the same shape as q holding the inverses
    """
    squared_norm = (q * q).sum(dim=-1, keepdim=True) + 1e-10
    scalar_part = q[..., :1]
    negated_vector_part = -q[..., 1:]
    conjugate = torch.cat([scalar_part, negated_vector_part], dim=-1)
    return conjugate / squared_norm



def q_norm(q):
    """Scale quaternions along the last axis to (approximately) unit length.

    A 1e-10 epsilon is applied both inside and outside the square root, so an
    all-zero input yields zeros instead of NaN.
    """
    length = torch.sqrt((q * q).sum(dim=-1, keepdim=True) + 1e-10)
    return q / (length + 1e-10)

def get_selected_idx(batch_size, out_H: int, out_W: int, stride_H: int, stride_W: int, device=None):
    """Build index grids that select every stride-th row/column of a (B, H, W, ...) tensor.

    Args:
        batch_size (int): number of samples B.
        out_H (int): height of the output grid.
        out_W (int): width of the output grid.
        stride_H (int): step between selected rows.
        stride_W (int): step between selected columns.
        device (torch.device | str | None): device for the index tensors; ``None``
            (the default) uses PyTorch's default device, as before.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``(batch_idx, row_idx, col_idx)``,
        each an int64 tensor of shape (B, out_H, out_W). They are intended for
        advanced indexing, ``array[batch_idx, row_idx, col_idx]``, which picks rows
        ``0, stride_H, ..., (out_H - 1) * stride_H`` and columns
        ``0, stride_W, ..., (out_W - 1) * stride_W``. The tensors are ``expand``-ed
        views (non-contiguous, shared storage); call ``.contiguous()`` before writing
        to them in place.
    """
    select_h_idx = torch.arange(0, out_H * stride_H, stride_H, device=device)
    select_w_idx = torch.arange(0, out_W * stride_W, stride_W, device=device)
    grid_shape = (batch_size, out_H, out_W)
    height_indices = select_h_idx.reshape(1, -1, 1).expand(grid_shape)
    width_indices = select_w_idx.reshape(1, 1, -1).expand(grid_shape)
    padding_indices = torch.arange(batch_size, device=device).reshape(-1, 1, 1).expand(grid_shape)

    return padding_indices, height_indices, width_indices


class pwc_model(nn.Module):
    def __init__(self, batch_size, H_input, W_input, is_training, bn_decay=None):
        """Build every sub-module of the model and allocate the history buffers.

        Args:
            batch_size: forwarded to the point/cost-volume/up-conv modules and to init_history.
            H_input: input grid height; per-level heights are ceil(H / stride) chained over levels.
            W_input: input grid width; per-level widths are computed the same way.
            is_training: stored as ``self.training`` and forwarded to sub-modules.
            bn_decay: forwarded to sub-modules that accept it.

        NOTE(review): the order in which sub-modules are created below fixes the order of
        ``self.parameters()`` (and therefore of saved optimizer states), and attribute
        names are ``state_dict`` keys — do not reorder or rename them casually.
        """
        super(pwc_model, self).__init__()
        self.count = 0
        #####   initialize the parameters (distance  &  stride ) ######
        self.H_input = H_input; self.W_input = W_input

        # Per-level `distance` arguments for down-sampling, up-sampling and cost-volume modules.
        self.Down_conv_dis = [0.75, 3.0, 6.0, 12.0]
        self.Up_conv_dis = [3.0, 6.0, 9.0]
        self.Cost_volume_dis = [1.0, 2.0, 4.5]

        # Per-level strides (height, width).
        self.stride_H_list = [4, 2, 2, 1]
        self.stride_W_list = [8, 2, 2, 2]
        # Not read in this view — TODO confirm its consumer.
        self.length = [700, 250, 100, 70]

        self.out_H_list = [math.ceil(self.H_input / self.stride_H_list[0])]
        self.out_W_list = [math.ceil(self.W_input / self.stride_W_list[0])]

        for i in range(1, 4):
            self.out_H_list.append(math.ceil(self.out_H_list[i - 1] / self.stride_H_list[i]))
            self.out_W_list.append(math.ceil(self.out_W_list[i - 1] / self.stride_W_list[i]))  # generate the output shape list


        # NOTE(review): overwrites nn.Module.training; later train()/eval() calls reset it.
        self.training = is_training
        # Learnable scalars initialised to 0.0 and -2.5; used outside this view
        # (presumably loss-weighting terms — confirm against the loss code).
        self.w_x = torch.nn.Parameter(torch.tensor([0.0]), requires_grad=True)
        self.w_q = torch.nn.Parameter(torch.tensor([-2.5]), requires_grad=True)


        # Four-level point pyramid; channels 3 -> 16 -> 32 -> 64 -> 128.
        self.layer0 = PointNetSaModule(batch_size = batch_size, K_sample = 32, kernel_size = [9, 15], H = self.out_H_list[0], W = self.out_W_list[0], \
                                       stride_H = self.stride_H_list[0], stride_W = self.stride_W_list[0], distance = self.Down_conv_dis[0], in_channels = 3,
                                       mlp = [8, 8, 16], is_training = self.training,
                                       bn_decay = bn_decay)  

        self.layer1 = PointNetSaModule(batch_size = batch_size, K_sample = 32, kernel_size = [7, 11], H = self.out_H_list[1], W = self.out_W_list[1], \
                                       stride_H = self.stride_H_list[1], stride_W = self.stride_W_list[1], distance = self.Down_conv_dis[1],
                                       in_channels = 16,
                                       mlp=[16, 16, 32], is_training=self.training,
                                       bn_decay = bn_decay) 

        self.layer2 = PointNetSaModule(batch_size = batch_size, K_sample = 16, kernel_size = [5, 9], H = self.out_H_list[2], W = self.out_W_list[2], \
                                       stride_H = self.stride_H_list[2], stride_W = self.stride_W_list[2], distance = self.Down_conv_dis[2],
                                       in_channels=32,
                                       mlp=[32, 32, 64], is_training=self.training,
                                       bn_decay=bn_decay)

        self.layer3 = PointNetSaModule(batch_size = batch_size, K_sample = 16, kernel_size = [5, 9], H = self.out_H_list[3], W = self.out_W_list[3], \
                                       stride_H = self.stride_H_list[3], stride_W = self.stride_W_list[3], distance = self.Down_conv_dis[3],
                                       in_channels=64,
                                       mlp=[64, 64, 128], is_training=self.training,
                                       bn_decay=bn_decay)  

        # `laye3_1` (sic): renaming would change state_dict keys and break checkpoints.
        self.laye3_1 = PointNetSaModule(batch_size = batch_size, K_sample = 16, kernel_size = [5, 9], H = self.out_H_list[3], W = self.out_W_list[3], \
                                       stride_H = self.stride_H_list[3], stride_W = self.stride_W_list[3], distance = self.Down_conv_dis[3],
                                       in_channels=64,
                                       mlp=[128, 64, 64], is_training=self.training,
                                       bn_decay=bn_decay)  


        # Cost volumes at pyramid levels 2, 2, 1 and 0 (in_channels match those levels).
        self.cost_volume1 = cost_volume(batch_size = batch_size, kernel_size1 = [3, 5], kernel_size2 = [5, 35] , nsample=4, nsample_q=32, \
                                       H = self.out_H_list[2], W = self.out_W_list[2], \
                                       stride_H = 1, stride_W = 1, distance = self.Cost_volume_dis[2],
                                       in_channels = [64, 64],
                                       mlp1=[128, 64, 64], mlp2=[128, 64], is_training=self.training, bn_decay=bn_decay,
                                       bn=True, pooling='max', knn=True, corr_func='concat')  
                                       
        self.cost_volume2 = cost_volume(batch_size = batch_size, kernel_size1 = [3, 5], kernel_size2 = [5, 15] , nsample=4, nsample_q = 6, \
                                       H = self.out_H_list[2], W = self.out_W_list[2], \
                                       stride_H = 1, stride_W = 1, distance = self.Cost_volume_dis[2],
                                       in_channels = [64, 64],
                                       mlp1=[128, 64, 64], mlp2=[128, 64], is_training=self.training, bn_decay=bn_decay,
                                       bn=True,
                                       pooling='max', knn=True, corr_func='concat')

        self.cost_volume3 = cost_volume(batch_size = batch_size, kernel_size1 = [3, 5], kernel_size2 = [7, 25] , nsample=4, nsample_q = 6, \
                                       H = self.out_H_list[1], W = self.out_W_list[1], \
                                       stride_H = 1, stride_W = 1, distance = self.Cost_volume_dis[1],
                                       in_channels = [32, 32],
                                       mlp1=[128, 64, 64], mlp2=[128, 64], is_training=self.training, bn_decay=bn_decay,
                                       bn=True,
                                       pooling='max', knn=True, corr_func='concat')  


        self.cost_volume4 = cost_volume(batch_size = batch_size, kernel_size1 = [3, 5], kernel_size2 = [11, 41] , nsample=4, nsample_q = 6, \
                                       H = self.out_H_list[0], W = self.out_W_list[0], \
                                       stride_H = 1, stride_W = 1, distance = self.Cost_volume_dis[0],
                                       in_channels = [16, 16],
                                       mlp1=[128, 64, 64], mlp2=[128, 64], is_training=self.training, bn_decay=bn_decay,
                                       bn=True,
                                       pooling='max', knn=True, corr_func='concat') 


        # Per-level predictors, each with a "_predict" and a "_w" variant; all output 64 channels (mlp ends at 64).
        self.flow_predictor0 = FlowPredictor(in_channels=64 * 3, mlp=[128, 64], is_training=self.training,
                                             bn_decay=bn_decay)  
        self.flow_predictor1_predict = FlowPredictor(in_channels=64 * 3, mlp=[128, 64], is_training=self.training,
                                             bn_decay=bn_decay)  
        self.flow_predictor1_w = FlowPredictor(in_channels=64 * 3, mlp=[128, 64], is_training=self.training,
                                             bn_decay=bn_decay)  
        self.flow_predictor2_predict = FlowPredictor(in_channels=64 * 2 + 32, mlp=[128, 64], is_training=self.training,
                                             bn_decay=bn_decay)  
        self.flow_predictor2_w = FlowPredictor(in_channels=64 * 2 + 32, mlp=[128, 64], is_training=self.training,
                                             bn_decay=bn_decay)  
        self.flow_predictor3_predict = FlowPredictor(in_channels=64 * 2 + 16, mlp=[128, 64], is_training=self.training,
                                             bn_decay=bn_decay)  
        self.flow_predictor3_w = FlowPredictor(in_channels=64 * 2 + 16, mlp=[128, 64], is_training=self.training,
                                             bn_decay=bn_decay)  


        # Up-sampling modules, coarse to fine (output sizes of levels 2, 1, 0).
        self.set_upconv1_w_upsample = set_upconv_module(batch_size = batch_size, kernel_size = [7, 15],
                                            H = self.out_H_list[2], W = self.out_W_list[2],
                                            stride_H = self.stride_H_list[-1], stride_W = self.stride_W_list[-1],
                                            nsample=8, distance = self.Up_conv_dis[2],
                                            in_channels=[64, 64],
                                            mlp=[128, 64], mlp2=[64], is_training=self.training,
                                            bn_decay=bn_decay, knn=True)  

        self.set_upconv1_upsample = set_upconv_module(batch_size = batch_size, kernel_size = [7, 15],
                                            H = self.out_H_list[2], W = self.out_W_list[2],                               
                                            stride_H = self.stride_H_list[-1], stride_W = self.stride_W_list[-1],
                                            nsample=8, distance = self.Up_conv_dis[2],
                                            in_channels=[64, 64],
                                            mlp=[128, 64], mlp2=[64], is_training=self.training,
                                            bn_decay=bn_decay, knn=True)  

        self.set_upconv2_w_upsample = set_upconv_module(batch_size = batch_size, kernel_size = [7, 15], 
                                            H = self.out_H_list[1], W = self.out_W_list[1],
                                            stride_H = self.stride_H_list[-2], stride_W = self.stride_W_list[-2], \
                                            nsample=8, distance = self.Up_conv_dis[1],
                                            in_channels=[32, 64],
                                            mlp=[128, 64], mlp2=[64], is_training=self.training,
                                            bn_decay=bn_decay, knn=True)  

        self.set_upconv2_upsample = set_upconv_module(batch_size = batch_size, kernel_size = [7, 15], 
                                            H = self.out_H_list[1], W = self.out_W_list[1],
                                            stride_H = self.stride_H_list[-2], stride_W = self.stride_W_list[-2], \
                                            nsample=8, distance = self.Up_conv_dis[1],
                                            in_channels=[32, 64],
                                            mlp=[128, 64], mlp2=[64], is_training=self.training,
                                            bn_decay=bn_decay, knn=True)  

        self.set_upconv3_w_upsample = set_upconv_module(batch_size = batch_size, kernel_size = [7, 15], 
                                            H = self.out_H_list[0], W = self.out_W_list[0],
                                            stride_H = self.stride_H_list[-3], stride_W = self.stride_W_list[-3], \
                                            nsample=8, distance = self.Up_conv_dis[0],
                                            in_channels=[16, 64],
                                            mlp=[128, 64], mlp2=[64], is_training=self.training,
                                            bn_decay=bn_decay, knn=True) 

        self.set_upconv3_upsample = set_upconv_module(batch_size = batch_size, kernel_size = [7, 15], 
                                            H = self.out_H_list[0], W = self.out_W_list[0],
                                            stride_H = self.stride_H_list[-3], stride_W = self.stride_W_list[-3], \
                                            nsample=8, distance = self.Up_conv_dis[0],
                                            in_channels=[16, 64],
                                            mlp=[128, 64], mlp2=[64], is_training=self.training,
                                            bn_decay=bn_decay, knn=True)  


        # Output heads per level: conv1_* -> 4 channels (presumably quaternion),
        # conv2_* -> 3 channels (presumably translation), conv3_* -> 64 to 256 channels.
        self.conv1_l3 = Conv1d(256+64, 4, use_activation=False)  
        self.conv1_l2 = Conv1d(256, 4, use_activation=False)  
        self.conv1_l1 = Conv1d(256, 4, use_activation=False)  
        self.conv1_l0 = Conv1d(256, 4, use_activation=False)  
        self.conv2_l3 = Conv1d(256+64, 3, use_activation=False)  
        self.conv2_l2 = Conv1d(256, 3, use_activation=False)  
        self.conv2_l1 = Conv1d(256, 3, use_activation=False)  
        self.conv2_l0 = Conv1d(256, 3, use_activation=False)  
        self.conv3_l3 = Conv1d(64, 256, use_activation=False)  
        self.conv3_l2 = Conv1d(64, 256, use_activation=False)  
        self.conv3_l1 = Conv1d(64, 256, use_activation=False)  
        self.conv3_l0 = Conv1d(64, 256, use_activation=False)

        # "_query" variants of the same heads (all 256-channel input).
        self.conv1_l3_query = Conv1d(256, 4, use_activation=False)  
        self.conv1_l2_query = Conv1d(256, 4, use_activation=False)
        self.conv1_l1_query = Conv1d(256, 4, use_activation=False)
        self.conv1_l0_query = Conv1d(256, 4, use_activation=False)
        self.conv2_l3_query = Conv1d(256, 3, use_activation=False) 
        self.conv2_l2_query = Conv1d(256, 3, use_activation=False)
        self.conv2_l1_query = Conv1d(256, 3, use_activation=False)
        self.conv2_l0_query = Conv1d(256, 3, use_activation=False)
        self.conv3_l3_query = Conv1d(64, 256, use_activation=False)
        self.conv3_l2_query = Conv1d(64, 256, use_activation=False)
        self.conv3_l1_query = Conv1d(64, 256, use_activation=False)
        self.conv3_l0_query = Conv1d(64, 256, use_activation=False)

        # Residual blocks, channels 3 -> 16 -> 32 -> 64 (presumably the 2-D branch that
        # GlobalFuser combines with the point features — confirm).
        self.BasicBlock_0 = BasicBlock(3, 16, stride=2)
        self.BasicBlock_1 = BasicBlock(16, 32, stride=1)
        self.BasicBlock_2 = BasicBlock(32, 64, stride=1)
        # self.BasicBlock_3 = BasicBlock(64, 128, stride=1)

        self.cluster_l0 = Deformable_Mamba(dim=16, out_dim=16, fold_w=8, fold_h=8, heads=1, head_dim=16, return_center=True)
        self.cluster_l1 = Deformable_Mamba(dim=32, out_dim=32, fold_w=4, fold_h=4, heads=1, head_dim=32, return_center=True)
        self.cluster_l2 = Deformable_Mamba(dim=64, out_dim=64, fold_w=2, fold_h=2, heads=1, head_dim=64, return_center=True)


        # Per-level 2-D/3-D fusion with matching channel widths.
        self.feat_fuserl0 = GlobalFuser(in_channels_2d=16, in_channels_3d=16)
        self.feat_fuserl1 = GlobalFuser(in_channels_2d=32, in_channels_3d=32)
        self.feat_fuserl2 = GlobalFuser(in_channels_2d=64, in_channels_3d=64)
        # self.feat_fuserl3 = GlobalFuser(in_channels_2d=128, in_channels_3d=128)
        # Linear maps from 3-d inputs (presumably xyz) to each level's channel width.
        self.generate_pos_embedding_l0 = nn.Linear(3, 16)
        self.generate_pos_embedding_l1 = nn.Linear(3, 32)
        self.generate_pos_embedding_l2 = nn.Linear(3, 64)
        self.generate_pos_embedding_l3 = nn.Linear(3, 128)

        # GatedMLP and Dynamic_Trajectory_Decoder are not imported by name in this file;
        # presumably they come from `from pwclonet_model_utils import *` — confirm.
        self.mlp_q = GatedMLP(4, 128, 64)
        self.mlp_t = GatedMLP(3, 128, 64)
        self.gmlp_q_f = GatedMLP(256, 256, 256)
        self.gmlp_t_f = GatedMLP(256, 256, 256)
        # LayerNorm -> ReLU -> Mamba -> LayerNorm stacks (64-d for q/t, 256-d for l3 features).
        self.mamba_q = nn.Sequential(
                                    nn.LayerNorm(64),
                                    nn.ReLU(),
                                    Mamba(d_model=64, d_state=16, d_conv=4, expand=2),
                                    nn.LayerNorm(64),
        )
        self.mamba_feature_q_l3 = nn.Sequential(
                                    nn.LayerNorm(256),
                                    nn.ReLU(),
                                    Mamba(d_model=256, d_state=16, d_conv=4, expand=2),
                                    nn.LayerNorm(256),
        )
        self.mamba_feature_t_l3 = nn.Sequential(
                                    nn.LayerNorm(256),
                                    nn.ReLU(),
                                    Mamba(d_model=256, d_state=16, d_conv=4, expand=2),
                                    nn.LayerNorm(256),
        )
        self.mamba_t = nn.Sequential(
                                    nn.LayerNorm(64),
                                    nn.ReLU(),
                                    Mamba(d_model=64, d_state=16, d_conv=4, expand=2),
                                    nn.LayerNorm(64),
        )

     
        # Final decoders: (64 + 4 + 4) -> 4 for rotation, (64 + 3 + 3) -> 3 for translation.
        self.final_predictor_q_dynamic = nn.Sequential(
                                    # nn.LayerNorm(64 + 4 + 4),
                                    # nn.ReLU(),
                                    Dynamic_Trajectory_Decoder(64 + 4 + 4, 4)
        )
        self.final_predictor_t_dynamic = nn.Sequential(
                                    # nn.LayerNorm(64 + 3 + 3),
                                    # nn.ReLU(),
                                    Dynamic_Trajectory_Decoder(64 + 3 + 3, 3)
        )
        # Separate "_composa" copies of the level-2/3 modules (process_layers_two_and_three
        # uses cost_volume1_composa). freeze_layers() keeps only parameters whose names
        # contain "composa" trainable.
        self.laye3_1_composa = PointNetSaModule(batch_size = batch_size, K_sample = 16, kernel_size = [5, 9], H = self.out_H_list[3], W = self.out_W_list[3], \
                                       stride_H = self.stride_H_list[3], stride_W = self.stride_W_list[3], distance = self.Down_conv_dis[3],
                                       in_channels=64,
                                       mlp=[128, 64, 64], is_training=self.training,
                                       bn_decay=bn_decay)  
        self.cost_volume1_composa = cost_volume(batch_size = batch_size, kernel_size1 = [3, 5], kernel_size2 = [5, 35] , nsample=4, nsample_q=32, \
                                       H = self.out_H_list[2], W = self.out_W_list[2], \
                                       stride_H = 1, stride_W = 1, distance = self.Cost_volume_dis[2],
                                       in_channels = [64, 64],
                                       mlp1=[128, 64, 64], mlp2=[128, 64], is_training=self.training, bn_decay=bn_decay,
                                       bn=True, pooling='max', knn=True, corr_func='concat')  
        self.cost_volume2_composa = cost_volume(batch_size = batch_size, kernel_size1 = [3, 5], kernel_size2 = [5, 15] , nsample=4, nsample_q = 6, \
                                       H = self.out_H_list[2], W = self.out_W_list[2], \
                                       stride_H = 1, stride_W = 1, distance = self.Cost_volume_dis[2],
                                       in_channels = [64, 64],
                                       mlp1=[128, 64, 64], mlp2=[128, 64], is_training=self.training, bn_decay=bn_decay,
                                       bn=True,
                                       pooling='max', knn=True, corr_func='concat')
        self.set_upconv1_w_upsample_composa = set_upconv_module(batch_size = batch_size, kernel_size = [7, 15],
                                            H = self.out_H_list[2], W = self.out_W_list[2],
                                            stride_H = self.stride_H_list[-1], stride_W = self.stride_W_list[-1],
                                            nsample=8, distance = self.Up_conv_dis[2],
                                            in_channels=[64, 64],
                                            mlp=[128, 64], mlp2=[64], is_training=self.training,
                                            bn_decay=bn_decay, knn=True) 
        self.flow_predictor0_composa = FlowPredictor(in_channels=64 * 3, mlp=[128, 64], is_training=self.training,
                                             bn_decay=bn_decay) 
        self.flow_predictor1_predict_composa = FlowPredictor(in_channels=64 * 3, mlp=[128, 64], is_training=self.training,
                                             bn_decay=bn_decay) 
        self.flow_predictor1_w_composa = FlowPredictor(in_channels=64 * 3, mlp=[128, 64], is_training=self.training,
                                             bn_decay=bn_decay)  
        self.set_upconv1_upsample_composa = set_upconv_module(batch_size = batch_size, kernel_size = [7, 15],
                                        H = self.out_H_list[2], W = self.out_W_list[2],                               
                                        stride_H = self.stride_H_list[-1], stride_W = self.stride_W_list[-1],
                                        nsample=8, distance = self.Up_conv_dis[2],
                                        in_channels=[64, 64],
                                        mlp=[128, 64], mlp2=[64], is_training=self.training,
                                        bn_decay=bn_decay, knn=True)
        self.conv1_l3_composa = Conv1d(256, 4, use_activation=False)
        self.conv1_l2_composa = Conv1d(256, 4, use_activation=False) 
        self.conv2_l3_composa = Conv1d(256, 3, use_activation=False) 
        self.conv2_l2_composa = Conv1d(256, 3, use_activation=False) 
        self.conv3_l3_composa = Conv1d(64, 256, use_activation=False)
        self.conv3_l2_composa = Conv1d(64, 256, use_activation=False) 
        # freeze layers
        # self.freeze_layers()
        # Allocates history buffers with .cuda(), so construction requires a CUDA device.
        self.init_history(batch_size)

    def freeze_layers(self):
        for name, param in self.named_parameters():
            if "composa" in name:
                continue
            print(name)
            param.requires_grad = False

    def init_history(self, batch_size):
        """Allocate (or reset) the sliding-window history buffers.

        Every history tensor stores ``composation_num + 1`` frames along dim 1 and is
        created with ``.cuda()``, so a CUDA device is required.

        Args:
            batch_size (int): leading dimension of every buffer.
        """
        feature_dim = 256
        max_frames = composation_num + 1
        self.continue_frames = 0

        # Per-frame 256-d feature histories for the quaternion and translation branches.
        self.l3_points_f1_new_q_history = torch.zeros((batch_size, max_frames, feature_dim)).cuda()
        self.l3_points_f1_new_t_history = torch.zeros((batch_size, max_frames, feature_dim)).cuda()

        # Pose histories start at the identity quaternion (w, x, y, z) = (1, 0, 0, 0)
        # and a tiny constant translation of 1e-15.
        identity_quaternion = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32).cuda()
        small_translation = torch.full((batch_size, max_frames, 3), 1e-15).cuda()

        self.l0_q_history = identity_quaternion.repeat(batch_size, max_frames, 1)
        self.l0_t_history = small_translation
        self.gt_q_history = identity_quaternion.repeat(batch_size, max_frames, 1)
        # Independent copy: binding the same tensor to both attributes made them share
        # storage, so an in-place write to one history silently altered the other.
        self.gt_t_history = small_translation.clone()

        # Starts all zeros; update_history_mask appends a column of ones per new frame.
        self.history_mask = torch.zeros((batch_size, max_frames)).cuda()

        # Frame-2 caches start as the scalar placeholder 0; real values are assigned
        # outside this method.
        self.his_l2_xyz_proj_f2 = 0
        self.his_l2_points_proj_f2 = 0
        self.his_l2_points_f2 = 0
        self.his_l3_xyz_proj_f2 = 0
        self.his_l3_points_f2 = 0
        self.his_l1_xyz_proj_f2 = 0
        self.his_l1_points_f2 = 0
        self.his_l0_xyz_proj_f2 = 0
        self.his_l0_points_f2 = 0

        # Frame-1 caches per pyramid level, max_frames entries each.
        h0, w0 = self.out_H_list[0], self.out_W_list[0]
        h2, w2 = self.out_H_list[2], self.out_W_list[2]
        h3, w3 = self.out_H_list[3], self.out_W_list[3]
        self.his_l0_xyz_proj_f1 = torch.zeros((batch_size, max_frames, h0, w0, 3)).cuda()
        self.his_l0_points_f1 = torch.zeros((batch_size, max_frames, w0 * h0, 16)).cuda()
        self.his_l2_xyz_proj_f1 = torch.zeros((batch_size, max_frames, h2, w2, 3)).cuda()
        self.his_l2_points_proj_f1 = torch.zeros((batch_size, max_frames, h2, w2, 64)).cuda()
        self.his_l2_points_f1 = torch.zeros((batch_size, max_frames, w2 * h2, 64)).cuda()
        self.his_l3_xyz_proj_f1 = torch.zeros((batch_size, max_frames, h3, w3, 3)).cuda()
        self.his_l3_points_f1 = torch.zeros((batch_size, max_frames, w3 * h3, 128)).cuda()

    def update_history(self, new_state, history_state):

        unpdate_state = torch.cat([history_state[:, 1:], new_state.clone().detach()], dim=1)
        
        return unpdate_state

    def update_history_mask(self, history_mask):
        update_mask = torch.cat([history_mask[:, 1:], torch.ones((history_mask.shape[0], 1)).cuda()], dim=1)
        return update_mask
    
    def compute_cumulative_gt_pred(self, batch_size):
        """Chain the last ``composation_num`` history poses into one accumulated pose.

        This is done separately for predictions (``l0_q_history`` / ``l0_t_history``) and
        ground truth (``gt_q_history`` / ``gt_t_history``). Frames are visited from history
        index ``-composation_num`` up to ``-1``. The first visited frame initialises the
        accumulator. Every later frame (q_k, t_k) updates it by:
          - conjugating the accumulated translation, written as (0, t), with q_k via
            ``mul_q_point`` / ``mul_point_q`` / ``inv_q``;
          - keeping the x, y, z components and adding t_k;
          - replacing the accumulated rotation with ``mul_point_q(q_k, q_acc)``.
        Predicted quaternions are passed through ``q_norm`` first; ground-truth
        quaternions are used as stored. The quaternion helpers come from outside this
        class (presumably the star import from pwclonet_model_utils — confirm).

        Args:
            batch_size (int): batch size forwarded to the quaternion helpers.

        Returns:
            tuple: (cumulative_q_pred, cumulative_t_pred, cumulative_q_gt, cumulative_t_gt).
            If ``composation_num`` is 0 the loop never runs and the return raises
            ``UnboundLocalError``.
        """

        def compose(q_step, t_step, q_acc, t_acc):
            # Rotate the accumulated translation by q_step, then add t_step;
            # left-multiply the accumulated rotation by q_step.
            q_step_inv = inv_q(q_step, batch_size)
            t_acc = torch.cat([torch.zeros([batch_size, 1, 1]).cuda(), t_acc], dim=-1)
            t_acc = mul_q_point(q_step, t_acc, batch_size)
            t_acc = torch.index_select(mul_point_q(t_acc, q_step_inv, batch_size), 2, torch.LongTensor(range(1, 4)).cuda())
            q_acc = mul_point_q(q_step, q_acc, batch_size)
            return q_acc, t_acc + t_step

        for step in range(composation_num):
            idx = step - composation_num
            q_pred_step = q_norm(self.l0_q_history[:, idx]).unsqueeze(1)
            t_pred_step = self.l0_t_history[:, idx].unsqueeze(1)
            q_gt_step = self.gt_q_history[:, idx].unsqueeze(1)
            t_gt_step = self.gt_t_history[:, idx].unsqueeze(1)
            if step == 0:
                cumulative_q_pred, cumulative_t_pred = q_pred_step, t_pred_step
                cumulative_q_gt, cumulative_t_gt = q_gt_step, t_gt_step
            else:
                cumulative_q_pred, cumulative_t_pred = compose(
                    q_pred_step, t_pred_step, cumulative_q_pred, cumulative_t_pred)
                cumulative_q_gt, cumulative_t_gt = compose(
                    q_gt_step, t_gt_step, cumulative_q_gt, cumulative_t_gt)
        return cumulative_q_pred, cumulative_t_pred, cumulative_q_gt, cumulative_t_gt

    def update_cumulative_predictions(self, q_diff, t_diff, cumulative_q_pred, cumulative_t_pred, batch_size):
        """
        Compose a relative pose (q_diff, t_diff) onto an accumulated pose.

        The accumulated translation is written as the pure quaternion (0, tx, ty, tz)
        (a zero is prepended to the last dimension) and conjugated by q_diff via
        ``mul_q_point`` / ``mul_point_q`` / ``inv_q``. Components 1..3 are kept and
        t_diff is added. The accumulated quaternion becomes
        ``mul_point_q(q_diff, cumulative_q_pred)``. The quaternion helpers are defined
        outside this class (presumably pwclonet_model_utils — confirm).

        Args:
            q_diff (torch.Tensor): The quaternion difference, shape [batch_size, 1, 4].
            t_diff (torch.Tensor): The translation difference, shape [batch_size, 1, 3].
            cumulative_q_pred (torch.Tensor): The cumulative quaternion predictions, shape [batch_size, 1, 4].
            cumulative_t_pred (torch.Tensor): The cumulative translation predictions, shape [batch_size, 1, 3].
            batch_size (int): The size of the batch.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: the updated quaternion and translation,
            each with dim 1 squeezed away (so [batch_size, 4] and [batch_size, 3] for
            the shapes documented above).
        """
        q_diff_inverse = inv_q(q_diff, batch_size)

        # Embed the translation as a pure quaternion (0, tx, ty, tz).
        zero_scalar = torch.zeros([batch_size, 1, 1]).cuda()
        t_as_quat = torch.cat([zero_scalar, cumulative_t_pred], dim=-1)

        # Conjugate by q_diff and keep only the vector part (indices 1..3).
        left = mul_q_point(q_diff, t_as_quat, batch_size)
        rotated = mul_point_q(left, q_diff_inverse, batch_size)
        vector_index = torch.LongTensor(range(1, 4)).cuda()
        rotated_t = torch.index_select(rotated, 2, vector_index)

        new_q = mul_point_q(q_diff, cumulative_q_pred, batch_size)
        new_t = rotated_t + t_diff

        return new_q.squeeze(1), new_t.squeeze(1)


    def process_layers_two_and_three(
        self,
        l2_xyz_proj_f1,
        l2_xyz_proj_f2,
        l2_points_proj_f1,
        l2_points_proj_f2,
        l3_xyz_proj_f1,
        l3_points_f1,
        l2_points_f1,
        batch_size
    ):
        """
        Estimate a coarse-to-fine relative pose with the ``*_composa`` sub-modules.

        A coarse pose is regressed at layer 3 from a masked, softmax-weighted
        pooling of the cost volume. The layer-2 points of frame 1 are then warped
        by that pose and re-projected onto the spherical ring. A residual pose is
        regressed from the refined layer-2 cost volume and composed with the
        coarse one.

        Args:
            l2_xyz_proj_f1 (torch.Tensor): Projected XYZ coordinates for frame 1 at layer 2.
            l2_xyz_proj_f2 (torch.Tensor): Projected XYZ coordinates for frame 2 at layer 2.
            l2_points_proj_f1 (torch.Tensor): Feature points for frame 1 at layer 2.
            l2_points_proj_f2 (torch.Tensor): Feature points for frame 2 at layer 2.
            l3_xyz_proj_f1 (torch.Tensor): Projected XYZ coordinates for frame 1 at layer 3.
            l3_points_f1 (torch.Tensor): Feature points for frame 1 at layer 3.
            l2_points_f1 (torch.Tensor): Original feature points for frame 1 at layer 2.
            batch_size (int): Batch size.

        Returns:
            torch.Tensor: Refined quaternion at layer 2 (``l2_q``), dim 1 squeezed.
            torch.Tensor: Refined translation at layer 2 (``l2_t``), dim 1 squeezed.
            torch.Tensor: Coarse quaternion at layer 3 (``l3_q``), dim 1 squeezed.
            torch.Tensor: Coarse translation at layer 3 (``l3_t``), dim 1 squeezed.
        """
        num_points_l2 = self.out_H_list[2] * self.out_W_list[2]

        def pool_cost_volume(feature, weight, mask):
            # forward() unpacks softmax_valid into (pooled_feature, query_feature).
            # Only the pooled feature is consumed here; passing the whole tuple on
            # to a conv module would fail, so drop any extra outputs.
            pooled = softmax_valid(feature_bnc=feature, weight_bnc=weight, mask_valid=mask)
            return pooled[0] if isinstance(pooled, tuple) else pooled

        def rotate(q, xyz):
            # Rotate 3-vectors by quaternion q as q * (0, xyz) * q^-1 and keep the
            # vector part. The zero scalar part is allocated on xyz's device/dtype.
            xyz_quat = torch.cat([xyz.new_zeros(tuple(xyz.shape[:-1]) + (1,)), xyz], dim=-1)
            rotated = mul_q_point(q, xyz_quat, batch_size)
            return mul_point_q(rotated, inv_q(q, batch_size), batch_size)[..., 1:4]

        def normalize_quaternion(q):
            # Unit-normalize along the last dim; the epsilons guard against a zero norm.
            return q / (torch.sqrt(torch.sum(q * q, dim=-1, keepdim=True) + 1e-10) + 1e-10)

        # ---------------- Layer 3: coarse pose ----------------
        # Initial cost volume between the two frames at layer 2.
        l2_cost_volume_origin = self.cost_volume1_composa(
            l2_xyz_proj_f1, l2_xyz_proj_f2, l2_points_proj_f1, l2_points_proj_f2
        )
        l2_cost_volume_origin_proj = torch.reshape(
            l2_cost_volume_origin,
            [batch_size, self.out_H_list[2], self.out_W_list[2], -1]
        )

        # Aggregate the layer-2 cost volume onto the layer-3 points and predict weights.
        l3_cost_volume, l3_cost_volume_proj = self.laye3_1_composa(
            l2_xyz_proj_f1, l2_cost_volume_origin_proj, l3_xyz_proj_f1
        )
        l3_cost_volume_w = self.flow_predictor0_composa(l3_points_f1, None, l3_cost_volume)
        l3_cost_volume_w_proj = torch.reshape(
            l3_cost_volume_w,
            [batch_size, self.out_H_list[3], self.out_W_list[3], -1]
        )

        # Points whose xyz is all zero are excluded from the weighted pooling.
        l3_xyz_f1 = torch.reshape(l3_xyz_proj_f1, [batch_size, -1, 3])
        mask_l3 = torch.any(l3_xyz_f1 != 0, dim=-1)

        l3_points_f1_new = pool_cost_volume(l3_cost_volume, l3_cost_volume_w, mask_l3)

        l3_points_f1_new_big = self.conv3_l3_composa(l3_points_f1_new)
        # Separate dropout draws feed the rotation and translation heads.
        l3_points_f1_new_q = F.dropout(l3_points_f1_new_big, p=0.5, training=self.training)
        l3_points_f1_new_t = F.dropout(l3_points_f1_new_big, p=0.5, training=self.training)

        l3_q_coarse = normalize_quaternion(self.conv1_l3_composa(l3_points_f1_new_q))
        l3_t_coarse = self.conv2_l3_composa(l3_points_f1_new_t)

        l3_q = torch.squeeze(l3_q_coarse, dim=1)
        l3_t = torch.squeeze(l3_t_coarse, dim=1)

        # ---------------- Layer 2: warp with coarse pose, regress residual ----------------
        l2_q_coarse = torch.reshape(l3_q, [batch_size, 1, -1])
        l2_t_coarse = torch.reshape(l3_t, [batch_size, 1, -1])

        l2_xyz_f1 = torch.reshape(l2_xyz_proj_f1, [batch_size, -1, 3])
        l2_flow_warped = rotate(l2_q_coarse, l2_xyz_f1) + l2_t_coarse
        # Zero out points that were all-zero before warping; the translation would otherwise move them.
        l2_mask = torch.any(l2_xyz_f1 != 0, dim=-1, keepdim=True).to(torch.float32)
        l2_flow_warped = l2_flow_warped * l2_mask

        # Re-project the warped points onto the spherical ring.
        l2_xyz_warp_proj_f1, l2_points_warp_proj_f1 = ProjectPC2SphericalRing(
            l2_flow_warped, l2_points_f1, self.out_H_list[2], self.out_W_list[2]
        )
        l2_xyz_warp_f1 = torch.reshape(l2_xyz_warp_proj_f1, [batch_size, -1, 3])
        l2_points_warp_f1 = torch.reshape(
            l2_points_warp_proj_f1, [batch_size, num_points_l2, -1]
        )
        l2_mask_warped = torch.any(l2_xyz_warp_f1 != 0, dim=-1)

        # Cost volume between warped frame 1 and frame 2 at layer 2.
        l2_cost_volume = self.cost_volume2_composa(
            l2_xyz_warp_proj_f1, l2_xyz_proj_f2,
            l2_points_warp_proj_f1, l2_points_proj_f2
        )

        # Upsample layer-3 weights and cost volume onto the warped layer-2 points.
        l2_cost_volume_w_upsample = self.set_upconv1_w_upsample_composa(
            l2_xyz_warp_proj_f1, l3_xyz_proj_f1,
            l2_points_warp_proj_f1, l3_cost_volume_w_proj
        )
        l2_cost_volume_upsample = self.set_upconv1_upsample_composa(
            l2_xyz_warp_proj_f1, l3_xyz_proj_f1,
            l2_points_warp_proj_f1, l3_cost_volume_proj
        )

        l2_cost_volume_predict = self.flow_predictor1_predict_composa(
            l2_points_warp_f1, l2_cost_volume_upsample, l2_cost_volume
        )
        l2_cost_volume_w = self.flow_predictor1_w_composa(
            l2_points_warp_f1, l2_cost_volume_w_upsample, l2_cost_volume
        )

        l2_cost_volume_sum = pool_cost_volume(
            l2_cost_volume_predict, l2_cost_volume_w, l2_mask_warped
        )

        l2_points_f1_new_big = self.conv3_l2_composa(l2_cost_volume_sum)
        l2_points_f1_new_q = F.dropout(l2_points_f1_new_big, p=0.5, training=self.training)
        l2_points_f1_new_t = F.dropout(l2_points_f1_new_big, p=0.5, training=self.training)

        l2_q_det = normalize_quaternion(self.conv1_l2_composa(l2_points_f1_new_q))
        l2_t_det = self.conv2_l2_composa(l2_points_f1_new_t)

        # Compose residual with coarse pose:
        #   q = q_det * q_coarse,  t = q_det * t_coarse * q_det^-1 + t_det
        l2_q = torch.squeeze(mul_point_q(l2_q_det, l2_q_coarse, batch_size), dim=1)
        l2_t = torch.squeeze(rotate(l2_q_det, l2_t_coarse) + l2_t_det, dim=1)

        return l2_q, l2_t, l3_q, l3_t

    def update_f2_history(self, l2_xyz_proj_f2, l2_points_proj_f2, l2_points_f2, l3_xyz_proj_f2, l3_points_f2, l1_xyz_proj_f2, l1_points_f2, l0_xyz_proj_f2, l0_points_f2):
        """
        Snapshot the frame-2 pyramid tensors onto ``self`` as ``his_*_f2`` attributes.

        Each tensor is cloned and detached, so the stored copy shares neither
        storage nor autograd history with the live graph. ``forward`` reads these
        attributes back when ``resuse_feature`` is enabled.
        """
        snapshot = (
            ("his_l2_xyz_proj_f2", l2_xyz_proj_f2),
            ("his_l2_points_proj_f2", l2_points_proj_f2),
            ("his_l2_points_f2", l2_points_f2),
            ("his_l3_xyz_proj_f2", l3_xyz_proj_f2),
            ("his_l3_points_f2", l3_points_f2),
            ("his_l1_xyz_proj_f2", l1_xyz_proj_f2),
            ("his_l1_points_f2", l1_points_f2),
            ("his_l0_xyz_proj_f2", l0_xyz_proj_f2),
            ("his_l0_points_f2", l0_points_f2),
        )
        for attr_name, tensor in snapshot:
            setattr(self, attr_name, tensor.clone().detach())


    def forward(self, input_xyz_f1, input_xyz_f2, input_img_f1, input_img_f2, input_xy_f1, input_xy_f2, T_gt, T_trans, T_trans_inv, pts_valid_flag_f1, pts_valid_flag_f2, fn2_dir):
        
        if visualization:
            self.training = visualization
        device = input_xyz_f1.device

        start_train = time.time()

        batch_size = input_xyz_f1.shape[0]

        input_points_proj_f1 = torch.zeros(batch_size, self.H_input, self.W_input, 3).cuda().detach()
        input_points_proj_f2 = torch.zeros(batch_size, self.H_input, self.W_input, 3).cuda().detach()


        self.l0_b_idx, self.l0_h_idx, self.l0_w_idx = get_selected_idx( batch_size, self.out_H_list[0], self.out_W_list[0], self.stride_H_list[0], self.stride_W_list[0] )
        self.l1_b_idx, self.l1_h_idx, self.l1_w_idx = get_selected_idx( batch_size, self.out_H_list[1], self.out_W_list[1], self.stride_H_list[1], self.stride_W_list[1] )
        self.l2_b_idx, self.l2_h_idx, self.l2_w_idx = get_selected_idx( batch_size, self.out_H_list[2], self.out_W_list[2], self.stride_H_list[2], self.stride_W_list[2] )
        self.l3_b_idx, self.l3_h_idx, self.l3_w_idx = get_selected_idx( batch_size, self.out_H_list[3], self.out_W_list[3], self.stride_H_list[3], self.stride_W_list[3] )

        torch.cuda.synchronize()
        start_time = time.time()

        aug_frame = np.random.choice([1, 2], size = batch_size, replace = True) # random choose aug frame 1 or 2
        input_xyz_aug_f1, input_xyz_aug_f2, q_gt, t_gt = PreProcess(input_xyz_f1, input_xyz_f2, T_gt, T_trans, T_trans_inv, aug_frame)

        # input_xyz_aug_proj_f1 = ProjectPC2SphericalRing(input_xyz_aug_f1, None, self.H_input, self.W_input)  ## proj func
        # input_xyz_aug_proj_f2 = ProjectPC2SphericalRing(input_xyz_aug_f2, None, self.H_input, self.W_input)
        if self.continue_frames == 0 or not resuse_feature:
            input_xyz_aug_proj_f1, input_flag_proj_f1, input_xy_proj_f1 = ProjectPCflagxy2SphericalRing(input_xyz_aug_f1, pts_valid_flag_f1, input_xy_f1, None, self.H_input,self.W_input)  ## proj func
        input_xyz_aug_proj_f2, input_flag_proj_f2, input_xy_proj_f2 = ProjectPCflagxy2SphericalRing(input_xyz_aug_f2, pts_valid_flag_f2, input_xy_f2, None, self.H_input, self.W_input)
        
    
        # print('data_pre_process_proj+aug: ', time.time() - start_time)


        ####  the l0 select bn3 xyz
        if self.continue_frames == 0 or not resuse_feature:
            l0_xyz_proj_f1 = input_xyz_aug_proj_f1[self.l0_b_idx.cuda().long(), self.l0_h_idx.cuda().long(), self.l0_w_idx.cuda().long(), :]
            l0_flag_proj_f1 = input_flag_proj_f1[self.l0_b_idx.cuda().long(), self.l0_h_idx.cuda().long(), self.l0_w_idx.cuda().long(), :]
            l0_xy_proj_f1 = input_xy_proj_f1[self.l0_b_idx.cuda().long(), self.l0_h_idx.cuda().long(), self.l0_w_idx.cuda().long(), :]

        l0_xyz_proj_f2 = input_xyz_aug_proj_f2[self.l0_b_idx.cuda().long(), self.l0_h_idx.cuda().long(), self.l0_w_idx.cuda().long(), :]
        l0_flag_proj_f2 = input_flag_proj_f2[self.l0_b_idx.cuda().long(), self.l0_h_idx.cuda().long(), self.l0_w_idx.cuda().long(), :]
        l0_xy_proj_f2 = input_xy_proj_f2[self.l0_b_idx.cuda().long(), self.l0_h_idx.cuda().long(), self.l0_w_idx.cuda().long(), :]

        ####  the l1 select bn3 xyz
        if self.continue_frames == 0 or not resuse_feature:
            l1_xyz_proj_f1 = l0_xyz_proj_f1[self.l1_b_idx.cuda().long(), self.l1_h_idx.cuda().long(), self.l1_w_idx.cuda().long(), :]
            l1_flag_proj_f1 = l0_flag_proj_f1[self.l1_b_idx.cuda().long(), self.l1_h_idx.cuda().long(), self.l1_w_idx.cuda().long(), :]
            l1_xy_proj_f1 = l0_xy_proj_f1[self.l1_b_idx.cuda().long(), self.l1_h_idx.cuda().long(), self.l1_w_idx.cuda().long(), :]

        l1_xyz_proj_f2 = l0_xyz_proj_f2[self.l1_b_idx.cuda().long(), self.l1_h_idx.cuda().long(), self.l1_w_idx.cuda().long(), :]
        l1_flag_proj_f2 = l0_flag_proj_f2[self.l1_b_idx.cuda().long(), self.l1_h_idx.cuda().long(), self.l1_w_idx.cuda().long(), :]
        l1_xy_proj_f2 = l0_xy_proj_f2[self.l1_b_idx.cuda().long(), self.l1_h_idx.cuda().long(), self.l1_w_idx.cuda().long(), :]

        ####  the l2 select bn3 xyz
        if self.continue_frames == 0 or not resuse_feature:
            l2_xyz_proj_f1 = l1_xyz_proj_f1[self.l2_b_idx.cuda().long(), self.l2_h_idx.cuda().long(), self.l2_w_idx.cuda().long(), :]
            l2_flag_proj_f1 = l1_flag_proj_f1[self.l2_b_idx.cuda().long(), self.l2_h_idx.cuda().long(), self.l2_w_idx.cuda().long(), :]
            
            l2_xy_proj_f1 = l1_xy_proj_f1[self.l2_b_idx.cuda().long(), self.l2_h_idx.cuda().long(), self.l2_w_idx.cuda().long(), :]
            
        
        l2_xyz_proj_f2 = l1_xyz_proj_f2[self.l2_b_idx.cuda().long(), self.l2_h_idx.cuda().long(), self.l2_w_idx.cuda().long(), :]
        
        l2_flag_proj_f2 = l1_flag_proj_f2[self.l2_b_idx.cuda().long(), self.l2_h_idx.cuda().long(), self.l2_w_idx.cuda().long(), :]
        l2_xy_proj_f2 = l1_xy_proj_f2[self.l2_b_idx.cuda().long(), self.l2_h_idx.cuda().long(), self.l2_w_idx.cuda().long(), :]

        ####  the l3 select bn3 xyz
        if self.continue_frames == 0 or not resuse_feature:
            l3_xyz_proj_f1 = l2_xyz_proj_f1[self.l3_b_idx.cuda().long(), self.l3_h_idx.cuda().long(), self.l3_w_idx.cuda().long(), :]
            l3_flag_proj_f1 = l2_flag_proj_f1[self.l3_b_idx.cuda().long(), self.l3_h_idx.cuda().long(), self.l3_w_idx.cuda().long(), :]
            l3_xy_proj_f1 = l2_xy_proj_f1[self.l3_b_idx.cuda().long(), self.l3_h_idx.cuda().long(), self.l3_w_idx.cuda().long(), :]

        l3_xyz_proj_f2 = l2_xyz_proj_f2[self.l3_b_idx.cuda().long(), self.l3_h_idx.cuda().long(), self.l3_w_idx.cuda().long(), :]
        l3_flag_proj_f2 = l2_flag_proj_f2[self.l3_b_idx.cuda().long(), self.l3_h_idx.cuda().long(), self.l3_w_idx.cuda().long(), :]
        l3_xy_proj_f2 = l2_xy_proj_f2[self.l3_b_idx.cuda().long(), self.l3_h_idx.cuda().long(), self.l3_w_idx.cuda().long(), :]

        # print('pre_process: ', time.time() - start_train)

        set_conv_start = time.time()
        if self.continue_frames == 0 or not resuse_feature:
            l0_mask_f1 = l0_flag_proj_f1.squeeze(-1).to(torch.bool)
            l1_mask_f1 = l1_flag_proj_f1.squeeze(-1).to(torch.bool)
            l2_mask_f1 = l2_flag_proj_f1.squeeze(-1).to(torch.bool)
            l3_mask_f1 = l3_flag_proj_f1.squeeze(-1).to(torch.bool)


        l0_mask_f2 = l0_flag_proj_f2.squeeze(-1).to(torch.bool)
        l1_mask_f2 = l1_flag_proj_f2.squeeze(-1).to(torch.bool)
        l2_mask_f2 = l2_flag_proj_f2.squeeze(-1).to(torch.bool)
        l3_mask_f2 = l3_flag_proj_f2.squeeze(-1).to(torch.bool)

        ## flame 1
        # layer 0
        if self.continue_frames == 0 or not resuse_feature:
            l0_points_f1, l0_points_proj_f1 = self.layer0(input_xyz_aug_proj_f1, input_points_proj_f1, l0_xyz_proj_f1)
            image0_f1 = self.BasicBlock_0(input_img_f1.to(torch.float32))
            b = l0_points_proj_f1.shape[0]
            h = l0_points_proj_f1.shape[1]
            w = l0_points_proj_f1.shape[2]
            c = l0_points_proj_f1.shape[3]

            l0_xy_cor_f1 = torch.zeros(b, self.length[0], 2)
            l0_proj_cor_f1 = torch.zeros(b, self.length[0], c)
            start_time = time.time()
            for batch in range(batch_size):
                l0_batch_xy_f1 = l0_xy_proj_f1[batch:batch + 1, :, :][l0_mask_f1[batch:batch + 1, :, :]][:, :2]  # [N, 2]
                l0_xy_cor_f1[batch, :l0_batch_xy_f1.shape[0], :] = l0_batch_xy_f1  # [B, N, 2]
                l0_batch_proj_f1 = l0_points_proj_f1[batch:batch + 1, :, :][l0_mask_f1[batch:batch + 1, :, :]][:,
                                :]  # [N, C]
                l0_proj_cor_f1[batch, :l0_batch_proj_f1.shape[0], :] = l0_batch_proj_f1  # [B, N, C]

            mask_valid_xy_f1 = torch.any(l0_xy_cor_f1 != 0, dim=-1, keepdim=True).squeeze(-1).to(torch.bool)
            l0_points_origin_f1 = l0_points_proj_f1.clone().cuda(device) 
            init_query_0 = torch.zeros(b, l0_xy_cor_f1.shape[1], l0_points_proj_f1.shape[-1]).cuda(device)
            init_query_0[mask_valid_xy_f1] = l0_points_proj_f1[l0_mask_f1]
            init_qery_anchor_0 = torch.zeros(b, l0_xy_cor_f1.shape[1], 3).cuda(device)
            init_qery_anchor_0[mask_valid_xy_f1] = l0_xyz_proj_f1[l0_mask_f1]
            init_query_pos_embed_0 = self.generate_pos_embedding_l0(init_qery_anchor_0)
            init_query_0 = init_query_0 + init_query_pos_embed_0
            l0_img_gather_feature_full_f1 = self.cluster_l0(init_query_0, l0_xy_cor_f1, image0_f1, device)  # [B, C, 1, N]
            l0_img_gather_feature_full_f1 = l0_img_gather_feature_full_f1.squeeze(2).permute(0, 2, 1)  # [B, N, C]
            l0_points_proj_f1[l0_mask_f1] = l0_img_gather_feature_full_f1[mask_valid_xy_f1]
            l0_points_proj_f1 = self.feat_fuserl0(l0_points_proj_f1, l0_points_origin_f1).permute(0, 2, 3, 1)  # [B,h,w,c]

            # layer 1
            l1_points_f1, l1_points_proj_f1 = self.layer1(l0_xyz_proj_f1, l0_points_proj_f1, l1_xyz_proj_f1)
            image1_f1 = self.BasicBlock_1(image0_f1.to(torch.float32))
            b = l1_points_proj_f1.shape[0]
            h = l1_points_proj_f1.shape[1]
            w = l1_points_proj_f1.shape[2]
            c = l1_points_proj_f1.shape[3]

            l1_xy_cor_f1 = torch.zeros(b, self.length[1], 2)
            l1_proj_cor_f1 = torch.zeros(b, self.length[1], c)
            for batch in range(batch_size):
                l1_batch_xy_f1 = l1_xy_proj_f1[batch:batch + 1, :, :][l1_mask_f1[batch:batch + 1, :, :]][:, :2]  # [N, 2]
                l1_xy_cor_f1[batch, :l1_batch_xy_f1.shape[0], :] = l1_batch_xy_f1  # [B, N, 2]
                l1_batch_proj_f1 = l1_points_proj_f1[batch:batch + 1, :, :][l1_mask_f1[batch:batch + 1, :, :]][:,
                                :]  # [N, C]
                l1_proj_cor_f1[batch, :l1_batch_proj_f1.shape[0], :] = l1_batch_proj_f1  # [B, N, C]

            mask_valid_xy_f1 = torch.any(l1_xy_cor_f1 != 0, dim=-1, keepdim=True).squeeze(-1).to(torch.bool)

            l1_points_origin_f1 = l1_points_proj_f1.clone().cuda(device)
            init_query_1 = torch.zeros(b, l1_xy_cor_f1.shape[1], l1_points_proj_f1.shape[-1]).cuda(device)
            init_query_1[mask_valid_xy_f1] = l1_points_proj_f1[l1_mask_f1]
            init_qery_anchor_1 = torch.zeros(b, l1_xy_cor_f1.shape[1], 3).cuda(device)
            init_qery_anchor_1[mask_valid_xy_f1] = l1_xyz_proj_f1[l1_mask_f1]
            init_query_pos_embed_1 = self.generate_pos_embedding_l1(init_qery_anchor_1)
            init_query_1 = init_query_1 + init_query_pos_embed_1
            l1_img_gather_feature_full_f1 = self.cluster_l1(init_query_1, l1_xy_cor_f1, image1_f1, device)
            l1_img_gather_feature_full_f1 = l1_img_gather_feature_full_f1.permute(0, 2, 3, 1).squeeze(1)
            l1_points_proj_f1[l1_mask_f1] = l1_img_gather_feature_full_f1[mask_valid_xy_f1]
            l1_points_proj_f1 = self.feat_fuserl1(l1_points_proj_f1, l1_points_origin_f1).permute(0, 2, 3, 1)  # [B,h,w,c]

            # layer 2
            l2_points_f1, l2_points_proj_f1 = self.layer2(l1_xyz_proj_f1, l1_points_proj_f1, l2_xyz_proj_f1)
            image2_f1 = self.BasicBlock_2(image1_f1.to(torch.float32))
            b = l2_points_proj_f1.shape[0]
            h = l2_points_proj_f1.shape[1]
            w = l2_points_proj_f1.shape[2]
            c = l2_points_proj_f1.shape[3]

            l2_xy_cor_f1 = torch.zeros(b, self.length[2], 2)
            l2_proj_cor_f1 = torch.zeros(b, self.length[2], c)
            for batch in range(batch_size):
                l2_batch_xy_f1 = l2_xy_proj_f1[batch:batch + 1, :, :][l2_mask_f1[batch:batch + 1, :, :]][:, :2]  # [N, 2]
                l2_xy_cor_f1[batch, :l2_batch_xy_f1.shape[0], :] = l2_batch_xy_f1  # [B, N, 2]
                l2_batch_proj_f1 = l2_points_proj_f1[batch:batch + 1, :, :][l2_mask_f1[batch:batch + 1, :, :]][:,
                                :]  # [N, C]
                l2_proj_cor_f1[batch, :l2_batch_proj_f1.shape[0], :] = l2_batch_proj_f1  # [B, N, C]

            mask_valid_xy_f1 = torch.any(l2_xy_cor_f1 != 0, dim=-1, keepdim=True).squeeze(-1).to(torch.bool)
            l2_points_origin_f1 = l2_points_proj_f1.clone().cuda(device)
            init_query_2 = torch.zeros(b, l2_xy_cor_f1.shape[1], l2_points_proj_f1.shape[-1]).cuda(device)
            init_query_2[mask_valid_xy_f1] = l2_points_proj_f1[l2_mask_f1]
            init_qery_anchor_2 = torch.zeros(b, l2_xy_cor_f1.shape[1], 3).cuda(device)
            init_qery_anchor_2[mask_valid_xy_f1] = l2_xyz_proj_f1[l2_mask_f1]
            init_query_pos_embed_2 = self.generate_pos_embedding_l2(init_qery_anchor_2)
            init_query_2 = init_query_2 + init_query_pos_embed_2

            l2_img_gather_feature_full_f1 = self.cluster_l2(init_query_2, l2_xy_cor_f1, image2_f1, device)
            l2_img_gather_feature_full_f1 = l2_img_gather_feature_full_f1.permute(0, 2, 3, 1).squeeze(1)
            l2_points_proj_f1[l2_mask_f1] = l2_img_gather_feature_full_f1[mask_valid_xy_f1]
            l2_points_proj_f1 = self.feat_fuserl2(l2_points_proj_f1, l2_points_origin_f1).permute(0, 2, 3,
                                                                                                1)  # [B, C, 1, N]

            # layer 3
            l3_points_f1, l3_points_proj_f1 = self.layer3(l2_xyz_proj_f1, l2_points_proj_f1, l3_xyz_proj_f1)
        elif resuse_feature:
            l2_xyz_proj_f1, l2_points_proj_f1, l2_points_f1, l3_xyz_proj_f1, l3_points_f1, l1_xyz_proj_f1, l1_points_f1, l0_xyz_proj_f1, l0_points_f1 =  self.his_l2_xyz_proj_f2.clone(), self.his_l2_points_proj_f2.clone(), self.his_l2_points_f2.clone(), self.his_l3_xyz_proj_f2.clone(), self.his_l3_points_f2.clone(), self.his_l1_xyz_proj_f2.clone(), self.his_l1_points_f2.clone(), self.his_l0_xyz_proj_f2.clone(), self.his_l0_points_f2.clone()


        ## flame 2
        # layer 0
        l0_points_f2, l0_points_proj_f2 = self.layer0(input_xyz_aug_proj_f2, input_points_proj_f2, l0_xyz_proj_f2)
        image0_f2 = self.BasicBlock_0(input_img_f2.to(torch.float32))
        b = l0_points_proj_f2.shape[0]
        h = l0_points_proj_f2.shape[1]
        w = l0_points_proj_f2.shape[2]
        c = l0_points_proj_f2.shape[3]

        l0_xy_cor_f2 = torch.zeros(b, self.length[0], 2)
        l0_proj_cor_f2 = torch.zeros(b, self.length[0], c)
        for batch in range(batch_size):
            l0_batch_xy_f2 = l0_xy_proj_f2[batch:batch + 1, :, :][l0_mask_f2[batch:batch + 1, :, :]][:, :2]  # [N, 2]
            l0_xy_cor_f2[batch, :l0_batch_xy_f2.shape[0], :] = l0_batch_xy_f2  # [B, N, 2]
            l0_batch_proj_f2 = l0_points_proj_f2[batch:batch + 1, :, :][l0_mask_f2[batch:batch + 1, :, :]][:,
                               :]  # [N, C]
            l0_proj_cor_f2[batch, :l0_batch_proj_f2.shape[0], :] = l0_batch_proj_f2  # [B, N, C]

        mask_valid_xy_f2 = torch.any(l0_xy_cor_f2 != 0, dim=-1, keepdim=True).squeeze(-1).to(torch.bool)

        l0_points_origin_f2 = l0_points_proj_f2.clone().cuda(device)
        init_query_02 = torch.zeros(b, l0_xy_cor_f2.shape[1], l0_points_proj_f2.shape[-1]).cuda(device)
        init_query_02[mask_valid_xy_f2] = l0_points_proj_f2[l0_mask_f2]
        init_qery_anchor_02 = torch.zeros(b, l0_xy_cor_f2.shape[1], 3).cuda(device)
        init_qery_anchor_02[mask_valid_xy_f2] = l0_xyz_proj_f2[l0_mask_f2]
        init_query_pos_embed_02 = self.generate_pos_embedding_l0(init_qery_anchor_02)
        init_query_02 = init_query_02 + init_query_pos_embed_02
        l0_img_gather_feature_full_f2 = self.cluster_l0(init_query_02, l0_xy_cor_f2, image0_f2, device)
        l0_img_gather_feature_full_f2 = l0_img_gather_feature_full_f2.permute(0, 2, 3, 1).squeeze(1)
        l0_points_proj_f2[l0_mask_f2] = l0_img_gather_feature_full_f2[mask_valid_xy_f2]
        l0_points_proj_f2 = self.feat_fuserl0(l0_points_proj_f2, l0_points_origin_f2).permute(0, 2, 3,
                                                                                              1)  # [B, C, 1, N]

        # layer 1
        l1_points_f2, l1_points_proj_f2 = self.layer1(l0_xyz_proj_f2, l0_points_proj_f2, l1_xyz_proj_f2)
        image1_f2 = self.BasicBlock_1(image0_f2.to(torch.float32))
        b = l1_points_proj_f2.shape[0]
        h = l1_points_proj_f2.shape[1]
        w = l1_points_proj_f2.shape[2]
        c = l1_points_proj_f2.shape[3]

        l1_xy_cor_f2 = torch.zeros(b, self.length[1], 2)
        l1_proj_cor_f2 = torch.zeros(b, self.length[1], c)
        for batch in range(batch_size):
            l1_batch_xy_f2 = l1_xy_proj_f2[batch:batch + 1, :, :][l1_mask_f2[batch:batch + 1, :, :]][:, :2]  # [N, 2]
            l1_xy_cor_f2[batch, :l1_batch_xy_f2.shape[0], :] = l1_batch_xy_f2  # [B, N, 2]
            l1_batch_proj_f2 = l1_points_proj_f2[batch:batch + 1, :, :][l1_mask_f2[batch:batch + 1, :, :]][:,
                               :]  # [N, C]
            l1_proj_cor_f2[batch, :l1_batch_proj_f2.shape[0], :] = l1_batch_proj_f2  # [B, N, C]

        mask_valid_xy_f2 = torch.any(l1_xy_cor_f2 != 0, dim=-1, keepdim=True).squeeze(-1).to(torch.bool)

        l1_points_origin_f2 = l1_points_proj_f2.clone().cuda(device)
        init_query_12 = torch.zeros(b, l1_xy_cor_f2.shape[1], l1_points_proj_f2.shape[-1]).cuda(device)
        init_query_12[mask_valid_xy_f2] = l1_points_proj_f2[l1_mask_f2]
        init_qery_anchor_12 = torch.zeros(b, l1_xy_cor_f2.shape[1], 3).cuda(device)
        init_qery_anchor_12[mask_valid_xy_f2] = l1_xyz_proj_f2[l1_mask_f2]
        init_query_pos_embed_12 = self.generate_pos_embedding_l1(init_qery_anchor_12)
        init_query_12 = init_query_12 + init_query_pos_embed_12

        l1_img_gather_feature_full_f2 = self.cluster_l1(init_query_12, l1_xy_cor_f2, image1_f2, device)
        l1_img_gather_feature_full_f2 = l1_img_gather_feature_full_f2.permute(0, 2, 3, 1).squeeze(1)
        l1_points_proj_f2[l1_mask_f2] = l1_img_gather_feature_full_f2[mask_valid_xy_f2]
        l1_points_proj_f2 = self.feat_fuserl1(l1_points_proj_f2, l1_points_origin_f2).permute(0, 2, 3,
                                                                                              1)  # [B, C, 1, N]

        # layer 2
        l2_points_f2, l2_points_proj_f2 = self.layer2(l1_xyz_proj_f2, l1_points_proj_f2, l2_xyz_proj_f2)
        image2_f2 = self.BasicBlock_2(image1_f2.to(torch.float32))
        b = l2_points_proj_f2.shape[0]
        h = l2_points_proj_f2.shape[1]
        w = l2_points_proj_f2.shape[2]
        c = l2_points_proj_f2.shape[3]

        l2_xy_cor_f2 = torch.zeros(b, self.length[2], 2)
        l2_proj_cor_f2 = torch.zeros(b, self.length[2], c)
        for batch in range(batch_size):
            l2_batch_xy_f2 = l2_xy_proj_f2[batch:batch + 1, :, :][l2_mask_f2[batch:batch + 1, :, :]][:, :2]  # [N, 2]
            l2_xy_cor_f2[batch, :l2_batch_xy_f2.shape[0], :] = l2_batch_xy_f2  # [B, N, 2]
            l2_batch_proj_f2 = l2_points_proj_f2[batch:batch + 1, :, :][l2_mask_f2[batch:batch + 1, :, :]][:,
                               :]  # [N, C]
            l2_proj_cor_f2[batch, :l2_batch_proj_f2.shape[0], :] = l2_batch_proj_f2  # [B, N, C]

        mask_valid_xy_f2 = torch.any(l2_xy_cor_f2 != 0, dim=-1, keepdim=True).squeeze(-1).to(torch.bool)
        l2_points_origin_f2 = l2_points_proj_f2.clone().cuda(device)
        init_query_22 = torch.zeros(b, l2_xy_cor_f2.shape[1], l2_points_proj_f2.shape[-1]).cuda(device)
        init_query_22[mask_valid_xy_f2] = l2_points_proj_f2[l2_mask_f2]
        init_qery_anchor_22 = torch.zeros(b, l2_xy_cor_f2.shape[1], 3).cuda(device)
        init_qery_anchor_22[mask_valid_xy_f2] = l2_xyz_proj_f2[l2_mask_f2]
        init_query_pos_embed_22 = self.generate_pos_embedding_l2(init_qery_anchor_22)
        init_query_22 = init_query_22 + init_query_pos_embed_22

        l2_img_gather_feature_full_f2 = self.cluster_l2(init_query_22, l2_xy_cor_f2, image2_f2, device)
        l2_img_gather_feature_full_f2 = l2_img_gather_feature_full_f2.permute(0, 2, 3, 1).squeeze(1)
        l2_points_proj_f2[l2_mask_f2] = l2_img_gather_feature_full_f2[mask_valid_xy_f2]
        l2_points_proj_f2 = self.feat_fuserl2(l2_points_proj_f2, l2_points_origin_f2).permute(0, 2, 3, 1)

        # layer 3
        l3_points_f2, l3_points_proj_f2 = self.layer3(l2_xyz_proj_f2, l2_points_proj_f2, l3_xyz_proj_f2)
        self.update_f2_history(l2_xyz_proj_f2, l2_points_proj_f2, l2_points_f2, l3_xyz_proj_f2, l3_points_f2, l1_xyz_proj_f2, l1_points_f2, l0_xyz_proj_f2, l0_points_f2)
        self.his_l0_xyz_proj_f1 = self.update_history(l0_xyz_proj_f1.unsqueeze(1), self.his_l0_xyz_proj_f1)
        self.his_l0_points_f1 = self.update_history(l0_points_f1.unsqueeze(1), self.his_l0_points_f1)
        self.his_l0_xyz_proj_f1 = self.update_history(l0_xyz_proj_f1.unsqueeze(1), self.his_l0_xyz_proj_f1)
        self.his_l0_points_f1 = self.update_history(l0_points_f1.unsqueeze(1), self.his_l0_points_f1)
        self.his_l2_xyz_proj_f1= self.update_history(l2_xyz_proj_f1.unsqueeze(1), self.his_l2_xyz_proj_f1)
        self.his_l2_points_proj_f1 = self.update_history(l2_points_proj_f1.unsqueeze(1), self.his_l2_points_proj_f1)
        self.his_l2_points_f1 = self.update_history(l2_points_f1.unsqueeze(1), self.his_l2_points_f1)
        self.his_l3_xyz_proj_f1 = self.update_history(l3_xyz_proj_f1.unsqueeze(1), self.his_l3_xyz_proj_f1)
        self.his_l3_points_f1 = self.update_history(l3_points_f1.unsqueeze(1), self.his_l3_points_f1)
        


        l2_cost_volume_origin = self.cost_volume1(l2_xyz_proj_f1, l2_xyz_proj_f2, l2_points_proj_f1, l2_points_proj_f2)
        l2_cost_volume_origin_proj = torch.reshape(l2_cost_volume_origin,  [batch_size, self.out_H_list[2], self.out_W_list[2], -1])

        # Layer 3 ##################
        
        l3_cost_volume, l3_cost_volume_proj = self.laye3_1(l2_xyz_proj_f1, l2_cost_volume_origin_proj, l3_xyz_proj_f1)
        l3_cost_volume_w = self.flow_predictor0(l3_points_f1, None, l3_cost_volume)
        l3_cost_volume_w_proj = torch.reshape(l3_cost_volume_w, [batch_size, self.out_H_list[3], self.out_W_list[3], -1])


        l3_xyz_f1 = torch.reshape(l3_xyz_proj_f1, [batch_size, -1, 3])
        mask_l3 = torch.any(l3_xyz_f1 != 0, dim = -1)
        # if self.training:
        l3_points_f1_new, l3_query_feature_final = softmax_valid(feature_bnc = l3_cost_volume, weight_bnc = l3_cost_volume_w, mask_valid = mask_l3)  # B 1 C

        l3_points_f1_new_big = self.conv3_l3(l3_points_f1_new)
        if self.training:
            l3_points_f1_new_big_query = self.conv3_l3_query(l3_query_feature_final)
        l3_points_f1_new_q = F.dropout(l3_points_f1_new_big, p = 0.5, training = self.training)
        l3_points_f1_new_t = F.dropout(l3_points_f1_new_big, p = 0.5, training = self.training)
        if self.continue_frames > 0:
            l3_points_f1_new_q = torch.max(self.mamba_feature_q_l3(self.gmlp_q_f(torch.cat([self.l3_points_f1_new_q_history[:,1:], l3_points_f1_new_q], dim=1))), dim = 1)[0].unsqueeze(1)
        q_his_encode = self.mlp_q(self.l0_q_history)
        t_his_encode = self.mlp_t(self.l0_t_history)
        t_l0_his_embed = torch.max(self.mamba_t(t_his_encode), dim = 1)[0].unsqueeze(1)
        q_l0_his_embed = torch.max(self.mamba_q(q_his_encode), dim = 1)[0].unsqueeze(1)
        if self.continue_frames > 0:
            l3_points_f1_new_t = torch.max(self.mamba_feature_t_l3(self.gmlp_t_f(torch.cat([self.l3_points_f1_new_t_history[:,1:], l3_points_f1_new_t], dim=1))), dim = 1)[0].unsqueeze(1)
        self.l3_points_f1_new_q_history = self.update_history(l3_points_f1_new_q, self.l3_points_f1_new_q_history)
        self.l3_points_f1_new_t_history = self.update_history(l3_points_f1_new_t, self.l3_points_f1_new_t_history)
        l3_points_f1_new_q = torch.cat([l3_points_f1_new_q, q_l0_his_embed], dim = -1)
        l3_points_f1_new_t = torch.cat([l3_points_f1_new_t, t_l0_his_embed], dim = -1)

        l3_q_coarse = self.conv1_l3(l3_points_f1_new_q)
        l3_q_coarse = l3_q_coarse / (torch.sqrt(torch.sum(l3_q_coarse * l3_q_coarse, dim=-1, keepdim=True) + 1e-10) + 1e-10)
        if self.training:
            l3_q_coarse_query = self.conv1_l3_query(l3_points_f1_new_big_query)
            self.l3_q_coarse_query = l3_q_coarse_query / (torch.sqrt(torch.sum(l3_q_coarse_query * l3_q_coarse_query, dim=-1, keepdim=True) + 1e-10) + 1e-10)

        l3_t_coarse = self.conv2_l3(l3_points_f1_new_t)
        if self.training:
            self.l3_t_coarse_query = self.conv2_l3_query(l3_points_f1_new_big_query)

        l3_q = torch.squeeze(l3_q_coarse, dim=1)
        l3_t = torch.squeeze(l3_t_coarse, dim=1)

        ################ layer 2 #################

        l2_q_coarse = torch.reshape(l3_q, [batch_size, 1, -1])
        l2_t_coarse = torch.reshape(l3_t, [batch_size, 1, -1])
        l2_q_inv = inv_q(l2_q_coarse, batch_size)

        ### warp layer2 pose

        l2_xyz_f1 = torch.reshape(l2_xyz_proj_f1, [batch_size, -1, 3])
        l2_xyz_bnc_q = torch.cat([torch.zeros([batch_size, self.out_H_list[2] * self.out_W_list[2], 1]).cuda(), l2_xyz_f1], dim=-1)

        l2_flow_warped = mul_q_point(l2_q_coarse, l2_xyz_bnc_q, batch_size)
        l2_flow_warped = torch.index_select(mul_point_q(l2_flow_warped, l2_q_inv, batch_size), 2, torch.LongTensor(range(1, 4)).cuda()) + l2_t_coarse

        l2_mask = torch.any(l2_xyz_f1 !=0, dim = -1, keepdim = True).to(torch.float32)
        l2_flow_warped = l2_flow_warped * l2_mask


        l2_xyz_warp_proj_f1, l2_points_warp_proj_f1 = ProjectPC2SphericalRing(l2_flow_warped, l2_points_f1, self.out_H_list[2], self.out_W_list[2])  # 
        l2_xyz_warp_f1 = torch.reshape(l2_xyz_warp_proj_f1, [batch_size, -1, 3])
        l2_points_warp_f1 = torch.reshape(l2_points_warp_proj_f1, [batch_size, self.out_H_list[2] * self.out_W_list[2], -1])

        l2_mask_warped = torch.any(l2_xyz_warp_f1 !=0, dim = -1, keepdim = False)


        # get the cost volume of warped layer3 flow and the points of frame2
        l2_cost_volume = self.cost_volume2(l2_xyz_warp_proj_f1, l2_xyz_proj_f2, l2_points_warp_proj_f1, l2_points_proj_f2)

        l2_cost_volume_w_upsample = self.set_upconv1_w_upsample(l2_xyz_warp_proj_f1, l3_xyz_proj_f1, l2_points_warp_proj_f1, l3_cost_volume_w_proj)
        l2_cost_volume_upsample = self.set_upconv1_upsample(l2_xyz_warp_proj_f1, l3_xyz_proj_f1, l2_points_warp_proj_f1, l3_cost_volume_proj)
        
        l2_cost_volume_predict = self.flow_predictor1_predict(l2_points_warp_f1, l2_cost_volume_upsample, l2_cost_volume)
        l2_cost_volume_w = self.flow_predictor1_w(l2_points_warp_f1, l2_cost_volume_w_upsample, l2_cost_volume)

        l2_cost_volume_proj = torch.reshape(l2_cost_volume_predict, [batch_size, self.out_H_list[2], self.out_W_list[2], -1])
        l2_cost_volume_w_proj = torch.reshape(l2_cost_volume_w, [batch_size, self.out_H_list[2], self.out_W_list[2], -1])
        # if self.training:
        l2_cost_volume_sum, l2_query_feature_final = softmax_valid(feature_bnc = l2_cost_volume_predict, weight_bnc = l2_cost_volume_w, mask_valid = l2_mask_warped)  # B 1 C


        l2_points_f1_new_big = self.conv3_l2(l2_cost_volume_sum)
        if self.training:
            l2_points_f1_new_big_query = self.conv3_l2_query(l2_query_feature_final)
        l2_points_f1_new_q = F.dropout(l2_points_f1_new_big, p = 0.5, training = self.training)
        l2_points_f1_new_t = F.dropout(l2_points_f1_new_big, p = 0.5, training = self.training)
        if self.training:
            l2_points_f1_new_q_query = F.dropout(l2_points_f1_new_big_query, p = 0.5, training = self.training)
            l2_points_f1_new_t_query = F.dropout(l2_points_f1_new_big_query, p = 0.5, training = self.training)

        l2_q_det = self.conv1_l2(l2_points_f1_new_q)
        l2_q_det = l2_q_det / (torch.sqrt(torch.sum(l2_q_det * l2_q_det, dim=-1, keepdim=True) + 1e-10) + 1e-10)
        if self.training:
            l2_q_det_query = self.conv1_l2_query(l2_points_f1_new_q_query)
            l2_q_det_query = l2_q_det_query / (torch.sqrt(torch.sum(l2_q_det_query * l2_q_det_query, dim=-1, keepdim=True) + 1e-10) + 1e-10)
        l2_t_det = self.conv2_l2(l2_points_f1_new_t)
        if self.training:
            l2_t_det_query = self.conv2_l2_query(l2_points_f1_new_t_query)


        l2_q_det_inv = inv_q(l2_q_det, batch_size)

        l2_t_coarse_trans = torch.cat([torch.zeros([batch_size, 1, 1]).cuda(), l2_t_coarse], dim=-1)
        l2_t_coarse_trans = mul_q_point(l2_q_det, l2_t_coarse_trans, batch_size)
        l2_t_coarse_trans = torch.index_select(mul_point_q(l2_t_coarse_trans, l2_q_det_inv, batch_size), 2,
                                                  torch.LongTensor(range(1, 4)).cuda())

        l2_q = torch.squeeze(mul_point_q(l2_q_det, l2_q_coarse, batch_size), dim=1)
        l2_t = torch.squeeze(l2_t_coarse_trans + l2_t_det, dim=1)
        if self.training:
            self.l2_q_query = mul_point_q(l2_q_det_query, l2_q_coarse, batch_size)
            self.l2_t_query = l2_t_coarse_trans + l2_t_det_query



        ############# layer1
        start_l1_refine = time.time()

        l1_q_coarse = torch.reshape(l2_q, [batch_size, 1, -1])
        l1_t_coarse = torch.reshape(l2_t, [batch_size, 1, -1])
        l1_q_inv = inv_q(l1_q_coarse, batch_size)

        ############# warp layer2 pose

        l1_xyz_f1 = torch.reshape(l1_xyz_proj_f1, [batch_size, -1, 3])
        l1_xyz_bnc_q = torch.cat([torch.zeros([batch_size, self.out_H_list[1] * self.out_W_list[1], 1]).cuda(), l1_xyz_f1], dim=-1)

        l1_flow_warped = mul_q_point(l1_q_coarse, l1_xyz_bnc_q, batch_size)
        l1_flow_warped = torch.index_select(mul_point_q(l1_flow_warped, l1_q_inv, batch_size), 2, torch.LongTensor(range(1, 4)).cuda()) + l1_t_coarse

        l1_mask = torch.any(l1_xyz_f1 !=0, dim = -1, keepdim = True).to(torch.float32)
        l1_flow_warped = l1_flow_warped * l1_mask


        ########## re-project

        l1_xyz_warp_proj_f1, l1_points_warp_proj_f1 = ProjectPC2SphericalRing(l1_flow_warped, l1_points_f1, self.out_H_list[1], self.out_W_list[1])  # 
        l1_xyz_warp_f1 = torch.reshape(l1_xyz_warp_proj_f1, [batch_size, -1, 3])
        l1_points_warp_f1 = torch.reshape(l1_points_warp_proj_f1, [batch_size, self.out_H_list[1] * self.out_W_list[1], -1])

        l1_mask_warped = torch.any(l1_xyz_warp_f1 !=0, dim = -1, keepdim = False)


        # get the cost volume of warped layer3 flow and the points of frame2
        l1_cost_volume = self.cost_volume3(l1_xyz_warp_proj_f1, l1_xyz_proj_f2, l1_points_warp_proj_f1, l1_points_proj_f2)

        l1_cost_volume_w_upsample = self.set_upconv2_w_upsample(l1_xyz_warp_proj_f1, l2_xyz_warp_proj_f1, l1_points_warp_proj_f1, l2_cost_volume_w_proj)
        l1_cost_volume_upsample = self.set_upconv2_upsample(l1_xyz_warp_proj_f1, l2_xyz_warp_proj_f1, l1_points_warp_proj_f1, l2_cost_volume_proj)
        
        l1_cost_volume_predict = self.flow_predictor2_predict(l1_points_warp_f1, l1_cost_volume_upsample, l1_cost_volume)
        l1_cost_volume_w = self.flow_predictor2_w(l1_points_warp_f1, l1_cost_volume_w_upsample, l1_cost_volume)

        l1_cost_volume_proj = torch.reshape(l1_cost_volume_predict, [batch_size, self.out_H_list[1], self.out_W_list[1], -1])
        l1_cost_volume_w_proj = torch.reshape(l1_cost_volume_w, [batch_size, self.out_H_list[1], self.out_W_list[1], -1])

        # if self.training:
        l1_cost_volume_sum, l1_query_feature_final = softmax_valid(feature_bnc = l1_cost_volume_predict, weight_bnc = l1_cost_volume_w, mask_valid = l1_mask_warped)  # B 1 C

        l1_points_f1_new_big = self.conv3_l1(l1_cost_volume_sum)
        if self.training:
            l1_points_f1_new_big_query = self.conv3_l1_query(l1_query_feature_final)
        l1_points_f1_new_q = F.dropout(l1_points_f1_new_big, p = 0.5, training = self.training)
        l1_points_f1_new_t = F.dropout(l1_points_f1_new_big, p = 0.5, training = self.training)
        if self.training:
            l1_points_f1_new_q_query = F.dropout(l1_points_f1_new_big_query, p = 0.5, training = self.training)
            l1_points_f1_new_t_query = F.dropout(l1_points_f1_new_big_query, p = 0.5, training = self.training)

        l1_q_det = self.conv1_l1(l1_points_f1_new_q)
        l1_q_det = l1_q_det / (torch.sqrt(torch.sum(l1_q_det * l1_q_det, dim=-1, keepdim=True) + 1e-10) + 1e-10)
        if self.training:
            l1_q_det_query = self.conv1_l1_query(l1_points_f1_new_q_query)
            l1_q_det_query = l1_q_det_query / (torch.sqrt(torch.sum(l1_q_det_query * l1_q_det_query, dim=-1, keepdim=True) + 1e-10) + 1e-10)

        l1_t_det = self.conv2_l1(l1_points_f1_new_t)
        if self.training:
            l1_t_det_query = self.conv2_l1_query(l1_points_f1_new_t_query)

        l1_q_det_inv = inv_q(l1_q_det, batch_size)

        l1_t_coarse_trans = torch.cat([torch.zeros([batch_size, 1, 1]).cuda(), l1_t_coarse], dim=-1)
        l1_t_coarse_trans = mul_q_point(l1_q_det, l1_t_coarse_trans, batch_size)

        l1_t_coarse_trans = torch.index_select(mul_point_q(l1_t_coarse_trans, l1_q_det_inv, batch_size), 2,
                                               torch.LongTensor(range(1, 4)).cuda())

        l1_q = torch.squeeze(mul_point_q(l1_q_det, l1_q_coarse, batch_size), dim=1)
        l1_t = torch.squeeze(l1_t_coarse_trans + l1_t_det, dim=1)
        if self.training:
            self.l1_q_query = mul_point_q(l1_q_det_query, l1_q_coarse, batch_size)
            self.l1_t_query = l1_t_coarse_trans + l1_t_det_query


        # print('l1_refine_time--------', time.time() - start_l1_refine)

        ################# layer0

        # start_l0_refine = time.time()

        l0_q_coarse = torch.reshape(l1_q, [batch_size, 1, -1])
        l0_t_coarse = torch.reshape(l1_t, [batch_size, 1, -1])

        l0_q_inv = inv_q(l0_q_coarse, batch_size)

        ############# warp layer2 pose

        l0_xyz_f1 = torch.reshape(l0_xyz_proj_f1, [batch_size, -1, 3])
        l0_xyz_bnc_q = torch.cat([torch.zeros([batch_size, self.out_H_list[0] * self.out_W_list[0], 1]).cuda(), l0_xyz_f1], dim=-1)

        l0_flow_warped = mul_q_point(l0_q_coarse, l0_xyz_bnc_q, batch_size)
        l0_flow_warped = torch.index_select(mul_point_q(l0_flow_warped, l0_q_inv, batch_size), 2, torch.LongTensor(range(1, 4)).cuda()) + l0_t_coarse

        l0_mask = torch.any(l0_xyz_f1 !=0, dim = -1, keepdim = True).to(torch.float32)
        l0_flow_warped = l0_flow_warped * l0_mask

        ########## re-project

        l0_xyz_warp_proj_f1, l0_points_warp_proj_f1 = ProjectPC2SphericalRing(l0_flow_warped, l0_points_f1, self.out_H_list[0], self.out_W_list[0])  # 
        l0_xyz_warp_f1 = torch.reshape(l0_xyz_warp_proj_f1, [batch_size, -1, 3])
        l0_points_warp_f1 = torch.reshape(l0_points_warp_proj_f1, [batch_size, self.out_H_list[0] * self.out_W_list[0], -1])

        l0_mask_warped = torch.any(l0_xyz_warp_f1 !=0, dim = -1, keepdim = False)


        # get the cost volume of warped layer3 flow and the points of frame2
        l0_cost_volume = self.cost_volume4(l0_xyz_warp_proj_f1, l0_xyz_proj_f2, l0_points_warp_proj_f1, l0_points_proj_f2)

        l0_cost_volume_w_upsample = self.set_upconv3_w_upsample(l0_xyz_warp_proj_f1, l1_xyz_warp_proj_f1, l0_points_warp_proj_f1, l1_cost_volume_w_proj)
        l0_cost_volume_upsample = self.set_upconv3_upsample(l0_xyz_warp_proj_f1, l1_xyz_warp_proj_f1, l0_points_warp_proj_f1, l1_cost_volume_proj)
        
        l0_cost_volume_predict = self.flow_predictor3_predict(l0_points_warp_f1, l0_cost_volume_upsample, l0_cost_volume)
        l0_cost_volume_w = self.flow_predictor3_w(l0_points_warp_f1, l0_cost_volume_w_upsample, l0_cost_volume)

        l0_cost_volume_sum, l0_query_feature_final = softmax_valid(feature_bnc = l0_cost_volume_predict, weight_bnc = l0_cost_volume_w, mask_valid = l0_mask_warped)  # B 1 C

        l0_points_f1_new_big = self.conv3_l0(l0_cost_volume_sum)

        l0_points_f1_new_q = F.dropout(l0_points_f1_new_big, p = 0.5, training = self.training)
        l0_points_f1_new_t = F.dropout(l0_points_f1_new_big, p = 0.5, training = self.training)

        l0_q_det = self.conv1_l0(l0_points_f1_new_q)
        l0_q_det = l0_q_det / (torch.sqrt(torch.sum(l0_q_det * l0_q_det, dim=-1, keepdim=True) + 1e-10) + 1e-10)

        l0_t_det = self.conv2_l0(l0_points_f1_new_t)

        l0_q_det_inv = inv_q(l0_q_det, batch_size)
        
        l0_t_coarse_trans = torch.cat([torch.zeros([batch_size, 1, 1]).cuda(), l0_t_coarse], dim=-1)
        l0_t_coarse_trans = mul_q_point(l0_q_det, l0_t_coarse_trans, batch_size)
        l0_t_coarse_trans = torch.index_select(mul_point_q(l0_t_coarse_trans, l0_q_det_inv, batch_size), 2,
                                               torch.LongTensor(range(1, 4)).cuda())

        l0_q_first = torch.squeeze(mul_point_q(l0_q_det, l0_q_coarse, batch_size), dim=1)
        l0_t_first = torch.squeeze(l0_t_coarse_trans + l0_t_det, dim=1)

        l0_q = self.final_predictor_q_dynamic(torch.cat([l0_q_first.unsqueeze(1), self.l0_q_history[:,-1:], q_l0_his_embed], dim = -1)).squeeze(1)
        l0_t = self.final_predictor_t_dynamic(torch.cat([l0_t_first.unsqueeze(1), self.l0_t_history[:,-1:], t_l0_his_embed], dim = -1)).squeeze(1)
        l0_q = q_norm(l0_q)

        self.l0_q_history = self.update_history(l0_q.unsqueeze(1), self.l0_q_history)
        self.l0_t_history = self.update_history(l0_t.unsqueeze(1), self.l0_t_history)
        self.history_mask = self.update_history_mask(self.history_mask)
        self.gt_q_history = self.update_history(q_gt.unsqueeze(1), self.gt_q_history)
        self.gt_t_history = self.update_history(t_gt.squeeze(-1).unsqueeze(1), self.gt_t_history)


        if self.continue_frames >= composation_num +1  and pred_composation:
            
            if not self.training and self.continue_frames % composation_num +1 == 0:
                cumulative_q_pred, cumulative_t_pred, cumulative_q_gt, cumulative_t_gt = self.compute_cumulative_gt_pred(batch_size)
                self.q_diff, self.t_diff = compute_pose_diff(cumulative_q_pred, cumulative_t_pred, cumulative_q_gt, cumulative_t_gt) 
                self.l2_q_composa, self.l2_t_composa, self.l3_q_composa, self.l3_t_composa = self.process_layers_two_and_three(
                                        self.his_l2_xyz_proj_f1[:,-composation_num-1].contiguous(),
                                        l2_xyz_proj_f2,
                                        self.his_l2_points_proj_f1[:,-composation_num].contiguous(),
                                        l2_points_proj_f2,
                                        self.his_l3_xyz_proj_f1[:,-composation_num].contiguous(),
                                        self.his_l3_points_f1[:,-composation_num].contiguous(),
                                        self.his_l2_points_f1[:,-composation_num].contiguous(),
                                        batch_size
                                    )

                l0_q, l0_t = self.update_cumulative_predictions(self.l2_q_composa.unsqueeze(1), self.l2_t_composa.unsqueeze(1),  q_norm(self.l2_q_composa).unsqueeze(1), self.l2_t_composa.unsqueeze(1), batch_size)
                self.composa_loss = 0
            else:
                cumulative_q_pred, cumulative_t_pred, cumulative_q_gt, cumulative_t_gt = self.compute_cumulative_gt_pred(batch_size)
                self.q_diff, self.t_diff = compute_pose_diff(cumulative_q_pred, cumulative_t_pred, cumulative_q_gt, cumulative_t_gt) 
                self.l2_q_composa, self.l2_t_composa, self.l3_q_composa, self.l3_t_composa = self.process_layers_two_and_three(
                                        self.his_l2_xyz_proj_f1[:,-composation_num].contiguous(),
                                        l2_xyz_proj_f2,
                                        self.his_l2_points_proj_f1[:,-composation_num].contiguous(),
                                        l2_points_proj_f2,
                                        self.his_l3_xyz_proj_f1[:,-composation_num].contiguous(),
                                        self.his_l3_points_f1[:,-composation_num].contiguous(),
                                        self.his_l2_points_f1[:,-composation_num].contiguous(),
                                        batch_size
                                    )
                self.composa_loss = self.get_loss_composa(self.l2_q_composa, self.l2_t_composa, self.l3_q_composa, self.l3_t_composa, self.q_diff, self.t_diff, self.w_x, self.w_q)
                self.composa_loss = self.composa_loss.mean()
        else:
            self.composa_loss = 0



        #l0_q = l1_q
        #l0_t = l1_t
        l0_q_first_norm = l0_q_first / (torch.sqrt(torch.sum(l0_q_first * l0_q_first, dim=-1, keepdim=True) + 1e-10) + 1e-10)
        l0_q_norm = l0_q / (torch.sqrt(torch.sum(l0_q * l0_q, dim=-1, keepdim=True) + 1e-10) + 1e-10)
        l1_q_norm = l1_q / (torch.sqrt(torch.sum(l1_q * l1_q, dim=-1, keepdim=True) + 1e-10) + 1e-10)
        l2_q_norm = l2_q / (torch.sqrt(torch.sum(l2_q * l2_q, dim=-1, keepdim=True) + 1e-10) + 1e-10)
        l3_q_norm = l3_q / (torch.sqrt(torch.sum(l3_q * l3_q, dim=-1, keepdim=True) + 1e-10) + 1e-10)

        # print('l0_refime_time: ---------', time.time() - start_l0_refine)
        if visualization:
            l2_origin_xy_f1 = l2_xy_proj_f1.reshape(-1,2)
            l1_origin_xy_f1 = l1_xy_proj_f1.reshape(-1,2)
            l3_origin_xy_f1 = l3_xy_proj_f1.reshape(-1,2)
            self.l1_topk_indices = torch.squeeze(self.l1_topk_indices)
            self.l2_topk_indices = torch.squeeze(self.l2_topk_indices)
            self.l3_topk_indices = torch.squeeze(self.l3_topk_indices)
            points_l1 = l1_origin_xy_f1[self.l1_topk_indices]
            points_l2 = l2_origin_xy_f1[self.l2_topk_indices]
            points_l3 = l3_origin_xy_f1[self.l3_topk_indices]
            points_all = torch.cat([points_l1, points_l2, points_l3], dim = 0)
            self.plot_points(points_all, fn2_dir)


        if self.training:
            self.query_loss = self.get_loss_query(self.l1_q_query, self.l1_t_query, self.l2_q_query, self.l2_t_query, self.l3_q_coarse_query, self.l3_t_coarse_query, q_gt, t_gt, self.w_x, self.w_q)   
        else:
            self.query_loss = 0
        self.continue_frames += 1
        return l0_q_norm, l0_t, l1_q_norm, l1_t, l2_q_norm, l2_t, l3_q_norm, l3_t, l1_xyz_f1, q_gt, t_gt, self.w_x, self.w_q, l0_q_first, l0_t_first, self.composa_loss, self.query_loss


    def get_loss_composa(self, l2_q, l2_t, l3_q, l3_t, qq_gt, t_gt, w_x, w_q, *, l2_weight=0.8, l3_weight=1.6):
        """Loss between the layer-2/3 composed-pose predictions and a target pose difference.

        Each level uses the same weighted form as ``get_loss``::

            loss_x * exp(-w_x) + w_x + loss_q * exp(-w_q) + w_q

        where ``loss_q`` is the mean L2 distance between the target quaternion and the
        normalised predicted quaternion, and ``loss_x`` is the mean element-wise
        ``sqrt(diff**2 + 1e-10)`` translation error.

        Args:
            l2_q, l2_t: layer-2 quaternion / translation predictions.
            l3_q, l3_t: layer-3 quaternion / translation predictions.
            qq_gt, t_gt: target quaternion / translation (in ``forward`` these are
                ``self.q_diff`` / ``self.t_diff`` from ``compute_pose_diff``).
            w_x, w_q: learnable weighting tensors for the translation / rotation terms.
            l2_weight, l3_weight: keyword-only weights for the two levels; the defaults
                mirror the layer-2/layer-3 weights used in ``get_loss``.

        Returns:
            Tensor ``l2_weight * l2_loss + l3_weight * l3_loss``.
        """
        # compute_pose_diff documents (B, 1, C) outputs. Squeeze both targets the same way:
        # an unsqueezed (B, 1, 4) quaternion would broadcast against the predictions
        # (presumably (B, 4), since forward unsqueezes them at dim 1) into (B, B, 4).
        t_gt = torch.squeeze(t_gt)
        qq_gt = torch.squeeze(qq_gt)

        l2_q_norm = l2_q / (torch.sqrt(torch.sum(l2_q * l2_q, -1, keepdim=True) + 1e-10) + 1e-10)
        l2_loss_q = torch.mean(torch.sqrt(torch.sum((qq_gt - l2_q_norm) * (qq_gt - l2_q_norm), -1, keepdim=True) + 1e-10))
        l2_loss_x = torch.mean(torch.sqrt((l2_t - t_gt) * (l2_t - t_gt) + 1e-10))
        l2_loss = l2_loss_x * torch.exp(-w_x) + w_x + l2_loss_q * torch.exp(-w_q) + w_q

        l3_q_norm = l3_q / (torch.sqrt(torch.sum(l3_q * l3_q, -1, keepdim=True) + 1e-10) + 1e-10)
        l3_loss_q = torch.mean(torch.sqrt(torch.sum((qq_gt - l3_q_norm) * (qq_gt - l3_q_norm), -1, keepdim=True) + 1e-10))
        l3_loss_x = torch.mean(torch.sqrt((l3_t - t_gt) * (l3_t - t_gt) + 1e-10))
        l3_loss = l3_loss_x * torch.exp(-w_x) + w_x + l3_loss_q * torch.exp(-w_q) + w_q

        # The caller calls .mean() on the result, so a tensor must be returned.
        return l2_weight * l2_loss + l3_weight * l3_loss

    def compute_loss_for_top_k(self, l_q, l_t, qq_gt, t_gt, w_x, w_q, k):
        """
        Compute the average loss for the top-k samples closest to the ground truth along the n-dimension.

        Args:
            l_q: predicted quaternions, normalised here along the last axis; the per-sample
                losses below are (b, n), so this is expected to be (b, n, 4).
            l_t: predicted translations, expected (b, n, 3).
            qq_gt, t_gt: targets that broadcast against l_q / l_t (callers pass (b, 1, C)).
            w_x, w_q: learnable weighting tensors for translation / rotation.
            k: number of lowest-loss samples per batch row to average. It is capped at n,
                because torch.topk raises when k exceeds the searched dimension.

        Returns:
            Tensor ``mean_x * exp(-w_x) + w_x + mean_q * exp(-w_q) + w_q``.

        Raises:
            ValueError: if k < 1 (the mean of an empty selection would be NaN).
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")

        # Normalize quaternions
        l_q_norm = l_q / (torch.sqrt(torch.sum(l_q * l_q, dim=-1, keepdim=True) + 1e-10) + 1e-10)

        # Compute quaternion loss and position loss
        loss_q = torch.sqrt(torch.sum((qq_gt - l_q_norm) ** 2, dim=-1) + 1e-10)  # shape: (b, n)
        loss_x = torch.sqrt(torch.sum((l_t - t_gt) ** 2, dim=-1) + 1e-10)  # shape: (b, n)

        # Rank samples by the unweighted sum of both errors
        total_loss = loss_q + loss_x  # shape: (b, n)

        # Cap k at n so fewer than k candidates does not crash torch.topk
        k = min(k, total_loss.size(1))
        _, topk_indices = torch.topk(total_loss, k, dim=1, largest=False)  # indices of the k smallest losses

        # Gather the top-k samples and average them
        mean_loss_q = torch.mean(torch.gather(loss_q, 1, topk_indices))
        mean_loss_x = torch.mean(torch.gather(loss_x, 1, topk_indices))

        # Apply the learnable weighting
        return mean_loss_x * torch.exp(-w_x) + w_x + mean_loss_q * torch.exp(-w_q) + w_q

    def get_loss_query(self, l1_q, l1_t, l2_q, l2_t, l3_q, l3_t, qq_gt, t_gt, w_x, w_q, k=100):
        """Top-k query loss summed over layers 1, 2 and 3.

        Each layer's query predictions are scored with ``compute_loss_for_top_k`` against
        the single ground-truth pose, and the three results are weighted by 1.6 each.

        Args:
            l1_q, l1_t, l2_q, l2_t, l3_q, l3_t: per-layer query quaternion / translation predictions.
            qq_gt: ground-truth quaternion; its first axis is used as the batch size
                (presumably (B, 4) — confirm against the caller).
            t_gt: ground-truth translation, any shape holding B * 3 values
                (e.g. (B, 3) or (B, 3, 1)).
            w_x, w_q: learnable weighting tensors.
            k: number of best samples per layer to average (forwarded to compute_loss_for_top_k).

        Returns:
            Tensor ``1.6 * l3_loss + 1.6 * l2_loss + 1.6 * l1_loss``.
        """
        # Reshape instead of a bare squeeze: squeeze also drops the batch axis when B == 1,
        # and the unsqueeze(1) below would then produce (3, 1) instead of (1, 1, 3).
        t_gt = t_gt.reshape(qq_gt.shape[0], -1)
        q_target = qq_gt.unsqueeze(1)
        t_target = t_gt.unsqueeze(1)

        # Average top-k loss for layers 1, 2 and 3
        l1_loss = self.compute_loss_for_top_k(l1_q, l1_t, q_target, t_target, w_x, w_q, k)
        l2_loss = self.compute_loss_for_top_k(l2_q, l2_t, q_target, t_target, w_x, w_q, k)
        l3_loss = self.compute_loss_for_top_k(l3_q, l3_t, q_target, t_target, w_x, w_q, k)

        # Weighted total
        return 1.6 * l3_loss + 1.6 * l2_loss + 1.6 * l1_loss


def get_loss(l0_q, l0_t, l1_q, l1_t, l2_q, l2_t, l3_q, l3_t, qq_gt, t_gt, w_x, w_q, l0_q_first, l0_t_first):
    """Multi-level pose loss with learnable weighting.

    Every level (the pre-refinement layer-0 output plus layers 0-3) is scored as::

        trans * exp(-w_x) + w_x + rot * exp(-w_q) + w_q

    where ``rot`` is the mean L2 distance between ``qq_gt`` and the normalised predicted
    quaternion, and ``trans`` is the mean element-wise ``sqrt(diff**2 + 1e-10)`` between
    the predicted translation and the squeezed ``t_gt``.

    Returns:
        Tensor ``1.6*l3 + 0.8*l2 + 0.8*l1 + 1.6*l0 + 0.8*l0_first``.
    """
    target_t = torch.squeeze(t_gt)
    eps = 1e-10

    def level_loss(pred_q, pred_t):
        # Unit-normalise the quaternion; eps keeps a zero vector from dividing by zero
        unit_q = pred_q / (torch.sqrt(torch.sum(pred_q * pred_q, dim=-1, keepdim=True) + eps) + eps)
        q_err = qq_gt - unit_q
        rot_term = torch.mean(torch.sqrt(torch.sum(q_err * q_err, dim=-1, keepdim=True) + eps))
        t_err = pred_t - target_t
        trans_term = torch.mean(torch.sqrt(t_err * t_err + eps))
        return trans_term * torch.exp(-w_x) + w_x + rot_term * torch.exp(-w_q) + w_q

    first_loss = level_loss(l0_q_first, l0_t_first)
    level0 = level_loss(l0_q, l0_t)
    level1 = level_loss(l1_q, l1_t)
    level2 = level_loss(l2_q, l2_t)
    level3 = level_loss(l3_q, l3_t)

    return 1.6 * level3 + 0.8 * level2 + 0.8 * level1 + 1.6 * level0 + 0.8 * first_loss



