import os
import numpy as np
import open3d as o3d
from scipy.spatial.transform import Rotation as R
from plyfile import PlyData, PlyElement

from utils.data_utils import load_pose_txt

def _write_pose_file(path, pose):
    """Write a 14-value two-arm pose to ``path`` as 'Left:'/'Right:' sections.

    The first 7 values go under "Left: " and values 7..13 under "Right: ",
    one value per line. Extra values beyond 14 are ignored.
    """
    lines = ["Left: \n"]
    lines.extend(f"{v}\n" for v in pose[:7])
    lines.append("Right: \n")
    lines.extend(f"{v}\n" for v in pose[7:14])
    with open(path, "w") as f:
        f.writelines(lines)


def save_pred_txt(pred, out_dir):
    """Save a predicted start/end pose pair as two text files in ``out_dir``.

    ``pred`` is laid out as [start-left(7), start-right(7), end-left(7),
    end-right(7)]; the 7 values per arm are presumably (x, y, z, qx, qy, qz,
    qw), the layout ``place_gripper`` unpacks. Writes
    ``start_pos_predict.txt`` (values 0..13) and ``end_pos_predict.txt``
    (values 14..27). Both sections of both files are bounded to 7 values, so
    any trailing extra values in ``pred`` are not written.
    """
    _write_pose_file(os.path.join(out_dir, "start_pos_predict.txt"), pred[:14])
    # Bounded to 28 so the end file's "Right" block never picks up stray
    # trailing values (previously pred[14:] leaked them into that section).
    _write_pose_file(os.path.join(out_dir, "end_pos_predict.txt"), pred[14:28])

def make_parallel_gripper_mesh(color=(1, 0, 0)):
    """Build a simple parallel-jaw gripper mesh in its local frame.

    Geometry (same units as the poses it is placed with):
      * body: a 0.04 x 0.015 x 0.06 box centred on x and z, spanning
        y in [0, 0.015];
      * fingers: two cylinders (radius 0.008, length 0.08) at x = -0.03 and
        x = +0.03, with their axes along y, spanning y in [0, 0.08].

    Args:
        color: RGB colour with components in [0, 1], painted uniformly over
            the whole mesh. Defaults to red. An immutable tuple is used as
            the default so it can never be shared and mutated across calls.

    Returns:
        An ``open3d.geometry.TriangleMesh`` combining the body and fingers.
    """
    w_body, h_body, d_body = 0.04, 0.015, 0.06
    r_finger, l_finger = 0.008, 0.08
    max_open = 0.03

    body = o3d.geometry.TriangleMesh.create_box(w_body, h_body, d_body)
    # create_box puts one corner at the origin; centre the box on x and z.
    body.translate([-w_body / 2, 0, -d_body / 2])

    def make_finger(x_offset):
        # One finger cylinder, axis along y, base at y = 0, at x = x_offset.
        finger = o3d.geometry.TriangleMesh.create_cylinder(r_finger, l_finger)
        # create_cylinder is centred at the origin with its axis on z. A
        # rotation about z would leave that shape unchanged, so rotate about
        # x to lay the axis along y; the +l_finger/2 shift below then seats
        # the finger's base at y = 0, next to the body.
        finger.rotate(
            finger.get_rotation_matrix_from_xyz((np.pi / 2, 0, 0)),
            center=[0, 0, 0],
        )
        finger.translate([x_offset, l_finger / 2, 0])
        return finger

    gripper = body + make_finger(-max_open) + make_finger(max_open)
    gripper.paint_uniform_color(color)
    return gripper

def place_gripper(mesh, pose):
    """Return a copy of ``mesh`` rotated and translated into ``pose``.

    Args:
        mesh: ``open3d.geometry.TriangleMesh`` in the gripper's local frame.
            It is not modified.
        pose: 7 values ``(x, y, z, qx, qy, qz, qw)``. The quaternion is
            scalar-last, which is the convention of SciPy's
            ``Rotation.from_quat``. It is normalised first; a quaternion with
            norm below 1e-6 is replaced by the identity rotation.

    Returns:
        A new ``TriangleMesh``, rotated about the local origin and then
        translated by ``(x, y, z)``.
    """
    x, y, z, qx, qy, qz, qw = pose
    quat = np.array([qx, qy, qz, qw], dtype=float)
    norm = np.linalg.norm(quat)
    # from_quat rejects a zero-norm quaternion; fall back to identity instead.
    if norm < 1e-6:
        quat = np.array([0.0, 0.0, 0.0, 1.0])
    else:
        quat = quat / norm

    # A single copy is enough to keep the caller's mesh untouched.
    placed = o3d.geometry.TriangleMesh(mesh)
    placed.rotate(R.from_quat(quat).as_matrix(), center=[0, 0, 0])
    placed.translate([x, y, z])
    return placed

def mesh_to_points(mesh, density=3000):
    """Uniformly sample ``density`` points on ``mesh``'s surface.

    Returns the sampled positions as a NumPy array (one row per point).
    """
    sampled_cloud = mesh.sample_points_uniformly(number_of_points=density)
    points = np.asarray(sampled_cloud.points)
    return points

