import sys
sys.path.append("gaussian-splatting")

# gaussian-splatting
from scene.gaussian_model_final import GaussianModel
from scene.cameras import Camera as GSCamera
from gaussian_renderer import render
from utils.sh_utils import eval_sh
from utils.system_utils import searchForMaxIteration
from utils.graphics_utils import focal2fov

# Utils
from utils.decode_param import *
from utils.transformation_utils import *
from utils.camera_view_utils import *
from utils.render_utils import *
from utils.Threedgs_general_utils import *
from utils.my_mpm_utils import *
from utils.my_utils import *

from my_loss_func import *
from mpm_solver_warp.warp_utils import *
import mpm_solver_warp.mpm_solver as warp_solver

def load_checkpoint_structure(data_path, sh_degree=3):
    """Create a ``GaussianModel`` and populate it from ``data_path``.

    If ``point_cloud_filling.ply`` exists inside ``data_path`` the model is
    loaded through ``load_ply_filling``; otherwise it falls back to
    ``load_structure_data_filling``.

    Args:
        data_path: Directory holding the scene data.
        sh_degree: Value forwarded to the ``GaussianModel`` constructor.

    Returns:
        The populated ``GaussianModel`` instance.
    """
    model = GaussianModel(sh_degree)
    has_filling_ply = os.path.exists(os.path.join(data_path, 'point_cloud_filling.ply'))
    if not has_filling_ply:
        model.load_structure_data_filling(data_path)
        return model
    model.load_ply_filling(data_path)
    print('load ply')
    return model

def RT_to_cov(rot, scale):
    """Convert per-point rotations and scales into packed covariances.

    Builds ``L = build_scaling_rotation(1.0 * scale, rot)``, forms
    ``L @ L^T`` and passes the result through ``strip_symmetric``.

    Args:
        rot: Rotations accepted by ``build_scaling_rotation``
            (presumably normalised quaternions - verify against callers).
        scale: Scales accepted by ``build_scaling_rotation``.

    Returns:
        Whatever ``strip_symmetric`` returns for the batch of covariances
        (presumably the unique entries of each symmetric matrix).
    """
    scaling_modifier = 1.0
    factor = build_scaling_rotation(scaling_modifier * scale, rot)
    full_covariance = factor @ factor.transpose(1, 2)
    return strip_symmetric(full_covariance)

def mask_materials(materials, mask):
    """Return the entries of ``materials`` whose ``mask`` entry is truthy.

    Args:
        materials: Sequence of per-particle materials.
        mask: Indexable object; ``mask[i]`` decides whether
            ``materials[i]`` is kept. Only indices below
            ``len(materials)`` are consulted.

    Returns:
        A new list with the selected materials in their original order.
    """
    return [entry for idx, entry in enumerate(materials) if mask[idx]]

class PipelineParamsNoparse:
    """Plain container for render-pipeline flags, all defaulting to False.

    Callers flip individual flags after construction (``init_scene`` sets
    ``compute_cov3D_python`` to True, for example).
    """

    def __init__(self):
        # All three flags start disabled.
        self.convert_SHs_python = self.compute_cov3D_python = self.debug = False

