import os
import glob
import json

import torch
import numpy as np
import pickle as pkl
def params2ply(self, path, n=None):
    """Load Gaussian parameters from an ``np.load`` archive and build PLY-style arrays.

    Despite its name, this function does not write a PLY file; it returns the
    arrays that a PLY exporter would need.

    Args:
        self: Any object exposing ``max_sh_degree``; it sizes the higher-order
            spherical-harmonics feature array.
        path: Archive readable by ``np.load`` (e.g. ``.npz``) containing the keys
            ``object_points``, ``logit_opacities``, ``log_scales``,
            ``unnorm_rotations`` and ``object_colors``.
        n: Keep only the first ``n`` points; ``None`` keeps all of them.

    Returns:
        dict with keys ``xyz``, ``opacities``, ``scales``, ``rots``,
        ``features_dc`` (shape ``(P, 3, 1)``) and ``features_extra``
        (shape ``(P, 3, (max_sh_degree + 1) ** 2 - 1)``, all zeros).
    """
    # ``arr[:None]`` is the whole array, so ``[:n]`` covers both "all points"
    # and "first n points". The context manager closes the archive file handle.
    with np.load(path) as params:
        # ``[0]`` takes index 0 of the leading axis (presumably the first frame).
        xyz = params['object_points'][0][:n]
        opacities = params['logit_opacities'][:n]
        scales = params['log_scales'][:n]
        rots = params['unnorm_rotations'][0][:n]
        precomp_colors = params['object_colors'][0][:n]

    num_coeffs = (self.max_sh_degree + 1) ** 2
    # Degree-0 SH basis constant C0 = 1 / sqrt(4 * pi).
    normalization = 1.0 / np.sqrt(4 * np.pi)
    # Invert color = 0.5 + C0 * dc so the DC term reproduces the stored color.
    features_dc = np.zeros((xyz.shape[0], 3, 1))
    features_dc[:, :, 0] = (precomp_colors - 0.5) / normalization
    # Higher-order SH coefficients are not in the archive; leave them at zero.
    features_extra = np.zeros((xyz.shape[0], 3, num_coeffs - 1))

    return {
        'xyz': xyz,
        'opacities': opacities,
        'scales': scales,
        'rots': rots,
        'features_dc': features_dc,
        'features_extra': features_extra,
    }

def _edge_offsets(start, length, count):
    """Return ``count`` evenly spaced values from ``start`` to ``start + length`` inclusive."""
    if count == 1:
        # A single sample has no spacing to compute; place it at the edge's start.
        return [start]
    step = length / (count - 1)
    return [start + i * step for i in range(count)]


def sample_rectangle_boundary_uniform(x1, y1, x2, y2, num_points):
    """Sample points along the boundary of the axis-aligned rectangle (x1, y1)-(x2, y2).

    The budget ``num_points`` is turned into a floored number of points per unit
    of perimeter. Each edge then gets that density times its length (truncated
    to an int), evenly spaced and including both end corners. Edges are visited
    in order: bottom (x1,y1)->(x2,y1), right (x2,y1)->(x2,y2),
    top (x2,y2)->(x1,y2), left (x1,y2)->(x1,y1).

    Because of the flooring, fewer than ``num_points`` points may be returned
    (none at all when the perimeter exceeds ``num_points``). Corners shared by
    two edges appear twice.

    Returns:
        float ``np.ndarray`` of shape ``(K, 2)``.
    """
    width = x2 - x1
    height = y2 - y1

    # Points per unit of perimeter (floored), scaled by each edge's length.
    # int() keeps range() valid when the coordinates are not integers.
    density = num_points // (2 * width + 2 * height)
    points_width = int(density * width)
    points_height = int(density * height)

    points = [(x, y1) for x in _edge_offsets(x1, width, points_width)]       # bottom
    points += [(x2, y) for y in _edge_offsets(y1, height, points_height)]    # right
    points += [(x, y2) for x in _edge_offsets(x2, -width, points_width)]     # top
    points += [(x1, y) for y in _edge_offsets(y2, -height, points_height)]   # left

    # reshape keeps a consistent (0, 2) shape when no points are produced.
    return np.array(points, dtype=float).reshape(-1, 2)

# Demo rectangle: lower-left corner (x1, y1), upper-right corner (x2, y2).
# NOTE(review): this runs at import time; consider moving it under
# ``if __name__ == "__main__":`` if nothing imports these names.
x1, y1, x2, y2 = 0, 0, 2, 1
# Sample the boundary with a budget of 100 points.
sampled_points = sample_rectangle_boundary_uniform(x1, y1, x2, y2, num_points=100)


