import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from render.dynamics_module import DynamicsModule
import yaml
from plyfile import PlyData, PlyElement
from render.renderer import Renderer
from render.phystwin_LBS import my_get_camera_view_phystwin
from E_mlp import PositionalEncodingMLP, PositionalEncodingMLPBatch
import pickle as pkl
import cv2

# Default compute device for this module: first CUDA GPU when one is available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def plot_3d_top_view_equal_scale(points, kp_gt, eef=None, rels=None, max_nobj=100, output_file='3d_top_view_equal_scale.png'):
    """Save a 3D scatter plot of predicted points, optional controller points and edges.

    Args:
        points: (N, 3) predicted positions, as a torch tensor (any device, with or
            without grad) or an array-like. The last two rows are excluded from the
            marker scatter but can still be used as edge endpoints.
        kp_gt: ground-truth keypoints. Not drawn (the GT scatter is disabled);
            kept only for interface compatibility.
        eef: optional (M, 3) controller positions (tensor or array-like), drawn in green.
        rels: optional iterable of (i, j) index pairs drawn as gray edges. An index
            smaller than len(points) refers to ``points``; any other index refers to
            ``eef[index - max_nobj]``. Edges whose endpoint cannot be resolved
            (no ``eef``, or the offset index falls outside ``eef``) are skipped.
        max_nobj: offset subtracted from out-of-range indices to address ``eef``.
        output_file: path of the PNG written (300 dpi).
    """
    def _as_numpy(arr):
        # Accept torch tensors (possibly on GPU / requiring grad) and array-likes alike.
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return np.asarray(arr)

    points = _as_numpy(points)
    if eef is not None:
        eef = _as_numpy(eef)
    n_points = len(points)

    def _endpoint(idx):
        # Resolve an edge index to a 3D position, or None if it cannot be resolved.
        # The explicit bounds check prevents a negative offset from silently
        # wrapping around to the end of `eef`.
        if idx < n_points:
            return points[idx]
        eef_idx = idx - max_nobj
        if eef is not None and 0 <= eef_idx < len(eef):
            return eef[eef_idx]
        return None

    fig = plt.figure(figsize=(10, 8))
    try:
        ax = fig.add_subplot(111, projection='3d')

        # Predicted points in blue; the last two rows are left out of the scatter.
        # NOTE(review): the reason for excluding them is not visible here.
        ax.scatter(points[:-2, 0], points[:-2, 1], points[:-2, 2],
                   c='b', marker='o', s=10, alpha=0.6, label='predict')
        # Controller points in green.
        if eef is not None:
            ax.scatter(eef[:, 0], eef[:, 1], eef[:, 2],
                       c='g', marker='o', s=10, alpha=0.6, label='controller')

        # Edges between related points (if rels is provided).
        if rels is not None:
            for i, j in rels:
                p_1 = _endpoint(int(i))
                p_2 = _endpoint(int(j))
                if p_1 is None or p_2 is None:
                    continue
                ax.plot([p_1[0], p_2[0]],
                        [p_1[1], p_2[1]],
                        [p_1[2], p_2[2]],
                        c='gray', alpha=0.3, linewidth=0.5)

        # Fix every axis range to [-0.5, 0.5].
        ax.set_xlim(-0.5, 0.5)
        ax.set_ylim(-0.5, 0.5)
        ax.set_zlim(-0.5, 0.5)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title('3D Top View with Edges' if rels is not None else '3D Top View')
        ax.legend()

        # Equal 1:1:1 box aspect so distances are not distorted per axis.
        ax.set_box_aspect([1, 1, 1])

        # Fixed camera angle: elevation -130 deg, azimuth 2 deg. Note this is not the
        # straight-down (elev=90) top view the title suggests.
        ax.view_init(elev=-130, azim=2)

        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always release this specific figure, even if plotting/saving fails.
        plt.close(fig)
    print(f"图片已保存为: {output_file}")

