import torch
import os
import numpy as np
import open3d as o3d
from tqdm import tqdm
from dgl.geometry import farthest_point_sampler

from gnn.model import DynamicsPredictor, DynamicsPredictorMy
from data.gripper_dataset_2 import construct_edges_from_states
from data.utils import fps_rad_idx_torch
from render.utils import interpolate_motions, relations_to_matrix
from render.phystwin_LBS import interpolate_motions as inter_motion
from render.phystwin_LBS import knn_weights_new, get_topk_indices

import glob

from src.gnn.model import DynamicsPredictorMyMultiLayer


class DynamicsModule:
    
    def __init__(self, config, epoch, device):
        """Load a trained dynamics checkpoint and cache the rollout settings.

        Args:
            config: dict with 'train_config', 'model_config' and
                'dataset_config' sections.
            epoch: 'latest' selects ``checkpoints/latest.pth``; any other value
                selects ``checkpoints/model_{epoch}.pth``, both under
                ``train_config['out_dir']``.
            device: device handed to ``load_model`` and used by the rollouts.
        """
        self.device = device
        train_cfg = config['train_config']
        model_cfg = config['model_config']

        ckpt_file = 'latest.pth' if epoch == 'latest' else 'model_{}.pth'.format(epoch)
        ckpt_path = os.path.join(train_cfg['out_dir'], 'checkpoints', ckpt_file)
        self.model = self.load_model(train_cfg, model_cfg, ckpt_path, self.device)

        self.n_his = train_cfg['n_his']
        self.dist_thresh = train_cfg['dist_thresh']

        # Only the first dataset entry is consulted; each (lo, hi) range is
        # collapsed to its midpoint for inference.
        dataset_cfg = config['dataset_config']['datasets'][0]

        def midpoint(key):
            bounds = dataset_cfg[key]
            return (bounds[0] + bounds[1]) / 2

        self.max_nobj = dataset_cfg['max_nobj']
        self.adj_thresh = midpoint('adj_radius_range')
        self.fps_radius = midpoint('fps_radius_range')
        self.topk = dataset_cfg['topk']
        self.connect_all = dataset_cfg['connect_all']

    def load_model(self, train_config, model_config, checkpoint_dir, device):
        """Build the multi-layer dynamics predictor and load its weights.

        Side effect: copies ``train_config['n_his']`` into ``model_config`` so
        the model is constructed with the training history length.

        Args:
            train_config: training config; only 'n_his' is read here.
            model_config: config passed to ``DynamicsPredictorMyMultiLayer``
                (mutated, see above).
            checkpoint_dir: path of the ``.pth`` state-dict file (a file path,
                despite the name).
            device: device the model lives on; checkpoint tensors are mapped
                onto it while loading.

        Returns:
            The model in eval mode with the checkpoint weights loaded.
        """
        model_config['n_his'] = train_config['n_his']
        model = DynamicsPredictorMyMultiLayer(model_config, device)
        model.to(device)
        model.eval()
        # map_location lets a checkpoint saved on a different device (e.g. a
        # GPU) load on this one instead of failing or landing on the saved
        # device before being copied.
        state_dict = torch.load(checkpoint_dir, map_location=device)
        model.load_state_dict(state_dict)
        return model
    
    def downsample_vertices(self, xyz):  # (n, 3)
        """Two-stage downsampling: farthest point sampling, then radius filtering.

        Stage 1 runs FPS (starting at index 0) on a detached CPU copy of the
        points, asking for at most ``self.max_nobj`` samples. Stage 2 passes the
        FPS subset to ``fps_rad_idx_torch`` with ``self.fps_radius``; its second
        output is used as indices into that subset.

        Args:
            xyz: point positions, (n, 3).

        Returns:
            (xyz_sub, fps_idx): the selected points (indexed from ``xyz``, so
            on its device) and their indices into ``xyz``, computed from the
            CPU copy.
        """
        particle_tensor = xyz[None, ...].detach().cpu()
        # Never request more FPS samples than there are points.
        n_fps = min(self.max_nobj, particle_tensor.shape[1])
        fps_idx_1 = farthest_point_sampler(particle_tensor, n_fps, start_idx=0)[0]
        downsampled_particle = particle_tensor[0, fps_idx_1, :]
        _, fps_idx_2 = fps_rad_idx_torch(downsampled_particle, self.fps_radius)
        # Compose the two index maps so fps_idx indexes the original xyz.
        fps_idx = fps_idx_1[fps_idx_2]
        return xyz[fps_idx], fps_idx

    @torch.no_grad
    def rollout_demo(self, xyz_0, rgb_0, quat_0, opa_0, eef_xyz, n_steps, inlier_idx_all):
        """Roll out the dynamics model along an end-effector trajectory and
        propagate the predicted particle motion to every point.

        Each step re-downsamples the 1000 tracked FPS points, builds a graph of
        object particles plus one end-effector node, queries the model, and
        spreads the predicted displacements to all points with
        ``interpolate_motions``.

        Args:
            xyz_0: (n_particles, 3) initial point positions.
            rgb_0: (n_particles, 3) colours; copied unchanged to every step.
            quat_0: (n_particles, 4) initial rotations.
            opa_0: (n_particles, 1) opacities; copied unchanged to every step.
            eef_xyz: end-effector positions indexed by step; ``eef_xyz[i]`` is
                used as a (1, 3) position (assumed shape — TODO confirm).
            n_steps: number of frames including the initial and the final
                state (the model is queried n_steps - 1 times).
            inlier_idx_all: indices into ``xyz_0`` of the points from which
                the 1000 tracked FPS points are chosen.

        Returns:
            (xyz, rgb, quat, opa, xyz_bones, eef): CPU tensors with a leading
            n_steps dimension; ``xyz_bones`` holds the graph particles of each
            step, zero padded to ``max_nobj``.
        """

        model = self.model
        device = self.device

        all_pos = xyz_0  # dense points; replaced each step by interpolate_motions
        # 1000 tracked points picked by FPS among the inliers; fixed for the whole rollout.
        fps_all_idx = farthest_point_sampler(xyz_0.cpu()[inlier_idx_all][None], 1000, start_idx=0)[0]
        fps_all_pos = all_pos[inlier_idx_all][fps_all_idx]
        fps_all_pos_history = fps_all_pos[None].repeat(model.model_config['n_his'], 1, 1)  # (n_his, n_particles, 3)

        eef_pos_history = eef_xyz[0][None].repeat(model.model_config['n_his'], 1, 1)  # (n_his, 1, 3)
        eef_pos = eef_xyz[0]  # (1, 3)

        # Graph particles of the initial frame; only used to fill xyz_bones[0].
        particle_pos_0, _ = self.downsample_vertices(fps_all_pos.clone())

        # results to store
        quat = quat_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 4)
        xyz = xyz_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        rgb = rgb_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        opa = opa_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 1)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)  # (n_steps, n_bones, 3)
        eef = eef_xyz.cpu()[0][None].repeat(n_steps, 1, 1)  # (n_steps, 1, 3)

        xyz_bones[0, :particle_pos_0.shape[0]] = particle_pos_0.cpu()

        eef_delta = torch.zeros(1, 3).to(device)  # overwritten in the first loop iteration

        key_point = []  # collected every step but not returned

        for i in tqdm(range(1, n_steps), dynamic_ncols=True):
            # Sanity checks: the newest history frame must equal the current state.
            assert torch.allclose(fps_all_pos, fps_all_pos_history[-1])
            assert torch.allclose(eef_pos, eef_pos_history[-1])

            # if torch.norm(eef_xyz[i] - eef_pos) < self.dist_thresh:
            #     # rot[i] = rot[i - 1].clone()
            #     quat[i] = quat[i - 1].clone()
            #     xyz[i] = xyz[i - 1].clone()
            #     rgb[i] = rgb[i - 1].clone()
            #     opa[i] = opa[i - 1].clone()
            #     xyz_bones[i] = xyz_bones[i - 1].clone()
            #     eef[i] = eef[i - 1].clone()
            #
            #     #TODO DEBUG
            #     particle_pos_temp, fps_idx_temp = self.downsample_vertices(fps_all_pos.clone())
            #     key_point.append(particle_pos_temp)
            #     continue

            eef_pos_this_step = eef_xyz[i]
            eef_delta = eef_pos_this_step - eef_pos

            # Re-downsampled every step, so the particle set (and nobj) can change between steps.
            particle_pos, fps_idx = self.downsample_vertices(fps_all_pos.clone())
            key_point.append(particle_pos)
            particle_pos_history = fps_all_pos_history[:, fps_idx]
            nobj = particle_pos.shape[0]

            # Node layout: nobj object particles first, the single end-effector node last.
            states = torch.zeros((1, self.n_his, nobj + 1, 3), device=device)
            states[:, :, :nobj] = particle_pos_history
            states[:, :, nobj:] = eef_pos_history

            # Only the end-effector node carries an action (its displacement this step).
            states_delta = torch.zeros((1, nobj + 1, 3), device=device)
            states_delta[:, nobj:] = eef_delta

            # One-hot node type: column 0 = object particle, column 1 = end effector.
            attrs = torch.zeros((1, nobj + 1, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.
            attrs[:, nobj:, 1] = 1.

            p_instance = torch.ones((1, nobj, 1), dtype=torch.float32, device=device)

            state_mask = torch.ones((1, nobj + 1), dtype=bool, device=device)

            eef_mask = torch.zeros((1, nobj + 1), dtype=bool, device=device)
            eef_mask[:, nobj] = 1

            obj_mask = torch.zeros((1, nobj + 1), dtype=bool, device=device)
            obj_mask[:, :nobj] = 1

            #TODO debug
            # states[0, :, -1, -1] = states[0, :, :-1, -1].mean(axis=1)

            # Edges are rebuilt every step from the most recent history frame.
            Rr, Rs = construct_edges_from_states(states[0, -1], self.adj_thresh,
                            mask=state_mask[0], tool_mask=eef_mask[0], topk=self.topk, connect_all=self.connect_all)
            Rr = Rr[None]
            Rs = Rs[None]


            graph = {
                # input information
                "state": states,  # (n_his, N+M, state_dim)
                "action": states_delta,  # (N+M, state_dim)

                # attr information
                "attrs": attrs,  # (N+M, attr_dim)
                # "p_rigid": p_rigid,  # (n_instance,)
                "p_instance": p_instance,  # (N, n_instance)
                "obj_mask": obj_mask,  # (N,)
                "state_mask": state_mask,  # (N+M,)
                "eef_mask": eef_mask,  # (N+M,)

                "Rr": Rr,  # (bsz, max_nR, N)
                "Rs": Rs,  # (bsz, max_nR, N)
            }

            pred_state, _ = model(**graph)  # (1, nobj, 3)

            #TODO DEBUG
            # pred_state[0, :, -1] = graph['state'][0, -1, :-1, -1]
            # key_point.append(pred_state[0])

            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step[None]], dim=0)
            eef_pos = eef_pos_this_step


            # interpolate all_pos and particle_pos
            # The graph particles act as bones: their predicted displacements are
            # spread to every dense point, and rotations are updated from quat[i - 1].
            all_pos, all_rot, _ = interpolate_motions(
                bones=particle_pos,
                motions=pred_state[0] - particle_pos,
                relations=relations_to_matrix(Rr, Rs)[:nobj, :nobj],
                xyz=all_pos,
                quat=quat[i - 1].to(device),
            )
            # Re-extract the tracked points from the interpolated dense cloud.
            fps_all_pos = all_pos[inlier_idx_all][fps_all_idx]
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], fps_all_pos[None]], dim=0)

            quat[i] = all_rot.cpu()
            xyz[i] = all_pos.cpu()
            rgb[i] = rgb[i - 1].clone()
            opa[i] = opa[i - 1].clone()
            xyz_bones[i, :nobj] = pred_state[0].cpu()
            eef[i] = eef_pos.cpu()
        print('rollout demo')
        return xyz, rgb, quat, opa, xyz_bones, eef

    @torch.no_grad()
    def rollout_gripper_pred(self, xyz_0, Es, control_velocity, control_mask, n_steps, xyz_gt):
        """Roll out the particle dynamics model on an FPS subset of ``xyz_0``.

        ``max_nobj`` particles are chosen by farthest point sampling (starting
        at index 0). At every step the same ``control_velocity`` is written as
        the action of the sampled control particles. The graph is then rebuilt
        from the latest state, and the prediction is appended to the history.

        Args:
            xyz_0: initial particle positions (n_particles, 3).
            Es: per-particle elastic modulus (n_particles,); its log is fed to
                the model as ``log_E``.
            control_velocity: written into the action of every sampled control
                particle at every step. It is NOT indexed per step, so it must
                broadcast to (n_control, 3); a single (3,) velocity is the
                usual case.
            control_mask: per-particle control-point mask (n_particles,).
            n_steps: number of predicted steps (excluding the initial state).
            xyz_gt: unused.

        Returns:
            dict with
                'pred_positions': (n_steps+1, n_particles, 3); unsampled particles stay zero,
                'sampled_positions': (n_steps+1, max_nobj, 3),
                'sampled_indices': (max_nobj,),
                'pred_velocities': (n_steps, max_nobj, 3) actions that were applied,
                'control_mask': (max_nobj,),
                'elastic_modulus': (max_nobj,).
        """
        model = self.model
        device = self.device
        n_his = model.model_config['n_his']
        max_nobj = self.max_nobj

        # 1. Sample key particles (FPS indices are produced on the CPU).
        fps_idx = farthest_point_sampler(xyz_0.cpu()[None], max_nobj, start_idx=0)[0]
        sampled_pos = xyz_0[fps_idx].to(device)
        sampled_Es = Es[fps_idx].to(device)
        # Keep the sampled mask on the model device, as the v2/v3 rollouts do.
        sampled_control_mask = control_mask.to(torch.bool)[fps_idx].to(device)

        # 2. History buffer: the initial state repeated n_his times.
        state_history = sampled_pos[None].repeat(n_his, 1, 1)  # (n_his, max_nobj, 3)

        # 3. Result buffers.
        pred_positions = torch.zeros(n_steps + 1, max_nobj, 3, device=device)
        pred_positions[0] = sampled_pos
        pred_velocities = []

        control_vel = control_velocity.to(device)  # the same action every step

        for i in tqdm(range(1, n_steps + 1), dynamic_ncols=True):
            states = state_history[None]  # (1, n_his, max_nobj, 3)

            # Only control particles receive a velocity.
            states_delta = torch.zeros(1, max_nobj, 3, device=device)
            states_delta[:, sampled_control_mask] = control_vel

            # One-hot node type: column 0 = regular particle, column 1 = control particle.
            attrs = torch.zeros(1, max_nobj, 2, device=device)
            attrs[:, :, 0] = ~sampled_control_mask
            attrs[:, :, 1] = sampled_control_mask

            obj_mask = torch.ones(1, max_nobj, dtype=bool, device=device)

            # NOTE(review): construct_edges_obj is not imported in the visible
            # part of this file — confirm it is defined/imported elsewhere.
            Rr, Rs = construct_edges_obj(
                state_history[-1],
                self.adj_thresh,
                mask=obj_mask[0],
                topk=self.topk
            )

            # Collider feature: z of the latest state clamped to [-adj_thresh, 0].
            collider_distance = torch.clamp(state_history[-1, :, -1], -self.adj_thresh, 0.0)[..., None]
            log_E = torch.log(sampled_Es)[:, None]  # (max_nobj, 1)

            graph = {
                "state": states,
                "action": states_delta,
                "attrs": attrs,
                "p_instance": torch.ones(1, max_nobj, 1, device=device),
                "obj_mask": obj_mask,
                "Rr": Rr[None],
                "Rs": Rs[None],
                "collider_distance": collider_distance[None],
                "log_E": log_E[None]
            }

            pred_state, _ = model(**graph)  # (1, max_nobj, 3)

            # Slide the history window forward with the prediction.
            new_state = pred_state[0].detach()
            state_history = torch.cat([state_history[1:], new_state.unsqueeze(0)], dim=0)

            pred_positions[i] = new_state
            pred_velocities.append(states_delta[0].cpu())

        # 4. Scatter the sampled predictions back into the full particle set.
        full_pred_positions = torch.zeros(n_steps + 1, len(xyz_0), 3, device=device)
        full_pred_positions[:, fps_idx] = pred_positions

        if pred_velocities:
            velocities = torch.stack(pred_velocities).cpu()
        else:  # n_steps == 0: torch.stack([]) would raise
            velocities = torch.zeros(0, max_nobj, 3)

        return {
            'pred_positions': full_pred_positions.cpu(),
            'sampled_positions': pred_positions.cpu(),
            'sampled_indices': fps_idx.cpu(),
            'pred_velocities': velocities,
            'control_mask': sampled_control_mask.cpu(),
            'elastic_modulus': sampled_Es.cpu()
        }

    @torch.no_grad()
    def rollout_gripper_pred_v2(self, xyz_0, Es, control_velocity, control_mask, n_steps, xyz_gt,
                                vel_scale=0.04):
        """Roll out the model with every control particle kept in the sample.

        Sampling: all control particles are kept (up to ``max_nobj``). The
        remaining slots are filled by farthest point sampling over the
        non-control particles. Control particles come first, so the first
        ``n_control`` entries of the sample are the control points.

        At every step the same action ``control_velocity * vel_scale`` is
        applied to the control particles. The graph is then rebuilt from the
        latest state, and the prediction is appended to the history.

        Args:
            xyz_0: initial particle positions (n_particles, 3).
            Es: per-particle elastic modulus (n_particles,); its log is fed to
                the model as ``log_E``.
            control_velocity: control velocity broadcast to every control
                particle, e.g. (3,).
            control_mask: per-particle control-point mask (n_particles,).
            n_steps: number of predicted steps; outputs hold n_steps + 1 frames.
            xyz_gt: unused.
            vel_scale: factor turning ``control_velocity`` into the per-step
                action (previously hard-coded 0.04; presumably the time step —
                confirm).

        Returns:
            dict with
                'pred_positions': (n_steps+1, n_particles, 3); unsampled particles stay zero,
                'sampled_positions': (n_steps+1, n_sample, 3),
                'sampled_indices': (n_sample,),
                'pred_velocities': (n_steps, n_sample, 3) actions that were applied,
                'control_mask': (n_sample,),
                'elastic_modulus': (n_sample,),
                'Rs', 'Rr': per-step relation matrices as numpy arrays,
            where n_sample == max_nobj whenever enough particles are available.
        """
        device = self.device
        # Must match the history length the model was built with (load_model
        # copies train_config['n_his'] into the model config); was hard-coded to 3.
        n_his = self.n_his
        max_nobj = self.max_nobj
        control_mask = control_mask.to(device).to(torch.bool)

        # 1. Two-stage sampling: control points first, then FPS over the rest.
        control_idx = torch.where(control_mask)[0]
        non_control_idx = torch.where(~control_mask)[0]

        n_control = min(len(control_idx), max_nobj)
        selected_control = control_idx[:n_control]
        # Never ask FPS for more points than there are non-control particles.
        remaining_slots = min(max_nobj - n_control, len(non_control_idx))

        if remaining_slots > 0:
            non_control_pos = xyz_0[non_control_idx].cpu()
            fps_idx = farthest_point_sampler(
                non_control_pos.unsqueeze(0),
                remaining_slots,
                start_idx=0
            )[0].to(device)
            selected_idx = torch.cat([selected_control, non_control_idx[fps_idx]])
        else:
            selected_idx = selected_control
        # Size every per-particle buffer by the real sample size, which is
        # smaller than max_nobj for small objects.
        n_sample = len(selected_idx)

        # 2. Initial state and physical attributes.
        sampled_pos = xyz_0[selected_idx].to(device)
        sampled_Es = Es[selected_idx].to(device)
        state_history = sampled_pos.unsqueeze(0).repeat(n_his, 1, 1)  # (n_his, n_sample, 3)

        # Control points occupy the first n_control slots of the sample.
        sampled_control_mask = torch.zeros(n_sample, dtype=bool, device=device)
        sampled_control_mask[:n_control] = True

        # 3. Result buffers.
        pred_positions = torch.zeros(n_steps + 1, n_sample, 3, device=device)
        pred_positions[0] = sampled_pos
        pred_velocities = []
        Rs_list = []
        Rr_list = []

        step_action = control_velocity.to(device) * vel_scale  # identical every step

        # 4. Autoregressive rollout.
        for i in range(1, n_steps + 1):
            # Only control particles carry an action.
            states_delta = torch.zeros(n_sample, 3, device=device)
            states_delta[sampled_control_mask] = step_action
            pred_velocities.append(states_delta.clone())

            # NOTE(review): construct_edges_obj is not imported in the visible
            # part of this file — confirm it is defined/imported elsewhere.
            Rr, Rs = construct_edges_obj(state_history[-1], self.adj_thresh, topk=self.topk)
            Rs_list.append(Rs.cpu().numpy())
            Rr_list.append(Rr.cpu().numpy())

            graph = {
                "state": state_history.unsqueeze(0),
                "action": states_delta.unsqueeze(0),
                "attrs": torch.stack([~sampled_control_mask, sampled_control_mask], dim=1).unsqueeze(0).float(),
                "p_instance": torch.ones(n_sample, 1, device=device).unsqueeze(0),
                "obj_mask": torch.ones(n_sample, dtype=bool, device=device).unsqueeze(0),
                "Rr": Rr.unsqueeze(0),
                "Rs": Rs.unsqueeze(0),
                # z of the latest state clamped to [-adj_thresh, 0].
                "collider_distance": torch.clamp(state_history[-1, :, -1], -self.adj_thresh, 0.0).unsqueeze(
                    -1).unsqueeze(0),
                "log_E": torch.log(sampled_Es).unsqueeze(-1).unsqueeze(0)
            }

            pred_state, _ = self.model(**graph)
            state_history = torch.cat([state_history[1:], pred_state[0].unsqueeze(0)], dim=0)

            pred_positions[i] = pred_state[0]

        # 5. Scatter back into the full particle set (unsampled particles stay zero).
        full_pred_positions = torch.zeros(n_steps + 1, len(xyz_0), 3, device=device)
        full_pred_positions[:, selected_idx] = pred_positions

        if pred_velocities:
            velocities = torch.stack(pred_velocities).cpu()
        else:  # n_steps == 0: torch.stack([]) would raise
            velocities = torch.zeros(0, n_sample, 3)

        return {
            'pred_positions': full_pred_positions.cpu(),
            'sampled_positions': pred_positions.cpu(),
            'sampled_indices': selected_idx.cpu(),
            'pred_velocities': velocities,
            'control_mask': sampled_control_mask.cpu(),
            'elastic_modulus': sampled_Es.cpu(),
            'Rs': Rs_list,
            'Rr': Rr_list
        }

    @torch.no_grad()
    def rollout_gripper_pred_v3(self, xyz_0, Es, control_velocity, control_mask, n_steps, xyz_gt,
                                vel_scale=0.04):
        """Roll out the model on an FPS sample seeded at a central control point.

        Control particles take part in ordinary farthest point sampling over
        all of ``xyz_0``. The sampler starts from the control point closest to
        the centroid of the control points, or from a random particle when
        there are none. At every step the same action
        ``control_velocity * vel_scale`` is applied to the sampled control
        particles, the graph is rebuilt from the latest state, and the
        prediction is appended to the history.

        Args:
            xyz_0: initial particle positions (n_particles, 3).
            Es: per-particle elastic modulus (n_particles,); its log is fed to
                the model as ``log_E``.
            control_velocity: control velocity broadcast to every sampled
                control particle, e.g. (3,).
            control_mask: per-particle control-point mask (n_particles,).
            n_steps: number of predicted steps; outputs hold n_steps + 1 frames.
            xyz_gt: unused.
            vel_scale: factor turning ``control_velocity`` into the per-step
                action (previously hard-coded 0.04; presumably the time step —
                confirm).

        Returns:
            dict with
                'pred_positions': (n_steps+1, n_particles, 3); unsampled particles stay zero,
                'sampled_positions': (n_steps+1, n_sample, 3),
                'sampled_indices': (n_sample,),
                'pred_velocities': (n_steps, n_sample, 3) actions that were applied,
                'control_mask': (n_sample,),
                'elastic_modulus': (n_sample,),
                'Rs', 'Rr': per-step relation matrices as numpy arrays,
            where n_sample = min(max_nobj, n_particles).
        """
        device = self.device
        # Must match the history length the model was built with (load_model
        # copies train_config['n_his'] into the model config); was hard-coded to 3.
        n_his = self.n_his
        max_nobj = self.max_nobj
        control_mask = control_mask.to(device).to(torch.bool)

        # 1. Choose the FPS start point.
        control_idx = torch.where(control_mask)[0]
        if len(control_idx) > 0:
            control_points = xyz_0[control_idx]
            control_center = torch.mean(control_points, dim=0)
            # Distances are measured for control points only, so the start
            # point is always the control point nearest the control centroid.
            distances = torch.norm(control_points - control_center, dim=1)
            # dgl documents start_idx as an int; avoid passing a device tensor.
            start_idx = int(control_idx[torch.argmin(distances)])
        else:
            # No control points: start from a random particle.
            start_idx = torch.randint(0, len(xyz_0), (1,)).item()

        # 2. FPS over all particles (control points included).
        n_sample = min(max_nobj, len(xyz_0))
        selected_idx = farthest_point_sampler(
            xyz_0.unsqueeze(0).cpu(),
            n_sample,
            start_idx=start_idx
        )[0].to(device)
        sampled_control_mask = control_mask[selected_idx]

        # 3. Initial state and physical attributes.
        sampled_pos = xyz_0[selected_idx].to(device)
        sampled_Es = Es[selected_idx].to(device)
        state_history = sampled_pos.unsqueeze(0).repeat(n_his, 1, 1)  # (n_his, n_sample, 3)

        # 4. Result buffers, sized by the real sample (smaller than max_nobj for small objects).
        pred_positions = torch.zeros(n_steps + 1, n_sample, 3, device=device)
        pred_positions[0] = sampled_pos
        pred_velocities = []
        Rs_list = []
        Rr_list = []

        step_action = control_velocity.to(device) * vel_scale  # identical every step

        # 5. Autoregressive rollout.
        for i in range(1, n_steps + 1):
            # Only control particles carry an action.
            states_delta = torch.zeros(n_sample, 3, device=device)
            states_delta[sampled_control_mask] = step_action
            pred_velocities.append(states_delta.clone())

            # NOTE(review): construct_edges_obj is not imported in the visible
            # part of this file — confirm it is defined/imported elsewhere.
            Rr, Rs = construct_edges_obj(state_history[-1], self.adj_thresh, topk=self.topk)
            Rs_list.append(Rs.cpu().numpy())
            Rr_list.append(Rr.cpu().numpy())

            graph = {
                "state": state_history.unsqueeze(0),
                "action": states_delta.unsqueeze(0),
                "attrs": torch.stack([~sampled_control_mask, sampled_control_mask], dim=1).unsqueeze(0).float(),
                "p_instance": torch.ones(n_sample, 1, device=device).unsqueeze(0),
                "obj_mask": torch.ones(n_sample, dtype=bool, device=device).unsqueeze(0),
                "Rr": Rr.unsqueeze(0),
                "Rs": Rs.unsqueeze(0),
                # NOTE(review): unlike rollout_gripper_pred/_v2 this is only
                # clamped from below (no upper bound of 0) — confirm which
                # convention the checkpoint was trained with.
                "collider_distance": torch.clamp(state_history[-1, :, -1], -self.adj_thresh).unsqueeze(
                    -1).unsqueeze(0),
                "log_E": torch.log(sampled_Es).unsqueeze(-1).unsqueeze(0)
            }

            pred_state, _ = self.model(**graph)
            state_history = torch.cat([state_history[1:], pred_state[0].unsqueeze(0)], dim=0)

            pred_positions[i] = pred_state[0]

        # 6. Scatter back into the full particle set (unsampled particles stay zero).
        full_pred_positions = torch.zeros(n_steps + 1, len(xyz_0), 3, device=device)
        full_pred_positions[:, selected_idx] = pred_positions

        if pred_velocities:
            velocities = torch.stack(pred_velocities).cpu()
        else:  # n_steps == 0: torch.stack([]) would raise
            velocities = torch.zeros(0, n_sample, 3)

        return {
            'pred_positions': full_pred_positions.cpu(),
            'sampled_positions': pred_positions.cpu(),
            'sampled_indices': selected_idx.cpu(),
            'pred_velocities': velocities,
            'control_mask': sampled_control_mask.cpu(),
            'elastic_modulus': sampled_Es.cpu(),
            'Rs': Rs_list,
            'Rr': Rr_list
        }

    @torch.no_grad
    def finetune(self, xyz_0, eef_xyz, n_steps, xyz_gt, norm_E=None, friction=None):
        """Roll the GNN forward from ``xyz_0``, driven by end-effector positions.

        The object is subsampled twice: a fixed 1000-point FPS, then
        ``downsample_vertices``. The model then predicts those particles step
        by step, feeding its own predictions back in as history. The
        interaction graph is built once, from the state at the first step, and
        reused for every later step.

        Args:
            xyz_0: initial object points, (n_particles, 3).
            eef_xyz: end-effector trajectory. A 2-D tensor is one effector
                indexed ``[t]``; any other rank is treated as two effectors
                indexed ``[k, t]`` (assumed (T, 3) / (2, T, 3) — TODO confirm).
            n_steps: number of frames including the initial state; the model
                is queried n_steps - 1 times.
            xyz_gt: unused.
            norm_E: per-point values gathered for the sampled particles and fed
                as the ``logE`` model input. Required when n_steps > 1.
            friction: ``friction[0]`` is fed as the ``friction`` input of every
                object particle. Required when n_steps > 1.

        Returns:
            (xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list):
                xyz: (n_steps, n_particles, 3) CPU tensor; every frame equals
                    ``xyz_0`` because the dense points are not updated here.
                xyz_bones: (n_steps, max_nobj, 3) CPU tensor of predicted
                    sampled particles, zero padded.
                eef: per-step end-effector positions (CPU).
                pred_pos: list with the initial sampled particles followed by
                    each prediction (CPU tensors).
                gt_pos: list with the initial sampled particles (tensor)
                    followed by numpy copies of the *initial* sampled positions
                    each step (no ground truth is consulted).
                rels_list: the relation object from the first step, once per step.

        Raises:
            ValueError: if ``norm_E`` or ``friction`` is None while n_steps > 1.
        """
        if n_steps > 1 and (norm_E is None or friction is None):
            # Both are indexed unconditionally inside the rollout loop.
            raise ValueError('finetune requires norm_E and friction when n_steps > 1')

        model = self.model
        device = self.device
        n_his = model.model_config['n_his']

        all_pos = xyz_0  # never updated below, so every xyz[i] equals xyz_0

        # Stage 1: fixed 1000-point FPS over the whole object.
        fps_all_idx = farthest_point_sampler(xyz_0.cpu()[None], 1000, start_idx=0)[0]
        fps_all_pos = all_pos[fps_all_idx]

        eef_num = 1 if len(eef_xyz.shape) == 2 else 2
        eef_pos = eef_xyz[0] if eef_num == 1 else eef_xyz[:, 0]
        eef_pos_history = eef_pos[None].repeat(n_his, 1, 1)  # (n_his, eef_num, 3)

        # Stage 2: downsample the FPS subset; these are the particles the model predicts.
        particle_pos_0, fps_idx_second = self.downsample_vertices(fps_all_pos.clone())
        fps_all_pos_history = fps_all_pos[fps_idx_second][None].repeat(n_his, 1, 1)  # (n_his, nobj, 3)

        # Results to store.
        xyz = xyz_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)  # (n_steps, n_bones, 3)
        eef_cpu = eef_xyz.cpu()
        eef_first = eef_cpu[0] if eef_num == 1 else eef_cpu[:, 0]
        eef = eef_first[None].repeat(n_steps, 1, 1)  # (n_steps, eef_num, 3)
        xyz_bones[0, :particle_pos_0.shape[0]] = particle_pos_0.cpu()

        pred_pos = [particle_pos_0.cpu()]
        gt_pos = [particle_pos_0.cpu()]
        rels_list = []

        Rr, Rs, rels = None, None, None
        fps_logE, friction_value = None, None
        for i in tqdm(range(1, n_steps), dynamic_ncols=True):
            eef_pos_this_step = eef_xyz[i] if eef_num == 1 else eef_xyz[:, i]
            eef_delta = eef_pos_this_step - eef_pos

            particle_pos_history = fps_all_pos_history
            nobj = particle_pos_history.shape[1]

            # Node layout: nobj object particles first, then the end-effector nodes.
            states = torch.zeros((1, self.n_his, nobj + eef_num, 3), device=device)
            states[:, :, :nobj] = particle_pos_history
            states[:, :, nobj:] = eef_pos_history

            # Only the end-effector nodes carry an action.
            states_delta = torch.zeros((1, nobj + eef_num, 3), device=device)
            states_delta[:, nobj:] = eef_delta

            # One-hot node type: column 0 = object particle, column 1 = end effector.
            attrs = torch.zeros((1, nobj + eef_num, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.
            attrs[:, nobj:, 1] = 1.

            p_instance = torch.ones((1, nobj, 1), dtype=torch.float32, device=device)

            state_mask = torch.ones((1, nobj + eef_num), dtype=bool, device=device)

            eef_mask = torch.zeros((1, nobj + eef_num), dtype=bool, device=device)
            eef_mask[:, nobj:] = 1

            obj_mask = torch.zeros((1, nobj + eef_num), dtype=bool, device=device)
            obj_mask[:, :nobj] = 1

            if i == 1:
                # The graph and the per-particle physical inputs are fixed after the first step.
                Rr, Rs, rels = construct_edges_from_states(states[0, -1], self.adj_thresh,
                                                           mask=state_mask[0], tool_mask=eef_mask[0], topk=self.topk,
                                                           connect_all=self.connect_all, return_rels=True)
                Rr = Rr[None]
                Rs = Rs[None]
                fps_logE = norm_E[fps_all_idx][fps_idx_second]
                friction_value = friction[0]
            rels_list.append(rels)

            # Negated z of the latest frame, clamped from below at -adj_thresh before negation.
            collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., None].clone(), -self.adj_thresh)
            logE = torch.zeros_like(collider_distance)
            logE[:, :nobj, 0] = fps_logE
            frictions = torch.zeros_like(collider_distance)
            frictions[:, :nobj, 0] = friction_value

            graph = {
                # input information
                "state": states,  # (n_his, N+M, state_dim)
                "action": states_delta,  # (N+M, state_dim)

                # attr information
                "attrs": attrs,  # (N+M, attr_dim)
                "p_instance": p_instance,  # (N, n_instance)
                "obj_mask": obj_mask,  # (N,)
                "state_mask": state_mask,  # (N+M,)
                "eef_mask": eef_mask,  # (N+M,)

                "Rr": Rr,  # (bsz, max_nR, N)
                "Rs": Rs,  # (bsz, max_nR, N)
                "collider_distance": collider_distance,
                "logE": logE,
                "friction": frictions
            }

            pred_state, _ = model(**graph)  # (1, nobj, 3)
            pred_pos.append(pred_state[0].cpu())

            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step[None]], dim=0)
            eef_pos = eef_pos_this_step

            # NOTE(review): fps_all_pos is never refreshed (an earlier variant
            # read xyz_gt[i][fps_all_idx]), so gt_pos repeats the initial state.
            gt_pos.append(fps_all_pos[fps_idx_second].cpu().numpy())
            # Feed the prediction back in as the newest history frame.
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state], dim=0)

            xyz[i] = all_pos.cpu()
            xyz_bones[i, :nobj] = pred_state[0].cpu()
            eef[i] = eef_pos.cpu()

        return xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list

    # @torch.no_grad
    def finetune_phys(self, xyz_0, eef_xyz, n_steps, xyz_gt, norm_E, friction):

        model = self.model
        device = self.device

        all_pos = xyz_0
        # fps_all_idx = farthest_point_sampler(xyz_0.cpu()[inlier_idx_all][None], 1000, start_idx=0)[0]
        # fps_all_pos = all_pos[inlier_idx_all][fps_all_idx]
        fps_all_idx = farthest_point_sampler(xyz_0.cpu()[None], 1000, start_idx=0)[0]
        # fps_all_idx = torch.arange(len(xyz_0), device=device)
        fps_all_pos = all_pos[fps_all_idx]
        eef_num = 1 if len(eef_xyz.shape) == 2 else 2
        if eef_num == 1:
            eef_pos_history = eef_xyz[0][None].repeat(model.model_config['n_his'], 1, 1)  # (n_his, 1, 3)
            eef_pos = eef_xyz[0]  # (1, 3)
        else:
            eef_pos_history = eef_xyz[:, 0][None].repeat(model.model_config['n_his'], 1, 1)  # (n_his, 1, 3)
            eef_pos = eef_xyz[:, 0]  # (1, 3)

        particle_pos_0, fps_idx_second = self.downsample_vertices(fps_all_pos.clone())
        fps_all_pos_history = fps_all_pos[fps_idx_second][None].repeat(model.model_config['n_his'], 1,
                                                                       1)  # (n_his, n_particles, 3)
        # results to store
        xyz = xyz_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)  # (n_steps,5 n_bones, 3)
        if eef_num == 1:
            eef = eef_xyz.cpu()[0][None].repeat(n_steps, 1, 1)  # (n_steps, 1, 3)
        else:
            eef = eef_xyz.cpu()[:, 0][None].repeat(n_steps, 1, 1)
        xyz_bones[0, :particle_pos_0.shape[0]] = particle_pos_0.cpu()

        key_point = []
        pred_pos = []
        gt_pos = []
        pred_pos.append(particle_pos_0.cpu())
        gt_pos.append(particle_pos_0.cpu())
        rels_list = []

        Rr, Rs, rels = None, None, None
        for i in tqdm(range(1, n_steps), dynamic_ncols=True):
            if eef_num == 1:
                eef_pos_this_step = eef_xyz[i]
            else:
                eef_pos_this_step = eef_xyz[:, i]
            eef_delta = eef_pos_this_step - eef_pos

            # particle_pos, fps_idx = self.downsample_vertices(fps_all_pos.clone())
            particle_pos = fps_all_pos_history[-1]

            key_point.append(particle_pos)
            particle_pos_history = fps_all_pos_history
            nobj = particle_pos.shape[0]

            states = torch.zeros((1, self.n_his, nobj + eef_num, 3), device=device)
            states[:, :, :nobj] = particle_pos_history
            states[:, :, nobj:] = eef_pos_history

            states_delta = torch.zeros((1, nobj + eef_num, 3), device=device)
            states_delta[:, nobj:] = eef_delta

            attrs = torch.zeros((1, nobj + eef_num, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.
            attrs[:, nobj:, 1] = 1.

            p_instance = torch.ones((1, nobj, 1), dtype=torch.float32, device=device)

            state_mask = torch.ones((1, nobj + eef_num), dtype=bool, device=device)

            eef_mask = torch.zeros((1, nobj + eef_num), dtype=bool, device=device)
            eef_mask[:, nobj:] = 1

            obj_mask = torch.zeros((1, nobj + eef_num), dtype=bool, device=device)
            obj_mask[:, :nobj] = 1

            # TODO debug
            # states[0, :, -1, -1] = states[0, :, :-1, -1].mean(axis=1)
            if i == 1:
                Rr, Rs, rels = construct_edges_from_states(states[0, -1], self.adj_thresh,
                                                           mask=state_mask[0], tool_mask=eef_mask[0], topk=self.topk,
                                                           connect_all=self.connect_all, return_rels=True)
                Rr = Rr[None]
                Rs = Rs[None]
            rels_list.append(rels)

            collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., torch.newaxis].clone(), -self.adj_thresh)
            # input_E = torch.zeros_like(collider_distance)
            # input_E[:] = float(norm_E[0])
            # input_E[0, -1] = 1.0
            logE = torch.zeros_like(collider_distance)
            fps_logE = norm_E[fps_all_idx][fps_idx_second]
            logE[:, :nobj, 0] = fps_logE
            frictions = torch.zeros_like(collider_distance)
            frictions[:, :nobj, 0] = friction[0]

            graph = {
                # input information
                "state": states,  # (n_his, N+M, state_dim)
                "action": states_delta,  # (N+M, state_dim)

                # attr information
                "attrs": attrs,  # (N+M, attr_dim)
                # "p_rigid": p_rigid,  # (n_instance,)
                "p_instance": p_instance,  # (N, n_instance)
                "obj_mask": obj_mask,  # (N,)
                "state_mask": state_mask,  # (N+M,)
                "eef_mask": eef_mask,  # (N+M,)

                "Rr": Rr,  # (bsz, max_nR, N)
                "Rs": Rs,  # (bsz, max_nR, N)
                "collider_distance": collider_distance,
                "logE": logE,
                "friction": frictions
            }

            pred_state, _ = model(**graph)  # (1, nobj, 3)
            pred_pos.append(pred_state[0].cpu())

            # TODO DEBUG
            # pred_state[0, :, -1] = graph['state'][0, -1, :-1, -1]
            # key_point.append(pred_state[0])

            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step[None]], dim=0)
            eef_pos = eef_pos_this_step

            # fps_all_pos = xyz_gt[i][fps_all_idx]

            gt_pos.append(fps_all_pos[fps_idx_second].cpu().numpy())
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state], dim=0)
            # fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state], dim=0)

            xyz[i] = all_pos.cpu()
            xyz_bones[i, :nobj] = pred_state[0].cpu()
            eef[i] = eef_pos.cpu()

        return xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list

    @torch.no_grad
    def rollout_gripper_pred_2(self, xyz_0, eef_xyz, n_steps, xyz_gt, norm_E=None, friction=None):
        """Autoregressively roll the dynamics model forward with one or two grippers.

        Keypoints are selected once (FPS down to at most 1000 points, then
        ``downsample_vertices``). The interaction graph is built once from the first
        step's state. Every later step feeds the model's own prediction back into the
        keypoint history. Dense particles are not advected.

        Args:
            xyz_0: initial particle positions, indexed along dim 0 (presumably (n_particles, 3)).
            eef_xyz: gripper trajectory. A 2-D tensor is treated as a single gripper
                indexed ``eef_xyz[step]``. Anything else is treated as two grippers
                indexed ``eef_xyz[:, step]``.
            n_steps: number of frames including the initial one.
            xyz_gt: unused; kept for interface compatibility.
            norm_E: per-particle values indexed like ``xyz_0``. They are gathered for the
                keypoints and written into the ``logE`` model input. Required.
            friction: ``friction[0]`` is written into the ``friction`` model input for the
                keypoints. Required.

        Returns:
            xyz: (n_steps, n_particles, 3) CPU tensor; every frame is a copy of ``xyz_0``.
            xyz_bones: (n_steps, max_nobj, 3) predicted keypoints, zero padded.
            eef: (n_steps, eef_num, 3) gripper positions (CPU).
            pred_pos: initial keypoints followed by each step's prediction (CPU tensors).
            gt_pos: initial keypoints followed by the (fixed) initial keypoints as numpy
                arrays, one per step.
            rels_list: the relation object returned when the graph was built, once per step.

        Raises:
            ValueError: if ``norm_E`` or ``friction`` is None.
        """
        if norm_E is None or friction is None:
            raise ValueError("rollout_gripper_pred_2 requires both norm_E and friction")

        model = self.model
        device = self.device
        n_his = model.model_config['n_his']

        all_pos = xyz_0
        n_fps = min(1000, xyz_0.shape[0])
        fps_all_idx = farthest_point_sampler(xyz_0.cpu()[None], n_fps, start_idx=0)[0]
        fps_all_pos = all_pos[fps_all_idx]

        eef_num = 1 if len(eef_xyz.shape) == 2 else 2

        def eef_at(step):
            # Normalise every gripper sample to (eef_num, 3). This keeps the history
            # concatenation below rank-consistent for both trajectory layouts.
            pos = eef_xyz[step] if eef_num == 1 else eef_xyz[:, step]
            return pos.reshape(eef_num, 3)

        eef_pos = eef_at(0)  # (eef_num, 3)
        eef_pos_history = eef_pos[None].repeat(n_his, 1, 1)  # (n_his, eef_num, 3)

        particle_pos_0, fps_idx_second = self.downsample_vertices(fps_all_pos.clone())
        fps_all_pos_history = fps_all_pos[fps_idx_second][None].repeat(n_his, 1, 1)  # (n_his, nobj, 3)

        # The keypoints are fixed for the whole rollout, so their material values are too.
        fps_logE = norm_E[fps_all_idx][fps_idx_second]

        # results to store
        xyz = xyz_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)  # (n_steps, n_bones, 3)
        eef = eef_pos.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, eef_num, 3)
        xyz_bones[0, :particle_pos_0.shape[0]] = particle_pos_0.cpu()

        pred_pos = [particle_pos_0.cpu()]
        gt_pos = [particle_pos_0.cpu()]
        rels_list = []

        Rr, Rs, rels = None, None, None
        for i in tqdm(range(1, n_steps), dynamic_ncols=True):
            eef_pos_this_step = eef_at(i)
            eef_delta = eef_pos_this_step - eef_pos

            particle_pos_history = fps_all_pos_history
            nobj = particle_pos_history.shape[1]
            n_nodes = nobj + eef_num

            states = torch.zeros((1, self.n_his, n_nodes, 3), device=device)
            states[:, :, :nobj] = particle_pos_history
            states[:, :, nobj:] = eef_pos_history

            states_delta = torch.zeros((1, n_nodes, 3), device=device)
            states_delta[:, nobj:] = eef_delta

            attrs = torch.zeros((1, n_nodes, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.
            attrs[:, nobj:, 1] = 1.

            p_instance = torch.ones((1, nobj, 1), dtype=torch.float32, device=device)
            state_mask = torch.ones((1, n_nodes), dtype=bool, device=device)

            # Every gripper node (not only the first) is a tool node.
            eef_mask = torch.zeros((1, n_nodes), dtype=bool, device=device)
            eef_mask[:, nobj:] = True

            obj_mask = torch.zeros((1, n_nodes), dtype=bool, device=device)
            obj_mask[:, :nobj] = True

            # The graph topology is built once from the first step and reused afterwards.
            if i == 1:
                Rr, Rs, rels = construct_edges_from_states(states[0, -1], self.adj_thresh,
                                                           mask=state_mask[0], tool_mask=eef_mask[0], topk=self.topk,
                                                           connect_all=self.connect_all, return_rels=True)
                Rr = Rr[None]
                Rs = Rs[None]
            rels_list.append(rels)

            # Last coordinate of every node in the newest frame, clamped from below at -adj_thresh.
            # NOTE(review): another rollout in this file negates this quantity; confirm which sign
            # the model was trained with.
            collider_distance = torch.clamp(states[:, -1, :, -1][..., None], -self.adj_thresh)
            logE = torch.zeros_like(collider_distance)
            logE[:, :nobj, 0] = fps_logE
            frictions = torch.zeros_like(collider_distance)
            frictions[:, :nobj, 0] = friction[0]

            graph = {
                # input information
                "state": states,  # (1, n_his, N+M, state_dim)
                "action": states_delta,  # (1, N+M, state_dim)

                # attr information
                "attrs": attrs,  # (1, N+M, attr_dim)
                "p_instance": p_instance,  # (1, N, n_instance)
                "obj_mask": obj_mask,  # (1, N+M)
                "state_mask": state_mask,  # (1, N+M)
                "eef_mask": eef_mask,  # (1, N+M)

                "Rr": Rr,  # (bsz, max_nR, N)
                "Rs": Rs,  # (bsz, max_nR, N)
                "collider_distance": collider_distance,
                "logE": logE,
                "friction": frictions
            }

            pred_state, _ = model(**graph)  # (1, nobj, 3)
            pred_pos.append(pred_state[0].cpu())

            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step[None]], dim=0)
            eef_pos = eef_pos_this_step

            gt_pos.append(fps_all_pos[fps_idx_second].cpu().numpy())
            # Autoregressive: the prediction becomes the newest history frame.
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state], dim=0)

            xyz_bones[i, :nobj] = pred_state[0].cpu()
            eef[i] = eef_pos.cpu()

        return xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list


    @torch.no_grad
    def rollout_mpm_pred(self, xyz_0, eef_xyz, n_steps, xyz_gt):
        """Autoregressive keypoint rollout driven by a single end-effector trajectory.

        Keypoints are chosen once (FPS to 1000 points, then ``downsample_vertices``).
        At each step the interaction graph is rebuilt from the newest frame, and the
        model's prediction is pushed into the keypoint history. ``xyz_gt`` is only read
        to record the ground-truth keypoints for comparison.

        Args:
            xyz_0: initial particle positions (n_particles, 3).
            eef_xyz: end-effector trajectory indexed by step (each entry presumably (1, 3)).
            n_steps: number of frames including the initial one.
            xyz_gt: ground-truth particle positions indexed by step.

        Returns:
            xyz: (n_steps, n_particles, 3) CPU tensor; every frame equals ``xyz_0``.
            xyz_bones: (n_steps, max_nobj, 3) predicted keypoints, zero padded.
            eef: per-frame end-effector positions (CPU).
            pred_pos: initial keypoints followed by each step's prediction (CPU tensors).
            gt_pos: initial keypoints followed by per-step ground-truth keypoints (numpy).
        """
        net = self.model
        dev = self.device
        hist_len = net.model_config['n_his']

        sample_idx = farthest_point_sampler(xyz_0.cpu()[None], 1000, start_idx=0)[0]
        sampled = xyz_0[sample_idx]

        eef_now = eef_xyz[0]
        eef_hist = eef_now[None].repeat(hist_len, 1, 1)

        bones_0, bone_idx = self.downsample_vertices(sampled.clone())
        obj_hist = sampled[bone_idx][None].repeat(hist_len, 1, 1)

        xyz = xyz_0.cpu()[None].repeat(n_steps, 1, 1)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)
        eef = eef_xyz.cpu()[0][None].repeat(n_steps, 1, 1)
        xyz_bones[0, :bones_0.shape[0]] = bones_0.cpu()

        pred_pos = [bones_0.cpu()]
        gt_pos = [bones_0.cpu()]

        for step in tqdm(range(1, n_steps), dynamic_ncols=True):
            eef_next = eef_xyz[step]
            n_obj = obj_hist[-1].shape[0]
            n_all = n_obj + 1  # keypoints followed by one tool node

            state = torch.zeros((1, self.n_his, n_all, 3), device=dev)
            state[:, :, :n_obj] = obj_hist
            state[:, :, n_obj:] = eef_hist

            # Only the tool node carries an action: its displacement this step.
            action = torch.zeros((1, n_all, 3), device=dev)
            action[:, n_obj:] = eef_next - eef_now

            node_attrs = torch.zeros((1, n_all, 2), dtype=torch.float32, device=dev)
            node_attrs[:, :n_obj, 0] = 1.
            node_attrs[:, n_obj:, 1] = 1.

            valid = torch.ones((1, n_all), dtype=bool, device=dev)
            is_tool = torch.zeros((1, n_all), dtype=bool, device=dev)
            is_tool[:, n_obj] = 1
            is_obj = torch.zeros((1, n_all), dtype=bool, device=dev)
            is_obj[:, :n_obj] = 1

            Rr, Rs = construct_edges_from_states(state[0, -1], self.adj_thresh,
                                                 mask=valid[0], tool_mask=is_tool[0], topk=self.topk,
                                                 connect_all=self.connect_all)

            pred_state, _ = net(
                state=state,
                action=action,
                attrs=node_attrs,
                p_instance=torch.ones((1, n_obj, 1), dtype=torch.float32, device=dev),
                obj_mask=is_obj,
                state_mask=valid,
                eef_mask=is_tool,
                Rr=Rr[None],
                Rs=Rs[None],
            )
            pred_pos.append(pred_state[0].cpu())

            eef_hist = torch.cat([eef_hist[1:], eef_next[None]], dim=0)
            eef_now = eef_next

            gt_pos.append(xyz_gt[step][sample_idx][bone_idx].cpu().numpy())
            obj_hist = torch.cat([obj_hist[1:], pred_state], dim=0)

            xyz_bones[step, :n_obj] = pred_state[0].cpu()
            eef[step] = eef_now.cpu()

        return xyz, xyz_bones, eef, pred_pos, gt_pos

    @torch.no_grad
    def rollout_mpm(self, xyz_0, eef_xyz, n_steps, xyz_gt):
        """Teacher-forced one-step keypoint predictions along a ground-truth trajectory.

        At every step the keypoints are re-selected with ``downsample_vertices`` from the
        current ground-truth FPS points. The model predicts their next positions. The
        history is then advanced with ground truth from ``xyz_gt``, not with the prediction.

        Args:
            xyz_0: initial particle positions (n_particles, 3).
            eef_xyz: end-effector trajectory indexed by step (each entry presumably (1, 3)).
            n_steps: number of frames including the initial one.
            xyz_gt: ground-truth particle positions indexed by step.

        Returns:
            xyz: (n_steps, n_particles, 3) CPU tensor; every frame equals ``xyz_0``.
            xyz_bones: (n_steps, max_nobj, 3) predicted keypoints, zero padded.
            eef: per-frame end-effector positions (CPU).
            pred_pos: initial keypoints followed by each step's prediction (CPU tensors).
            gt_pos: initial keypoints followed by per-step ground-truth keypoints (numpy).
        """
        net = self.model
        dev = self.device
        hist_len = net.model_config['n_his']

        sample_idx = farthest_point_sampler(xyz_0.cpu()[None], 1000, start_idx=0)[0]
        gt_points = xyz_0[sample_idx]
        gt_points_hist = gt_points[None].repeat(hist_len, 1, 1)

        eef_now = eef_xyz[0]
        eef_hist = eef_now[None].repeat(hist_len, 1, 1)

        bones_0, _ = self.downsample_vertices(gt_points.clone())

        xyz = xyz_0.cpu()[None].repeat(n_steps, 1, 1)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)
        eef = eef_xyz.cpu()[0][None].repeat(n_steps, 1, 1)
        xyz_bones[0, :bones_0.shape[0]] = bones_0.cpu()

        pred_pos = [bones_0.cpu()]
        gt_pos = [bones_0.cpu()]

        for step in tqdm(range(1, n_steps), dynamic_ncols=True):
            # Invariant: the newest history frame is always the current frame.
            assert torch.allclose(gt_points, gt_points_hist[-1])
            assert torch.allclose(eef_now, eef_hist[-1])

            eef_next = eef_xyz[step]

            bones, sel = self.downsample_vertices(gt_points.clone())
            n_obj = bones.shape[0]
            n_all = n_obj + 1  # keypoints followed by one tool node

            state = torch.zeros((1, self.n_his, n_all, 3), device=dev)
            state[:, :, :n_obj] = gt_points_hist[:, sel]
            state[:, :, n_obj:] = eef_hist

            action = torch.zeros((1, n_all, 3), device=dev)
            action[:, n_obj:] = eef_next - eef_now

            node_attrs = torch.zeros((1, n_all, 2), dtype=torch.float32, device=dev)
            node_attrs[:, :n_obj, 0] = 1.
            node_attrs[:, n_obj:, 1] = 1.

            valid = torch.ones((1, n_all), dtype=bool, device=dev)
            is_tool = torch.zeros((1, n_all), dtype=bool, device=dev)
            is_tool[:, n_obj] = 1
            is_obj = torch.zeros((1, n_all), dtype=bool, device=dev)
            is_obj[:, :n_obj] = 1

            Rr, Rs = construct_edges_from_states(state[0, -1], self.adj_thresh,
                                                 mask=valid[0], tool_mask=is_tool[0], topk=self.topk,
                                                 connect_all=self.connect_all)

            pred_state, _ = net(
                state=state,
                action=action,
                attrs=node_attrs,
                p_instance=torch.ones((1, n_obj, 1), dtype=torch.float32, device=dev),
                obj_mask=is_obj,
                state_mask=valid,
                eef_mask=is_tool,
                Rr=Rr[None],
                Rs=Rs[None],
            )
            pred_pos.append(pred_state[0].cpu())

            eef_hist = torch.cat([eef_hist[1:], eef_next[None]], dim=0)
            eef_now = eef_next

            # Teacher forcing: advance the history with ground truth.
            gt_points = xyz_gt[step][sample_idx]
            gt_pos.append(gt_points[sel].cpu().numpy())
            gt_points_hist = torch.cat([gt_points_hist[1:], gt_points[None]], dim=0)

            xyz_bones[step, :n_obj] = pred_state[0].cpu()
            eef[step] = eef_now.cpu()

        return xyz, xyz_bones, eef, pred_pos, gt_pos

    @torch.no_grad
    def rollout(self, xyz_0, rgb_0, quat_0, opa_0, eef_xyz, n_steps, inlier_idx_all):
        """Roll out the model on a fixed keypoint set and propagate motion to all particles.

        FPS picks up to 1000 anchor points among the inliers. ``downsample_vertices``
        then selects keypoints among those anchors once; that selection is reused every
        step. Each step:
        - the graph is rebuilt;
        - the model predicts the keypoints' next positions;
        - ``interpolate_motions`` moves all particles and their rotations accordingly;
        - the anchors are re-read from the moved particles.

        Args:
            xyz_0: particle positions (n_particles, 3).
            rgb_0: particle colors (n_particles, 3); copied unchanged to every frame.
            quat_0: particle rotations (n_particles, 4).
            opa_0: particle opacities (n_particles, 1); copied unchanged to every frame.
            eef_xyz: end-effector trajectory indexed by step (each entry presumably (1, 3)).
            n_steps: number of frames including the initial one.
            inlier_idx_all: indices (or mask) of the particles eligible as anchors.

        Returns:
            xyz, rgb, quat, opa, xyz_bones, eef (all per-frame CPU tensors), and the list
            of keypoint positions used as bones at each step.
        """
        net = self.model
        dev = self.device
        hist_len = net.model_config['n_his']

        all_pos = xyz_0
        sample_idx = farthest_point_sampler(xyz_0.cpu()[inlier_idx_all][None], 1000, start_idx=0)[0]
        anchors = all_pos[inlier_idx_all][sample_idx]
        anchors_hist = anchors[None].repeat(hist_len, 1, 1)

        eef_now = eef_xyz[0]
        eef_hist = eef_now[None].repeat(hist_len, 1, 1)

        bones_0, bone_idx = self.downsample_vertices(anchors.clone())

        quat = quat_0.cpu()[None].repeat(n_steps, 1, 1)
        xyz = xyz_0.cpu()[None].repeat(n_steps, 1, 1)
        # Colors and opacities are static, so the initial copies already hold every frame.
        rgb = rgb_0.cpu()[None].repeat(n_steps, 1, 1)
        opa = opa_0.cpu()[None].repeat(n_steps, 1, 1)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)
        eef = eef_xyz.cpu()[0][None].repeat(n_steps, 1, 1)
        xyz_bones[0, :bones_0.shape[0]] = bones_0.cpu()

        key_point = []

        for step in tqdm(range(1, n_steps), dynamic_ncols=True):
            # Invariant: the newest history frame is always the current frame.
            assert torch.allclose(anchors, anchors_hist[-1])
            assert torch.allclose(eef_now, eef_hist[-1])

            eef_next = eef_xyz[step]

            bones = anchors.clone()[bone_idx]
            key_point.append(bones)
            n_obj = bones.shape[0]
            n_all = n_obj + 1  # keypoints followed by one tool node

            state = torch.zeros((1, self.n_his, n_all, 3), device=dev)
            state[:, :, :n_obj] = anchors_hist[:, bone_idx]
            state[:, :, n_obj:] = eef_hist

            action = torch.zeros((1, n_all, 3), device=dev)
            action[:, n_obj:] = eef_next - eef_now

            node_attrs = torch.zeros((1, n_all, 2), dtype=torch.float32, device=dev)
            node_attrs[:, :n_obj, 0] = 1.
            node_attrs[:, n_obj:, 1] = 1.

            valid = torch.ones((1, n_all), dtype=bool, device=dev)
            is_tool = torch.zeros((1, n_all), dtype=bool, device=dev)
            is_tool[:, n_obj] = 1
            is_obj = torch.zeros((1, n_all), dtype=bool, device=dev)
            is_obj[:, :n_obj] = 1

            Rr, Rs = construct_edges_from_states(state[0, -1], self.adj_thresh,
                                                 mask=valid[0], tool_mask=is_tool[0], topk=self.topk,
                                                 connect_all=self.connect_all)
            Rr, Rs = Rr[None], Rs[None]

            pred_state, _ = net(
                state=state,
                action=action,
                attrs=node_attrs,
                p_instance=torch.ones((1, n_obj, 1), dtype=torch.float32, device=dev),
                obj_mask=is_obj,
                state_mask=valid,
                eef_mask=is_tool,
                Rr=Rr,
                Rs=Rs,
            )

            eef_hist = torch.cat([eef_hist[1:], eef_next[None]], dim=0)
            eef_now = eef_next

            # Propagate keypoint motion to every particle using the object-object edges.
            all_pos, all_rot, _ = interpolate_motions(
                bones=bones,
                motions=pred_state[0] - bones,
                relations=relations_to_matrix(Rr, Rs)[:n_obj, :n_obj],
                xyz=all_pos,
                quat=quat[step - 1].to(dev),
            )
            anchors = all_pos[inlier_idx_all][sample_idx]
            anchors_hist = torch.cat([anchors_hist[1:], anchors[None]], dim=0)

            quat[step] = all_rot.cpu()
            xyz[step] = all_pos.cpu()
            xyz_bones[step, :n_obj] = pred_state[0].cpu()
            eef[step] = eef_now.cpu()

        return xyz, rgb, quat, opa, xyz_bones, eef, key_point

    def update_fps_points(self, current_pos, prev_fps_idx, adj_thresh=0.1):
        """Keep connected keypoints and re-sample disconnected ones, preserving the count.

        A keypoint counts as connected when at least one *other* keypoint lies strictly
        closer than ``adj_thresh``. Disconnected keypoints are dropped. The same number
        of replacements is then picked greedily from ``current_pos`` by farthest-point
        sampling seeded with the connected keypoints: each new point is the one farthest
        from every keypoint selected so far. The result therefore always has
        ``len(prev_fps_idx)`` entries.

        Args:
            current_pos: (n_points, D) float tensor of point positions.
            prev_fps_idx: 1-D integer tensor of keypoint indices into ``current_pos``.
            adj_thresh: neighbor distance threshold.

        Returns:
            ``prev_fps_idx`` itself when every keypoint is connected. Otherwise a new 1-D
            tensor (same dtype/device as ``prev_fps_idx``): the connected indices in their
            original order, followed by the newly sampled ones.
        """
        n_fps = prev_fps_idx.shape[0]
        if n_fps == 0:
            return prev_fps_idx

        keypoints = current_pos[prev_fps_idx]
        pair_dist = torch.cdist(keypoints, keypoints)  # (n_fps, n_fps)
        # Each keypoint is at distance 0 from itself. Exclude the diagonal so that
        # "has a neighbour" means another keypoint.
        pair_dist.fill_diagonal_(float('inf'))
        is_valid = (pair_dist < adj_thresh).any(dim=1)

        if bool(is_valid.all()):
            return prev_fps_idx

        kept_idx = prev_fps_idx[is_valid]
        n_new = n_fps - kept_idx.shape[0]

        # Distance from every point to its nearest already-selected keypoint.
        if kept_idx.numel() > 0:
            min_dist = torch.cdist(current_pos, current_pos[kept_idx]).min(dim=1).values
        else:
            min_dist = torch.full((current_pos.shape[0],), float('inf'),
                                  dtype=current_pos.dtype, device=current_pos.device)

        new_idx = []
        for _ in range(n_new):
            # torch.argmax returns the first index on ties, so selection is deterministic.
            farthest = int(torch.argmax(min_dist))
            new_idx.append(farthest)
            step_dist = (current_pos - current_pos[farthest]).norm(dim=1)
            min_dist = torch.minimum(min_dist, step_dist)

        new_idx = torch.tensor(new_idx, dtype=prev_fps_idx.dtype, device=prev_fps_idx.device)
        return torch.cat([kept_idx, new_idx], dim=0)


    @torch.no_grad
    def rollout_gripper_3dgs_fps_fixed(self, xyz_0, quat_0, eef_xyz, n_steps, xyz_gt, xyz_FPS=None, norm_E=None, fps_idx=None):
        """Roll out the model on a fixed keypoint set and skin every particle to it.

        The keypoints are the particles ``xyz_0[fps_idx]`` and stay the same particles
        for the whole rollout. The interaction graph is built once from the first step.
        At every step:
        - the model predicts the keypoints' next positions;
        - ``inter_motion`` propagates that motion to all particle positions and rotations,
          using 5-nearest-keypoint weights;
        - the keypoint history is re-read from the moved particles.

        Args:
            xyz_0: initial particle positions (n_particles, 3).
            quat_0: initial particle rotations (n_particles, 4).
            eef_xyz: end-effector trajectory indexed by step (each entry presumably (1, 3)).
            n_steps: number of frames including the initial one.
            xyz_gt: unused; kept for interface compatibility.
            xyz_FPS: unused; kept for interface compatibility.
            norm_E: value(s) written into the ``logE`` model input for the keypoint nodes.
                Must broadcast to (n_keypoints, 1). Required.
            fps_idx: indices of the keypoints in ``xyz_0``. When None, up to 1000 points
                are chosen by farthest-point sampling starting at index 0.

        Returns:
            xyz: (n_steps, n_particles, 3) particle positions (CPU).
            quat: (n_steps, n_particles, 4) particle rotations (CPU).
            xyz_bones: (n_steps, max_nobj, 3) predicted keypoints, zero padded.
            eef: (n_steps, 1, 3) end-effector positions.
            pred_pos: initial keypoints followed by each step's prediction (CPU tensors).
            gt_pos: a single entry holding the initial keypoints (CPU tensor).
            rels_list: the relation object returned when the graph was built, once per step.

        Raises:
            ValueError: if ``norm_E`` is None.
        """
        if norm_E is None:
            raise ValueError("rollout_gripper_3dgs_fps_fixed requires norm_E")

        model = self.model
        device = self.device
        n_his = model.model_config['n_his']

        # Dense particle state.
        all_pos = xyz_0
        all_rot = quat_0.clone()

        # Fixed keypoint set.
        if fps_idx is None:
            n_fps = min(1000, xyz_0.shape[0])
            fps_idx = farthest_point_sampler(xyz_0.cpu()[None], n_fps, start_idx=0)[0]
        fps_all_idx = fps_idx
        fps_all_pos = all_pos[fps_all_idx]
        fps_all_pos_history = fps_all_pos[None].repeat(n_his, 1, 1)  # (n_his, nobj, 3)
        particle_pos_0 = fps_all_pos
        nobj = particle_pos_0.shape[0]
        n_nodes = nobj + 1  # keypoints followed by one tool node

        eef_pos_history = eef_xyz[0][None].repeat(n_his, 1, 1)  # (n_his, 1, 3)
        eef_pos = eef_xyz[0]  # (1, 3)

        # Result buffers.
        xyz = torch.zeros(n_steps, *xyz_0.shape)  # (n_steps, n_particles, 3)
        quat = torch.zeros(n_steps, *quat_0.shape)  # (n_steps, n_particles, 4)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)  # (n_steps, n_bones, 3)
        eef = torch.zeros(n_steps, 1, 3)  # (n_steps, 1, 3)

        # Frame 0 is the initial state.
        xyz[0] = all_pos.cpu()
        quat[0] = all_rot.cpu()
        xyz_bones[0, :nobj] = particle_pos_0.cpu()
        eef[0] = eef_pos.cpu()

        # Debug outputs.
        pred_pos = [particle_pos_0.cpu()]
        gt_pos = [fps_all_pos.cpu().clone()]
        rels_list = []
        Rr, Rs, rels = None, None, None

        for i in tqdm(range(1, n_steps), dynamic_ncols=True):
            # Gripper displacement for this step.
            eef_pos_this_step = eef_xyz[i]
            eef_delta = eef_pos_this_step - eef_pos

            # Snapshot of the current keypoints; they are this step's skinning bones.
            particle_pos = fps_all_pos.clone()

            # Model input state.
            states = torch.zeros((1, self.n_his, n_nodes, 3), device=device)
            states[:, :, :nobj] = fps_all_pos_history
            states[:, :, nobj:] = eef_pos_history

            states_delta = torch.zeros((1, n_nodes, 3), device=device)
            states_delta[:, nobj:] = eef_delta

            # Node attributes and masks.
            attrs = torch.zeros((1, n_nodes, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.  # particle nodes
            attrs[:, nobj:, 1] = 1.  # gripper node

            p_instance = torch.ones((1, nobj, 1), dtype=torch.float32, device=device)
            state_mask = torch.ones((1, n_nodes), dtype=bool, device=device)
            eef_mask = torch.zeros((1, n_nodes), dtype=bool, device=device)
            eef_mask[:, nobj] = 1
            obj_mask = torch.zeros((1, n_nodes), dtype=bool, device=device)
            obj_mask[:, :nobj] = 1

            # Graph topology is built once from the first step and reused afterwards.
            if i == 1:
                Rr, Rs, rels = construct_edges_from_states(
                    states[0, -1], self.adj_thresh,
                    mask=state_mask[0], tool_mask=eef_mask[0],
                    topk=self.topk, connect_all=self.connect_all,
                    return_rels=True
                )
                Rr, Rs = Rr[None], Rs[None]
            rels_list.append(rels)

            # Last coordinate of every node in the newest frame, clamped from below at -adj_thresh.
            collider_distance = torch.clamp(states[:, -1, :, -1][..., None], -self.adj_thresh)

            logE = torch.zeros_like(collider_distance)
            logE[0, :nobj] = norm_E
            logE[0, nobj:] = 1.0

            # Predict keypoint motion.
            graph = {
                "state": states,
                "action": states_delta,
                "attrs": attrs,
                "p_instance": p_instance,
                "obj_mask": obj_mask,
                "state_mask": state_mask,
                "eef_mask": eef_mask,
                "Rr": Rr,
                "Rs": Rs,
                "collider_distance": collider_distance,
                "logE": logE,
            }
            pred_state, _ = model(**graph)  # (1, nobj, 3)

            # Propagate keypoint motion to every particle (positions and rotations).
            relation = get_topk_indices(particle_pos, K=5)
            weights = knn_weights_new(particle_pos, all_pos, K=5)
            all_pos, all_rot, _ = inter_motion(
                bones=particle_pos,
                motions=pred_state[0] - particle_pos,
                relations=relation,
                xyz=all_pos,
                quat=all_rot,
                weights=weights
            )

            # Advance the histories; keypoints are re-read from the moved particles.
            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step[None]], dim=0)
            eef_pos = eef_pos_this_step
            fps_all_pos = all_pos[fps_all_idx]
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], fps_all_pos[None]], dim=0)

            # Store this frame.
            xyz[i] = all_pos.cpu()
            quat[i] = all_rot.cpu()
            xyz_bones[i, :nobj] = pred_state[0].cpu()
            eef[i] = eef_pos.cpu()

            pred_pos.append(pred_state[0].cpu())

        return xyz, quat, xyz_bones, eef, pred_pos, gt_pos, rels_list

    @torch.no_grad
    def rollout_gripper_3dgs(self, xyz_0, quat_0, eef_xyz, n_steps, xyz_gt, xyz_FPS=None, norm_E=None, friction=None):
        """Roll out the learned dynamics for Gaussian particles pushed by one or two grippers.

        Each step, the model predicts the motion of a fixed set of keypoints. The keypoints
        are a second down-sampling of a 1000-point FPS subset of ``xyz_0``. The motion of
        every particle is then interpolated from the keypoints with ``inter_motion``.

        Args:
            xyz_0: (n_particles, 3) initial particle positions.
            quat_0: (n_particles, 4) initial particle rotations (quaternions).
            eef_xyz: end-effector trajectory. A 2-D tensor ``(n_steps, 3)`` drives a single
                gripper. Any other rank is read as ``eef_xyz[:, step]`` and treated as two
                gripper nodes, presumably ``(2, n_steps, 3)``; TODO confirm with callers.
            n_steps: total number of frames, including the initial one.
            xyz_gt: not used by this method (kept for interface compatibility).
            xyz_FPS: not used by this method (kept for interface compatibility).
            norm_E: per-particle values, indexed like ``xyz_0``, that are fed to the model as
                ``logE`` for the keypoints. Presumably a normalized log stiffness; confirm.
                Required.
            friction: indexable object whose first element is fed to the model as the
                friction of every keypoint. Required.

        Returns:
            ``(xyz, quat, xyz_bones, eef, pred_pos, gt_pos, rels_list)``:
            - ``xyz``, ``quat``, ``xyz_bones`` and ``eef``: per-frame CPU tensors of particle
              positions, particle rotations, keypoint positions (padded to ``max_nobj``) and
              end-effector positions.
            - ``pred_pos``: list of predicted keypoints, one entry per frame.
            - ``gt_pos``: list holding only the initial keypoints.
            - ``rels_list``: the relations returned by ``construct_edges_from_states``,
              appended once per step.

        Raises:
            ValueError: if ``norm_E`` or ``friction`` is None.
        """
        if norm_E is None or friction is None:
            raise ValueError('rollout_gripper_3dgs requires both norm_E and friction')

        model = self.model
        device = self.device
        # Use a single history length for both the history buffers and the model states.
        # load_model() copies train_config['n_his'] into the model config, so both agree.
        n_his = self.n_his

        # A 2-D trajectory drives a single gripper; otherwise grippers are read as eef_xyz[:, step].
        eef_num = 1 if len(eef_xyz.shape) == 2 else 2

        def eef_at(step):
            return eef_xyz[step] if eef_num == 1 else eef_xyz[:, step]

        all_pos = xyz_0
        all_rot = quat_0.clone()

        # Take a fixed 1000-point FPS subset of the full cloud. Down-sampling it again gives
        # the keypoints fed to the model. Both index sets stay fixed for the whole rollout.
        fps_all_idx = farthest_point_sampler(xyz_0.cpu()[None], 1000, start_idx=0)[0]
        fps_all_pos = all_pos[fps_all_idx]
        fps_all_pos_history = fps_all_pos[None].repeat(n_his, 1, 1)  # (n_his, n_fps, 3)
        particle_pos_0, fps_idx_second = self.downsample_vertices(fps_all_pos.clone())

        # The eef history always has shape (n_his, n_eef_nodes, 3). Each step is reshaped to
        # (1, n_eef_nodes, 3) before it is appended. Appending a 2-D (1, 3) tensor, as the
        # single-gripper path used to, makes torch.cat fail on the rank mismatch.
        eef_pos = eef_at(0)
        eef_pos_history = eef_pos.reshape(1, -1, 3).repeat(n_his, 1, 1)

        # These per-keypoint model inputs do not change during the rollout.
        fps_logE = norm_E[fps_all_idx][fps_idx_second]
        friction_value = friction[0]

        # Per-frame outputs, stored on the CPU.
        xyz = torch.zeros(n_steps, *xyz_0.shape)  # (n_steps, n_particles, 3)
        quat = torch.zeros(n_steps, *quat_0.shape)  # (n_steps, n_particles, 4)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)  # (n_steps, max_nobj, 3)
        eef = torch.zeros(n_steps, eef_num, 3)  # (n_steps, eef_num, 3)

        xyz[0] = all_pos.cpu()
        quat[0] = all_rot.cpu()
        xyz_bones[0, :particle_pos_0.shape[0]] = particle_pos_0.cpu()
        eef[0] = eef_pos.cpu()

        # Debug outputs.
        pred_pos = [particle_pos_0.cpu()]
        gt_pos = [fps_all_pos[fps_idx_second].cpu()]
        rels_list = []
        Rr, Rs, rels = None, None, None

        for i in tqdm(range(1, n_steps), dynamic_ncols=True):
            eef_pos_this_step = eef_at(i)
            eef_delta = eef_pos_this_step - eef_pos

            # The current keypoints use fixed indices into the moving FPS subset.
            fps_idx = fps_idx_second
            particle_pos = fps_all_pos.clone()[fps_idx]
            particle_pos_history = fps_all_pos_history[:, fps_idx]
            nobj = particle_pos.shape[0]
            n_nodes = nobj + eef_num

            # Model inputs: object keypoints come first, then the gripper node(s).
            states = torch.zeros((1, n_his, n_nodes, 3), device=device)
            states[:, :, :nobj] = particle_pos_history
            states[:, :, nobj:] = eef_pos_history

            states_delta = torch.zeros((1, n_nodes, 3), device=device)
            states_delta[:, nobj:] = eef_delta  # only the gripper nodes carry an action

            attrs = torch.zeros((1, n_nodes, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.  # particle attribute
            attrs[:, nobj:, 1] = 1.  # gripper attribute

            p_instance = torch.ones((1, nobj, 1), dtype=torch.float32, device=device)
            state_mask = torch.ones((1, n_nodes), dtype=bool, device=device)
            eef_mask = torch.zeros((1, n_nodes), dtype=bool, device=device)
            eef_mask[:, nobj:] = 1
            obj_mask = torch.zeros((1, n_nodes), dtype=bool, device=device)
            obj_mask[:, :nobj] = 1

            # Build the relation graph once, from the first step's state, and reuse it afterwards.
            if i == 1:
                Rr, Rs, rels = construct_edges_from_states(
                    states[0, -1], self.adj_thresh,
                    mask=state_mask[0], tool_mask=eef_mask[0],
                    topk=self.topk, connect_all=self.connect_all,
                    return_rels=True
                )
                Rr, Rs = Rr[None], Rs[None]
            rels_list.append(rels)
            # Take the last coordinate of each node, clamp it from below at -adj_thresh, then negate it.
            collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., None], -self.adj_thresh)

            logE = torch.zeros_like(collider_distance)
            logE[:, :nobj, 0] = fps_logE
            frictions = torch.zeros_like(collider_distance)
            frictions[:, :nobj, 0] = friction_value

            graph = {
                "state": states,
                "action": states_delta,
                "attrs": attrs,
                "p_instance": p_instance,
                "obj_mask": obj_mask,
                "state_mask": state_mask,
                "eef_mask": eef_mask,
                "Rr": Rr,
                "Rs": Rs,
                "collider_distance": collider_distance,
                "logE": logE,
                "friction": frictions,
            }
            pred_state, _ = model(**graph)  # (1, nobj, 3)

            # Propagate the keypoint motion to every particle using K=5 nearest-keypoint weights.
            relation = get_topk_indices(particle_pos, K=5)
            weights = knn_weights_new(particle_pos, all_pos, K=5)
            all_pos, all_rot, _ = inter_motion(
                bones=particle_pos,
                motions=pred_state[0] - particle_pos,
                relations=relation,
                xyz=all_pos,
                quat=all_rot,
                weights=weights
            )

            # Advance the histories.
            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step.reshape(1, -1, 3)], dim=0)
            eef_pos = eef_pos_this_step

            fps_all_pos = all_pos[fps_all_idx]
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], fps_all_pos[None]], dim=0)

            xyz[i] = all_pos.cpu()
            quat[i] = all_rot.cpu()
            xyz_bones[i, :nobj] = pred_state[0].cpu()
            eef[i] = eef_pos.cpu()

            pred_pos.append(pred_state[0].cpu())

        return xyz, quat, xyz_bones, eef, pred_pos, gt_pos, rels_list

    @torch.no_grad
    def rollout_gripper(self, xyz_0, rgb_0, quat_0, opa_0, eef_xyz, n_steps, inlier_idx_all):
        """Roll out the learned dynamics for one gripper, rebuilding the relation graph every step.

        Keypoints are a fixed down-sampled subset of a 1000-point FPS sample taken over the
        inlier particles. Their predicted motion is propagated to every particle with
        ``interpolate_motions``. Colours and opacities are copied unchanged from frame to frame.

        Args:
            xyz_0: (n_particles, 3) initial particle positions.
            rgb_0: (n_particles, 3) particle colours.
            quat_0: (n_particles, 4) initial particle rotations (quaternions).
            opa_0: (n_particles, 1) particle opacities.
            eef_xyz: end-effector trajectory. ``eef_xyz[i]`` is the step-``i`` position,
                presumably ``(n_steps, 1, 3)``; TODO confirm.
            n_steps: number of frames, including the initial and final state
                (``n_steps - 1`` predicted steps).
            inlier_idx_all: indices into ``xyz_0`` of the particles used for FPS sampling.

        Returns:
            ``(xyz, rgb, quat, opa, xyz_bones, eef, key_point)``:
            - the first six are per-frame CPU tensors;
            - ``key_point`` lists the keypoint positions fed to the model at each step.
        """
        model = self.model
        device = self.device
        # load_model() copies train_config['n_his'] into the model config, so this matches the model.
        n_his = self.n_his

        all_pos = xyz_0
        # Take a 1000-point FPS sample over the inlier particles. fps_all_idx indexes the inlier subset.
        fps_all_idx = farthest_point_sampler(xyz_0.cpu()[inlier_idx_all][None], 1000, start_idx=0)[0]
        fps_all_pos = all_pos[inlier_idx_all][fps_all_idx]
        fps_all_pos_history = fps_all_pos[None].repeat(n_his, 1, 1)  # (n_his, n_fps, 3)

        eef_pos_history = eef_xyz[0][None].repeat(n_his, 1, 1)
        eef_pos = eef_xyz[0]

        # Choose the keypoint indices once and reuse them for every step.
        particle_pos_0, fps_idx_0 = self.downsample_vertices(fps_all_pos.clone())

        # Per-frame outputs on the CPU, initialised by repeating the first frame.
        quat = quat_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 4)
        xyz = xyz_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        rgb = rgb_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        opa = opa_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 1)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)  # (n_steps, max_nobj, 3)
        eef = eef_xyz.cpu()[0][None].repeat(n_steps, 1, 1)

        xyz_bones[0, :particle_pos_0.shape[0]] = particle_pos_0.cpu()

        key_point = []

        for i in tqdm(range(1, n_steps), dynamic_ncols=True):
            # Sanity checks: the newest history entries must equal the current state.
            assert torch.allclose(fps_all_pos, fps_all_pos_history[-1])
            assert torch.allclose(eef_pos, eef_pos_history[-1])

            eef_pos_this_step = eef_xyz[i]
            eef_delta = eef_pos_this_step - eef_pos

            particle_pos = fps_all_pos.clone()[fps_idx_0]
            key_point.append(particle_pos)
            particle_pos_history = fps_all_pos_history[:, fps_idx_0]
            nobj = particle_pos.shape[0]

            # Model inputs: object keypoints come first, then the single gripper node.
            states = torch.zeros((1, n_his, nobj + 1, 3), device=device)
            states[:, :, :nobj] = particle_pos_history
            states[:, :, nobj:] = eef_pos_history

            states_delta = torch.zeros((1, nobj + 1, 3), device=device)
            states_delta[:, nobj:] = eef_delta  # only the gripper carries an action

            attrs = torch.zeros((1, nobj + 1, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.  # particle attribute
            attrs[:, nobj:, 1] = 1.  # gripper attribute

            p_instance = torch.ones((1, nobj, 1), dtype=torch.float32, device=device)
            state_mask = torch.ones((1, nobj + 1), dtype=bool, device=device)
            eef_mask = torch.zeros((1, nobj + 1), dtype=bool, device=device)
            eef_mask[:, nobj] = 1
            obj_mask = torch.zeros((1, nobj + 1), dtype=bool, device=device)
            obj_mask[:, :nobj] = 1

            # Unlike the other rollouts, rebuild the relation graph from the current state every step.
            Rr, Rs = construct_edges_from_states(states[0, -1], self.adj_thresh,
                            mask=state_mask[0], tool_mask=eef_mask[0], topk=self.topk, connect_all=self.connect_all)
            Rr = Rr[None]
            Rs = Rs[None]

            # NOTE(review): rollout_gripper_3dgs also passes collider_distance/logE/friction,
            # but this graph does not. Confirm that the loaded model accepts their absence.
            graph = {
                # input information
                "state": states,  # (1, n_his, N+M, state_dim)
                "action": states_delta,  # (1, N+M, state_dim)

                # attr information
                "attrs": attrs,  # (1, N+M, attr_dim)
                "p_instance": p_instance,  # (1, N, n_instance)
                "obj_mask": obj_mask,  # (1, N+M)
                "state_mask": state_mask,  # (1, N+M)
                "eef_mask": eef_mask,  # (1, N+M)

                "Rr": Rr,  # (bsz, max_nR, N)
                "Rs": Rs,  # (bsz, max_nR, N)
            }

            pred_state, _ = model(**graph)  # (1, nobj, 3)

            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step[None]], dim=0)
            eef_pos = eef_pos_this_step

            # Propagate the keypoint motion to all particles using the graph's object-object relations.
            all_pos, all_rot, _ = interpolate_motions(
                bones=particle_pos,
                motions=pred_state[0] - particle_pos,
                relations=relations_to_matrix(Rr, Rs)[:nobj, :nobj],
                xyz=all_pos,
                quat=quat[i - 1].to(device),
            )
            fps_all_pos = all_pos[inlier_idx_all][fps_all_idx]
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], fps_all_pos[None]], dim=0)

            quat[i] = all_rot.cpu()
            xyz[i] = all_pos.cpu()
            rgb[i] = rgb[i - 1].clone()
            opa[i] = opa[i - 1].clone()
            xyz_bones[i, :nobj] = pred_state[0].cpu()
            eef[i] = eef_pos.cpu()

        return xyz, rgb, quat, opa, xyz_bones, eef, key_point

    def load_sep_params(self, output_dir):
        """Load Gaussian parameters saved under ``output_dir`` as a dict of NumPy arrays.

        Two layouts are supported:

        * Per-frame ``params_<k>.npz`` files, where ``k`` is an integer. The files are sorted
          numerically and stacked along a new leading axis. The starting entry is
          ``params_0.npz`` (with a leading axis added) when it exists; otherwise it is
          ``params.npz``, used as-is. Keys that appear only in the starting file are copied
          through unchanged.
        * A single ``params.npz``, used when no ``params_<k>.npz`` files exist.

        Every ``.npz`` handle is closed after its arrays are read.

        Returns:
            dict mapping parameter name to ``np.ndarray``.

        Raises:
            ValueError: if neither layout is present, if ``params.npz`` is needed as the
                starting entry but is missing, or if ``params_0.npz`` is the only per-frame file.
        """
        params_path = os.path.join(output_dir, 'params.npz')
        first_sep_path = os.path.join(output_dir, 'params_0.npz')

        def _load(path):
            # Copy every array out so that the NpzFile (an open zip archive) gets closed.
            with np.load(path) as npz:
                return {k: npz[k] for k in npz.files}

        # 'params_' is 7 characters long; sort by the integer suffix so that params_10 follows params_9.
        sep_paths = sorted(
            glob.glob(os.path.join(output_dir, 'params_*.npz')),
            key=lambda p: int(os.path.splitext(os.path.basename(p))[0][7:]),
        )

        if not sep_paths:
            if not os.path.exists(params_path):
                raise ValueError(f'Params dir {params_path} not found')
            return _load(params_path)

        if os.path.exists(first_sep_path):
            from_params_file = False
            # Add a leading axis so that frame 0 concatenates with the stacked frames.
            start_params = {k: v[None] for k, v in _load(first_sep_path).items()}
            sep_paths.remove(first_sep_path)
        else:
            if not os.path.exists(params_path):
                raise ValueError(
                    f'{params_path} is required as the starting frame when params_0.npz is absent')
            from_params_file = True
            start_params = _load(params_path)

        if not sep_paths:
            raise ValueError(f'No params_<k>.npz files besides params_0.npz found in {output_dir}')

        frames = [_load(p) for p in sep_paths]
        stacked = {k: np.stack([frame[k] for frame in frames]) for k in frames[0]}

        params = {k: np.concatenate((start_params[k], stacked[k])) for k in stacked}
        # Copy through keys that have no per-frame files, removing the axis added for params_0.
        for k, v in start_params.items():
            if k not in stacked:
                params[k] = v if from_params_file else v[0]
        # NOTE(review): behaviour carried over unchanged. When more than 600 entries of
        # means3D were assembled, the first element of EVERY array is dropped, including
        # arrays that are not per-frame. Confirm intent.
        if len(params['means3D']) > 600:
            params = {k: v[1:] for k, v in params.items()}
        return params

    def collect_scene_data(self, data_path, output_path):
        """Run a dynamics rollout from saved Gaussian parameters and package per-frame render data.

        Steps:
          1. Load the Gaussian parameters from ``output_path`` with ``load_sep_params`` and
             take the first entry. Drop Gaussians whose opacity is below 0.1.
          2. Repeatedly remove statistical outliers with Open3D to obtain ``inlier_idx_all``.
          3. Build the end-effector trajectory. This is currently a hard-coded debug file;
             see the NOTE below.
          4. Call ``self.rollout``. Then fill runs of frames in which no particle moved by
             linearly interpolating between the surrounding frames that did change.

        Returns:
            ``(scene_data, vis_data, key_point)``:
            - ``scene_data``: per-frame render dicts;
            - ``vis_data``: per-frame dicts of keypoint and tool-keypoint numpy arrays;
            - ``key_point``: the value returned by the rollout.
        """
        params = self.load_sep_params(output_path)
        # Place the parameters on the module's device so they match the tensors the rollout uses.
        params = {k: torch.tensor(v).to(self.device).float() for k, v in params.items()}

        xyz_0 = params['means3D'][0]
        rgb_0 = params['rgb_colors'][0]
        quat_0 = torch.nn.functional.normalize(params['unnorm_rotations'][0])
        opa_0 = torch.sigmoid(params['logit_opacities'])
        scales_0 = torch.exp(params['log_scales'])

        # Discard nearly transparent Gaussians.
        keep = ~(opa_0[:, 0] < 0.1)
        xyz_0, rgb_0, quat_0, opa_0, scales_0 = (t[keep] for t in (xyz_0, rgb_0, quat_0, opa_0, scales_0))

        # Strip statistical outliers repeatedly, relaxing std_ratio on each pass, until a pass
        # removes nothing. inlier_idx_all maps the surviving points back to indices into xyz_0.
        inlier_idx_all = np.arange(len(xyz_0))
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(xyz_0.detach().cpu().numpy())
        rm_iter = 0
        while True:
            _, inlier_idx = pcd.remove_statistical_outlier(
                nb_neighbors=50, std_ratio=2.0 + rm_iter * 0.5
            )
            inlier_idx_all = inlier_idx_all[inlier_idx]
            n_removed = len(pcd.select_by_index(inlier_idx, invert=True).points)
            pcd = pcd.select_by_index(inlier_idx)
            rm_iter += 1
            if n_removed == 0:
                break

        # NOTE(review): the trajectory returned by load_eef_pos is discarded and replaced by
        # the hard-coded debug file below, which also ignores data_path. The call is kept so
        # that any validation or side effects it performs still run. Confirm which source is intended.
        _loaded_eef_xyz, _frame_idxs = load_eef_pos(data_path, output_path)

        eef_xyz = np.load('./gen_data/0/eef_pos.npy')
        eef_xyz = eef_xyz[:, None, :]
        # Set the last coordinate (presumably height) of every eef position to the particles' maximum.
        eef_xyz[:, 0, -1] = xyz_0[:, -1].max().item()
        # Extend the trajectory by 50 steps at its final per-step displacement.
        speed = eef_xyz[-1] - eef_xyz[-2]
        extension = []
        last = eef_xyz[-1]
        for _ in range(50):
            last = last + speed
            extension.append(last)
        eef_xyz = np.concatenate([eef_xyz, np.stack(extension)])

        eef_xyz = torch.from_numpy(eef_xyz).float().to(self.device)

        n_steps = min(len(eef_xyz), 1000)

        xyz, rgb, quat, opa, xyz_bones, eef, key_point = self.rollout(
                xyz_0, rgb_0, quat_0, opa_0, eef_xyz, n_steps, inlier_idx_all)

        # A frame is a change point when any particle moved since the previous frame. Frames
        # strictly between two change points are refilled by linear interpolation.
        moved = (xyz - torch.cat([xyz[0:1], xyz[:-1]], dim=0)).norm(dim=-1).sum(dim=-1)
        change_points = torch.cat([torch.tensor([0]), moved.nonzero().squeeze(1)]).tolist()
        for start, end in zip(change_points[:-1], change_points[1:]):
            if end - start < 2:  # adjacent change points leave no frames to fill
                continue
            lerp_w = torch.linspace(0, 1, end - start + 1)[:, None, None]
            for buf in (xyz, rgb, quat, opa, xyz_bones, eef):
                buf[start:end] = torch.lerp(buf[start][None], buf[end][None], lerp_w.to(buf.device))[:-1]
        quat = torch.nn.functional.normalize(quat, dim=-1)

        scene_data = []  # one dict of rendering variables per frame
        vis_data = []
        for t in range(n_steps):
            rendervar = {
                'means3D': xyz[t],
                'colors_precomp': rgb[t],
                'rotations': quat[t],
                'opacities': opa[t],
                'scales': scales_0,
                'means2D': torch.zeros_like(xyz[t]),
            }
            visvar = {
                'kp': xyz_bones[t].numpy(),
                'tool_kp': eef[t].numpy(),
            }
            scene_data.append(rendervar)
            vis_data.append(visvar)

        return scene_data, vis_data, key_point