def random_sample_on_rectangle_boundary_one(x1, y1, x2, y2, span=1.1):
    """Draw one random point on the boundary of a scaled axis-aligned rectangle.

    All four corner coordinates are multiplied by ``span``; this scales about
    the origin, not about the rectangle's center. One of the four edges is
    chosen with equal probability (``np.random.randint``), and a position along
    it is drawn with ``np.random.uniform``. Edges are therefore picked
    independently of their length.

    Returns:
        ``np.ndarray`` of shape ``(2,)`` holding ``[x, y]``.
    """
    left, bottom, right, top = (coord * span for coord in (x1, y1, x2, y2))

    # Edge index: 0 = bottom (y1), 1 = right (x2), 2 = top (y2), 3 = left (x1).
    side = np.random.randint(0, 4)
    if side == 0:
        return np.array([np.random.uniform(left, right), bottom])
    if side == 1:
        return np.array([right, np.random.uniform(bottom, top)])
    if side == 2:
        return np.array([np.random.uniform(left, right), top])
    return np.array([left, np.random.uniform(bottom, top)])

def rpy_to_rotation_matrix(roll, pitch, yaw):
    """Build a 3x3 rotation matrix from roll, pitch and yaw given in degrees.

    The result is ``Rz(yaw) @ Ry(pitch) @ Rx(roll)``. Acting on column
    vectors, roll about x is applied first, then pitch about y, then yaw
    about z, all about fixed axes.
    """
    # Degrees -> radians.
    r, p, y = (angle / 180 * np.pi for angle in (roll, pitch, yaw))
    cr, sr = np.cos(r), np.sin(r)
    cp, sp = np.cos(p), np.sin(p)
    cy, sy = np.cos(y), np.sin(y)

    rot_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]])
    rot_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]])
    rot_z = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]])
    return rot_z @ rot_y @ rot_x

def _moving_average_rows(positions, window_size, out):
    """Write the edge-padded moving average of 2-D ``positions`` (along axis 0) into ``out``."""
    pad_size = window_size // 2
    # Edge padding repeats the first/last row so the output length equals the input length.
    padded = np.pad(positions, ((pad_size, pad_size), (0, 0)), mode='edge')
    for i in range(len(positions)):
        out[i] = np.mean(padded[i:i + window_size], axis=0)
    return out


def smooth_positions_multi(positions, window_size=7):
    """Smooth position trajectories with a moving average along the time axis.

    Args:
        positions: Array-like of shape ``(T, D)``, or ``(G, T, D)`` for ``G``
            independent trajectories, each smoothed along its ``T`` axis.
        window_size: Number of samples averaged per output sample (>= 1). The
            window for sample ``i`` covers padded rows ``i .. i + window_size - 1``,
            so an even ``window_size`` is centered half a sample early.

    Returns:
        ``np.ndarray`` with the same shape as ``positions``. Floating inputs keep
        their dtype; non-floating inputs produce ``float64`` so means are not
        truncated.

    Raises:
        ValueError: If ``window_size < 1`` or ``positions`` is not 2-D or 3-D.
    """
    positions = np.asarray(positions)
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if positions.ndim not in (2, 3):
        raise ValueError(f"positions must be 2-D or 3-D, got shape {positions.shape}")

    out_dtype = positions.dtype if np.issubdtype(positions.dtype, np.floating) else np.float64
    smoothed = np.zeros(positions.shape, dtype=out_dtype)
    if positions.ndim == 2:
        return _moving_average_rows(positions, window_size, smoothed)
    # 3-D: each group's slice of ``smoothed`` is a view, so writes land in place.
    for group, out_group in zip(positions, smoothed):
        _moving_average_rows(group, window_size, out_group)
    return smoothed
    
def smooth_positions_torch(positions, window_size=7, device='cuda'):
    """Moving-average a 2-D position tensor along its first axis.

    The tensor is detached and copied to host memory, smoothed with
    edge-padded windows of ``window_size`` samples, and converted back to a
    float32 tensor.

    Args:
        positions: torch tensor of shape ``(T, D)`` (the padding spec requires 2-D).
        window_size: Number of samples averaged per output sample.
        device: Device for the returned tensor. Defaults to ``'cuda'``, the
            previously hard-coded value; pass ``'cpu'`` (or ``positions.device``)
            on machines without CUDA.

    Returns:
        float32 tensor with the same shape as ``positions`` on ``device``.
    """
    positions = positions.detach().cpu().numpy()
    pad_size = window_size // 2
    # Edge padding repeats the first/last row so the output length equals the input length.
    padded_positions = np.pad(positions, ((pad_size, pad_size), (0, 0)), mode='edge')
    # Accumulate in float for integer inputs so the means are not truncated.
    work_dtype = positions.dtype if np.issubdtype(positions.dtype, np.floating) else np.float64
    smoothed_positions = np.zeros(positions.shape, dtype=work_dtype)
    for i in range(len(positions)):
        smoothed_positions[i] = np.mean(padded_positions[i:i + window_size], axis=0)
    return torch.tensor(smoothed_positions, device=device, dtype=torch.float32)

def load_split(data_dir):
    """Read ``<data_dir>/split.json`` and return ``(frame_len, train_frame, test_frame)``.

    ``train_frame`` and ``test_frame`` are element ``[1]`` of the JSON
    ``train`` and ``test`` entries (presumably their end frame indices;
    confirm against the split writer).
    """
    split_path = f"{data_dir}/split.json"
    with open(split_path, "r") as handle:
        split_info = json.load(handle)
    return split_info["frame_len"], split_info["train"][1], split_info["test"][1]

