import os
from pathlib import Path
import torch
import numpy as np
import glob
from PIL import Image
import cv2
import pickle as pkl
import json
import yaml
import argparse
import open3d as o3d
from dgl.geometry import farthest_point_sampler
from scipy.spatial.distance import cdist

from data.utils import label_colormap, opengl2cam
from real_world.utils.pcd_utils import depth2fgpcd, rpy_to_rotation_matrix
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm


def get_eef_points(xyz, rpy, calib):
    """Map the gripper reference point into the world frame.

    Args:
        xyz: Gripper translation in the robot base frame. It is divided by
            1000 before use (presumably millimetres -> metres; confirm with
            the robot driver).
        rpy: Three orientation values, passed in order to
            ``rpy_to_rotation_matrix``.
        calib: Mapping that provides ``'R_base2world'`` and
            ``'t_base2world'``.

    Returns:
        The first (and only) transformed point.
    """
    # Reference point expressed in the gripper frame: a fixed offset along z.
    tip_in_gripper = np.array([[0.0, 0.0, 0.17]])

    # Gripper pose in the robot base frame.
    rot_g2b = rpy_to_rotation_matrix(rpy[0], rpy[1], rpy[2])
    trans_g2b = np.array(xyz) / 1000

    # Chain with the base -> world calibration.
    rot_b2w, trans_b2w = calib['R_base2world'], calib['t_base2world']
    rot_g2w = rot_b2w @ rot_g2b
    trans_g2w = rot_b2w @ trans_g2b + trans_b2w

    tip_in_world = (rot_g2w @ tip_in_gripper.T + trans_g2w[:, np.newaxis]).T
    return tip_in_world[0]  # only one point


def test_validity(data_dir, output_dir):  # test if the camera recording is problematic
    """Check that an episode has enough recorded actions for its frames.

    Args:
        data_dir: Directory holding ``train_meta.json`` and ``actions.txt``.
        output_dir: Directory expected to hold ``params.npz`` or
            ``params_0.npz``.

    Returns:
        ``False`` (after printing a warning) when ``actions.txt`` has more
        than 10 fewer lines than there are frames listed in the metadata,
        ``True`` otherwise.

    Raises:
        ValueError: If neither params file exists in ``output_dir``.
    """
    candidates = [os.path.join(output_dir, fname) for fname in ('params.npz', 'params_0.npz')]
    if not any(os.path.exists(path) for path in candidates):
        raise ValueError(f'Params dir {candidates[0]} not found')

    with open(os.path.join(data_dir, 'train_meta.json'), 'r') as meta_file:
        filenames = np.array(json.load(meta_file)['fn'])  # n_frames, 4

    # Parse every frame index so a malformed filename still surfaces here,
    # even though only the number of frames is used below.
    frame_ids = np.array(
        [int(row[0].split('/')[-1].split('_')[1].split('.')[0]) for row in filenames]
    )
    n_recorded = len(frame_ids)

    with open(os.path.join(data_dir, 'actions.txt'), 'r') as actions_file:
        action_lines = actions_file.read().rstrip('\n').split('\n')  # one record per line
    n_actions = len(action_lines)

    if n_actions - n_recorded >= -10:
        return True

    print(f'warning: in data_dir {data_dir}, json_data length {n_actions}, num_frames {n_recorded}')
    return False


def _parse_frame_indices(fn):
    """Return the integer frame index encoded in the first filename of each row.

    The index is the second ``_``-separated token of the basename, truncated
    at its first ``.`` (e.g. ``a/b/color_12.png`` -> ``12``).
    """
    return np.array([int(row[0].split('/')[-1].split('_')[1].split('.')[0]) for row in fn])


