import os.path

import matplotlib.pyplot as plt
import numpy as np
import torch

from finetune_utils import *
import random
from dgl.geometry import farthest_point_sampler
from data.utils import fps_rad_idx_torch
from tqdm import tqdm
from data.gripper_dataset_2 import construct_edges_from_states
from data.fix_idx_dataset import construct_edges_from_states_sep
from gnn.model import DynamicsPredictorMyMultiLayer
from chamferdist import ChamferDistance
from render.phystwin_LBS import interpolate_motions as inter_motion
from render.phystwin_LBS import interpolate_motions_grad as inter_motion_grad
from render.phystwin_LBS import knn_weights_new, get_topk_indices
import math
import argparse
from pytorch3d.loss import chamfer_distance
import json
import time

# Default compute device for the whole module: the first CUDA GPU when one is
# available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def set_all_seeds(seed, save_state=True):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed handed to every generator.
        save_state: Accepted for interface compatibility; it is not read
            anywhere in this function.
    """
    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Apply the same seed to every generator, in the original order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Runs at import time: every RNG is seeded with 42 as soon as this module loads.
set_all_seeds(42)

def draw_points_2(w2c, k, points, color=(0, 255, 0), radius=1):
    """Project 3D points through a pinhole camera and draw them on a white canvas.

    Args:
        w2c: World-to-camera matrix applied to homogeneous world points.
        k: Camera intrinsic matrix; fx, fy, cx, cy are read from it.
        points: (N, 3) point coordinates, either a torch.Tensor or an array.
        color: Circle colour passed to ``cv2.circle``.
        radius: Circle radius in pixels.

    Returns:
        A (720, 1280, 3) uint8 image with one filled circle per point that
        lies in front of the camera.
    """
    canvas = np.full((720, 1280, 3), 255, dtype=np.uint8)

    # Bring the points to numpy and append the homogeneous coordinate.
    if isinstance(points, torch.Tensor):
        world_xyz = points.cpu().numpy()
    else:
        world_xyz = points
    ones_column = np.ones((world_xyz.shape[0], 1))
    world_hom = np.hstack((world_xyz, ones_column))

    # World frame -> camera frame.
    cam_xyz = (w2c @ world_hom.T).T

    # Pinhole projection, keeping only points with positive depth.
    fx, fy, cx, cy = k[0, 0], k[1, 1], k[0, 2], k[1, 2]
    pixels = [
        (int(fx * (p[0] / p[2]) + cx), int(fy * (p[1] / p[2]) + cy))
        for p in cam_xyz
        if p[2] > 0
    ]

    for pixel in pixels:
        cv2.circle(canvas, pixel, radius=radius, color=color, thickness=-1)
    return canvas


def draw_points_3(w2c, k, points, color=(30, 144, 255), radius=3, alpha=0.7,
                  add_glow=True, background_color=(255, 255, 255)):
    """Render a projected point cloud with depth shading, glow and alpha blending.

    Args:
        w2c: World-to-camera transformation matrix.
        k: Camera intrinsic matrix; fx, fy, cx, cy are read from it.
        points: Point cloud of shape (N, 3), a torch.Tensor or an array.
        color: Base point colour (BGR order, as used by OpenCV).
        radius: Radius of the core disc of each point, in pixels.
        alpha: Opacity of the drawn layer when blended over the background (0-1).
        add_glow: Whether to draw a lighter halo of radius ``radius + 2``
            underneath each point.
        background_color: Canvas colour (BGR order).

    Returns:
        A (720, 1280, 3) uint8 image. When no point lies in front of the
        camera the plain background canvas is returned.
    """
    # Create the canvas.
    im = np.full((720, 1280, 3), background_color, dtype=np.uint8)

    # Convert to numpy and append the homogeneous coordinate. detach() lets
    # tensors that still require grad be converted without raising.
    xyz = points.detach().cpu().numpy() if isinstance(points, torch.Tensor) else points
    xyz_hom = np.hstack((xyz, np.ones((xyz.shape[0], 1))))

    # Transform into the camera frame and read the intrinsics.
    cam_coords = (w2c @ xyz_hom.T).T
    fx, fy, cx, cy = k[0, 0], k[1, 1], k[0, 2], k[1, 2]

    # Collect every valid projection together with its depth.
    img_coords = []
    depths = []
    for pt in cam_coords:
        if pt[2] > 0:  # only points in front of the camera
            u = fx * (pt[0] / pt[2]) + cx
            v = fy * (pt[1] / pt[2]) + cy
            img_coords.append((int(u), int(v)))
            depths.append(pt[2])

    if not img_coords:
        return im  # nothing visible: return the blank canvas

    # Normalise depth to [0, 1] for shading. Shading is skipped when the depth
    # range is zero (single point, or all points at the same depth) or not
    # finite; dividing by it would produce NaN and make int() raise below.
    depths = np.array(depths)
    depth_span = depths.max() - depths.min()
    shade_by_depth = bool(np.isfinite(depth_span) and depth_span > 0)
    if shade_by_depth:
        depths = (depths - depths.min()) / depth_span

    # Draw on a separate layer so it can be alpha-blended afterwards.
    overlay = im.copy()

    for pt, depth in zip(img_coords, depths):
        if shade_by_depth:
            depth_color = tuple(int(c * (0.5 + 0.5 * depth)) for c in color)
        else:
            depth_color = color

        # Halo first, core second: drawing the larger halo afterwards would
        # completely cover the core disc.
        if add_glow:
            glow_color = tuple(min(c + 60, 255) for c in depth_color)
            cv2.circle(overlay, pt, radius=radius + 2, color=glow_color, thickness=-1)
        cv2.circle(overlay, pt, radius=radius, color=depth_color, thickness=-1)

    # Blend the drawn layer over the background in place.
    cv2.addWeighted(overlay, alpha, im, 1 - alpha, 0, im)

    # Cheap smoothing for larger points: upsample 2x, then resample back down.
    if radius > 2:
        im = cv2.resize(im, None, fx=2, fy=2, interpolation=cv2.INTER_LINEAR)
        im = cv2.resize(im, (1280, 720), interpolation=cv2.INTER_AREA)

    return im

def load_split(data_dir):
    """Read ``split.json`` from ``data_dir`` and return its frame bookkeeping.

    Args:
        data_dir: Directory that contains ``split.json``.

    Returns:
        A tuple ``(frame_len, train_frame, test_frame)``, where the last two
        are element 1 of the ``"train"`` and ``"test"`` entries.
    """
    split_file = f"{data_dir}/split.json"
    with open(split_file, "r") as handle:
        meta = json.load(handle)
    return meta["frame_len"], meta["train"][1], meta["test"][1]

def smooth_positions_torch(positions, window_size=7):
    """Smooth a 2-D trajectory tensor with an edge-padded moving average.

    The first axis is padded by ``window_size // 2`` rows on each side
    (repeating the edge rows) and every output row is the mean of
    ``window_size`` consecutive padded rows.

    Args:
        positions: 2-D torch.Tensor; the first axis is the one smoothed over.
        window_size: Number of consecutive rows averaged per output row.

    Returns:
        A float32 tensor with the same shape as ``positions``, placed on the
        same device as the input. Previously this was hard-coded to
        ``'cuda'``, which fails on CPU-only hosts.
    """
    target_device = positions.device
    # detach() so inputs that track gradients can be converted to numpy.
    positions = positions.detach().cpu().numpy()
    pad_size = window_size // 2
    padded_positions = np.pad(positions, ((pad_size, pad_size), (0, 0)), mode='edge')
    smoothed_positions = np.zeros_like(positions)
    for i in range(len(positions)):
        smoothed_positions[i] = np.mean(padded_positions[i:i + window_size], axis=0)
    return torch.tensor(smoothed_positions, device=target_device, dtype=torch.float32)

def pad_torch(x, max_dim, dim=0):
    """Zero-pad tensor ``x`` along axis ``dim`` up to length ``max_dim``.

    The original values occupy the leading ``x.shape[dim]`` positions of the
    padded axis; the remainder is zero. Any tensor rank and any valid axis
    are supported (the earlier version handled only 2-D/``dim=0`` and
    3-D/``dim=1`` and hit an unbound local for every other ``dim``).

    Args:
        x: Tensor to pad.
        max_dim: Target length of axis ``dim``.
        dim: Axis to pad.

    Returns:
        A new tensor with the dtype and device of ``x`` whose axis ``dim``
        has length ``max_dim``.

    Raises:
        ValueError: If ``dim`` is not a valid axis of ``x``.
    """
    if not -x.dim() <= dim < x.dim():
        raise ValueError(f"dim={dim} is out of range for a tensor with {x.dim()} dimensions")
    padded_shape = list(x.shape)
    padded_shape[dim] = max_dim
    x_pad = torch.zeros(padded_shape, dtype=x.dtype, device=x.device)
    # Write the source into the leading slice of the padded axis; torch raises
    # here if x is longer than max_dim along that axis.
    x_pad.narrow(dim, 0, x.shape[dim]).copy_(x)
    return x_pad

# Module-level switch read by PhysFinetuner.__init__: it selects the output
# directory and, when False, sets both optimizer learning rates to 0.0.
is_train = True
class PhysFinetuner:
    def __init__(self, data_path, config_path, ori_data_path, ply_path, name, mpm_test_path=None, save_path='../debug_view'):
        """Load episode data, the dynamics model, render assets and graph indices, then build the optimizer.

        Args:
            data_path: Directory containing ``test_data.pkl``.
            config_path: YAML file with ``train_config``, ``model_config`` and
                ``dataset_config`` sections.
            ori_data_path: Directory containing ``calibrate.pkl``,
                ``metadata.json`` and ``split.json``.
            ply_path: Directory containing
                ``point_cloud/iteration_10000/point_cloud.ply``.
            name: Scene name; used for the push-scene check and to build
                several relative paths (``../fps_idx_pre``, ``../idx_data``,
                the save root).
            mpm_test_path: Not used by the active code; only referenced in
                commented-out lines.
            save_path: Base directory used to build ``self.episode_path``.
        """
        self.data_path = data_path
        self.config_path = config_path
        self.name = name
        # Scenes listed here rebuild tool edges on every step in train_step_structure.
        self.is_push = False
        if name in ['single_push_rope', 'single_push_rope_1', 'single_push_rope_4']:
            self.is_push = True
        self.smooth_eef = False
        # NOTE(review): `pkl` is not imported explicitly in this file; presumably it
        # comes from `from finetune_utils import *`. Unpickling can execute arbitrary
        # code, so only trusted files should be loaded here.
        with open(os.path.join(data_path, 'test_data.pkl'), 'rb') as f:
            data = pkl.load(f)

        # with open(os.path.join(mpm_test_path, 'test_data.pkl'), 'rb') as f:
        #     mpm_test_data = pkl.load(f)

        self.structure_points = torch.from_numpy(data['structure_points']).to(torch.float32).to(device)
        # self.structure_points = torch.from_numpy(mpm_test_data['structure_points']).to(torch.float32).to(device)

        # eef_pos_mpm = mpm_test_data['eef_pos']
        # self.eef_pos = eef_pos_mpm

        self.eef_pos = data['eef_pos']
        # self.eef_pos[:] = self.eef_pos[0][None]
        # d = eef_pos_mpm[:, 0] - self.eef_pos[:, 0]
        # self.eef_pos += d[:, None, :]

        # Stored as the natural log of data['E']; divided by 14.0 below to
        # build the optimised parameter.
        self.log_E = np.log(data['E'])
        # self.log_E = np.log(mpm_test_data['E'])

        friction_np = np.array(data['friction'])
        # friction_np = np.array(mpm_test_data['friction'])

        self.structure_num = len(self.structure_points)

        # #TODO UNSEEN INTERACTION
        # self.eef_pos[:, :-1] = self.structure_points[:, :-1].mean(axis=0)

        # self.xyz_0 = torch.from_numpy(data['xyz_0']).to(torch.float32).to(device)
        # fps_idx_1000 = np.load(os.path.join(f'../mpm_data/phystwin_data/{name}/0/', 'fps_idx_pre.npy'))

        # Precomputed index array selecting a subset of the structure points
        # (presumably 1000 farthest-point samples, judging by the name — TODO confirm).
        fps_idx_1000 = np.load(os.path.join(f'../fps_idx_pre/{name}.npy'))
        self.fps_idx_1000 = fps_idx_1000

        # fps_idx_1000 = np.load(os.path.join('/data/dev/gs-dyn/mpm_data/fps_1000_idx', f'{name}_fps_idx.npy'))
        self.xyz_0 = self.structure_points[fps_idx_1000]
        xyz_gt = data['xyz_gt']
        visibility = data['visibility']
        # Per frame, the indices whose visibility flag equals 1.
        vis_indices = []
        for i in range(len(visibility)):
            vis_idx = torch.from_numpy(np.where(visibility[i] == 1)[0]).to(device)
            vis_indices.append(vis_idx)
        # Convert each ground-truth frame to a float32 tensor on `device`
        # (this overwrites the entries of data['xyz_gt'] in place).
        for i in range(len(xyz_gt)):
            xyz_gt[i] = torch.from_numpy(xyz_gt[i]).to(torch.float32).to(device)
        self.vis_indices = vis_indices
        self.xyz_gt = xyz_gt


        #DEBUG
        # self.eef_pos[:] = self.eef_pos[0]

        # A 2-D eef_pos array is treated as one end effector; any other rank as two.
        eef_num = 1 if len(self.eef_pos.shape) == 2 else 2
        self.eef_num = eef_num
        if eef_num == 1:
            # A singleton axis is inserted at position 1 for the single end effector.
            if self.smooth_eef:
                self.eef_pos_tensor = smooth_positions_torch(torch.from_numpy(self.eef_pos).to(device))[:, None, :]
            else:
                self.eef_pos_tensor = torch.from_numpy(self.eef_pos).to(device)[:, None, :]
        else:
            # Here axis 0 indexes the end effector (entries 0 and 1 are smoothed separately).
            if self.smooth_eef:
                eef_pos_1 = smooth_positions_torch(torch.from_numpy(self.eef_pos).to(device)[0])
                eef_pos_2 = smooth_positions_torch(torch.from_numpy(self.eef_pos).to(device)[1])
                self.eef_pos_tensor = torch.stack([eef_pos_1, eef_pos_2], dim=0)
            else:
                self.eef_pos_tensor = torch.from_numpy(self.eef_pos).to(device)
        self.frame_num = len(self.xyz_gt)
        # optim_params_path = f'../optim_stage2/{name}/best_params.pkl'
        # with open(optim_params_path, 'rb') as f:
        #     optim_params = pkl.load(f)

        # self.log_E = optim_params['log_E'][:self.structure_num]
        # # self.log_E[:] = math.log(20000)
        # self.friction = torch.tensor([optim_params['friction']]).to(device).requires_grad_(True)

        # self.log_E[:] = self.log_E.mean()
        # Friction is one of the two optimised parameters.
        self.friction = torch.tensor(friction_np).to(device).requires_grad_(True)
        # self.normalized_E = torch.from_numpy(np.log(self.E) / 14.0).to(device).requires_grad_(True)

        # Second optimised parameter: log(E) / 14.0 restricted to the fps_idx_1000 subset.
        self.normalized_E = torch.from_numpy(self.log_E / 14.0).to(device)[fps_idx_1000].requires_grad_(True)
        # self.log_E[:] = 9
        # self.normalized_E = torch.from_numpy(self.log_E / 14.0).to(device)[fps_idx_1000].requires_grad_(True)
        # NOTE(review): `yaml` is not imported explicitly in this file, and
        # yaml.CLoader is only available when PyYAML is built with libyaml.
        with open(config_path, 'r') as f:
            self.config = yaml.load(f, Loader=yaml.CLoader)
        # NOTE(review): this local is not used again in this method.
        dataset_config = self.config['dataset_config']['datasets'][0]


        # self.dm = DynamicsModule(self.config, 'latest', device)
        self.episode_path = os.path.join(save_path, str(0))
        # Loads the 'latest' checkpoint and copies config values onto self.
        self.init_config(self.config, 'latest', device)

        # LBS visualization
        ply = os.path.join(ply_path, 'point_cloud/iteration_10000/point_cloud.ply')
        self.gs_params = load_3dgs_model(ply)
        self.renderer = Renderer(device)
        camera_path = ori_data_path
        calibrate_path = os.path.join(camera_path, 'calibrate.pkl')
        meta_path = os.path.join(camera_path, 'metadata.json')
        self.w2c, self.k = my_get_camera_view_phystwin(calibrate_path, meta_path, cam_id=0)
        # The output directory depends on the module-level is_train flag.
        self.save_root = f'../unseen_interaction/{name}' if is_train else f'../finetune_E/{name}_no_train'
        os.makedirs(self.save_root, exist_ok=True)

        # Precomputed sampling indices and graph edges for this scene.
        idx_data_path = f'../idx_data/{name}/train'
        fps_idx_1_path = os.path.join(idx_data_path, 'fps_idx_1.npy')
        fps_idx_2_path = os.path.join(idx_data_path, 'fps_idx_2.npy')
        Rs_path = os.path.join(idx_data_path, 'Rs.npy')
        Rr_path = os.path.join(idx_data_path, 'Rr.npy')
        rels_path = os.path.join(idx_data_path, 'rels.npy')
        nearest_indices_path = os.path.join(idx_data_path, 'nearest_indices.npy')
        self.fps_all_idx = np.load(fps_idx_1_path)
        self.fps_idx_second = np.load(fps_idx_2_path)
        self.Rr = torch.from_numpy(np.load(Rr_path)).to(device)
        self.Rs = torch.from_numpy(np.load(Rs_path)).to(device)
        self.rels = torch.from_numpy(np.load(rels_path)).to(device)
        if os.path.exists(nearest_indices_path):
            self.nearest_indices = torch.from_numpy(np.load(nearest_indices_path)).to(device)
        else:
            # Fallback: for every key point take the k closest points of xyz_0
            # by squared Euclidean distance, with k = 1000 // number of key points.
            fps_points = self.xyz_0[self.fps_all_idx][self.fps_idx_second]
            dists_sq = torch.sum((fps_points.unsqueeze(1) - self.xyz_0.unsqueeze(0)) ** 2, dim=-1)
            k = 1000 // len(fps_points)
            _, self.nearest_indices = torch.topk(dists_sq, k=k, dim=1, largest=False, sorted=True)

        self.particle_pos_0 = None

        # Optimizer setup
        # self.optimizer = torch.optim.SGD([self.normalized_E, self.friction], lr=0.1)
        # self.optimizer = torch.optim.SGD([{'params': self.normalized_E, 'lr': 0.01},
        #                                   {'params': self.friction, 'lr': 0.001},])
        # With is_train False both learning rates are 0.0.
        E_lr = 0.0001 if is_train else 0.0
        friction_lr = 0.0001 if is_train else 0.0
        self.optimizer = torch.optim.Adam([{'params': self.normalized_E, 'lr': E_lr},
                                          {'params': self.friction, 'lr': friction_lr}])
        # self.optimizer = torch.optim.SGD([{'params': self.normalized_E, 'lr': 1.0},
        #                                   {'params': self.friction, 'lr': 0.1},])
        self.chamfer_loss = ChamferDistance()
        self.track_loss = torch.nn.MSELoss()
        # Weights of the two loss terms combined in train_step_structure.
        self.chamfer_weight = 1.0
        self.track_weight = 0.1
        # self.frame_len, _, self.train_frame = load_split(ply_path)
        self.frame_len, self.train_frame, self.test_frame = load_split(ori_data_path)

        # The epoch count doubles as T_max of the cosine learning-rate schedule.
        self.epoch = 150
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optimizer, T_max=self.epoch)

    def load_model(self, train_config, model_config, checkpoint_dir, device):
        """Build the dynamics predictor and load its weights from a checkpoint.

        Args:
            train_config: Training config; its ``n_his`` value is copied into
                ``model_config`` (the dict is modified in place).
            model_config: Model config passed to ``DynamicsPredictorMyMultiLayer``.
            checkpoint_dir: Path of the checkpoint file handed to ``torch.load``
                (despite the name, a file path rather than a directory).
            device: Device the model is moved to and the checkpoint is mapped onto.

        Returns:
            The model in eval mode with the checkpoint's state dict loaded.
        """
        model_config['n_his'] = train_config['n_his']
        model = DynamicsPredictorMyMultiLayer(model_config, device)
        model.to(device)
        model.eval()
        # map_location keeps loading working when the checkpoint was saved on
        # a device that does not exist on this host (e.g. CUDA -> CPU-only).
        state_dict = torch.load(checkpoint_dir, map_location=device)
        model.load_state_dict(state_dict)
        return model

    def init_config(self, config, epoch, device):
        self.device = device
        train_config = config['train_config']
        model_config = config['model_config']
        if epoch == 'latest':
            checkpoint_dir = os.path.join(train_config['out_dir'], 'checkpoints', 'latest.pth')
        else:
            checkpoint_dir = os.path.join(train_config['out_dir'], 'checkpoints', 'model_{}.pth'.format(epoch))
        self.model = self.load_model(train_config, model_config, checkpoint_dir, self.device)
        self.n_his = train_config['n_his']
        self.dist_thresh = train_config['dist_thresh']

        dataset_config = config['dataset_config']['datasets'][0]
        self.max_nobj = dataset_config['max_nobj']
        self.max_nR = dataset_config['max_nR']
        self.adj_thresh = (dataset_config['adj_radius_range'][0] + dataset_config['adj_radius_range'][1]) / 2
        self.fps_radius = (dataset_config['fps_radius_range'][0] + dataset_config['fps_radius_range'][1]) / 2
        self.topk = dataset_config['topk']
        self.connect_all = dataset_config['connect_all']

    def downsample_vertices(self, xyz):  # (n, 3)
        """Pick key points from ``xyz`` with two successive sampling passes.

        The first pass asks ``farthest_point_sampler`` for ``self.max_nobj``
        points; the second pass filters them with ``fps_rad_idx_torch`` using
        ``self.fps_radius``.

        Args:
            xyz: Point tensor; the trailing comment on the signature gives (n, 3).

        Returns:
            ``(xyz[fps_idx], fps_idx)`` — the selected points and their indices
            into the input.
        """
        # Sampling runs on a detached CPU copy with a leading batch axis.
        cpu_batch = xyz[None, ...].detach().cpu()
        coarse_idx = farthest_point_sampler(cpu_batch, self.max_nobj, start_idx=0)[0]
        coarse_points = cpu_batch[0, coarse_idx, :]

        # The second pass returns indices relative to the coarse subset.
        _, fine_idx = fps_rad_idx_torch(coarse_points, self.fps_radius)
        fps_idx = coarse_idx[fine_idx]
        return xyz[fps_idx], fps_idx

    def train_step(self, eval_mode=False):
        """Roll the dynamics model over all frames starting from ``self.xyz_0``.

        For each frame ``i`` in ``1 .. frame_num - 1`` the method builds the
        graph inputs from the key-point history and the end-effector motion,
        runs the model, evaluates ``self.loss_func`` against ``self.xyz_gt[i]``
        and, unless ``eval_mode`` is set, takes one optimizer step. Graph edges
        are constructed once, on the first step, and reused afterwards.

        NOTE(review): ``self.loss_func`` is not defined in the visible part of
        this class — confirm it exists elsewhere.

        Args:
            eval_mode: When True, no backward pass or optimizer step is run.

        Returns:
            Tuple ``(xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss)``:
            per-frame CPU buffers for all points, key points and end effector;
            lists of predicted and reference key points; the edge list appended
            once per step; and the summed scalar loss.
        """
        xyz_0, eef_xyz, n_steps = \
        self.xyz_0, self.eef_pos_tensor, self.frame_num
        model = self.model
        device = self.device
        all_pos = xyz_0
        # Fallback sampling; __init__ normally loads these indices from disk.
        if self.fps_all_idx is None:
            self.fps_all_idx = farthest_point_sampler(xyz_0.cpu()[None], 1000, start_idx=0)[0]
        fps_all_idx = self.fps_all_idx
        # fps_all_idx = torch.arange(len(xyz_0), device=device)
        fps_all_pos = all_pos[fps_all_idx]
        eef_num = self.eef_num
        # Initialise the end-effector history by repeating its frame-0 position n_his times.
        if eef_num == 1:
            eef_pos_history = eef_xyz[0][None].repeat(model.model_config['n_his'], 1, 1)  # (n_his, 1, 3)
            eef_pos = eef_xyz[0]  # (1, 3)
        else:
            eef_pos_history = eef_xyz[:, 0][None].repeat(model.model_config['n_his'], 1, 1)  # presumably (n_his, eef_num, 3) — TODO confirm
            eef_pos = eef_xyz[:, 0]  # presumably (eef_num, 3) — TODO confirm
        if self.fps_idx_second is None:
            _, self.fps_idx_second = self.downsample_vertices(fps_all_pos.clone())
        fps_idx_second = self.fps_idx_second
        particle_pos_0 = fps_all_pos[fps_idx_second]

        # Key-point history, initialised by repeating the initial key points n_his times.
        fps_all_pos_history = fps_all_pos[fps_idx_second][None].repeat(model.model_config['n_his'], 1,
                                                                       1)  # (n_his, n_particles, 3)
        # results to store
        xyz = xyz_0.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        xyz_bones = torch.zeros(n_steps, self.max_nobj, 3)  # (n_steps, max_nobj, 3)
        if eef_num == 1:
            eef = eef_xyz.cpu()[0][None].repeat(n_steps, 1, 1)  # (n_steps, 1, 3)
        else:
            eef = eef_xyz.cpu()[:, 0][None].repeat(n_steps, 1, 1)
        xyz_bones[0, :particle_pos_0.shape[0]] = particle_pos_0.cpu()

        key_point = []
        pred_pos = []
        gt_pos = []
        pred_pos.append(particle_pos_0.cpu())
        gt_pos.append(particle_pos_0.cpu())
        rels_list = []
        total_loss = 0.0

        Rr, Rs, rels = None, None, None
        for i in range(1, n_steps):
            # End-effector displacement from the previous frame to this one.
            if eef_num == 1:
                eef_pos_this_step = eef_xyz[i]
            else:
                eef_pos_this_step = eef_xyz[:, i]
            eef_delta = eef_pos_this_step - eef_pos

            # particle_pos, fps_idx = self.downsample_vertices(fps_all_pos.clone())
            # The most recent entry of the history is the current key-point state.
            particle_pos = fps_all_pos_history[-1]

            key_point.append(particle_pos)
            particle_pos_history = fps_all_pos_history
            nobj = particle_pos.shape[0]

            # Node layout: the first nobj slots are object key points, the rest end effectors.
            states = torch.zeros((1, self.n_his, nobj + eef_num, 3), device=device)
            states[:, :, :nobj] = particle_pos_history
            states[:, :, nobj:] = eef_pos_history

            # Only the end-effector nodes carry an action (their displacement).
            states_delta = torch.zeros((1, nobj + eef_num, 3), device=device)
            states_delta[:, nobj:] = eef_delta

            # Two-channel node type flag: channel 0 marks object nodes, channel 1 end-effector nodes.
            attrs = torch.zeros((1, nobj + eef_num, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.
            attrs[:, nobj:, 1] = 1.

            p_instance = torch.ones((1, nobj, 1), dtype=torch.float32, device=device)

            state_mask = torch.ones((1, nobj + eef_num), dtype=bool, device=device)

            eef_mask = torch.zeros((1, nobj + eef_num), dtype=bool, device=device)
            eef_mask[:, nobj:] = 1

            obj_mask = torch.zeros((1, nobj + eef_num), dtype=bool, device=device)
            obj_mask[:, :nobj] = 1

            # TODO debug
            # Edges are built from the first step's state only and reused on later steps.
            if i == 1:
                Rr, Rs, rels = construct_edges_from_states(states[0, -1], self.adj_thresh,
                                                           mask=state_mask[0], tool_mask=eef_mask[0], topk=self.topk,
                                                           connect_all=self.connect_all, return_rels=True)
                Rr = Rr[None]
                Rs = Rs[None]
            rels_list.append(rels)

            # Last coordinate of the latest state, clamped from below at -adj_thresh, then negated.
            collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., torch.newaxis].clone(), -self.adj_thresh)
            # Per-node material inputs: the optimised parameters fill the object
            # slots; end-effector slots stay zero.
            logE = torch.zeros_like(collider_distance)
            fps_logE = self.normalized_E[fps_all_idx][fps_idx_second]
            logE[:, :nobj, 0] = fps_logE
            frictions = torch.zeros_like(collider_distance)
            frictions[:, :nobj, 0] = self.friction[0]

            graph = {
                # input information
                "state": states,  # (n_his, N+M, state_dim)
                "action": states_delta,  # (N+M, state_dim)

                # attr information
                "attrs": attrs,  # (N+M, attr_dim)
                # "p_rigid": p_rigid,  # (n_instance,)
                "p_instance": p_instance,  # (N, n_instance)
                "obj_mask": obj_mask,  # (N,)
                "state_mask": state_mask,  # (N+M,)
                "eef_mask": eef_mask,  # (N+M,)

                "Rr": Rr,  # (bsz, max_nR, N)
                "Rs": Rs,  # (bsz, max_nR, N)
                "collider_distance": collider_distance,
                "logE": logE,
                "friction": frictions
            }

            pred_state, _ = model(**graph)  # (1, nobj, 3)
            loss = self.loss_func(pred_state, self.xyz_gt[i][None])
            # One optimizer step per frame, not per rollout.
            if not eval_mode:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            total_loss += loss.item()
            pred_pos.append(pred_state[0].cpu())

            # Slide the end-effector history window forward by one frame.
            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step[None]], dim=0).detach()
            eef_pos = eef_pos_this_step

            # NOTE(review): this appends the *initial* key points as a numpy array on
            # every step, while the first entry of gt_pos is a tensor.
            gt_pos.append(fps_all_pos[fps_idx_second].cpu().numpy())
            # Detach so gradients do not flow across frames.
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state], dim=0).detach()
            # fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state], dim=0)

            # NOTE(review): all_pos is never updated in this method, so every
            # frame of xyz keeps the initial positions.
            xyz[i] = all_pos.cpu()
            xyz_bones[i, :nobj] = pred_state[0].cpu()
            eef[i] = eef_pos.cpu()

        return xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss

    def train_step_structure(self, eval_mode=False, return_test_loss=False):
        """Roll the model over all frames and fit the parameters on the dense structure points.

        Key points are taken from ``self.structure_points`` through
        ``self.fps_idx_1000``, ``self.fps_all_idx`` and ``self.fps_idx_second``.
        On every frame the model predicts the next key-point state, and
        ``inter_motion_grad`` propagates the key-point motion to all structure
        points. The loss is the sum of two terms:

        * ``chamfer_weight * chamfer(all_pos, xyz_gt[i]) / len(all_pos)``
        * ``track_weight * MSE(all_pos[vis_indices[i]], xyz_gt[i])``

        An optimizer step is taken per frame only when ``eval_mode`` is False
        and ``i < self.train_frame``. Losses of frames ``i >= self.train_frame``
        are additionally summed into the test loss.

        Args:
            eval_mode: When True, no backward pass or optimizer step is run.
            return_test_loss: When True, the summed test-frame loss is appended
                to the returned tuple.

        Returns:
            ``(xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss)``,
            plus ``total_test_loss`` as an eighth element when
            ``return_test_loss`` is True.
        """
        xyz_0, eef_xyz, n_steps = \
        self.structure_points, self.eef_pos_tensor, self.frame_num
        model = self.model
        device = self.device
        all_pos = xyz_0
        # Fallback sampling; __init__ normally loads these indices from disk.
        if self.fps_all_idx is None:
            self.fps_all_idx = farthest_point_sampler(xyz_0.cpu()[None], 1000, start_idx=0)[0]
        fps_all_idx = self.fps_all_idx
        # fps_all_idx = torch.arange(len(xyz_0), device=device)
        # Two-level indexing: structure points -> fps_idx_1000 subset -> fps_all_idx subset.
        fps_all_pos = all_pos[self.fps_idx_1000][fps_all_idx]
        eef_num = self.eef_num
        # Initialise the end-effector history by repeating its frame-0 position n_his times.
        if eef_num == 1:
            eef_pos_history = eef_xyz[0][None].repeat(model.model_config['n_his'], 1, 1)  # (n_his, 1, 3)
            eef_pos = eef_xyz[0]  # (1, 3)
        else:
            eef_pos_history = eef_xyz[:, 0][None].repeat(model.model_config['n_his'], 1, 1)  # presumably (n_his, eef_num, 3) — TODO confirm
            eef_pos = eef_xyz[:, 0]  # presumably (eef_num, 3) — TODO confirm
        if self.fps_idx_second is None:
            self.particle_pos_0, self.fps_idx_second = self.downsample_vertices(fps_all_pos.clone())
        fps_idx_second = self.fps_idx_second
        particle_pos_0 = fps_all_pos[fps_idx_second]

        # Key-point history, initialised by repeating the initial key points n_his times.
        fps_all_pos_history = fps_all_pos[fps_idx_second][None].repeat(model.model_config['n_his'], 1,
                                                                       1)  # (n_his, n_particles, 3)
        # results to store
        xyz = self.structure_points.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        nobj = len(fps_idx_second)
        xyz_bones = torch.zeros(n_steps, nobj, 3)  # (n_steps, nobj, 3)
        if eef_num == 1:
            eef = eef_xyz.cpu()[0][None].repeat(n_steps, 1, 1)  # (n_steps, 1, 3)
        else:
            eef = eef_xyz.cpu()[:, 0][None].repeat(n_steps, 1, 1)
        xyz_bones[0, :particle_pos_0.shape[0]] = particle_pos_0.cpu()

        key_point = []
        pred_pos = []
        gt_pos = []
        pred_pos.append(particle_pos_0.cpu())
        gt_pos.append(particle_pos_0.cpu())
        rels_list = []
        total_loss = 0.0
        total_test_loss = 0.0

        # Key-point relations and interpolation weights are computed once from
        # the initial configuration and reused on every frame.
        relation = get_topk_indices(particle_pos_0, K=5)
        all_pos = self.structure_points.clone()
        weights = knn_weights_new(particle_pos_0, all_pos, K=5)
        start_frame = 1
        for i in range(start_frame, n_steps):
            # End-effector displacement from the previous frame to this one.
            if eef_num == 1:
                eef_pos_this_step = eef_xyz[i]
            else:
                eef_pos_this_step = eef_xyz[:, i]
            eef_delta = eef_pos_this_step - eef_pos

            # particle_pos, fps_idx = self.downsample_vertices(fps_all_pos.clone())
            # The most recent entry of the history is the current key-point state.
            particle_pos = fps_all_pos_history[-1]

            key_point.append(particle_pos)
            particle_pos_history = fps_all_pos_history
            nobj = particle_pos.shape[0]

            # Node layout: max_nobj object slots (the first nobj in use), followed by
            # eef_num end-effector slots.
            if i <= start_frame:
                # First step: fill the whole history window from the initial state.
                states = torch.zeros((1, self.n_his, self.max_nobj + eef_num, 3), device=device)
                states[:, :, :nobj] = particle_pos_history
                states[:, :, self.max_nobj:] = eef_pos_history
            else:
                # Later steps: drop the oldest frame and append the previous prediction
                # (detached) together with the end-effector position stored in eef_pos.
                next_state = torch.zeros((1, 1, self.max_nobj + eef_num, 3), device=device)
                next_state[:, :, self.max_nobj:] = eef_pos
                next_state[:, :, :self.max_nobj] = pred_state.detach()
                states = torch.cat([states.detach()[:, 1:], next_state], dim=1)

            # Only the end-effector slots carry an action (their displacement).
            states_delta = torch.zeros((1, self.max_nobj + eef_num, 3), device=device)
            states_delta[:, self.max_nobj:] = eef_delta

            # Two-channel node type flag: channel 0 marks object nodes, channel 1 end-effector nodes.
            attrs = torch.zeros((1, self.max_nobj + eef_num, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.
            attrs[:, self.max_nobj:, 1] = 1.

            p_instance = torch.ones((1, self.max_nobj, 1), dtype=torch.float32, device=device)

            # Valid nodes: the nobj used object slots plus all end-effector slots.
            state_mask = torch.zeros((1, self.max_nobj + eef_num), dtype=bool, device=device)

            state_mask[:, :nobj] = True
            state_mask[:, self.max_nobj:] = True

            eef_mask = torch.zeros((1, self.max_nobj + eef_num), dtype=bool, device=device)
            eef_mask[:, self.max_nobj:] = 1

            obj_mask = torch.zeros((1, self.max_nobj + eef_num), dtype=bool, device=device)
            obj_mask[:, :nobj] = 1

            if self.is_push:
                # Push scenes: rebuild the tool edges from the current state on every
                # step and append them to the precomputed edges loaded in __init__.
                Rr_obj, Rs_obj, Rr_tool, Rs_tool, edge_obj, edge_tool = (
                    construct_edges_from_states_sep(states[0, -1], self.adj_thresh, mask=state_mask[0], tool_mask=eef_mask[0],
                                                    topk=self.topk, connect_all=self.connect_all, return_rels=True))
                Rr = torch.concat((self.Rr, Rr_tool))[None]
                Rs = torch.concat((self.Rs, Rs_tool))[None]
                rels = torch.concat((self.rels, edge_tool))
            else:
                # Other scenes: use the precomputed edges, zero-padded to max_nR rows.
                Rr = pad_torch(self.Rr, self.max_nR)[None]
                Rs = pad_torch(self.Rs, self.max_nR)[None]
                rels = self.rels
            # TODO debug
            # if i == 1:
            #     Rr, Rs, rels = construct_edges_from_states(states[0, -1], self.adj_thresh,
            #                                                mask=state_mask[0], tool_mask=eef_mask[0], topk=self.topk,
            #                                                connect_all=self.connect_all, return_rels=True)
            #     Rr = Rr[None]
            #     Rs = Rs[None]
            rels_list.append(rels)


            # Last coordinate of the latest state, clamped from below at -adj_thresh, then negated.
            collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., torch.newaxis].clone(), -self.adj_thresh)
            # collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., torch.newaxis].clone(),
            #                                        max=-self.adj_thresh)
            # if i <= start_frame:
            #     collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., torch.newaxis].clone(), max=-self.adj_thresh)
            # else:
            #     collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., torch.newaxis].clone(), -self.adj_thresh)

            # collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., torch.newaxis].clone(),
            #                                        max=-self.adj_thresh)
            # A 1-D normalized_E gives one stiffness channel per node; otherwise three.
            E_dim = 1 if len(self.normalized_E.shape) == 1 else 3
            # Per key point: mean of the optimised stiffness over its nearest_indices neighbours.
            fps_logE = self.normalized_E[fps_all_idx][self.nearest_indices].mean(dim=1)
            logE = torch.zeros((1, self.max_nobj + eef_num, E_dim), dtype=collider_distance.dtype).to(device)
            # logE = torch.zeros_like(collider_distance)
            # fps_logE = self.normalized_E[fps_all_idx][fps_idx_second]
            # print(fps_logE)
            if E_dim == 1:
                logE[:, :nobj, 0] = fps_logE
            else:
                logE[:, :nobj] = fps_logE
            # logE[:, :nobj, 0] = fps_logE
            frictions = torch.zeros_like(collider_distance)
            frictions[:, :nobj, 0] = self.friction[0]


            graph = {
                # input information
                "state": states,  # (n_his, N+M, state_dim)
                "action": states_delta,  # (N+M, state_dim)

                # attr information
                "attrs": attrs,  # (N+M, attr_dim)
                # "p_rigid": p_rigid,  # (n_instance,)
                "p_instance": p_instance,  # (N, n_instance)
                "obj_mask": obj_mask,  # (N,)
                "state_mask": state_mask,  # (N+M,)
                "eef_mask": eef_mask,  # (N+M,)

                "Rr": Rr,  # (bsz, max_nR, N)
                "Rs": Rs,  # (bsz, max_nR, N)
                "collider_distance": collider_distance,
                "logE": logE,
                "friction": frictions
            }

            pred_state, _ = model(**graph)  # (1, nobj, 3)

            # relation = get_topk_indices(particle_pos, K=5)
            # weights = knn_weights_new(particle_pos, all_pos, K=5)
            # Motion interpolation: update all particles from the key-point motion.
            all_pos, _, _ = inter_motion_grad(
                bones=particle_pos,  # current key-point positions
                motions=pred_state[0, :nobj] - particle_pos,  # key-point displacement
                relations=relation,  # relation matrix
                xyz=all_pos,  # current positions of all particles
                weights=weights
            )
            # Chamfer term is divided by the number of structure points; the tracking
            # term compares the visible subset with the ground truth of this frame.
            chamfer_loss = self.chamfer_weight * self.chamfer_loss(all_pos[None], self.xyz_gt[i][None]) / len(all_pos)
            track_loss = self.track_weight * self.track_loss(all_pos[self.vis_indices[i]][None], self.xyz_gt[i][None])
            loss =  chamfer_loss + track_loss
            # loss = (self.chamfer_weight * self.chamfer_loss(all_pos[None], self.xyz_gt[i][None]) +
            #         self.track_weight * self.track_loss(all_pos[self.vis_indices[i]][None], self.xyz_gt[i][None]))
            # loss = (self.chamfer_weight * chamfer_distance(self.xyz_gt[i][None], all_pos[None], single_directional=True, norm=1)[0] +
            #         self.track_weight * self.track_loss(all_pos[self.vis_indices[i]][None], self.xyz_gt[i][None]))
            # loss = self.chamfer_loss(all_pos[None], self.xyz_gt[i][None])
            # Optimise only on training frames; later frames are evaluated but never fitted.
            if (not eval_mode) and i < self.train_frame:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if i >= self.train_frame:
                total_test_loss += loss.item()
            total_loss += loss.item()
            pred_pos.append(pred_state[0, :nobj].cpu())
            # Detach so the next frame's graph does not extend through this one.
            all_pos = all_pos.detach()

            # Slide the end-effector history window forward by one frame.
            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step[None]], dim=0).detach()
            eef_pos = eef_pos_this_step

            # NOTE(review): this appends the *initial* key points as a numpy array on
            # every step, while the first entry of gt_pos is a tensor.
            gt_pos.append(fps_all_pos[fps_idx_second].cpu().numpy())
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state[:, :nobj]], dim=0).detach()
            # fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state], dim=0)

            xyz[i] = all_pos.cpu()
            xyz_bones[i, :nobj] = pred_state[0, :nobj].cpu()
            eef[i] = eef_pos.cpu()
        if not return_test_loss:
            return xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss
        else:
            return xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss, total_test_loss

    def inference(self, eval_mode=False, return_test_loss=False):
        """Roll the dynamics model forward over every frame without computing a loss.

        Starting from the down-sampled key points of ``self.structure_points``,
        the model is called once per frame (frames ``1 .. self.frame_num - 1``)
        and each prediction becomes the newest entry of the state history.
        No loss is evaluated, no optimizer step is taken and the dense point
        cloud is never re-interpolated, so ``total_loss`` / ``total_test_loss``
        are returned as ``0.0`` and every row of ``xyz`` holds the initial
        structure points.

        The sampling indices ``self.fps_all_idx`` and ``self.fps_idx_second``
        (and ``self.particle_pos_0``) are computed and cached on first use.

        Args:
            eval_mode: Not read by this method.
            return_test_loss: When true, ``total_test_loss`` is appended to the
                returned tuple.

        Returns:
            ``(xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss)``,
            followed by ``total_test_loss`` when ``return_test_loss`` is true.
        """
        xyz_0, eef_xyz, n_steps = \
        self.structure_points, self.eef_pos_tensor, self.frame_num
        model = self.model
        device = self.device
        all_pos = xyz_0
        # Lazily compute (and cache) the first-stage farthest-point-sampling indices.
        if self.fps_all_idx is None:
            self.fps_all_idx = farthest_point_sampler(xyz_0.cpu()[None], 1000, start_idx=0)[0]
        fps_all_idx = self.fps_all_idx
        # fps_all_idx = torch.arange(len(xyz_0), device=device)
        fps_all_pos = all_pos[self.fps_idx_1000][fps_all_idx]
        eef_num = self.eef_num
        # The end-effector tensor is indexed by frame on axis 0 for a single
        # end effector and on axis 1 otherwise.
        if eef_num == 1:
            eef_pos_history = eef_xyz[0][None].repeat(model.model_config['n_his'], 1, 1)  # (n_his, 1, 3)
            eef_pos = eef_xyz[0]  # (1, 3)
        else:
            # presumably (n_his, eef_num, 3) here - TODO confirm
            eef_pos_history = eef_xyz[:, 0][None].repeat(model.model_config['n_his'], 1, 1)
            eef_pos = eef_xyz[:, 0]  # presumably (eef_num, 3) - TODO confirm
        # Lazily compute (and cache) the second-stage down-sampling to key points.
        if self.fps_idx_second is None:
            self.particle_pos_0, self.fps_idx_second = self.downsample_vertices(fps_all_pos.clone())
        fps_idx_second = self.fps_idx_second
        particle_pos_0 = fps_all_pos[fps_idx_second]

        # The history starts as n_his copies of the initial key points.
        fps_all_pos_history = fps_all_pos[fps_idx_second][None].repeat(model.model_config['n_his'], 1,
                                                                       1)  # (n_his, n_particles, 3)
        # results to store
        xyz = self.structure_points.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        nobj = len(fps_idx_second)
        xyz_bones = torch.zeros(n_steps, nobj, 3)  # (n_steps, n_bones, 3)
        if eef_num == 1:
            eef = eef_xyz.cpu()[0][None].repeat(n_steps, 1, 1)  # (n_steps, 1, 3)
        else:
            eef = eef_xyz.cpu()[:, 0][None].repeat(n_steps, 1, 1)
        xyz_bones[0, :particle_pos_0.shape[0]] = particle_pos_0.cpu()

        # NOTE(review): key_point is filled in the loop but never returned.
        key_point = []
        pred_pos = []
        gt_pos = []
        pred_pos.append(particle_pos_0.cpu())
        gt_pos.append(particle_pos_0.cpu())
        rels_list = []
        total_loss = 0.0
        total_test_loss = 0.0

        # NOTE(review): relation and weights are computed but not used anywhere
        # in this method (no motion interpolation is performed below).
        relation = get_topk_indices(particle_pos_0, K=5)
        all_pos = self.structure_points.clone()
        weights = knn_weights_new(particle_pos_0, all_pos, K=5)
        start_frame = 1
        for i in range(start_frame, n_steps):
            if eef_num == 1:
                eef_pos_this_step = eef_xyz[i]
            else:
                eef_pos_this_step = eef_xyz[:, i]
            # End-effector displacement since the previous frame; this is the action.
            eef_delta = eef_pos_this_step - eef_pos

            # particle_pos, fps_idx = self.downsample_vertices(fps_all_pos.clone())
            particle_pos = fps_all_pos_history[-1]

            key_point.append(particle_pos)
            particle_pos_history = fps_all_pos_history
            nobj = particle_pos.shape[0]

            if i <= start_frame:
                # First step: build the full history from the initial positions.
                states = torch.zeros((1, self.n_his, self.max_nobj + eef_num, 3), device=device)
                states[:, :, :nobj] = particle_pos_history
                states[:, :, self.max_nobj:] = eef_pos_history
            else:
                # Later steps: drop the oldest history entry and append the
                # previous prediction together with the previous end-effector pose.
                next_state = torch.zeros((1, 1, self.max_nobj + eef_num, 3), device=device)
                next_state[:, :, self.max_nobj:] = eef_pos
                next_state[:, :, :self.max_nobj] = pred_state.detach()
                states = torch.cat([states.detach()[:, 1:], next_state], dim=1)

            # Only the end-effector slots carry a non-zero action.
            states_delta = torch.zeros((1, self.max_nobj + eef_num, 3), device=device)
            states_delta[:, self.max_nobj:] = eef_delta

            # Two-channel attribute: channel 0 marks object slots, channel 1 end-effector slots.
            attrs = torch.zeros((1, self.max_nobj + eef_num, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.
            attrs[:, self.max_nobj:, 1] = 1.

            p_instance = torch.ones((1, self.max_nobj, 1), dtype=torch.float32, device=device)

            # state_mask: valid object slots plus all end-effector slots.
            state_mask = torch.zeros((1, self.max_nobj + eef_num), dtype=bool, device=device)

            state_mask[:, :nobj] = True
            state_mask[:, self.max_nobj:] = True

            eef_mask = torch.zeros((1, self.max_nobj + eef_num), dtype=bool, device=device)
            eef_mask[:, self.max_nobj:] = 1

            obj_mask = torch.zeros((1, self.max_nobj + eef_num), dtype=bool, device=device)
            obj_mask[:, :nobj] = 1

            if self.is_push:
                # Rebuild the tool edges from the current state and append them
                # to the stored relations; the object edges returned here are unused.
                Rr_obj, Rs_obj, Rr_tool, Rs_tool, edge_obj, edge_tool = (
                    construct_edges_from_states_sep(states[0, -1], self.adj_thresh, mask=state_mask[0], tool_mask=eef_mask[0],
                                                    topk=self.topk, connect_all=self.connect_all, return_rels=True))
                Rr = torch.concat((self.Rr, Rr_tool))[None]
                Rs = torch.concat((self.Rs, Rs_tool))[None]
                rels = torch.concat((self.rels, edge_tool))
            else:
                # Reuse the stored relations, padded to max_nR rows.
                Rr = pad_torch(self.Rr, self.max_nR)[None]
                Rs = pad_torch(self.Rs, self.max_nR)[None]
                rels = self.rels
            # TODO debug
            # if i == 1:
            #     Rr, Rs, rels = construct_edges_from_states(states[0, -1], self.adj_thresh,
            #                                                mask=state_mask[0], tool_mask=eef_mask[0], topk=self.topk,
            #                                                connect_all=self.connect_all, return_rels=True)
            #     Rr = Rr[None]
            #     Rs = Rs[None]
            rels_list.append(rels)


            # Last coordinate of the newest state, clamped from below at
            # -adj_thresh and then negated.
            collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., torch.newaxis].clone(), -self.adj_thresh)
            # Alternative kept for reference (clamps from above instead):
            # collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., torch.newaxis].clone(),
            #                                        max=-self.adj_thresh)
            # A 1-D normalized_E is treated as one value per point, anything else as 3 values.
            E_dim = 1 if len(self.normalized_E.shape) == 1 else 3
            # Per-key-point value: mean over the entries selected by nearest_indices.
            fps_logE = self.normalized_E[fps_all_idx][self.nearest_indices].mean(dim=1)
            logE = torch.zeros((1, self.max_nobj + eef_num, E_dim), dtype=collider_distance.dtype).to(device)
            # logE = torch.zeros_like(collider_distance)
            # fps_logE = self.normalized_E[fps_all_idx][fps_idx_second]
            if E_dim == 1:
                logE[:, :nobj, 0] = fps_logE
            else:
                logE[:, :nobj] = fps_logE
            # A single friction value (self.friction[0]) is applied to every object slot.
            frictions = torch.zeros_like(collider_distance)
            frictions[:, :nobj, 0] = self.friction[0]


            graph = {
                # input information
                "state": states,  # (n_his, N+M, state_dim)
                "action": states_delta,  # (N+M, state_dim)

                # attr information
                "attrs": attrs,  # (N+M, attr_dim)
                # "p_rigid": p_rigid,  # (n_instance,)
                "p_instance": p_instance,  # (N, n_instance)
                "obj_mask": obj_mask,  # (N,)
                "state_mask": state_mask,  # (N+M,)
                "eef_mask": eef_mask,  # (N+M,)

                "Rr": Rr,  # (bsz, max_nR, N)
                "Rs": Rs,  # (bsz, max_nR, N)
                "collider_distance": collider_distance,
                "logE": logE,
                "friction": frictions
            }

            # presumably (1, max_nobj, 3): it is written into the max_nobj
            # object slots of next_state above - TODO confirm
            pred_state, _ = model(**graph)

            # relation = get_topk_indices(particle_pos, K=5)
            # weights = knn_weights_new(particle_pos, all_pos, K=5)
            # Motion interpolation (updating all particles from the key-point
            # motion) is intentionally not performed in this method.

            pred_pos.append(pred_state[0, :nobj].cpu())
            all_pos = all_pos.detach()

            # Slide the end-effector history window forward by one frame.
            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step[None]], dim=0).detach()
            eef_pos = eef_pos_this_step

            # NOTE(review): fps_all_pos never changes, so every appended entry
            # equals the initial key points (and is a NumPy array, unlike entry 0).
            gt_pos.append(fps_all_pos[fps_idx_second].cpu().numpy())
            # Slide the key-point history window forward with the new prediction.
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state[:, :nobj]], dim=0).detach()
            # fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state], dim=0)

            xyz[i] = all_pos.cpu()
            xyz_bones[i, :nobj] = pred_state[0, :nobj].cpu()
            eef[i] = eef_pos.cpu()
        if not return_test_loss:
            return xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss
        else:
            return xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss, total_test_loss

    def train_step_structure_lerp(self, eval_mode=False, return_test_loss=False):
        """Roll out the sequence and fit the parameters on the interpolated dense points.

        For every frame ``i`` in ``1 .. self.frame_num - 1`` the dynamics model
        predicts the next key-point positions, the dense structure points are
        moved by interpolating the key-point motion (``inter_motion_grad``), and
        a weighted chamfer + tracking loss against ``self.xyz_gt[i]`` is
        computed.  Unless ``eval_mode`` is set, one optimizer step is taken for
        each frame with ``i < self.train_frame``.

        The sampling indices ``self.fps_all_idx`` and ``self.fps_idx_second``
        (and ``self.particle_pos_0``) are computed and cached on first use.

        Args:
            eval_mode: When true, losses are accumulated but no optimizer step
                is performed.
            return_test_loss: When true, the loss summed over frames
                ``i >= self.train_frame`` is appended to the returned tuple.

        Returns:
            ``(xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss)``,
            followed by ``total_test_loss`` when ``return_test_loss`` is true.
            ``xyz`` holds the dense points per frame, ``xyz_bones`` the
            predicted key points per frame and ``eef`` the end-effector
            positions per frame.
        """
        xyz_0, eef_xyz, n_steps = self.structure_points, self.eef_pos_tensor, self.frame_num
        model = self.model
        device = self.device
        n_his = model.model_config['n_his']

        # Lazily compute (and cache) the first-stage farthest-point-sampling indices.
        if self.fps_all_idx is None:
            self.fps_all_idx = farthest_point_sampler(xyz_0.cpu()[None], 1000, start_idx=0)[0]
        fps_all_idx = self.fps_all_idx
        fps_all_pos = xyz_0[self.fps_idx_1000][fps_all_idx]

        # The end-effector tensor is indexed by frame on axis 0 for a single
        # end effector and on axis 1 otherwise.
        eef_num = self.eef_num
        if eef_num == 1:
            eef_pos = eef_xyz[0]
        else:
            eef_pos = eef_xyz[:, 0]
        eef_pos_history = eef_pos[None].repeat(n_his, 1, 1)

        # Lazily compute (and cache) the second-stage down-sampling to key points.
        if self.fps_idx_second is None:
            self.particle_pos_0, self.fps_idx_second = self.downsample_vertices(fps_all_pos.clone())
        fps_idx_second = self.fps_idx_second
        particle_pos_0 = fps_all_pos[fps_idx_second]

        # The history starts as n_his copies of the initial key points.
        fps_all_pos_history = particle_pos_0[None].repeat(n_his, 1, 1)  # (n_his, n_particles, 3)

        # results to store
        xyz = self.structure_points.cpu()[None].repeat(n_steps, 1, 1)  # (n_steps, n_particles, 3)
        nobj = len(fps_idx_second)
        xyz_bones = torch.zeros(n_steps, nobj, 3)  # (n_steps, n_bones, 3)
        eef = eef_pos.cpu()[None].repeat(n_steps, 1, 1)
        xyz_bones[0, :particle_pos_0.shape[0]] = particle_pos_0.cpu()

        pred_pos = [particle_pos_0.cpu()]
        gt_pos = [particle_pos_0.cpu()]
        rels_list = []
        total_loss = 0.0
        total_test_loss = 0.0

        # Skinning relations/weights are fixed from the initial configuration
        # and reused for every frame.
        relation = get_topk_indices(particle_pos_0, K=5)
        all_pos = self.structure_points.clone()
        weights = knn_weights_new(particle_pos_0, all_pos, K=5)
        start_frame = 1
        for i in range(start_frame, n_steps):
            if eef_num == 1:
                eef_pos_this_step = eef_xyz[i]
            else:
                eef_pos_this_step = eef_xyz[:, i]
            # End-effector displacement since the previous frame; this is the action.
            eef_delta = eef_pos_this_step - eef_pos

            particle_pos = fps_all_pos_history[-1]
            nobj = particle_pos.shape[0]

            if i <= start_frame:
                # First step: build the full history from the initial positions.
                states = torch.zeros((1, self.n_his, self.max_nobj + eef_num, 3), device=device)
                states[:, :, :nobj] = fps_all_pos_history
                states[:, :, self.max_nobj:] = eef_pos_history
            else:
                # Later steps: drop the oldest history entry and append the
                # previous prediction together with the previous end-effector pose.
                next_state = torch.zeros((1, 1, self.max_nobj + eef_num, 3), device=device)
                next_state[:, :, self.max_nobj:] = eef_pos
                next_state[:, :, :self.max_nobj] = pred_state.detach()
                states = torch.cat([states.detach()[:, 1:], next_state], dim=1)

            # Only the end-effector slots carry a non-zero action.
            states_delta = torch.zeros((1, self.max_nobj + eef_num, 3), device=device)
            states_delta[:, self.max_nobj:] = eef_delta

            # Two-channel attribute: channel 0 marks object slots, channel 1 end-effector slots.
            attrs = torch.zeros((1, self.max_nobj + eef_num, 2), dtype=torch.float32, device=device)
            attrs[:, :nobj, 0] = 1.
            attrs[:, self.max_nobj:, 1] = 1.

            p_instance = torch.ones((1, self.max_nobj, 1), dtype=torch.float32, device=device)

            # state_mask: valid object slots plus all end-effector slots.
            state_mask = torch.zeros((1, self.max_nobj + eef_num), dtype=bool, device=device)
            state_mask[:, :nobj] = True
            state_mask[:, self.max_nobj:] = True

            eef_mask = torch.zeros((1, self.max_nobj + eef_num), dtype=bool, device=device)
            eef_mask[:, self.max_nobj:] = 1

            obj_mask = torch.zeros((1, self.max_nobj + eef_num), dtype=bool, device=device)
            obj_mask[:, :nobj] = 1

            if self.is_push:
                # Rebuild the tool edges from the current state and append them
                # to the stored relations; the object edges returned here are unused.
                _, _, Rr_tool, Rs_tool, _, edge_tool = (
                    construct_edges_from_states_sep(states[0, -1], self.adj_thresh, mask=state_mask[0],
                                                    tool_mask=eef_mask[0], topk=self.topk,
                                                    connect_all=self.connect_all, return_rels=True))
                # Add the leading batch axis so Rr/Rs have the same
                # (bsz, n_rel, N) layout as in the non-push branch below.
                Rr = torch.concat((self.Rr, Rr_tool))[None]
                Rs = torch.concat((self.Rs, Rs_tool))[None]
                rels = torch.concat((self.rels, edge_tool))
            else:
                # Reuse the stored relations, padded to max_nR rows.
                Rr = pad_torch(self.Rr, self.max_nR)[None]
                Rs = pad_torch(self.Rs, self.max_nR)[None]
                rels = self.rels
            rels_list.append(rels)

            # Last coordinate of the newest state, clamped from below at
            # -adj_thresh and then negated.
            collider_distance = -1.0 * torch.clamp(states[:, -1, :, -1][..., torch.newaxis].clone(), -self.adj_thresh)

            # A 1-D normalized_E is treated as one value per point, anything else as 3 values.
            # These tensors are rebuilt every frame because the optimizer may
            # have updated normalized_E / friction in the previous iteration.
            E_dim = 1 if len(self.normalized_E.shape) == 1 else 3
            fps_logE = self.normalized_E[fps_all_idx][self.nearest_indices].mean(dim=1)
            logE = torch.zeros((1, self.max_nobj + eef_num, E_dim), dtype=collider_distance.dtype).to(device)
            if E_dim == 1:
                logE[:, :nobj, 0] = fps_logE
            else:
                logE[:, :nobj] = fps_logE
            # A single friction value (self.friction[0]) is applied to every object slot.
            frictions = torch.zeros_like(collider_distance)
            frictions[:, :nobj, 0] = self.friction[0]

            graph = {
                # input information
                "state": states,  # (n_his, N+M, state_dim)
                "action": states_delta,  # (N+M, state_dim)

                # attr information
                "attrs": attrs,  # (N+M, attr_dim)
                "p_instance": p_instance,  # (N, n_instance)
                "obj_mask": obj_mask,  # (N,)
                "state_mask": state_mask,  # (N+M,)
                "eef_mask": eef_mask,  # (N+M,)

                "Rr": Rr,  # (bsz, max_nR, N)
                "Rs": Rs,  # (bsz, max_nR, N)
                "collider_distance": collider_distance,
                "logE": logE,
                "friction": frictions
            }

            pred_state, _ = model(**graph)

            # Motion interpolation: move all dense points according to the
            # displacement of the key points.
            all_pos, _, _ = inter_motion_grad(
                bones=particle_pos,  # current key-point positions
                motions=pred_state[0, :nobj] - particle_pos,  # key-point displacement
                relations=relation,  # key-point neighbourhood indices
                xyz=all_pos,  # current positions of all points
                weights=weights
            )

            chamfer_loss = self.chamfer_weight * self.chamfer_loss(all_pos[None], self.xyz_gt[i][None]) / len(all_pos)
            track_loss = self.track_weight * self.track_loss(all_pos[self.vis_indices[i]][None], self.xyz_gt[i][None])
            loss = chamfer_loss + track_loss

            # Optimize only on the training portion of the sequence.
            if (not eval_mode) and i < self.train_frame:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if i >= self.train_frame:
                total_test_loss += loss.item()
            total_loss += loss.item()
            pred_pos.append(pred_state[0, :nobj].cpu())
            # Cut the autograd graph so the next frame does not backpropagate
            # through this one.
            all_pos = all_pos.detach()

            # Slide both history windows forward by one frame.
            eef_pos_history = torch.cat([eef_pos_history[1:], eef_pos_this_step[None]], dim=0).detach()
            eef_pos = eef_pos_this_step

            # NOTE(review): fps_all_pos never changes, so every appended entry
            # equals the initial key points (and is a NumPy array, unlike entry 0).
            gt_pos.append(fps_all_pos[fps_idx_second].cpu().numpy())
            fps_all_pos_history = torch.cat([fps_all_pos_history[1:], pred_state[:, :nobj]], dim=0).detach()

            xyz[i] = all_pos.cpu()
            xyz_bones[i, :nobj] = pred_state[0, :nobj].cpu()
            eef[i] = eef_pos.cpu()

        if not return_test_loss:
            return xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss
        return xyz, xyz_bones, eef, pred_pos, gt_pos, rels_list, total_loss, total_test_loss

    def LBS_render(self, xyz_bones, save_path):
        """Render the Gaussian scene for each key-point step and save the frames.

        The Gaussian centres and rotations from ``self.gs_params`` are advanced
        step by step with ``inter_motion`` using the displacement between
        consecutive rows of ``xyz_bones``; after every step the scene is
        rendered with ``self.renderer`` and written to
        ``{save_path}/frame_{i:04d}.png`` (cropped to 480 x 848).

        Args:
            xyz_bones: Key-point positions per frame; moved to the module-level
                ``device`` before use.
            save_path: Output directory, created if missing.
        """
        xyz_bones = xyz_bones.to(device)
        os.makedirs(save_path, exist_ok=True)

        gs = self.gs_params
        means = gs['xyz_ply']
        quats = gs['rots_ply']
        # means3D / rotations are filled in per frame inside the loop.
        rendervar = {
            'means3D': None,
            'colors_precomp': gs['precomp_colors'],
            'rotations': None,
            'opacities': gs['opacities_ply'],
            'scales': gs['scales_ply'],
            'means2D': torch.zeros_like(means),
        }

        # Neighbourhoods and skinning weights are fixed from the first frame.
        first_bones = xyz_bones[0]
        relation = get_topk_indices(first_bones, K=5)
        weights = knn_weights_new(first_bones, means, K=5)

        for i in range(len(xyz_bones) - 1):
            cur_bones = xyz_bones[i]
            means, quats, _ = inter_motion(
                bones=cur_bones,  # current key-point positions
                motions=xyz_bones[i + 1] - cur_bones,  # key-point displacement
                relations=relation,  # key-point neighbourhood indices
                xyz=means,  # current Gaussian centres
                quat=quats,  # current Gaussian rotations
                weights=weights
            )
            rendervar['means3D'] = means
            rendervar['rotations'] = quats

            # Render the Gaussians on a white background.
            im, depth = self.renderer.render(self.w2c, self.k, rendervar, bg=[1.0, 1.0, 1.0])
            # Channel-first -> channel-last, reverse the channel order, scale to 0..255.
            hwc = im.detach().cpu().numpy().transpose(1, 2, 0)
            scaled = hwc[:, :, ::-1] * 255.0
            frame = scaled.copy().astype(np.uint8)
            frame = frame[0:480, 0:848]

            cv2.imwrite(f'{save_path}/frame_{i:04d}.png', frame)

    def save_points_to_pkl(self, save_path, data):
        """Pickle ``data`` as a NumPy array into ``<save_path>/inference.pkl``.

        Args:
            save_path: Existing directory that receives ``inference.pkl``.
            data: Anything ``np.array`` accepts.
        """
        payload = np.array(data)
        pkl_path = os.path.join(save_path, 'inference.pkl')
        with open(pkl_path, 'wb') as handle:
            pkl.dump(payload, handle)
        print(f"pkl save to {pkl_path}")

    def LBS_structure_points(self, xyz_bones, epoch):
        """Propagate the structure points along the key-point trajectory and save them.

        The dense ``self.structure_points`` are advanced frame by frame with
        ``inter_motion`` using the displacement between consecutive rows of
        ``xyz_bones``.  The positions of every frame (including the initial
        one) are stacked into one array and written to
        ``<self.save_root>/<epoch>/inference.pkl``.

        Args:
            xyz_bones: Key-point positions per frame; moved to the module-level
                ``device`` before use.
            epoch: Used as the name of the output sub-directory.
        """
        xyz_bones = xyz_bones.to(device)
        all_pos = self.structure_points.clone()
        # detach() so the conversion also works for tensors that track gradients.
        save_pos = [all_pos.detach().cpu().numpy()]
        # Neighbourhoods and skinning weights are fixed from the first frame.
        relation = get_topk_indices(xyz_bones[0], K=5)
        weights = knn_weights_new(xyz_bones[0], all_pos, K=5)
        for i in range(len(xyz_bones) - 1):
            all_pos, _, _ = inter_motion(
                bones=xyz_bones[i],  # current key-point positions
                motions=xyz_bones[i + 1] - xyz_bones[i],  # key-point displacement
                relations=relation,  # key-point neighbourhood indices
                xyz=all_pos,  # current positions of all points
                weights=weights
            )
            save_pos.append(all_pos.detach().cpu().numpy())
        # The list holds NumPy arrays, so it must be stacked with NumPy;
        # torch.stack only accepts tensors and would raise a TypeError here.
        save_pos = np.stack(save_pos)
        save_path = os.path.join(self.save_root, f'{epoch}')
        os.makedirs(save_path, exist_ok=True)
        self.save_points_to_pkl(save_path, save_pos)

    def draw_points(self, xyz, epoch):
        """Plot sub-sampled predicted vs. ground-truth points for each frame.

        For every frame except the last, 100 points are picked by
        farthest-point sampling from both ``xyz[i]`` and ``self.xyz_gt[i]`` and
        passed to ``plot_3d_top_view_equal_scale``, which writes
        ``<self.save_root>/draw_<epoch>/<i>.png``.

        Args:
            xyz: Per-frame point positions (indexable by frame).
            epoch: Used in the name of the output directory.
        """
        out_dir = f'{self.save_root}/draw_{epoch}'
        os.makedirs(out_dir, exist_ok=True)
        n_frames = len(xyz) - 1
        for i in range(n_frames):
            pred_points = xyz[i]
            gt_points = self.xyz_gt[i].cpu()
            # Sub-sample both clouds to 100 points to keep the plot readable.
            gt_idx = farthest_point_sampler(gt_points[None], 100, start_idx=0)[0]
            pred_idx = farthest_point_sampler(pred_points.cpu()[None], 100, start_idx=0)[0]
            plot_3d_top_view_equal_scale(
                pred_points[pred_idx],
                gt_points[gt_idx],
                rels=None,
                output_file=f'{out_dir}/{i}.png',
            )


    def finetune_phys(self):
        """Run the fine-tuning loop with an evaluation/render pass every 10th epoch.

        Epochs with ``i % 10 == 0`` run ``train_step_structure`` in eval mode
        under ``torch.no_grad()`` and then render the result and export the
        propagated structure points.  All other epochs run one training pass
        and advance the learning-rate scheduler.  After every epoch the loss is
        printed and the loss curve is written to ``<self.save_root>/loss.png``.
        """
        loss_plot = []
        for i in range(self.epoch):
            if i % 10 == 0:
                with torch.no_grad():
                    _, xyz_bones, _, _, _, _, total_loss = self.train_step_structure(eval_mode=True)
                    # makedirs creates the intermediate 'render' directory as well.
                    render_path_epoch = os.path.join(self.save_root, 'render', f'{i}')
                    os.makedirs(render_path_epoch, exist_ok=True)
                    self.LBS_render(xyz_bones, render_path_epoch)
                    self.LBS_structure_points(xyz_bones, i)
            else:
                _, xyz_bones, _, _, _, _, total_loss = self.train_step_structure()
                self.scheduler.step()

            # get_last_lr() reports the rate actually in use; get_lr() is meant
            # for the scheduler's internal use and warns when called directly.
            print('epoch: {}, total_loss: {}, logE_mean: {}, friction: {}, LR:{}'.format(
                i,
                total_loss,
                self.normalized_E[self.fps_all_idx].detach().mean().item(),
                self.friction[0].detach().item(),
                self.scheduler.get_last_lr()))
            loss_plot.append(total_loss)

            # Draw on a fresh figure and close it afterwards; plotting on the
            # implicit global figure every epoch stacks one more curve on it each time.
            fig, ax = plt.subplots()
            ax.plot(loss_plot)
            fig.savefig(os.path.join(self.save_root, 'loss.png'), dpi=300)
            plt.close(fig)

    def finetune_phys_pure(self):
        """Fine-tuning loop that evaluates before every training pass.

        Each epoch first runs ``train_step_structure`` in eval mode under
        ``torch.no_grad()``.  If the evaluation loss is the best so far, the
        key points and dense points are saved in ``self.save_root``.  Every
        10th epoch additionally saves them in ``<self.save_root>/<epoch>``,
        renders the sequence and stores the current ``normalized_E`` /
        ``friction`` values.  Then one training pass is run and the scheduler
        advanced.  The printed and plotted loss is the evaluation loss.
        """
        loss_plot = []
        best_loss = 1e10
        for i in range(self.epoch):
            with torch.no_grad():
                xyz, xyz_bones, _, _, _, _, total_loss, _ = self.train_step_structure(
                    eval_mode=True, return_test_loss=True)

            # Keep the outputs of the best evaluation pass seen so far.
            if total_loss < best_loss:
                best_loss = total_loss
                np.save(os.path.join(self.save_root, 'xyz_bones_best.npy'), xyz_bones)
                self.save_points_to_pkl(self.save_root, xyz)

            # Periodic snapshot: points, render and the optimized parameters.
            if i % 10 == 0:
                save_p = os.path.join(self.save_root, f'{i}')
                os.makedirs(save_p, exist_ok=True)
                np.save(os.path.join(save_p, 'xyz_bones.npy'), xyz_bones)
                self.save_points_to_pkl(save_p, xyz)

                # makedirs creates the intermediate 'render' directory as well.
                render_path_epoch = os.path.join(self.save_root, 'render', f'{i}')
                os.makedirs(render_path_epoch, exist_ok=True)
                self.LBS_render(xyz_bones, render_path_epoch)

                torch.save({'normalized_E': self.normalized_E.detach().cpu(),
                            'friction': self.friction.detach().cpu()},
                           os.path.join(save_p, 'optimized_params.pth'))

            # Training pass; its return values are not used here.
            self.train_step_structure()
            self.scheduler.step()

            # get_last_lr() reports the rate actually in use; get_lr() is meant
            # for the scheduler's internal use and warns when called directly.
            print('epoch: {}, total_loss: {}, logE_mean: {}, friction: {}, LR:{}'.format(
                i,
                total_loss,
                self.normalized_E[self.fps_all_idx].detach().mean().item(),
                self.friction[0].detach().item(),
                self.scheduler.get_last_lr()))
            loss_plot.append(total_loss)

            # Draw on a fresh figure and close it afterwards; plotting on the
            # implicit global figure every epoch stacks one more curve on it each time.
            fig, ax = plt.subplots()
            ax.plot(loss_plot)
            fig.savefig(os.path.join(self.save_root, 'loss.png'), dpi=300)
            plt.close(fig)

    def test_time(self, n_runs=10):
        """Measure and print the average wall-clock time of ``self.inference``.

        Args:
            n_runs: Number of inference passes to average over (default 10,
                the previously hard-coded value).

        Raises:
            ValueError: If ``n_runs`` is smaller than 1.
        """
        if n_runs < 1:
            raise ValueError(f'n_runs must be at least 1, got {n_runs}')

        # CUDA kernels run asynchronously, so synchronize around the timed
        # region; skip it on CPU-only hosts where synchronize() would raise.
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        with torch.no_grad():
            for _ in range(n_runs):
                self.inference(eval_mode=True, return_test_loss=True)
        if use_cuda:
            torch.cuda.synchronize()
        inference_time = (time.perf_counter() - start_time) / n_runs
        print(f'PhysWorld Inference time: {inference_time}')

    def draw_track(self, track_path='/data/dev/PhysTwin/data/different_types/double_lift_sloth',
                   save_path=None):
        """Project the tracked object points of every frame into the camera and save images.

        Loads ``final_data.pkl`` from ``track_path``, draws the entry
        ``'object_points'`` frame by frame with ``draw_points_2`` using
        ``self.w2c`` / ``self.k``, crops each image to 480 x 848 and writes it
        to ``<save_path>/frame_<t>.png``.

        Args:
            track_path: Directory containing ``final_data.pkl``.  Defaults to
                the previously hard-coded location.
            save_path: Output directory; defaults to ``<track_path>/draw_track``.
        """
        if save_path is None:
            save_path = os.path.join(track_path, 'draw_track')

        # NOTE: pickle can execute arbitrary code on load; only use trusted files.
        with open(os.path.join(track_path, 'final_data.pkl'), 'rb') as f:
            data = pkl.load(f)
        obj_points = data['object_points']

        # cv2.imwrite does not create directories and only returns False when
        # the target directory is missing, so make sure it exists first.
        os.makedirs(save_path, exist_ok=True)
        for t in range(len(obj_points)):
            im = draw_points_2(self.w2c, self.k, obj_points[t], color=(255, 216, 173), radius=2)
            im = im[0:480, 0:848]
            cv2.imwrite(f'{save_path}/frame_{t:04d}.png', im)



if __name__ == '__main__':
    # The scene name selects every data/config path below.  It defaults to the
    # previously hard-coded scene, so running without arguments is unchanged.
    parser = argparse.ArgumentParser(description='Fine-tune / visualize physical parameters for a scene')
    parser.add_argument('--name', type=str, default='double_lift_sloth',
                        help='Name of the configuration (e.g. double_lift_cloth_3)')
    # parse_known_args: unrelated command-line arguments keep being ignored.
    args, _ = parser.parse_known_args()
    name = args.name

    ori_data_path = f'/data/dev/PhysGaussian_rebuild_2/data/phystwin_data/{name}'
    ply_path = f'/data/dev/PhysGaussian_rebuild_2/data/phystwin_data/{name}/gaussian_output'
    test_data_path = f'../test_data/{name}'
    mpm_test_path = f'../test_data/{name}_MPM'
    config_path = f'./mpm_config/{name}.yaml'
    optimizer = PhysFinetuner(test_data_path, config_path, ori_data_path, ply_path, name, mpm_test_path)

    # Select the entry point to run by (un)commenting one of the calls below.
    # optimizer.finetune_phys_pure()
    optimizer.draw_track()
    # optimizer.finetune_phys()
    # optimizer.test_time()