def _load_background_cloud(ply_path):
    """Read vertex positions and colours from the PLY file at ``ply_path``.

    Returns:
        ``(xyz, rgb)``: ``xyz`` is float32 with shape (N, 3); ``rgb`` is uint8
        with shape (N, 3). If the vertex element lacks red/green/blue
        properties, every point is coloured white (255, 255, 255).
    """
    vertex = PlyData.read(ply_path)["vertex"]
    xyz = np.vstack([vertex["x"], vertex["y"], vertex["z"]]).T.astype(np.float32)
    if {"red", "green", "blue"}.issubset(vertex.data.dtype.names):
        rgb = np.vstack(
            [vertex["red"], vertex["green"], vertex["blue"]]
        ).T.astype(np.uint8)
    else:
        rgb = np.full((xyz.shape[0], 3), 255, dtype=np.uint8)
    return xyz, rgb


def _color_rows(color, count):
    """Return ``count`` rows of a 0-1 RGB ``color`` converted to uint8 0-255."""
    return np.tile((np.asarray(color) * 255).astype(np.uint8), (count, 1))


def _sample_grippers(poses, color, density=1500):
    """Place a gripper mesh at each 7-value pose and surface-sample it.

    Returns:
        ``(points, colors)``: the stacked sample points of all grippers and
        one uint8 ``color`` row per point.
    """
    meshes = [place_gripper(make_parallel_gripper_mesh(color), pose) for pose in poses]
    points = np.vstack([mesh_to_points(m, density=density) for m in meshes])
    return points, _color_rows(color, len(points))


def _segment_points(pairs, color, num=100):
    """Sample ``num`` evenly spaced points (endpoints included) per segment.

    Returns:
        ``(points, colors)``: points of every ``(p0, p1)`` segment in order,
        and one uint8 ``color`` row per point.
    """
    points = np.vstack([np.linspace(p0, p1, num) for p0, p1 in pairs])
    return points, _color_rows(color, len(points))


def save_prediction_ply(pred_pose, save_path, init_ply_path):
    """Write a coloured point cloud comparing predicted and ground-truth grippers.

    The binary PLY written to ``save_path`` contains, in this order:
      * the background cloud from ``init_ply_path`` (its own colours, or white);
      * surface samples of the 4 ground-truth gripper meshes (green);
      * surface samples of the 4 predicted gripper meshes (yellow);
      * red segments from each ground-truth position to its prediction;
      * blue start-to-end segments per arm, ground truth first, then
        prediction.

    Ground truth is read with ``load_pose_txt`` from ``start_pos.txt`` and
    ``end_pos.txt`` in the same directory as ``save_path``. Each is sliced as
    left pose [0:7] then right pose [7:14].

    Args:
        pred_pose: at least 28 values laid out as start-left, start-right,
            end-left, end-right, 7 values each in the
            ``(x, y, z, qx, qy, qz, qw)`` form ``place_gripper`` unpacks.
            Values beyond 28 are ignored.
        save_path: output PLY path.
        init_ply_path: background PLY path.

    Raises:
        ValueError: if ``pred_pose`` has fewer than 28 values.
    """
    COLOR_GT = np.array([0, 1, 0])
    COLOR_PR = np.array([1, 1, 0])
    COLOR_ERR = np.array([1, 0, 0])
    COLOR_TRAJ = np.array([0, 0, 1])

    if len(pred_pose) < 28:
        raise ValueError(
            f"pred_pose must contain at least 28 values, got {len(pred_pose)}"
        )

    bg_xyz, bg_colors = _load_background_cloud(init_ply_path)

    base_dir = os.path.dirname(save_path)
    gt_start = load_pose_txt(os.path.join(base_dir, "start_pos.txt"))
    gt_end = load_pose_txt(os.path.join(base_dir, "end_pos.txt"))

    # Both lists are ordered start-left, start-right, end-left, end-right.
    gt_poses = [gt_start[0:7], gt_start[7:14], gt_end[0:7], gt_end[7:14]]
    pr_poses = [pred_pose[i:i + 7] for i in range(0, 28, 7)]

    gt_pts, gt_cols = _sample_grippers(gt_poses, COLOR_GT)
    pr_pts, pr_cols = _sample_grippers(pr_poses, COLOR_PR)

    # Position error: ground truth -> prediction, gripper by gripper.
    err_pts, err_cols = _segment_points(
        [(gt[:3], pr[:3]) for gt, pr in zip(gt_poses, pr_poses)],
        COLOR_ERR,
    )

    # Start -> end per arm (indices 0/2 are left, 1/3 are right), ground
    # truth first, then prediction.
    traj_pts, traj_cols = _segment_points(
        [
            (poses[arm][:3], poses[arm + 2][:3])
            for poses in (gt_poses, pr_poses)
            for arm in (0, 1)
        ],
        COLOR_TRAJ,
    )

    all_points = np.vstack([bg_xyz, gt_pts, pr_pts, err_pts, traj_pts])
    all_colors = np.vstack([bg_colors, gt_cols, pr_cols, err_cols, traj_cols])

    vertex_data = np.empty(
        len(all_points),
        dtype=[
            ("x", "f4"), ("y", "f4"), ("z", "f4"),
            ("red", "u1"), ("green", "u1"), ("blue", "u1"),
        ],
    )
    for col, name in enumerate(("x", "y", "z")):
        vertex_data[name] = all_points[:, col]
    for col, name in enumerate(("red", "green", "blue")):
        vertex_data[name] = all_colors[:, col]

    PlyData([PlyElement.describe(vertex_data, "vertex")], text=False).write(save_path)