def _load_eef_positions(data_dir, frame_idx_lists):
    """Load recorded robot poses and convert them to world-frame gripper points.

    Reads ``actions.txt`` (one JSON record per line, each with a ``'pose'``
    entry) and ``calibration_handeye_result.pkl`` from ``data_dir``.

    Returns:
        A list with one ``get_eef_points`` result per entry of
        ``frame_idx_lists``.
    """
    num_frames = len(frame_idx_lists)

    with open(os.path.join(data_dir, 'actions.txt'), 'r') as f:
        json_data = f.read().rstrip('\n').split('\n')

    if len(json_data) != num_frames:
        # Deal with camera frame recording mismatch: left-pad with copies of
        # the first record so the list reaches ``max(frame index) + 1`` lines.
        json_data = [json_data[0]] * (max(frame_idx_lists) + 1 - len(json_data)) + json_data

    if len(json_data) - num_frames > 10:
        json_data = json_data[:num_frames]

    poses = []
    for frame_idx in frame_idx_lists:
        try:
            actions = json.loads(json_data[frame_idx])
        except (IndexError, ValueError):
            # Missing or unparsable record (JSONDecodeError is a ValueError):
            # fall back to the last recorded action.
            actions = json.loads(json_data[-1])
        poses.append(actions['pose'])
    poses = np.array(poses)

    # NOTE: pickle executes arbitrary code on load; only use calibration
    # files produced by this project.
    with open(os.path.join(data_dir, "calibration_handeye_result.pkl"), "rb") as f:
        calib = pkl.load(f)

    eef_xyz = poses[:, :3]
    eef_xyz[:, 1] -= 1.0  # slightly adjust calibration error
    eef_rpy = poses[:, 3:]

    # Compute every gripper point once instead of once per candidate pair in
    # the quadratic trajectory search.
    return [get_eef_points(eef_xyz[i], eef_rpy[i], calib) for i in range(num_frames)]