def load_eef_pos(data_dir, gripper_type, eef_k_shift=None):
    """Load and smooth end-effector (controller) trajectories from ``final_data.pkl``.

    The controller point(s) closest to the object in frame 0 are selected,
    smoothed over time with ``smooth_positions_multi``, and offset by
    ``eef_k_shift``.

    Args:
        data_dir: Directory containing ``final_data.pkl``. The pickle must hold
            ``object_points`` and ``controller_points`` indexed frame-first;
            ``controller_points`` must convert to a ``(T, C, 3)`` array.
        gripper_type: ``'single_gripper'`` or ``'push'`` (one controller point),
            or ``'double_gripper'`` (the two controller points closest to the object).
        eef_k_shift: Optional length-3 offset added to every selected position.
            Defaults to ``[0, 0, 0]``.

    Returns:
        ``np.ndarray`` passed through ``np.squeeze``. Before squeezing it is
        ``(T, 3)`` for the single-point types or ``(2, T, 3)`` for
        ``'double_gripper'``.

    Raises:
        ValueError: If ``gripper_type`` is not one of the supported values.
    """
    if gripper_type not in ('single_gripper', 'push', 'double_gripper'):
        raise ValueError(f"Unsupported gripper_type: {gripper_type!r}")

    data_path = os.path.join(data_dir, "final_data.pkl")
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(data_path, "rb") as f:
        data = pkl.load(f)

    object_points = np.array(data["object_points"], dtype=np.float32)
    controller_points = np.array(data["controller_points"], dtype=np.float32)

    first_frame_object_points = object_points[0].reshape(-1, 3)
    first_frame_controller_points = controller_points[0].reshape(-1, 3)

    # For each controller point: distance to the nearest object point in frame 0.
    distances = np.array([
        np.min(np.linalg.norm(first_frame_object_points - ctrl_point, axis=1))
        for ctrl_point in first_frame_controller_points
    ])
    sorted_indices = np.argsort(distances)

    shift = np.asarray([0, 0, 0] if eef_k_shift is None else eef_k_shift)

    if gripper_type == 'double_gripper':
        selected = np.stack([
            controller_points[:, sorted_indices[0]],
            controller_points[:, sorted_indices[1]],
        ])
        selected = smooth_positions_multi(selected)
        selected += shift.reshape(1, 1, 3)
    else:
        # 'single_gripper' and 'push' both track the single closest controller point.
        selected = controller_points[:, sorted_indices[0], :]
        selected = smooth_positions_multi(selected)
        selected += shift
    return np.squeeze(selected)

def load_track_data(track_path, filter_num=None, device='cuda'):
    """Load ``<track_path>/final_data.pkl`` and add placeholder Gaussian attributes.

    Track data has no Gaussian-splatting attributes. When the unpickled mapping
    contains ``'object_points'`` (shape ``(num_frames, num_points, 3)``), this
    function adds the following float64 placeholders to it:

    * ``seg_colors`` ``(num_points, 3)`` filled with ``[0, 1, 0]``
    * ``unnorm_rotations`` ``(num_frames, num_points, 4)`` of ones
    * ``logit_opacities`` ``(num_points, 1)`` of ones
    * ``log_scales`` ``(num_points, 3)`` filled with ``-7``

    Args:
        track_path: Directory containing ``final_data.pkl``.
        filter_num: Unused; kept for call-site compatibility.
        device: Unused; kept for call-site compatibility.

    Returns:
        The unpickled object, mutated in place as described above.
    """
    track_file = os.path.join(track_path, 'final_data.pkl')
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(track_file, 'rb') as f:
        track = pkl.load(f)

    if 'object_points' in track:
        # np.shape also works when the points were pickled as nested lists.
        num_frames, num_points, _ = np.shape(track['object_points'])
        # Float fills keep every placeholder consistently floating-point.
        track['seg_colors'] = np.full((num_points, 3), [0.0, 1.0, 0.0])
        track['unnorm_rotations'] = np.ones((num_frames, num_points, 4))
        track['logit_opacities'] = np.ones((num_points, 1))
        track['log_scales'] = np.full((num_points, 3), -7.0)

    return track

def make_points_over_collider(tensor, table_height, offset=0.000001):
    """Clamp points whose last coordinate exceeds ``table_height``.

    Entries of ``tensor[:, -1]`` strictly greater than ``table_height`` are set
    to ``table_height - offset``. The modified last column is then written into
    column 2. ``tensor`` is modified in place and also returned.

    NOTE(review): the original variable naming treats ``z > table_height`` as
    "below the table", which suggests a z-down convention; confirm against the
    caller. Columns 2 and -1 coincide only for 3-column input.
    """
    heights = tensor[:, -1]
    past_table = heights > table_height
    heights[past_table] = table_height - offset
    tensor[:, 2] = heights
    return tensor