def load_3dgs_model(ply_path, max_sh_degree=3, device=None):
    """Load a 3D Gaussian Splatting point cloud from a PLY file into torch tensors.

    All attributes are read from the first element of the PLY file.

    Args:
        ply_path: path to the .ply file.
        max_sh_degree: spherical-harmonics degree the file was saved with; the
            number of ``f_rest_*`` properties must equal
            ``3 * (max_sh_degree + 1) ** 2 - 3``.
        device: torch device for the returned tensors. Defaults to ``"cuda"``
            when CUDA is available, otherwise ``"cpu"``.

    Returns:
        dict of float32 tensors:
            'xyz_ply': (N, 3) positions.
            'features_dc_ply': (N, 3, 1) DC SH coefficients.
            'features_extra_ply': (N, 3, (max_sh_degree + 1) ** 2 - 1) higher-order SH coefficients.
            'opacities_ply': (N, 1) opacities after sigmoid.
            'scales_ply': (N, 3) scales after exp (a single stored scale column is repeated 3x).
            'rots_ply': (N, R) rotation values as stored, one column per ``rot*`` property.
            'precomp_colors': (N, 3) colours from the DC term, ``f_dc / sqrt(4*pi) + 0.5``.

    Raises:
        ValueError: if the number of ``f_rest_*`` properties does not match ``max_sh_degree``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    plydata = PlyData.read(ply_path)
    vertex = plydata.elements[0]

    def _column(name):
        # One scalar property of the element as a numpy array of length N.
        return np.asarray(vertex[name])

    def _sorted_props(prefix):
        # Stack all properties starting with `prefix` as columns, ordered by their
        # trailing integer suffix (e.g. scale_0, scale_1, ...). Shape (N, k).
        names = [p.name for p in vertex.properties if p.name.startswith(prefix)]
        names.sort(key=lambda name: int(name.split('_')[-1]))
        if not names:
            return np.zeros((num_points, 0))
        return np.stack([_column(name) for name in names], axis=1)

    xyz_ply = np.stack((_column("x"), _column("y"), _column("z")), axis=1)
    num_points = xyz_ply.shape[0]
    print(f"original 3dgs num: {num_points}")
    opacities_ply = _column("opacity")[..., np.newaxis]

    # (N, 3, 1): one DC coefficient per colour channel.
    features_dc_ply = np.stack([_column(f"f_dc_{c}") for c in range(3)], axis=1)[..., np.newaxis]

    coeffs_per_channel = (max_sh_degree + 1) ** 2 - 1
    features_extra_ply = _sorted_props("f_rest_")
    if features_extra_ply.shape[1] != 3 * coeffs_per_channel:
        raise ValueError(
            f"expected {3 * coeffs_per_channel} f_rest_* properties for "
            f"max_sh_degree={max_sh_degree}, found {features_extra_ply.shape[1]}")
    # The reshape treats the flat f_rest_* order as channel-major.
    features_extra_ply = features_extra_ply.reshape((num_points, 3, coeffs_per_channel))

    scales_ply = _sorted_props("scale_")
    if scales_ply.shape[1] == 1:
        # A single (isotropic) scale column is expanded to (N, 3).
        scales_ply = np.tile(scales_ply, (1, 3))

    rots_ply = _sorted_props("rot")

    def _to_tensor(arr):
        return torch.tensor(arr, dtype=torch.float, device=device)

    xyz = _to_tensor(xyz_ply)
    features_dc = _to_tensor(features_dc_ply)
    # 1/sqrt(4*pi) is the zeroth-order SH basis constant; +0.5 recentres the colour.
    normalization = 1.0 / np.sqrt(4 * np.pi)
    precomp_colors = features_dc[:, :, 0] * normalization + 0.5

    return {
        'xyz_ply': xyz,
        'features_dc_ply': features_dc,
        'features_extra_ply': _to_tensor(features_extra_ply),
        # Stored opacities/scales are in logit/log space; map them back.
        'opacities_ply': torch.sigmoid(_to_tensor(opacities_ply)),
        'scales_ply': torch.exp(_to_tensor(scales_ply)),
        'rots_ply': _to_tensor(rots_ply),
        'precomp_colors': precomp_colors,
    }

def draw_points(im, w2c, k, points, color=(0, 255, 0), radius=2):
    """Project 3D world points with a pinhole camera and draw them as filled dots on ``im``.

    ``im`` is modified in place; nothing is returned.

    Args:
        im: image drawn on with ``cv2.circle``; its first two dims are (height, width).
        w2c: world-to-camera matrix applied to homogeneous points (4x4 or 3x4),
            as a numpy array or torch tensor.
        k: 3x3 intrinsics (numpy array or torch tensor); fx, fy, cx, cy are read
            from k[0, 0], k[1, 1], k[0, 2], k[1, 2].
        points: (N, 3) world-space points (numpy array, array-like or torch tensor,
            with or without grad).
        color: dot colour passed to ``cv2.circle`` (default green).
        radius: dot radius in pixels.
    """
    def _as_numpy(arr):
        # detach() so tensors that require grad can be converted as well.
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return np.asarray(arr)

    pts = _as_numpy(points).reshape(-1, 3)
    w2c = _as_numpy(w2c)
    k = _as_numpy(k)

    # Homogeneous world coordinates -> camera coordinates.
    pts_hom = np.hstack((pts, np.ones((pts.shape[0], 1))))
    cam_coords = (w2c @ pts_hom.T).T

    # Only points in front of the camera (positive depth) are projected.
    cam_coords = cam_coords[cam_coords[:, 2] > 0]

    fx, fy, cx, cy = k[0, 0], k[1, 1], k[0, 2], k[1, 2]
    u = fx * (cam_coords[:, 0] / cam_coords[:, 2]) + cx
    v = fy * (cam_coords[:, 1] / cam_coords[:, 2]) + cy

    # Skip non-finite projections and dots that cannot touch the image; depths
    # close to zero yield huge coordinates that can overflow cv2's int coordinates.
    height, width = im.shape[:2]
    visible = (
        np.isfinite(u) & np.isfinite(v)
        & (u >= -radius) & (u <= width - 1 + radius)
        & (v >= -radius) & (v <= height - 1 + radius)
    )

    # Round (rather than truncate) to the nearest pixel centre.
    us = np.rint(u[visible]).astype(int)
    vs = np.rint(v[visible]).astype(int)
    for px, py in zip(us, vs):
        cv2.circle(im, (int(px), int(py)), radius=radius, color=color, thickness=-1)