def _eef_distance(p, q):
    """Euclidean distance between the first three components of ``p`` and ``q``."""
    return np.sqrt((p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2 + (p[2] - q[2]) ** 2)


def _build_frame_trajectory(eef_pos, curr_frame, dist_thresh, n_his, n_future):
    """Build the history/future frame window centred on ``curr_frame``.

    Walking backward from ``curr_frame``, a frame is recorded each time the
    gripper is at least ``dist_thresh`` away from the last recorded frame,
    until the window holds ``n_his`` frames. The same walk is then done
    forward until it holds ``n_his + n_future`` frames. When the episode
    boundary is reached first, the window is padded by repeating the
    outermost frame with distance 0.

    Returns:
        ``(frame_traj, dist_traj)``: frame indices in chronological order and
        the distances between consecutive recorded frames.
    """
    num_frames = len(eef_pos)
    frame_traj = [curr_frame]
    dist_traj = []

    # search backward
    anchor = eef_pos[curr_frame]
    fi = curr_frame
    while fi >= 0:
        dist_curr = _eef_distance(anchor, eef_pos[fi])
        if dist_curr >= dist_thresh:
            frame_traj.append(fi)
            dist_traj.append(dist_curr)
            anchor = eef_pos[fi]
        fi -= 1
        if len(frame_traj) == n_his:
            break
    else:
        # pad to n_his
        curr_len = len(frame_traj)
        frame_traj = frame_traj + [frame_traj[-1]] * (n_his - curr_len)
        dist_traj = dist_traj + [0] * (n_his - curr_len)

    frame_traj = frame_traj[::-1]
    dist_traj = dist_traj[::-1]

    # search forward
    anchor = eef_pos[curr_frame]
    fi = curr_frame
    last_frame = num_frames - 1
    while fi < num_frames:
        dist_curr = _eef_distance(anchor, eef_pos[fi])
        # The final frame of the episode is accepted with a relaxed threshold.
        if dist_curr >= dist_thresh or (fi == last_frame and dist_curr >= 0.75 * dist_thresh):
            frame_traj.append(fi)
            dist_traj.append(dist_curr)
            anchor = eef_pos[fi]
        fi += 1
        if len(frame_traj) == n_his + n_future:
            break
    else:
        # When assuming quasi-static, we can pad to n_his + n_future
        curr_len = len(frame_traj)
        frame_traj = frame_traj + [frame_traj[-1]] * (n_his + n_future - curr_len)
        dist_traj = dist_traj + [0] * (n_his + n_future - curr_len)

    return frame_traj, dist_traj


def _save_frame_pairs(frame_idx_dir, episode_idx, frame_idxs, dists):
    """Write the selected frame windows and their step distances as text files."""
    np.savetxt(os.path.join(frame_idx_dir, f'{episode_idx}.txt'), np.array(frame_idxs), fmt='%d')
    print(f'episode {episode_idx} has {len(frame_idxs)} unit pushes')
    np.savetxt(os.path.join(frame_idx_dir, f'dist_{episode_idx}.txt'), np.array(dists), fmt='%f')


def extract_pushes(data_dir, output_dir, save_dir, dist_thresh, n_his, n_future, episode_idx=0):
    """Extract frame windows in which the gripper is close to the object.

    For every frame a window of ``n_his`` history and ``n_future`` future
    frames is built (see ``_build_frame_trajectory``). The window is kept
    only if, at its centre frame, the gripper point is closer than 0.09 to
    the nearest tracked particle (robot-object contact filter).

    Args:
        data_dir: Episode directory with ``train_meta.json``, ``actions.txt``
            and ``calibration_handeye_result.pkl``.
        output_dir: Directory read by ``load_sep_params``.
        save_dir: Results are written to ``<save_dir>/frame_pairs``.
        dist_thresh: Minimum gripper displacement between recorded frames.
        n_his: Number of history frames per window (including the centre).
        n_future: Number of future frames per window.
        episode_idx: Used to name the output files ``<idx>.txt`` and
            ``dist_<idx>.txt``.
    """
    # use overlapping samples
    # provide canonical frame info
    # compatible to other data layouts (make a general episode list)
    frame_idx_dir = os.path.join(save_dir, 'frame_pairs')
    os.makedirs(frame_idx_dir, exist_ok=True)

    params = load_sep_params(output_dir)
    xyz = params['means3D']  # n_frames, n_particles, 3

    with open(os.path.join(data_dir, 'train_meta.json'), 'r') as f:
        meta = json.load(f)
    frame_idx_lists = _parse_frame_indices(np.array(meta['fn']))  # fn: n_frames, 4
    eef_pos = _load_eef_positions(data_dir, frame_idx_lists)

    rob_obj_dist_thresh = 0.09
    frame_idxs = []
    dists = []
    for curr_frame in range(len(eef_pos)):
        frame_traj, dist_traj = _build_frame_trajectory(eef_pos, curr_frame, dist_thresh, n_his, n_future)

        # Keep only windows whose centre frame shows robot-object contact.
        frame_idx = frame_traj[n_his - 1]
        robot_obj_dist = np.min(cdist(eef_pos[frame_idx][None], xyz[frame_idx]))
        if robot_obj_dist < rob_obj_dist_thresh:
            frame_idxs.append(frame_traj)
            dists.append(dist_traj)

    _save_frame_pairs(frame_idx_dir, episode_idx, frame_idxs, dists)


def extract_pushes_pcd(data_dir, output_dir, save_dir, dist_thresh, n_his, n_future, episode_idx=0):
    """Extract frame windows in which the tracked point cloud actually moves.

    Windows are built exactly as in ``extract_pushes``. A window is kept only
    if the particle centroid moves more between its last history frame and
    its first future frame than the largest frame-to-frame centroid motion
    seen in the first 100 tracked frames (presumably a static warm-up
    period; confirm with the recording protocol).

    Nothing is written when no window passes the filter.

    Args:
        data_dir: Episode directory with ``train_meta.json``, ``actions.txt``
            and ``calibration_handeye_result.pkl``.
        output_dir: Directory read by ``load_sep_params``.
        save_dir: Results are written to ``<save_dir>/frame_pairs``.
        dist_thresh: Minimum gripper displacement between recorded frames.
        n_his: Number of history frames per window (including the centre).
        n_future: Number of future frames per window.
        episode_idx: Used to name the output files ``<idx>.txt`` and
            ``dist_<idx>.txt``.
    """
    # use overlapping samples
    # provide canonical frame info
    # compatible to other data layouts (make a general episode list)
    frame_idx_dir = os.path.join(save_dir, 'frame_pairs')
    os.makedirs(frame_idx_dir, exist_ok=True)

    with open(os.path.join(data_dir, 'train_meta.json'), 'r') as f:
        meta = json.load(f)
    fn = np.array(meta['fn'])  # n_frames, 4

    params = load_sep_params(output_dir)
    xyz = params['means3D']  # n_frames, n_particles, 3
    xyz_mean = xyz.mean(axis=1)
    xyz_diff = np.linalg.norm(xyz_mean[:-1] - xyz_mean[1:], axis=1)
    diff_thresh = max(xyz_diff[:100])

    frame_idx_lists = _parse_frame_indices(fn)
    eef_pos = _load_eef_positions(data_dir, frame_idx_lists)

    frame_idxs = []
    dists = []
    for curr_frame in range(len(eef_pos)):
        frame_traj, dist_traj = _build_frame_trajectory(eef_pos, curr_frame, dist_thresh, n_his, n_future)

        try:
            diff = np.linalg.norm(xyz_mean[frame_traj[n_his - 1]] - xyz_mean[frame_traj[n_his]])
        except IndexError:
            # The window refers to a frame with no tracked particles (or has
            # no future frame); report the offending window and skip it.
            print(f'diff error: window {frame_traj} is out of range for {len(xyz_mean)} tracked frames')
            continue

        if diff > diff_thresh:
            frame_idxs.append(frame_traj)
            dists.append(dist_traj)

    if len(frame_idxs) < 1:
        return
    _save_frame_pairs(frame_idx_dir, episode_idx, frame_idxs, dists)


def load_sep_params(output_dir):
    """Load tracking parameters, merging per-frame checkpoint files if present.

    Two layouts are supported:

    * ``params_<i>.npz`` files (one per frame, ordered by ``<i>``): the
      frames are stacked along a new leading axis and appended to the start
      parameters. The start parameters come from ``params_0.npz`` (a single
      frame, given a leading axis) when it exists, otherwise from
      ``params.npz`` (used as stored). Keys that only exist in the start
      parameters are copied through unchanged.
    * Only ``params.npz``: the file is returned as loaded by ``np.load``.

    If the merged ``means3D`` has more than 600 frames, the first frame is
    dropped from every per-frame array. Arrays that exist only in the start
    parameters are not indexed by frame and are left untouched.

    Args:
        output_dir: Directory containing the ``.npz`` checkpoint files.

    Returns:
        A dict of arrays (merged layout) or the ``NpzFile`` of ``params.npz``.

    Raises:
        ValueError: If no usable checkpoint files are found.
    """

    def _read_npz(path):
        # Materialise the arrays so the file handle is released immediately;
        # keeping one open NpzFile per frame can exhaust file descriptors.
        with np.load(path) as data:
            return {k: data[k] for k in data.files}

    params_path = os.path.join(output_dir, 'params.npz')
    first_frame_path = os.path.join(output_dir, 'params_0.npz')
    frame_paths = sorted(glob.glob(os.path.join(output_dir, 'params_*.npz')),
                         key=lambda x: int(os.path.splitext(os.path.basename(x))[0][7:]))

    if not frame_paths:
        if not os.path.exists(params_path):
            raise ValueError(f'Params dir {params_path} not found')
        return np.load(params_path)

    if os.path.exists(first_frame_path):
        # params_0.npz stores a single frame: add the frame axis.
        start_has_frame_axis = False
        start_params = {k: v[None] for k, v in _read_npz(first_frame_path).items()}
        frame_paths.remove(first_frame_path)
    else:
        if not os.path.exists(params_path):
            raise ValueError(f'Params dir {params_path} not found')
        start_has_frame_axis = True
        start_params = _read_npz(params_path)

    if not frame_paths:
        raise ValueError(f'No per-frame params found after {first_frame_path}')

    frames = [_read_npz(path) for path in frame_paths]
    per_frame_keys = list(frames[0])

    params = {
        k: np.concatenate((start_params[k], np.stack([frame[k] for frame in frames])))
        for k in per_frame_keys
    }
    for k, v in start_params.items():
        if k not in params:
            # Static array: undo the frame axis added to params_0.npz above.
            params[k] = v if start_has_frame_axis else v[0]

    if len(params['means3D']) > 600:
        # Drop the first frame, but only from arrays that have a frame axis.
        for k in per_frame_keys:
            params[k] = params[k][1:]

    return params


def my_downsample(output_dir):
    """Downsample and smooth the particle tracks loaded by ``load_sep_params``.

    Steps: keep particles whose first ``logit_opacities`` column is positive,
    discard particles whose total motion is an outlier, pick 1000 particles
    by farthest point sampling on the first frame, smooth the trajectories
    over time, and save the result to ``<output_dir>/param_downsampled.npy``.

    Raises:
        ValueError: If neither ``params.npz`` nor ``params_0.npz`` exists.
    """
    target_points = 1000

    main_ckpt = os.path.join(output_dir, 'params.npz')
    first_frame_ckpt = os.path.join(output_dir, 'params_0.npz')
    if not (os.path.exists(main_ckpt) or os.path.exists(first_frame_ckpt)):
        raise ValueError(f'Params dir {main_ckpt} not found')

    params = load_sep_params(output_dir)
    tracks = params['means3D']  # n_frames, n_particles, 3

    # Opacity filter.
    visible = (params['logit_opacities'] > 0)[:, 0]
    tracks = tracks[:, visible]

    # Total path length travelled by each particle over the episode.
    step_lengths = np.linalg.norm(tracks[1:] - tracks[:-1], axis=-1)
    total_motion = step_lengths.sum(axis=0)

    # Outlier filter: keep particles within 3 median absolute deviations of
    # the median motion (everything is kept when the deviation is zero).
    deviation = np.abs(total_motion - np.median(total_motion))
    median_dev = np.median(deviation)
    if median_dev:
        score = deviation / median_dev
    else:
        score = np.zeros(len(deviation))
    tracks = tracks[:, score < 3]

    # Farthest point sampling on the first frame; same indices for all frames.
    tracks_t = torch.from_numpy(tracks).float()
    sample_ids = farthest_point_sampler(tracks_t[0:1], target_points, start_idx=0)[0]
    tracks = tracks_t[:, sample_ids].numpy()

    # trajectory smoothing: 10 passes of a 3-frame moving average
    for _ in range(10):
        tracks[1:-1] = (tracks[:-2] + tracks[1:-1] + tracks[2:]) / 3

    np.save(os.path.join(output_dir, "param_downsampled.npy"), tracks)
    print(f"Downsampled {tracks.shape[1]} points")


def downsample(output_dir):
    """Downsample and smooth the particle tracks stored in ``params.npz``.

    Same pipeline as ``my_downsample`` but reads ``params.npz`` directly:
    opacity filter, motion-outlier filter, farthest point sampling of 1000
    particles on the first frame, temporal smoothing, then saving to
    ``<output_dir>/param_downsampled.npy``.

    Raises:
        ValueError: If ``params.npz`` does not exist in ``output_dir``.
    """

    def within_mad(values, max_score):
        """Mask of entries whose MAD-normalised deviation is below ``max_score``."""
        abs_dev = np.abs(values - np.median(values))
        mad = np.median(abs_dev)
        if not mad:
            return np.zeros(len(abs_dev)) < max_score
        return abs_dev / mad < max_score

    ckpt_path = os.path.join(output_dir, 'params.npz')
    if not os.path.exists(ckpt_path):
        raise ValueError(f'Params dir {ckpt_path} not found')
    ckpt = np.load(ckpt_path)

    points = ckpt['means3D']  # n_frames, n_particles, 3
    points = points[:, (ckpt['logit_opacities'] > 0)[:, 0]]

    # Accumulated frame-to-frame displacement of every particle.
    travelled = np.sum(np.linalg.norm(points[1:] - points[:-1], axis=-1), axis=0)
    points = points[:, within_mad(travelled, max_score=3)]

    points_t = torch.from_numpy(points).float()
    keep_idx = farthest_point_sampler(points_t[0:1], 1000, start_idx=0)[0]
    points = points_t[:, keep_idx].numpy()

    # trajectory smoothing
    n_passes = 10
    for _ in range(n_passes):
        points[1:-1] = (points[:-2] + points[1:-1] + points[2:]) / 3

    out_path = os.path.join(output_dir, "param_downsampled.npy")
    np.save(out_path, points)
    print(f"Downsampled {points.shape[1]} points")


def preprocess(config):
    """Run push extraction and downsampling for every episode of a dataset.

    Only the first entry of ``config['dataset_config']['datasets']`` is
    processed. Directory layout, relative to its ``base_dir``:

    * ``data/<name>[/episode_XX]``: recorded data.
    * ``ckpts/exp_<name>/<name>[/episode_XX]``: tracking output.
    * ``preprocessed/exp_<name>[/episode_XX]``: results written here.

    Episodes are the ``episode_*`` output directories that contain
    ``params.npz`` or ``params_0.npz``. When none are found, the dataset is
    treated as a single episode (index 0) without the ``episode_XX`` level.

    For each valid episode, ``extract_pushes`` and ``my_downsample`` are run
    and ``metadata.txt`` (``dist_thresh,n_future,n_his``) is written. If
    downsampling fails, the episode's preprocess directory is removed and the
    episode is skipped.

    Args:
        config: Mapping with ``dataset_config`` and ``train_config``
            (``dist_thresh``, ``n_his``, ``n_future``) sections.
    """
    import shutil  # local import: only used to clean up failed episodes

    dataset_cfg = config['dataset_config']['datasets'][0]
    train_cfg = config['train_config']

    base_dir = Path(dataset_cfg['base_dir'])
    name = dataset_cfg['name']  # 'episodes_giraffe_0521'
    exp_data_dir = base_dir / 'data' / f'{name}'  # initial data such as rgbd, etc.
    exp_output_dir = base_dir / 'ckpts' / f'exp_{name}'  # 3DGS model output
    exp_preprocess_dir = base_dir / 'preprocessed' / f'exp_{name}'  # to save preprocessed data

    dist_thresh = train_cfg['dist_thresh']
    n_his = train_cfg['n_his']
    n_future = train_cfg['n_future']

    print('dist_thresh', dist_thresh)

    episodes = [
        epi for epi in sorted(glob.glob(str(exp_output_dir / name / 'episode_*')))
        if os.path.exists(os.path.join(epi, 'params.npz'))
        or os.path.exists(os.path.join(epi, 'params_0.npz'))
    ]
    episode_idxs = [int(epi.split('_')[-1]) for epi in episodes]

    n_episodes = len(episode_idxs)
    print(f'Processing {n_episodes} episodes')

    # Fallback: no per-episode directories, treat the dataset as one episode.
    no_episodes_ver = n_episodes == 0
    if no_episodes_ver:
        episode_idxs = [0]

    for episode_idx in episode_idxs:
        print(f'Processing episode {episode_idx}')

        if no_episodes_ver:
            epi_data_dir = exp_data_dir
            epi_output_dir = exp_output_dir / name
            epi_preprocess_dir = exp_preprocess_dir
        else:
            episode_name = f'episode_{episode_idx:02d}'
            epi_data_dir = exp_data_dir / episode_name
            epi_output_dir = exp_output_dir / name / episode_name
            epi_preprocess_dir = exp_preprocess_dir / episode_name

        if not test_validity(epi_data_dir, epi_output_dir):
            print(f'Episode {episode_idx} is invalid')
            continue

        os.makedirs(epi_preprocess_dir, exist_ok=True)

        extract_pushes(epi_data_dir, epi_output_dir, epi_preprocess_dir,
                       dist_thresh=dist_thresh, n_his=n_his, n_future=n_future, episode_idx=episode_idx)
        try:
            my_downsample(epi_output_dir)
        except Exception as err:
            # Best effort: report why the episode failed, discard its partial
            # output and move on. KeyboardInterrupt is deliberately not caught.
            print(f'Failed to downsample episode {episode_idx}: {err!r}')
            shutil.rmtree(epi_preprocess_dir, ignore_errors=True)
            continue

        # save metadata
        os.makedirs(epi_preprocess_dir, exist_ok=True)
        with open(os.path.join(epi_preprocess_dir, 'metadata.txt'), 'w') as f:
            f.write(f'{dist_thresh},{n_future},{n_his}')


def downsample_only(output_dir):
    """Downsample generated episodes stored as ``<output_dir>/<episode>/data.npz``.

    For every entry of ``output_dir`` the ``trans_pos`` tracks are reduced to
    1000 points by farthest point sampling on the first frame, smoothed over
    time and saved as ``param_downsampled.npy``. The ``trans_eef_pos`` array
    is saved unchanged as ``eef_pos.npy``.
    """
    n_points = 1000
    for episode in os.listdir(output_dir):
        episode_path = os.path.join(output_dir, episode)
        data = dict(np.load(os.path.join(episode_path, 'data.npz')))
        tracks = data['trans_pos']
        # End-effector positions are taken as recorded; they are not rebuilt
        # from the primitive's constant speed (0.28).
        eef_pos = data['trans_eef_pos']

        tracks_t = torch.from_numpy(tracks).float()
        picked = farthest_point_sampler(tracks_t[0:1], n_points, start_idx=0)[0]
        tracks = tracks_t[:, picked].numpy()

        # trajectory smoothing
        for _ in range(10):
            tracks[1:-1] = (tracks[:-2] + tracks[1:-1] + tracks[2:]) / 3

        np.save(os.path.join(episode_path, "param_downsampled.npy"), tracks)
        print(f"episode{episode}, Downsampled {tracks.shape[1]} points")

        np.save(os.path.join(episode_path, "eef_pos.npy"), eef_pos)


def downsample_only_gripper(output_dir):
    """Downsample gripper episodes stored as ``<output_dir>/<episode>/data.npz``.

    Episodes that already contain ``E_downsample.npy`` are skipped. For the
    others, 10000 points are chosen by farthest point sampling on the first
    frame of ``pos``; the same indices are applied to ``control_mask`` and
    ``E``. Outputs: ``pos_downsampled.npy``, ``control_mask_downsample.npy``,
    ``E_downsample.npy`` and ``control_velocity.npy`` (a copy of
    ``velocity``).

    Raises:
        AssertionError: If no sampled point has a non-zero control mask.
    """
    n_points = 10000
    for episode in os.listdir(output_dir):
        episode_path = os.path.join(output_dir, episode)

        # E_downsample.npy doubles as the "already processed" marker.
        if os.path.exists(os.path.join(episode_path, 'E_downsample.npy')):
            continue

        data = dict(np.load(os.path.join(episode_path, 'data.npz')))

        pos_t = torch.from_numpy(data['pos']).float()
        picked = farthest_point_sampler(pos_t[0:1], n_points, start_idx=0)[0]
        pos = pos_t[:, picked].numpy()

        picked_control_mask = data['control_mask'][picked]
        assert picked_control_mask.sum() > 0
        picked_E = data['E'][picked]
        control_velocity = data['velocity']

        # No trajectory smoothing is applied for this data.
        np.save(os.path.join(episode_path, "pos_downsampled.npy"), pos)
        print(f"episode{episode}, Downsampled {pos.shape[1]} points")
        np.save(os.path.join(episode_path, "control_mask_downsample.npy"), picked_control_mask)
        np.save(os.path.join(episode_path, "E_downsample.npy"), picked_E)
        np.save(os.path.join(episode_path, "control_velocity.npy"), control_velocity)


def print_bad_npz(output_dir):
    """Print each episode name, then fully load its ``data.npz``.

    A corrupt or missing archive raises while loading, so the last name
    printed identifies the bad episode.
    """
    for episode_name in os.listdir(output_dir):
        npz_path = os.path.join(output_dir, episode_name, 'data.npz')
        print(episode_name)
        # dict() forces every array in the archive to be read.
        dict(np.load(npz_path))


def downsample_mpm_structure(output_dir):
    """Downsample MPM episodes under ``<output_dir>/phystwin_data/<name>/<episode>``.

    Only the scene ``rope_double_hand`` is processed. ``num_dict.pkl`` gives,
    per scene, how many leading points of ``pos`` are kept before sampling.
    Farthest point sampling (1000 points) is run on the first frame of the
    first episode of a scene, and the resulting indices are reused for all
    of its episodes.

    Files written per episode: ``fps_idx_pre.npy``, ``pos_downsampled.npy``,
    ``eef_pos_downsampled.npy``, ``E_downsample.npy`` and ``friction.npy``.
    """
    n_points = 1000
    selected_scenes = ('rope_double_hand',)
    # Scenes whose end-effector height is overwritten with the mean height of
    # the sampled points in the first frame.
    flatten_eef_scenes = ('single_push_rope', 'single_push_rope_1', 'single_push_rope_4')

    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open(os.path.join(output_dir, 'num_dict.pkl'), 'rb') as f:
        num_dict = pkl.load(f)

    data_root = os.path.join(output_dir, 'phystwin_data')
    for name in os.listdir(data_root):
        if name not in selected_scenes:
            continue

        structure_num = num_dict[name]
        scene_dir = os.path.join(data_root, name)
        episode_names = os.listdir(scene_dir)

        shared_fps_idx = None  # computed once per scene, on its first episode
        for episode in tqdm(episode_names, desc=f'processing {name}'):
            episode_path = os.path.join(scene_dir, episode)
            data = dict(np.load(os.path.join(episode_path, 'data.npz')))

            pos_t = torch.from_numpy(data['pos'][:, :structure_num]).float()
            if shared_fps_idx is None:
                shared_fps_idx = farthest_point_sampler(pos_t[0:1], n_points, start_idx=0)[0]
            pos = pos_t[:, shared_fps_idx].numpy()

            eef_pos = data['eef_pos']
            if name in flatten_eef_scenes:
                eef_pos[:, -1] = pos[0, :, -1].mean()

            sampled_E = data['E'][shared_fps_idx]
            friction = data['friction']

            # No trajectory smoothing is applied for this data.
            np.save(os.path.join(episode_path, "fps_idx_pre.npy"), shared_fps_idx)
            np.save(os.path.join(episode_path, "pos_downsampled.npy"), pos)
            np.save(os.path.join(episode_path, "eef_pos_downsampled.npy"), eef_pos)
            np.save(os.path.join(episode_path, "E_downsample.npy"), sampled_E)
            np.save(os.path.join(episode_path, "friction.npy"), friction)


if __name__ == "__main__":
    # Script entry point: downsample the MPM data set. The other helpers in
    # this file (downsample_only on '../gen_data', downsample_only_gripper,
    # print_bad_npz) can be called here instead when needed.
    gen_data_path = '../mpm_data'
    downsample_mpm_structure(output_dir=gen_data_path)