class MPMSimulator:
    def __init__(self,
                 data_path,
                 config_path,
                 output_path,
                 gripper_type='single_gripper',
                 backward=False,
                 device="cuda:0",
                 optimized_params_path=None,
                 ):
        """Validate paths, decode the config and allocate solver/loss state.

        Most attributes are only declared here (set to ``None``) and are
        filled in later by ``init_scene``.

        Args:
            data_path: Scene directory; must contain ``calibrate.pkl`` and
                ``metadata.json``.
            config_path: Path handed to ``decode_param_json``.
            output_path: Output directory; created if it does not exist.
            gripper_type: One of ``'double_gripper'``, ``'single_gripper'``
                or ``'push'``.
            backward: Used as ``requires_grad`` for the warp loss buffers.
            device: Device string passed to the MPM solver.
            optimized_params_path: Optional pickle file holding a dict whose
                items override entries of ``material_params``.

        Raises:
            AssertionError: If a required path is missing or ``gripper_type``
                is not recognised (note: asserts are stripped under ``-O``).
        """
        os.makedirs(output_path, exist_ok=True)
        # path
        self.data_path = data_path
        self.config_path = config_path
        self.output_path = output_path
        self.calibrate_path = os.path.join(self.data_path, 'calibrate.pkl')
        self.meta_path = os.path.join(self.data_path, 'metadata.json')
        assert os.path.exists(self.data_path), "data does not exist!"
        assert os.path.exists(self.config_path), "config does not exist!"
        assert os.path.exists(self.calibrate_path), "calibrate does not exist!"
        assert os.path.exists(self.meta_path), "meta does not exist!"
        assert gripper_type in ['double_gripper', 'single_gripper', 'push'],"gripper_type error!"
        # cuda
        self.device = device
        self.backward = backward
        # MPM
        # The solver is built with fixed arguments (10, 10, 10); the real
        # particle data is uploaded in init_scene via
        # load_initial_data_from_torch.
        self.mpm_solver = warp_solver.MPM_Simulator_WARP(10, 10, 10, device=self.device)
        # Cached initial states, one entry per push index (see init_scene).
        self.mpm_init_pos_list = []
        self.mpm_init_vol_list = []
        self.mpm_init_cov_list = []
        self.n_particle = None
        self.obj_num = None
        self.material = None
        # params
        (
            self.material_params,
            self.bc_params,
            self.time_params,
            self.preprocessing_params,
            self.camera_params,
        ) = decode_param_json(config_path)

        # Optional overrides for material parameters.
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # optimized_params_path at trusted files.
        if optimized_params_path is not None:
            with open(optimized_params_path, "rb") as f:
                optimized_params = pickle.load(f)
                for k, v in optimized_params.items():
                    self.material_params[k] = v

        # point_num
        self.gs_num = None
        self.filter_num = self.preprocessing_params["filter_num"]

        # time/frame
        self.pre_move_frame = 10
        self.start_frame = 0
        # end_frame is overwritten in init_scene from the loaded frame split.
        self.end_frame = self.time_params["frame_num"] - 1
        self.substep_dt = self.time_params["substep_dt"]
        self.frame_dt = self.time_params["frame_dt"]
        # Number of solver substeps per frame; int() truncates the quotient.
        self.step_per_frame = int(self.frame_dt / self.substep_dt)

        self.current_frame = self.start_frame
        self.current_substep = 0

        # background
        # NOTE(review): the background tensor is placed on "cuda", not on
        # self.device - confirm this is intended when device != default GPU.
        self.white_bg = False
        self.background = (
            torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
            if self.white_bg
            else torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda"))

        # gaussian
        self.part_field_knn = None
        self.structure_num = None
        self.active_sh_degree = None
        self.max_sh_degree = None
        self.scaling = None

        self.init_screen_points = None
        self.shs = None
        self.opacity = None
        self.colors_precomp = None
        # gt
        self.rotation_matrices = None
        self.gaussian_gt = None
        self.shift_tensor = None
        self.scale_origin = None
        self.original_mean_pos = None
        # eef
        self.gripper_type = gripper_type
        self.eef_pos = None
        self.eef_pos_ori = None
        self.gripper_mask = None
        self.gripper_mask_wp = None
        self.target_point_indice = None
        self.target_point1_indice = None
        self.target_point2_indice = None
        self.eef_k = self.preprocessing_params["eef_k"]
        self.eef_k_shift = self.preprocessing_params["eef_k_shift"]
        # bc
        self.need_collider = True
        self.ee_mask = None
        self.collider_z = 0
        # camera
        self.viewpoint_center_worldspace = None
        self.observant_coordinates = None
        self.current_camera = None
        # rasterize
        self.rasterize = None
        # train
        self.frame_num = None
        self.train_frame = None
        self.test_frame = None
        # loss
        self.chamfer_loss_weight = 1.0
        self.track_loss_weight = 0.0
        self.temp_loss = 0.0
        # Single-element warp buffers; they track gradients only when
        # backward is True. No device is passed, so warp picks its default.
        self.loss = wp.zeros((1,), dtype=wp.float32, requires_grad=self.backward)
        self.chamfer_loss_total = wp.zeros((1,), dtype=wp.float32, requires_grad=self.backward)
        self.track_loss_total = wp.zeros((1,), dtype=wp.float32, requires_grad=self.backward)

        self.l2_loss = wp.zeros((1,), dtype=wp.float32, requires_grad=self.backward)
        self.chamfer_loss = wp.zeros((1,), dtype=wp.float32, requires_grad=self.backward)
        self.track_loss = wp.zeros((1,), dtype=wp.float32, requires_grad=self.backward)
        self.grad_loss = wp.zeros((1,), dtype=wp.float32, requires_grad=self.backward)

        self.tape = wp.Tape()

        # save
        # Snapshot slots filled by save_current_phys_state.
        self.saved_x = None
        self.saved_v = None
        self.saved_c = None
        self.saved_f = None
        self.saved_covs = None
        self.saved_vols = None

    def init_scene(self, push_idx):
        """Load scene data, move it into solver space and configure the solver.

        Loads the Gaussian model, end-effector positions, frame split and
        ground-truth tracks from ``self.data_path``; rotates and normalises
        particle positions and covariances; registers boundary conditions;
        builds the camera and rasterizer; and uploads the initial particle
        state for ``push_idx`` to the MPM solver.

        Args:
            push_idx: Index into the cached initial-state lists. The freshly
                computed state is appended only when the lists currently hold
                exactly ``push_idx`` entries; the solver is then initialised
                from entry ``push_idx``.
        """
        # Load the Gaussian-splat data used by the MPM simulation
        gaussian = load_checkpoint_structure(self.data_path,sh_degree=3)
        self.structure_num = gaussian.structure_num
        self.active_sh_degree = gaussian.active_sh_degree
        self.max_sh_degree = gaussian.max_sh_degree
        self.scaling = gaussian.get_scaling
        print(gaussian.obj_num)

        # Load the control points (end-effector positions)
        eef_pos = load_eef_pos(self.data_path, self.gripper_type,self.eef_k_shift)
        # Load the train/test frame split
        self.frame_num,self.train_frame,self.test_frame = load_split(self.data_path)
        self.start_frame = 0
        self.end_frame = self.frame_num - 1

        # Load the ground-truth data; selected entries become float32 CUDA tensors
        gaussian_gt = load_track_data(self.data_path)
        for k in gaussian_gt.keys():
            if k in ['object_points', 'object_colors', 'unnorm_rotations', 'seg_colors', 'logit_opacities','log_scales']:
                gaussian_gt[k] = torch.tensor(gaussian_gt[k], device='cuda', dtype=torch.float32)
        self.gaussian_gt = gaussian_gt

        # Load the material
        # NOTE(review): material_params['material'] is replaced by this list
        # further down, so a second call to init_scene would read a list here
        # and build a list of lists - confirm init_scene is called only once
        # per instance or that this is intended.
        particle_materials = [self.material_params['material']] * gaussian.obj_num
        self.material = self.material_params['material']
        # Adapt the data for GS rendering
        pipeline = PipelineParamsNoparse()
        pipeline.compute_cov3D_python = True
        params = load_params_from_gs(gaussian, pipeline,override_color = [0,1,0])
        init_pos = params["pos"]
        init_cov = params["cov3D_precomp"]
        init_screen_points = params["screen_points"]
        init_opacity = params["opacity"]
        init_shs = params["shs"]
        colors_precomp = params["colors_precomp"]
        # Filtering is unnecessary when only track points are used; it can be
        # re-enabled if we go back to embedding the full Gaussian splats.
        # mask = init_opacity[:, 0] > self.preprocessing_params["opacity_threshold"]
        # mask[:gaussian.structure_num] = True
        # init_pos = init_pos[mask]
        # init_cov = init_cov[mask]
        # init_opacity = init_opacity[mask]
        # init_screen_points = init_screen_points[mask]
        # init_shs = init_shs[mask]
        #particle_materials = mask_materials(particle_materials, mask)
        self.shs = init_shs
        self.colors_precomp = torch.tensor(colors_precomp, dtype=torch.float32,device=self.device)
        self.opacity = init_opacity
        self.init_screen_points = init_screen_points
        self.material_params['material'] = particle_materials

        # Transform into the MPM solver space: rotate, normalise, then shift
        shift_tensor = torch.tensor([1.0, 1.0, 1.0], device="cuda")
        rotation_matrices = generate_rotation_matrices(
            torch.tensor(self.preprocessing_params["rotation_degree"]),
            self.preprocessing_params["rotation_axis"])
        rotated_pos = apply_rotations(init_pos, rotation_matrices)
        transformed_pos, scale_origin, original_mean_pos = transform2origin(rotated_pos)
        transformed_pos = shift2center111(transformed_pos, shift_tensor)

        # Apply the same normalisation and shift to the end-effector positions.
        # NOTE(review): unlike the particles, the end-effector positions are
        # not passed through apply_rotations - fine if the configured rotation
        # is the identity; confirm otherwise.
        if self.gripper_type in ['push', 'single_gripper']:
            eef_xyz_tensor = torch.tensor(eef_pos, device="cuda")
            transformed_eef_pos = (eef_xyz_tensor - original_mean_pos) * scale_origin
            self.eef_pos = shift2center111(transformed_eef_pos, shift_tensor)
            self.eef_pos_ori = eef_xyz_tensor
        elif self.gripper_type == 'double_gripper':
            eef_xyz_tensor_1 = torch.tensor(eef_pos[0], device="cuda")
            transformed_eef_pos = (eef_xyz_tensor_1 - original_mean_pos) * scale_origin
            eef_pos_1 = shift2center111(transformed_eef_pos, shift_tensor)
            eef_xyz_tensor_2 = torch.tensor(eef_pos[1], device="cuda")
            transformed_eef_pos = (eef_xyz_tensor_2 - original_mean_pos) * scale_origin
            eef_pos_2 = shift2center111(transformed_eef_pos, shift_tensor)
            self.eef_pos = torch.stack([eef_pos_1, eef_pos_2])
            self.eef_pos_ori = torch.stack([eef_xyz_tensor_1, eef_xyz_tensor_2])

        self.rotation_matrices = rotation_matrices
        self.shift_tensor = shift_tensor
        self.scale_origin = scale_origin
        self.original_mean_pos = original_mean_pos
        self.gs_num = transformed_pos.shape[0]

        # ee_mask
        # The mask is all zeros, i.e. no particle is flagged here.
        # Boundary conditions are appended to self.bc_params, so every call
        # to init_scene grows that list.
        ee_mask = np.zeros(len(transformed_pos), dtype=int)
        ee_bc = get_ee_bc(ee_mask)
        self.bc_params.extend(ee_bc)
        if self.gripper_type == 'double_gripper':
            # A second end-effector boundary condition for the second gripper.
            ee_bc = get_ee_bc(ee_mask)
            self.bc_params.extend(ee_bc)
        self.gripper_mask_wp = wp.from_numpy(ee_mask, dtype=wp.int32, device=self.device)
        self.ee_mask = self.gripper_mask_wp

        if self.need_collider:
            shift_array = shift_tensor.detach().cpu().numpy()
            # collider_z = transformed_pos[:, -1].max() + 0.0001
            # World-space z = 0 mapped through the same normalise + shift.
            collider_z = (0.0 - original_mean_pos[-1]) * scale_origin + shift_array[-1]

            # Lift every point that lies below the table up onto the table
            transformed_pos = make_points_over_collider(transformed_pos,collider_z)

            # Debug info: count how many points are below the table
            # (counted after the correction above, using z > collider_z)
            print(f"z 轴低于桌面的点有: {torch.sum(transformed_pos[:, 2] > collider_z).item()}个, 一共有: {transformed_pos.shape[0]}个点")

            collider_bc = get_collider_bc(point_x=shift_array[0], point_y=shift_array[1], point_z=collider_z, friction=self.material_params["friction"])
            self.bc_params.extend(collider_bc)
            self.collider_z = collider_z


        mpm_init_pos = transformed_pos.to(device=self.device)
        # NOTE(review): material_params["material"] was replaced by a list
        # above, so the comparison with "sand" below is always False and
        # `uniform` is never enabled - confirm whether self.material was meant.
        mpm_init_vol = get_particle_volume(
            mpm_init_pos,
            self.material_params["n_grid"],
            self.material_params["grid_lim"] / self.material_params["n_grid"],
            uniform = self.material_params["material"] == "sand",
        ).to(device=self.device)
        # Covariances get the same rotation and a squared scale factor.
        init_cov = apply_cov_rotations(init_cov, rotation_matrices)
        init_cov = scale_origin * scale_origin * init_cov
        mpm_init_cov = torch.zeros((mpm_init_pos.shape[0], 6), device=self.device)
        mpm_init_cov[:self.gs_num] = init_cov

        # Cache the initial state only if this push index is the next new one.
        if len(self.mpm_init_pos_list) == push_idx:
            self.mpm_init_pos_list.append(mpm_init_pos)
            self.mpm_init_vol_list.append(mpm_init_vol)
            self.mpm_init_cov_list.append(mpm_init_cov)

        # camera setting
        mpm_space_viewpoint_center = (
            torch.tensor(self.camera_params["mpm_space_viewpoint_center"]).reshape((1, 3)).cuda()
        )
        mpm_space_vertical_upward_axis = (
            torch.tensor(self.camera_params["mpm_space_vertical_upward_axis"])
            .reshape((1, 3))
            .cuda()
        )
        (
            self.viewpoint_center_worldspace,
            self.observant_coordinates,
        ) = get_center_view_worldspace_and_observant_coordinate(
            mpm_space_viewpoint_center,
            mpm_space_vertical_upward_axis,
            rotation_matrices,
            scale_origin,
            original_mean_pos,
        )

        # camera
        self.current_camera = get_camera_view(self.calibrate_path, self.meta_path, pipeline,cam_id=0)

        #rasterize
        self.rasterize = my_initialize_resterize_v2(self.current_camera)
        # Upload the cached initial state for this push to the solver.
        self.mpm_solver.load_initial_data_from_torch(
            self.mpm_init_pos_list[push_idx],
            self.mpm_init_vol_list[push_idx],
            self.mpm_init_cov_list[push_idx],
            n_grid=self.material_params["n_grid"],
            grid_lim=self.material_params["grid_lim"],
        )
        self.n_particle = self.mpm_solver.n_particles
        self.obj_num = self.n_particle
        self.mpm_solver.set_parameters_dict(self.material_params)
        if self.gripper_type == 'push':
            self.mpm_solver.set_primitive(self.eef_pos, self.frame_dt, self.start_frame, self.step_per_frame)
        set_boundary_conditions(self.mpm_solver, self.bc_params, self.time_params)

    def init_solver(self, push_idx, frame_now=None):
        self.mpm_solver.reset_pos_from_torch(self.mpm_init_pos_list[push_idx], self.mpm_init_vol_list[push_idx], self.mpm_init_cov_list[push_idx])
        self.current_frame = self.start_frame if frame_now is None else frame_now
        if self.gripper_type == 'push':
            self.mpm_solver.primitive.reset()

    def set_friction(self, friction):
        """Write a torch friction value into the solver model's friction array.

        Args:
            friction: Torch tensor wrapped with ``wp.from_torch`` and passed
                to the ``set_value_friction_array`` kernel (launched with
                ``dim=1``).
        """
        friction_wp = wp.from_torch(friction)
        wp.launch(
            kernel=set_value_friction_array,
            dim=1,
            inputs=[self.mpm_solver.mpm_model.friction, friction_wp],
            device=self.device,
        )

    def wp_set_phys_property_E(self, logE):
        """Set the solver's ``E`` array from ``logE`` and refresh derived terms.

        Launches ``set_value_mask_logE`` over all solver particles with the
        end-effector mask, then calls ``finalize_mu_lam`` on the solver.

        Args:
            logE: Torch tensor handed to the kernel via ``wp.from_torch``.
        """
        solver = self.mpm_solver
        wp.launch(
            kernel=set_value_mask_logE,
            dim=solver.n_particles,
            inputs=[solver.mpm_model.E, wp.from_torch(logE), self.ee_mask],
            device=self.device,
        )
        solver.finalize_mu_lam()

    def wp_set_phys_property_ED(self, logE, density):
        """Set ``E`` and particle density, then recompute particle mass.

        Args:
            logE: Torch tensor passed to ``set_value_mask_logE``.
            density: Torch tensor passed to ``set_value_mask_density``.
        """
        solver = self.mpm_solver
        state = solver.mpm_state
        wp.launch(
            kernel=set_value_mask_logE,
            dim=solver.n_particles,
            inputs=[solver.mpm_model.E, wp.from_torch(logE), self.ee_mask],
            device=self.device,
        )
        wp.launch(
            kernel=set_value_mask_density,
            dim=solver.n_particles,
            inputs=[state.particle_density, wp.from_torch(density), self.ee_mask],
            device=self.device,
        )
        # Mass must follow the new density: launch the product kernel on
        # density, volume and mass.
        wp.launch(
            kernel=get_float_array_product,
            dim=self.n_particle,
            inputs=[state.particle_density, state.particle_vol, state.particle_mass],
            device=self.device,
        )
        solver.finalize_mu_lam()

    def wp_set_phys_property_ED_single(self, logE, density):
        """Variant of ``wp_set_phys_property_ED`` using the ``_single`` E kernel.

        Identical flow, except that ``E`` is written by
        ``set_value_mask_logE_single`` (presumably one shared ``logE`` value
        for all particles - confirm against the kernel).

        Args:
            logE: Torch tensor passed to ``set_value_mask_logE_single``.
            density: Torch tensor passed to ``set_value_mask_density``.
        """
        solver = self.mpm_solver
        state = solver.mpm_state
        wp.launch(
            kernel=set_value_mask_logE_single,
            dim=solver.n_particles,
            inputs=[solver.mpm_model.E, wp.from_torch(logE), self.ee_mask],
            device=self.device,
        )
        wp.launch(
            kernel=set_value_mask_density,
            dim=solver.n_particles,
            inputs=[state.particle_density, wp.from_torch(density), self.ee_mask],
            device=self.device,
        )
        # Refresh particle mass from the updated density and the volume.
        wp.launch(
            kernel=get_float_array_product,
            dim=self.n_particle,
            inputs=[state.particle_density, state.particle_vol, state.particle_mass],
            device=self.device,
        )
        solver.finalize_mu_lam()

    def wp_set_phys_property_cloth_D(self, log_warp_stiffness, log_weft_stiffness, log_shear_stiffness, density):
        """Set the cloth stiffness array and density, then recompute mass.

        Args:
            log_warp_stiffness: Torch tensor, log of the warp-direction stiffness.
            log_weft_stiffness: Torch tensor, log of the weft-direction stiffness.
            log_shear_stiffness: Torch tensor, log of the shear stiffness.
            density: Torch tensor passed to ``set_value_mask_density``.
        """
        solver = self.mpm_solver
        state = solver.mpm_state
        cloth_inputs = [
            solver.mpm_model.E_cloth,            # destination array
            wp.from_torch(log_warp_stiffness),   # log of warp stiffness
            wp.from_torch(log_weft_stiffness),   # log of weft stiffness
            wp.from_torch(log_shear_stiffness),  # log of shear stiffness
            self.ee_mask,                        # mask array
        ]
        wp.launch(
            kernel=set_E_cloth_from_log_stiffness,
            dim=solver.n_particles,
            inputs=cloth_inputs,
            device=self.device,
        )
        wp.launch(
            kernel=set_value_mask_density,
            dim=solver.n_particles,
            inputs=[state.particle_density, wp.from_torch(density), self.ee_mask],
            device=self.device,
        )
        # Refresh particle mass from the updated density and the volume.
        wp.launch(
            kernel=get_float_array_product,
            dim=self.n_particle,
            inputs=[state.particle_density, state.particle_vol, state.particle_mass],
            device=self.device,
        )
        solver.finalize_mu_lam()

    def wp_set_phys_property_cloth_D_single(self, log_warp_stiffness, log_weft_stiffness, log_shear_stiffness, density):
        """Variant of ``wp_set_phys_property_cloth_D`` using the ``_single`` kernel.

        Identical flow, except that the cloth stiffness array is written by
        ``set_E_cloth_from_log_stiffness_single``.

        Args:
            log_warp_stiffness: Torch tensor, log of the warp-direction stiffness.
            log_weft_stiffness: Torch tensor, log of the weft-direction stiffness.
            log_shear_stiffness: Torch tensor, log of the shear stiffness.
            density: Torch tensor passed to ``set_value_mask_density``.
        """
        solver = self.mpm_solver
        state = solver.mpm_state
        cloth_inputs = [
            solver.mpm_model.E_cloth,
            wp.from_torch(log_warp_stiffness),
            wp.from_torch(log_weft_stiffness),
            wp.from_torch(log_shear_stiffness),
            self.ee_mask,
        ]
        wp.launch(
            kernel=set_E_cloth_from_log_stiffness_single,
            dim=solver.n_particles,
            inputs=cloth_inputs,
            device=self.device,
        )
        wp.launch(
            kernel=set_value_mask_density,
            dim=solver.n_particles,
            inputs=[state.particle_density, wp.from_torch(density), self.ee_mask],
            device=self.device,
        )
        # Refresh particle mass from the updated density and the volume.
        wp.launch(
            kernel=get_float_array_product,
            dim=self.n_particle,
            inputs=[state.particle_density, state.particle_vol, state.particle_mass],
            device=self.device,
        )
        solver.finalize_mu_lam()

    def calculate_loss_gradient(self, no_grad=False):
        """Compute the chamfer + track loss for the current frame.

        Particle positions are exported from the solver, mapped back to the
        original coordinate frame and compared with the ground-truth points
        visible at ``self.current_frame``. The detached loss is stored in
        ``self.temp_loss``.

        Args:
            no_grad: When False (default), the gradient of the loss with
                respect to the exported solver positions is copied into
                ``particle_x.grad`` and ``sum_vec3`` is launched into
                ``self.grad_loss``. When True only the loss is evaluated.
        """
        solver = self.mpm_solver
        trans_pos = solver.export_particle_x_to_torch()

        # Undo shift, normalisation and rotation (inverse of init_scene).
        unshifted = undoshift2center111(trans_pos, self.shift_tensor)
        unscaled = undotransform2origin(unshifted, self.scale_origin, self.original_mean_pos)
        pos = apply_inverse_rotations(unscaled, self.rotation_matrices)

        # Keep only the points flagged visible in the current frame.
        visibilities = self.gaussian_gt['object_visibilities']
        pos_gt = self.gaussian_gt['object_points'][self.current_frame, visibilities[self.current_frame]]
        pos_pred = pos[:visibilities.shape[1]][visibilities[self.current_frame]]
        pos_num = pos_gt.shape[0]

        # plot_point_cloud_with_lines(pos_gt,pos_pred)
        means3d_gt = wp.from_torch(pos_gt, dtype=wp.vec3, requires_grad=False)
        means3d_pred = wp.from_torch(pos_pred, dtype=wp.vec3, requires_grad=False)

        # Nearest-neighbour search runs in warp without gradient tracking.
        distance_matrix = wp.zeros((pos_num, pos_num), requires_grad=False)
        nearest_gt_to_pred = wp.zeros(pos_num, dtype=wp.int32, requires_grad=False)
        nearest_pred_to_gt = wp.zeros(pos_num, dtype=wp.int32, requires_grad=False)
        zero_mask_wp = wp.from_numpy(np.zeros(pos_num, dtype=np.int32),dtype=wp.int32)
        wp.launch(
            kernel=compute_distances,
            dim=(pos_num, pos_num),
            inputs=[means3d_pred, means3d_gt, zero_mask_wp],
            outputs=[distance_matrix],
        )
        wp.launch(
            kernel=compute_neigh_indices,
            dim=pos_num,
            inputs=[distance_matrix],
            outputs=[nearest_gt_to_pred],
        )
        wp.launch(
            kernel=compute_neigh_indices_inverse,
            dim=pos_num,
            inputs=[distance_matrix],
            outputs=[nearest_pred_to_gt],
        )

        neigh_indices = wp.to_torch(nearest_gt_to_pred)
        neigh_indices_pred_to_gt = wp.to_torch(nearest_pred_to_gt)

        # The losses themselves are evaluated in torch so autograd can be used.
        zero_mask = torch.zeros(pos_pred.shape[0], device=self.device)
        chamfer_loss = compute_chamfer_loss_torch_bidirectional(
            pos_pred, pos_gt, zero_mask, neigh_indices,
            neigh_indices_pred_to_gt, self.chamfer_loss_weight)
        track_loss = compute_track_loss_torch(pos_pred, pos_gt, zero_mask, self.track_loss_weight)
        loss = chamfer_loss + track_loss

        if not no_grad:
            grad_x = torch.autograd.grad(
                loss,
                trans_pos,
                retain_graph=False,  # set to True if gradients are computed again later
                create_graph=False,  # set to True only if grad_x itself must be differentiated (usually False)
                only_inputs=True
            )[0]
            grad_x = wp.from_torch(grad_x, dtype=wp.vec3f)
            # Seed the solver's position gradient with d(loss)/d(x).
            wp.launch(
                kernel=set_value_vec,
                dim=solver.n_particles,
                inputs=[solver.mpm_state.particle_x.grad, grad_x],
                device=self.device,
            )
            wp.launch(
                kernel=sum_vec3,
                dim=solver.n_particles,
                inputs=[solver.mpm_state.particle_x, grad_x],
                outputs=[self.grad_loss],
                device=self.device,
            )
        self.temp_loss = loss.detach()

    def save_params(self, log_E, density,friction, file_path):
        data = {'log_E': log_E.detach().cpu().numpy(),
                'density': density.detach().cpu().numpy(),
                'friction': friction.detach().cpu().numpy()}  # 将 torch 向量转换为 numpy 数组并保存
        with open(file_path, 'wb') as f:
            pkl.dump(data, f)

    def save_params_cloth(self, log_warp_stiffness, log_weft_stiffness, log_shear_stiffness, density,friction, file_path):
        data = {'log_warp_stiffness': log_warp_stiffness.detach().cpu().numpy(),
                'log_weft_stiffness': log_weft_stiffness.detach().cpu().numpy(),
                'log_shear_stiffness': log_shear_stiffness.detach().cpu().numpy(),
                'density': density.detach().cpu().numpy(),
                'friction': friction.detach().cpu().numpy()}  # 将 torch 向量转换为 numpy 数组并保存
        with open(file_path, 'wb') as f:
            pkl.dump(data, f)

    def load_params(self, params_path):
        """Load pickled ``log_E``/``density``/``friction`` and apply them.

        Args:
            params_path: Pickle file containing the keys ``"log_E"``,
                ``"density"`` and ``"friction"``.
        """
        # NOTE(review): pickle can execute arbitrary code on load; only use
        # trusted parameter files.
        with open(params_path, "rb") as in_file:
            stored = pkl.load(in_file)
        # Build all tensors first, in the same order, before touching the solver.
        log_E_tensor, density_tensor, friction_tensor = (
            torch.tensor(stored[key], device=self.device)
            for key in ("log_E", "density", "friction")
        )
        self.wp_set_phys_property_ED(log_E_tensor, density_tensor)
        self.set_friction(friction_tensor)

    def load_params_cloth(self, params_path):
        """Load pickled cloth parameters and apply them to the solver.

        Args:
            params_path: Pickle file containing the keys
                ``"log_warp_stiffness"``, ``"log_weft_stiffness"``,
                ``"log_shear_stiffness"``, ``"density"`` and ``"friction"``.
        """
        # NOTE(review): pickle can execute arbitrary code on load; only use
        # trusted parameter files.
        with open(params_path, "rb") as in_file:
            stored = pkl.load(in_file)
        # Build all tensors first, in the same order, before touching the solver.
        warp_t, weft_t, shear_t, density_t, friction_t = (
            torch.tensor(stored[key], device=self.device)
            for key in (
                "log_warp_stiffness",
                "log_weft_stiffness",
                "log_shear_stiffness",
                "density",
                "friction",
            )
        )
        self.wp_set_phys_property_cloth_D(warp_t, weft_t, shear_t, density_t)
        self.set_friction(friction_t)

    def clear_loss(self):
        self.l2_loss.zero_()
        self.chamfer_loss.zero_()
        self.track_loss.zero_()
        self.grad_loss.zero_()
        self.loss.zero_()
        self.chamfer_loss_total.zero_()
        self.track_loss_total.zero_()

    def get_current_state(
        self,
        save_path,
        epoch,
        draw_stress=False,
        draw_F=False,
        draw_gradient=None,
        draw_gt=False,
        draw_primitive=None,
        draw_E=False,
        mask=None,
    ):
        """Render the current particle state and write it to a PNG file.

        Particle positions and covariances are exported from the solver,
        mapped back to the original coordinate frame and rasterized with
        ``self.rasterize``. The image is written to
        ``<save_path>/<epoch>[/<mode>]/<frame>.png``, where ``<mode>`` is
        added for the debug colourings and the file name is left-padded with
        zeros to eight characters.

        The colouring mode is the first one that applies, in this order:
        ``draw_stress``, ``draw_F``, ``draw_E``, ``draw_gradient``, ``mask``,
        otherwise the pre-computed colours.

        Args:
            save_path: Root directory for rendered frames.
            epoch: Used (via ``f'{epoch}'``) as the sub-directory name.
            draw_stress: Highlight the particles with the largest
                Frobenius norm of the stress in red, others in blue.
            draw_F: Colour the red channel by the min-max normalised
                determinant of ``particle_F_trial``.
            draw_gradient: Optional per-particle magnitudes; the largest ones
                are highlighted in red, others in blue.
            draw_gt: Also draw the ground-truth gaussians of the current
                frame, in white.
            draw_primitive: Optional sequence ``(xyz, logit_opacities,
                log_scales, unnormalised_rotations, colors)`` of extra
                gaussians given in solver space.
            draw_E: Colour the red channel by the min-max normalised
                ``log(E)``.
            mask: Optional per-particle mask; entries equal to 1 are drawn red.
        """
        solver = self.mpm_solver
        trans_pos = solver.export_particle_x_to_torch()[:self.gs_num].to(self.device)
        trans_cov3D = solver.export_particle_cov_to_torch()
        trans_rot = solver.export_particle_R_to_torch()
        cov3D = trans_cov3D.view(-1, 6)[:self.gs_num].to(self.device)
        rot = trans_rot.view(-1, 3, 3)[:self.gs_num].to(self.device)

        # Undo shift, normalisation and rotation (inverse of init_scene).
        pos = apply_inverse_rotations(
            undotransform2origin(
                undoshift2center111(trans_pos, self.shift_tensor), self.scale_origin, self.original_mean_pos
            ),
            self.rotation_matrices,
        )
        cov3D = cov3D / (self.scale_origin * self.scale_origin)
        cov3D = apply_inverse_cov_rotations(cov3D, self.rotation_matrices)

        opacity = self.opacity
        shs = self.shs
        camera_center = self.current_camera.campos
        if self.colors_precomp is None:
            colors_precomp = convert_SH_v2(shs, self.current_camera, self.active_sh_degree,
                                           self.max_sh_degree, pos, rot, camera_center)
        else:
            colors_precomp = torch.tile(self.colors_precomp, (self.n_particle, 1)).to(self.device)

        color_map = {'r': 0, 'g': 1, 'b': 2}
        save_name = f'{epoch}'
        init_screen_points = self.init_screen_points
        select_num = 2000  # number of top-magnitude particles to highlight

        if draw_stress:
            colors = torch.zeros_like(colors_precomp, device=self.device)
            S = wp.to_torch(solver.mpm_state.particle_stress)
            magnitudes = torch.norm(S, p='fro', dim=(1, 2))[:self.obj_num]
            max_idx = magnitudes.sort(descending=True).indices[:select_num]
            # Blue everywhere, red for the highest-stress particles.
            colors[:, color_map['b']] = 1
            colors[max_idx, color_map['r']] = 1
            colors[max_idx, color_map['b']] = 0
            render_color = colors
            save_name = os.path.join(save_name, 'stress_norm')
        elif draw_F:
            colors = torch.zeros_like(colors_precomp, device=self.device)
            F = wp.to_torch(solver.mpm_state.particle_F_trial)
            magnitudes = torch.det(F)[:self.obj_num]
            det_min = magnitudes.min()
            det_span = magnitudes.max() - det_min
            # Guard the min-max normalisation: when all determinants are equal
            # the span is zero and the division would produce NaN colours.
            # Fall back to 1.0, as the draw_E branch does.
            if det_span > 0.0:
                normalized = (magnitudes - det_min) / det_span
            else:
                normalized = torch.ones_like(magnitudes)
            colors[:, color_map['r']] = normalized
            render_color = colors
            save_name = os.path.join(save_name, 'F_det')
        elif draw_E:
            visual_E = wp.to_torch(solver.mpm_model.E).detach()
            colors = torch.zeros_like(colors_precomp, device=self.device)
            magnitudes = torch.log(visual_E)
            if magnitudes.max() - magnitudes.min() > 0.0:
                magnitudes = (magnitudes - magnitudes.min()) / (magnitudes.max() - magnitudes.min())
            else:
                magnitudes[:] = 1.0
            colors[:self.obj_num, color_map['r']] = magnitudes
            render_color = colors
            save_name = os.path.join(save_name, 'E_visual')
        elif draw_gradient is not None:
            colors = torch.zeros_like(colors_precomp, device=self.device)
            magnitudes = draw_gradient[:self.obj_num]
            max_idx = magnitudes.sort(descending=True).indices[:select_num]
            # Blue everywhere, red for the largest-gradient particles.
            colors[:, color_map['b']] = 1
            colors[max_idx, color_map['r']] = 1
            colors[max_idx, color_map['b']] = 0
            render_color = colors
            save_name = os.path.join(save_name, 'gradient')
        elif mask is not None:
            colors_precomp[mask == 1] = torch.tensor([1.0, 0.0, 0.0], device=self.device)[None]
            render_color = colors_precomp
        else:
            render_color = colors_precomp

        if draw_gt:
            offset_frame = 0
            gt_frame = min(self.current_frame, self.end_frame) + offset_frame
            gt_means3D = self.gaussian_gt['object_points'][gt_frame]
            gt_means2D = torch.zeros_like(gt_means3D, requires_grad=False, device=self.device)
            # Draw the ground truth in white. A fresh tensor is allocated on
            # purpose: indexing object_colors returns a view, so writing 1.0
            # into it in place would overwrite the stored ground-truth colours.
            gt_colors_precomp = torch.ones_like(self.gaussian_gt['object_colors'][gt_frame])
            gt_opacities = torch.sigmoid(self.gaussian_gt['logit_opacities'])
            gt_scales = torch.exp(self.gaussian_gt['log_scales'])
            gt_rot = torch.nn.functional.normalize(self.gaussian_gt['unnorm_rotations'][gt_frame])
            gt_cov = RT_to_cov(gt_rot, gt_scales)
            pos = torch.cat([pos, gt_means3D], dim=0)
            init_screen_points = torch.cat([init_screen_points, gt_means2D], dim=0)
            render_color = torch.cat([render_color, gt_colors_precomp], dim=0)
            opacity = torch.cat([opacity, gt_opacities], dim=0)
            cov3D = torch.cat([cov3D, gt_cov], dim=0)

        if draw_primitive is not None:
            xyz_ee = draw_primitive[0]
            opacities_ee = torch.sigmoid(draw_primitive[1])
            scales_ee = torch.exp(draw_primitive[2])
            rots_ee = torch.nn.functional.normalize(draw_primitive[3])
            precomp_colors_ee = draw_primitive[4]

            # Primitive positions arrive in solver space; map them back too.
            xyz_ee = apply_inverse_rotations(
                undotransform2origin(
                    undoshift2center111(xyz_ee, self.shift_tensor), self.scale_origin, self.original_mean_pos
                ),
                self.rotation_matrices,
            )
            ee_cov = RT_to_cov(rots_ee, scales_ee)

            ee_num = xyz_ee.shape[0]
            ee_means_2D = torch.zeros((ee_num, 3), requires_grad=False, device=self.device)

            pos = torch.cat([pos, xyz_ee], dim=0)
            init_screen_points = torch.cat([init_screen_points, ee_means_2D], dim=0)
            render_color = torch.cat([render_color, precomp_colors_ee], dim=0)
            opacity = torch.cat([opacity, opacities_ee], dim=0)
            cov3D = torch.cat([cov3D, ee_cov], dim=0)

        rendering, _ = self.rasterize(
            means3D=pos,
            means2D=init_screen_points,
            shs=None,
            colors_precomp=render_color,
            opacities=opacity,
            scales=None,
            rotations=None,
            cov3D_precomp=cov3D,
        )
        # Channels-first render -> HxWxC numpy image with the R and B channels
        # swapped, because cv2.imwrite expects BGR.
        cv2_img = rendering.permute(1, 2, 0).detach().cpu().numpy()
        cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
        output_path = os.path.join(save_path, save_name)
        os.makedirs(output_path, exist_ok=True)
        frame_id = self.current_frame
        cv2.imwrite(
            os.path.join(output_path, f"{frame_id}.png".rjust(8, "0")),
            255 * cv2_img,
        )

    def save_current_phys_state(self):
        """Snapshot the solver's particle buffers onto ``self.saved_*``.

        Positions (``saved_x``), velocities (``saved_v``) and volumes
        (``saved_vols``) are cloned exactly as the solver exports them. The C,
        F-trial and covariance buffers are cloned, viewed as ``(-1, 3, 3)`` /
        ``(-1, 3, 3)`` / ``(-1, 6)`` and cut down to the first ``self.gs_num``
        rows.
        """
        solver = self.mpm_solver
        keep = self.gs_num

        # Stored untouched: no reshape and no truncation.
        self.saved_x = solver.export_particle_x_to_torch().clone()
        self.saved_v = solver.export_particle_v_to_torch().clone()

        # Reshaped per particle, then limited to the leading ``gs_num`` rows.
        c_flat = solver.export_particle_C_to_torch().clone()
        self.saved_c = c_flat.view(-1, 3, 3)[:keep]

        f_flat = solver.export_particle_F_trial_to_torch().clone()
        self.saved_f = f_flat.view(-1, 3, 3)[:keep]

        cov_flat = solver.export_particle_cov_to_torch().clone()
        self.saved_covs = cov_flat.view(-1, 6)[:keep]

        self.saved_vols = solver.export_particle_vol_to_torch().clone()

    def load_current_phys_state(self):
        """Push the snapshot taken by ``save_current_phys_state`` back into the solver.

        Every saved tensor is detached before being handed to
        ``reset_states_from_torch``. Arguments are passed positionally in the
        order: positions, volumes, velocities, F, C, covariances.
        """
        reset = self.mpm_solver.reset_states_from_torch
        snapshot = (
            self.saved_x,
            self.saved_vols,
            self.saved_v,
            self.saved_f,
            self.saved_c,
            self.saved_covs,
        )
        reset(*(tensor.detach() for tensor in snapshot))

    def finalize_phys(self):
        """Delegate to the MPM solver's ``finalize_mu_lam``."""
        solver = self.mpm_solver
        solver.finalize_mu_lam()

    def extract_structure_points(self):
        """Return the structure particles' current positions as a NumPy array.

        The solver's exported positions are limited to the first
        ``self.gs_num`` rows and then to the first ``self.structure_num`` of
        those. The result is passed through ``undoshift2center111``,
        ``undotransform2origin`` and ``apply_inverse_rotations`` (going by
        their names, these map simulation-space coordinates back to the
        original frame; confirm against the helpers) before being moved to
        the CPU.
        """
        sim_xyz = self.mpm_solver.export_particle_x_to_torch()
        sim_xyz = sim_xyz[:self.gs_num]
        structure_xyz = sim_xyz[:self.structure_num]

        unshifted = undoshift2center111(structure_xyz, self.shift_tensor)
        unscaled = undotransform2origin(unshifted, self.scale_origin, self.original_mean_pos)
        restored = apply_inverse_rotations(unscaled, self.rotation_matrices)
        return restored.cpu().numpy()

    def step_push(self, last_frame_grad=False):
        """Advance the simulation by one frame in push mode and render it.

        Runs ``self.step_per_frame`` substeps of ``p2g2p`` with ``push=True``.
        When ``last_frame_grad`` is truthy, the final 10 substeps of the frame
        are executed inside ``self.tape``. Afterwards the current state is
        rendered through ``get_current_state`` (with the solver primitive's
        ``visual_3dgs`` and the ground truth drawn), and the frame counter is
        advanced and the substep counter reset, unless ``self.end_frame`` has
        already been reached.
        """
        total_steps = self.step_per_frame
        taped_from = total_steps - 10  # first substep that is taped
        for step in range(total_steps):
            if last_frame_grad and step >= taped_from:
                with self.tape:
                    self.mpm_solver.p2g2p(step, self.substep_dt, device=self.device, push=True)
                continue
            self.mpm_solver.p2g2p(step, self.substep_dt, device=self.device, push=True)

        ee_primitive = self.mpm_solver.primitive.visual_3dgs
        scene_name = os.path.basename(self.data_path)
        save_path = os.path.join(self.output_path, scene_name, 'temp')
        self.get_current_state(save_path=save_path, epoch=0, draw_primitive=ee_primitive, draw_gt=True)

        # The frame is rendered before the counter moves on.
        if self.current_frame >= self.end_frame:
            return
        self.current_frame += 1
        self.current_substep = 0

    def single_gripper_init(self):
        """Select the particles driven by the single gripper and build its velocity track.

        The ``self.eef_k`` candidate particles closest to the first
        end-effector position (``self.eef_pos[0]``) are flagged with 1 in a
        mask of length ``self.obj_num``. The mask is stored both as a warp
        array (``self.gripper_mask_wp``) and as a torch tensor
        (``self.gripper_mask``).

        * ``preprocessing_params['eef_only']`` truthy: every particle in
          ``self.mpm_init_pos_list[0]`` is a candidate, and ``self.velocity``
          is the finite difference of ``self.eef_pos`` divided by
          ``self.frame_dt``, smoothed (window 7) and scaled by 0.9.
        * otherwise: candidates are the first ``N`` particles (``N`` being the
          number of ground-truth ``logit_opacities``) whose
          ``object_motions_valid`` flag is set along the whole first axis
          except its last entry (presumably the frame axis; confirm against
          the data loader). The nearest such point becomes
          ``self.target_point_indice`` and ``self.velocity`` is the finite
          difference of its ``object_points`` trajectory, multiplied by
          ``self.scale_origin``, divided by ``self.frame_dt`` and smoothed
          (window 7).
        """
        k = self.eef_k
        eef_only = self.preprocessing_params['eef_only']
        gripper_point = self.eef_pos[0].clone()
        # gripper_point[-1] = self.eef_z_mean  # pin the z coordinate
        ee_mask_numpy = np.zeros(self.obj_num, dtype=int)

        if eef_only:
            # Every object particle is a candidate.
            positions = self.mpm_init_pos_list[0]
            # Euclidean distance from the gripper point to each candidate.
            distances = torch.cdist(gripper_point.unsqueeze(0), positions).squeeze(0)
            # Indices of the k nearest candidates.
            _, indices = torch.topk(distances, k=k, largest=False)
        else:
            positions = self.mpm_init_pos_list[0][:self.gaussian_gt['logit_opacities'].shape[0]]
            # Built on the device of `positions` (rather than a hard-coded
            # 'cuda') so the index tensor below can always index it.
            object_motions_valid = torch.tensor(self.gaussian_gt['object_motions_valid'],
                                                device=positions.device,
                                                dtype=torch.bool)
            valid_points_mask = torch.all(object_motions_valid[:-1] == 1, dim=0)
            valid_point_indices = torch.nonzero(valid_points_mask).flatten()
            # Keep only the valid candidate positions.
            valid_positions = positions[valid_point_indices]
            # Euclidean distance from the gripper point to each valid candidate.
            distances = torch.cdist(gripper_point.unsqueeze(0), valid_positions).squeeze(0)
            # k nearest, as indices into `valid_positions` ...
            _, valid_indices = torch.topk(distances, k=k, largest=False)
            # ... mapped back to indices into `positions`.
            indices = valid_point_indices[valid_indices]

        ee_mask_numpy[indices.cpu().numpy()] = 1
        self.gripper_mask_wp = wp.from_numpy(ee_mask_numpy, dtype=wp.int32)
        self.gripper_mask = torch.from_numpy(ee_mask_numpy)

        if eef_only:
            raw_velocity = (self.eef_pos[1:] - self.eef_pos[:-1]) / self.frame_dt
            self.velocity = smooth_positions_torch(raw_velocity, window_size=7) * 0.9
        else:
            # Follow the trajectory of the single nearest valid point.
            self.target_point_indice = indices[0]
            trajectory = (self.gaussian_gt['object_points'][:, self.target_point_indice]).clone()
            raw_velocity = ((trajectory[1:] - trajectory[:-1]) * self.scale_origin) / self.frame_dt
            self.velocity = smooth_positions_torch(raw_velocity, window_size=7)

    def step_single_gripper(self, last_frame_grad=False):
        """Advance the simulation by one frame in single-gripper mode and render it.

        On the frame equal to ``self.start_frame`` the gripper mask and
        velocity track are first built via ``single_gripper_init``. The first
        particle-velocity modifier then receives the current frame's velocity
        and the gripper mask, ``self.step_per_frame`` substeps of ``p2g2p``
        run with ``single_gripper=True`` (the last 10 inside ``self.tape``
        when ``last_frame_grad`` is truthy), the frame counter is incremented
        and the new state is rendered through ``get_current_state``.
        """
        if self.current_frame == self.start_frame:
            self.single_gripper_init()

        modifier = self.mpm_solver.particle_velocity_modifier_params[0]
        frame_velocity = self.velocity[self.current_frame]
        modifier.velocity = wp.vec3(
            frame_velocity[0].item(),
            frame_velocity[1].item(),
            frame_velocity[2].item(),
        )
        modifier.mask = self.gripper_mask_wp

        total_steps = self.step_per_frame
        taped_from = total_steps - 10  # first substep that is taped
        for step in range(total_steps):
            if last_frame_grad and step >= taped_from:
                with self.tape:
                    self.mpm_solver.p2g2p(step, self.substep_dt, device=self.device, single_gripper=True)
                continue
            self.mpm_solver.p2g2p(step, self.substep_dt, device=self.device, single_gripper=True)

        # The counter moves on before the frame is rendered.
        self.current_frame += 1
        scene_name = os.path.basename(self.data_path)
        save_path = os.path.join(self.output_path, scene_name, 'temp')
        self.get_current_state(save_path=save_path, epoch=0, mask=self.gripper_mask, draw_gt=True)

    def double_gripper_init(self):
        """Select the particles driven by each of the two grippers and build their velocity tracks.

        For each gripper ``g`` (0 and 1), the ``self.eef_k`` candidate
        particles closest to ``self.eef_pos[g, 0]`` are flagged with 1 in a
        mask of length ``self.obj_num``. The masks are stored as warp arrays
        (``self.gripper_mask_wp_1`` / ``self.gripper_mask_wp_2``), and their
        element-wise OR as a torch tensor (``self.gripper_mask``).

        * ``preprocessing_params['eef_only']`` truthy: every particle in
          ``self.mpm_init_pos_list[0]`` is a candidate;
          ``self.velocity_1`` / ``self.velocity_2`` are the finite differences
          of ``self.eef_pos[0]`` / ``self.eef_pos[1]`` divided by
          ``self.frame_dt`` and smoothed (window 4).
        * otherwise: candidates are the first ``N`` particles (``N`` being the
          number of ground-truth ``logit_opacities``) whose
          ``object_motions_valid`` flag is set along the whole first axis
          except its last entry (presumably the frame axis; confirm against
          the data loader). The nearest point per gripper becomes
          ``self.target_point1_indice`` / ``self.target_point2_indice``; its
          ``object_points`` trajectory is smoothed (window 7) and its finite
          difference, multiplied by ``self.scale_origin`` and divided by
          ``self.frame_dt``, becomes the velocity track.
        """
        k = self.eef_k
        ee_mask_numpy_1 = np.zeros(self.obj_num, dtype=int)
        ee_mask_numpy_2 = np.zeros(self.obj_num, dtype=int)
        eef_only = self.preprocessing_params['eef_only']

        def nearest_k(point, candidates):
            # Indices (into `candidates`) of the k rows closest to `point`
            # by Euclidean distance.
            distances = torch.cdist(point.unsqueeze(0), candidates).squeeze(0)
            _, nearest = torch.topk(distances, k=k, largest=False)
            return nearest

        gripper_point_1 = self.eef_pos[0, 0].clone()
        gripper_point_2 = self.eef_pos[1, 0].clone()
        # gripper_point_2[-1] = self.eef_z_mean  # pin the z coordinate

        if eef_only:
            # Every object particle is a candidate.
            positions = self.mpm_init_pos_list[0].clone()
            indices_1 = nearest_k(gripper_point_1, positions)
            indices_2 = nearest_k(gripper_point_2, positions)
        else:
            positions = self.mpm_init_pos_list[0][:self.gaussian_gt['logit_opacities'].shape[0]]
            # Built on the device of `positions` (rather than a hard-coded
            # 'cuda') so the index tensor below can always index it.
            object_motions_valid = torch.tensor(self.gaussian_gt['object_motions_valid'],
                                                device=positions.device,
                                                dtype=torch.bool)
            valid_points_mask = torch.all(object_motions_valid[:-1] == 1, dim=0)
            valid_point_indices = torch.nonzero(valid_points_mask).flatten()
            valid_positions = positions[valid_point_indices]
            # The search runs over the valid subset; map the hits back to
            # indices into `positions`.
            indices_1 = valid_point_indices[nearest_k(gripper_point_1, valid_positions)]
            indices_2 = valid_point_indices[nearest_k(gripper_point_2, valid_positions)]

        ee_mask_numpy_1[indices_1.cpu().numpy()] = 1
        ee_mask_numpy_2[indices_2.cpu().numpy()] = 1
        self.gripper_mask_wp_1 = wp.from_numpy(ee_mask_numpy_1, dtype=wp.int32)
        self.gripper_mask_wp_2 = wp.from_numpy(ee_mask_numpy_2, dtype=wp.int32)
        self.gripper_mask = torch.from_numpy(ee_mask_numpy_1 | ee_mask_numpy_2)

        if eef_only:
            # Differentiate the raw end-effector tracks, then smooth the velocities.
            positions_1 = self.eef_pos[0]
            positions_2 = self.eef_pos[1]
            velocity_1 = (positions_1[1:] - positions_1[:-1]) / self.frame_dt
            velocity_2 = (positions_2[1:] - positions_2[:-1]) / self.frame_dt
            self.velocity_1 = smooth_positions_torch(velocity_1, window_size=4)
            self.velocity_2 = smooth_positions_torch(velocity_2, window_size=4)
        else:
            # Follow the nearest valid point of each gripper; here the
            # trajectories are smoothed before differentiating.
            self.target_point1_indice = indices_1[0]
            self.target_point2_indice = indices_2[0]
            positions_1 = self.gaussian_gt['object_points'][:, self.target_point1_indice]
            positions_2 = self.gaussian_gt['object_points'][:, self.target_point2_indice]
            positions_1 = smooth_positions_torch(positions_1, window_size=7)
            positions_2 = smooth_positions_torch(positions_2, window_size=7)
            self.velocity_1 = ((positions_1[1:] - positions_1[:-1]) * self.scale_origin) / self.frame_dt
            self.velocity_2 = ((positions_2[1:] - positions_2[:-1]) * self.scale_origin) / self.frame_dt

    def step_double_gripper(self, last_frame_grad=False):
        """Advance the simulation by one frame in double-gripper mode and render it.

        On the frame equal to ``self.start_frame`` the gripper masks and
        velocity tracks are first built via ``double_gripper_init``. The first
        two particle-velocity modifiers then receive the current frame's
        velocity and the mask of gripper 1 and gripper 2 respectively,
        ``self.step_per_frame`` substeps of ``p2g2p`` run with
        ``double_gripper=True`` (the last 10 inside ``self.tape`` when
        ``last_frame_grad`` is truthy), the frame counter is incremented and
        the new state is rendered through ``get_current_state``.
        """
        if self.current_frame == self.start_frame:
            self.double_gripper_init()

        modifiers = self.mpm_solver.particle_velocity_modifier_params
        per_gripper = (
            (self.velocity_1, self.gripper_mask_wp_1),
            (self.velocity_2, self.gripper_mask_wp_2),
        )
        for slot, (track, mask_wp) in enumerate(per_gripper):
            frame_velocity = track[self.current_frame]
            modifiers[slot].velocity = wp.vec3(
                frame_velocity[0].item(),
                frame_velocity[1].item(),
                frame_velocity[2].item(),
            )
            modifiers[slot].mask = mask_wp

        total_steps = self.step_per_frame
        taped_from = total_steps - 10  # first substep that is taped
        for step in range(total_steps):
            if last_frame_grad and step >= taped_from:
                with self.tape:
                    self.mpm_solver.p2g2p(step, self.substep_dt, device=self.device, double_gripper=True)
                continue
            self.mpm_solver.p2g2p(step, self.substep_dt, device=self.device, double_gripper=True)

        # The counter moves on before the frame is rendered.
        self.current_frame += 1
        scene_name = os.path.basename(self.data_path)
        save_path = os.path.join(self.output_path, scene_name, 'temp')
        self.get_current_state(save_path=save_path, epoch=0, mask=self.gripper_mask, draw_gt=True)

    def pre_move_save_state(self, push_idx):
        """Run the pre-move simulation frames and record the resulting particle state.

        Steps the solver for ``self.pre_move_frame`` frames of
        ``self.step_per_frame`` substeps each, calling ``p2g2p`` without any
        push/gripper flag. The exported positions, covariances (reshaped to
        ``(-1, 6)``) and volumes are then cloned into
        ``self.mpm_init_pos_list`` / ``self.mpm_init_cov_list`` /
        ``self.mpm_init_vol_list`` at ``push_idx``.

        Args:
            push_idx: Slot of the ``mpm_init_*_list`` containers to overwrite.
        """
        print('pre_move')
        # self.set_ee_velocity(torch.zeros_like(self.eef_pos[0], device=self.device))
        for pre_frame in tqdm(range(self.pre_move_frame)):
            for step in range(self.step_per_frame):
                self.mpm_solver.p2g2p(step, self.substep_dt, device=self.device)
        # Read the particle state back after the pre-move frames.
        pos = self.mpm_solver.export_particle_x_to_torch()
        cov = self.mpm_solver.export_particle_cov_to_torch().reshape(-1, 6)
        vol = self.mpm_solver.export_particle_vol_to_torch()
        # Clone so the stored state does not alias the exported tensors.
        self.mpm_init_pos_list[push_idx] = pos.clone()
        self.mpm_init_cov_list[push_idx] = cov.clone()
        self.mpm_init_vol_list[push_idx] = vol.clone()
        # Max and mean of the last position column (presumably z -- confirm).
        self.eef_z = pos[:, -1].max().item()
        self.eef_z_mean = pos[:, -1].mean().item()