import torch as th
import numpy as np
from typing import Optional, Tuple
import datetime, time
import os
import h5py
import open3d as o3d
import matplotlib.pyplot as plt

def pcd_vis(pc):
    """Show a point cloud together with a unit coordinate frame in an Open3D window.

    Blocks until the Open3D window is closed, then prints the number of points.

    Args:
        pc: array-like whose values can be reshaped to (N, 3) xyz points,
            e.g. an (N, 3) or (H, W, 3) array.
    """
    # Open3D's Vector3dVector expects a float64 (N, 3) array.
    points = np.asarray(pc, dtype=np.float64).reshape(-1, 3)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1, origin=[0, 0, 0])
    o3d.visualization.draw_geometries([pcd, axis])
    # Count after reshaping: pc.shape[0] is only the point count for (N, 3) input.
    print('number points', points.shape[0])


def color_pcd_vis(color_pcd, name=None):
    """Show a colored point cloud in Open3D and save a screenshot of it.

    The screenshot is written to ``color_pcd_{name}.png`` in the current working
    directory. It is taken right after the scene is first rendered. The window
    then stays open for interactive viewing until the user closes it.

    Args:
        color_pcd: (N, 6) array; columns 0-2 are per-point colors (presumably RGB in
            [0, 1], as Open3D expects — confirm with the producer) and columns 3-5
            are xyz coordinates.
        name: suffix used in the screenshot file name.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.colors = o3d.utility.Vector3dVector(np.asarray(color_pcd[:, :3], dtype=np.float64))
    pcd.points = o3d.utility.Vector3dVector(np.asarray(color_pcd[:, 3:6], dtype=np.float64))
    axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3, origin=[0, 0, 0])

    # draw_geometries() returns None, so it cannot be used to take a screenshot;
    # drive an explicit Visualizer instead.
    vis = o3d.visualization.Visualizer()
    vis.create_window()
    try:
        vis.add_geometry(pcd)
        vis.add_geometry(axis)
        vis.poll_events()
        vis.update_renderer()
        vis.capture_screen_image(f"color_pcd_{name}.png", do_render=True)
        vis.run()  # blocks until the user closes the window
    finally:
        vis.destroy_window()
    print('number points', color_pcd.shape[0])


def plot_frame_from_matrix(ax, T_matrix, length=1.0, label='world_to_cam'):
    """Draw the x/y/z axes of a homogeneous transform as red/green/blue arrows.

    Args:
        ax: a 3D matplotlib axis (anything with a ``quiver(x, y, z, u, v, w, **kw)`` method).
        T_matrix: 4x4 homogeneous transform (array-like). Its translation column is the
            arrow origin, and the columns of its rotation block are the arrow directions.
        length: scale applied to every arrow (default 1.0).
        label: prefix for the arrow labels; each arrow is labeled ``"<label> X"``,
            ``"<label> Y"`` or ``"<label> Z"``.
    """
    T_matrix = np.asarray(T_matrix)
    assert T_matrix.shape == (4, 4), "T must be a 4x4 matrix"
    origin = T_matrix[:3, 3]        # translation
    rot_matrix = T_matrix[:3, :3]   # rotation matrix

    # Column i of the rotation matrix is the i-th axis of the frame.
    for col, (color, axis_name) in enumerate(zip('rgb', 'XYZ')):
        direction = rot_matrix[:, col] * length
        ax.quiver(*origin, *direction, color=color, label=label + ' ' + axis_name)



def plot_pcd_with_matplotlib(pcd, name=None, out_dir="./vis"):
    """Scatter-plot a point cloud in 3D and save it as ``{out_dir}/color_pcd_{name}.png``.

    All three axes are fixed to [-2, 2] with a 1:1:1 box aspect. Nothing is shown
    interactively, and the figure is always closed afterwards.

    Args:
        pcd: (N, >=3) array; columns 0-2 are plotted as x, y, z.
        name: suffix used in the output file name.
        out_dir: output directory, created if missing (default ``./vis``).
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(pcd[:, 0], pcd[:, 1], pcd[:, 2], c='r', marker='o', s=0.1)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        ax.set_zlim(-2, 2)
        ax.set_box_aspect([1, 1, 1])  # aspect ratio is 1:1:1
        # savefig does not create directories; without this, the call fails on a fresh checkout.
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, f"color_pcd_{name}.png"))
    finally:
        plt.close(fig)


def _plot_pose_frame(ax, pose_row, index, text_color):
    """Draw the frame of one [x, y, z, quat] pose row and label it with ``index``."""
    T_matrix = pose2mat((th.as_tensor(pose_row[:3]), th.as_tensor(pose_row[3:]))).numpy()
    plot_frame_from_matrix(ax, T_matrix)
    ax.text(pose_row[0], pose_row[1], pose_row[2], str(index), color=text_color)


def plot_traj(
        eye_pose,
        left_cam_pose,
        right_cam_pose,
        base_link_pose,
        left_eef_pose,
        right_eef_pose,
        step=40,
        ):
    """Plot sampled eye-camera and base-link frames in 3D, then save and show the figure.

    Every ``step``-th row of ``eye_pose`` and ``base_link_pose`` is drawn as an RGB
    axis triad and annotated with its row index (magenta for the eye, black for the
    base link). The figure is saved to ``eye_camera_frame.png``, shown, and closed.

    Args:
        eye_pose: (T, 7) array or tensor of [x, y, z, quat] rows. The quaternion goes to
            ``pose2mat``, which reads it as (x, y, z, w) — TODO confirm this matches
            how the poses were recorded.
        left_cam_pose, right_cam_pose, left_eef_pose, right_eef_pose: currently unused;
            kept for interface compatibility.
        base_link_pose: same layout as ``eye_pose``, with at least as many rows.
        step: sampling stride over the frames (default 40).
    """
    fig = plt.figure()
    fig.set_size_inches(20, 20)
    ax = fig.add_subplot(111, projection='3d')
    for i in range(0, eye_pose.shape[0], step):
        _plot_pose_frame(ax, eye_pose[i], i, 'magenta')
        _plot_pose_frame(ax, base_link_pose[i], i, 'black')

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('eye camera Frame and base link pose frame')
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_zlim(-0., 3.0)
    ax.set_box_aspect([1, 1, 1])  # aspect ratio is 1:1:1
    plt.savefig('eye_camera_frame.png')
    plt.show()
    # With non-interactive backends, show() is a no-op; close to avoid leaking figures.
    plt.close(fig)



def write_to_hdf5_per_demo(dict, output_path, force_output_path=False):
    """Write a collection of demos to one HDF5 file laid out as ``data/demo_<i>/...``.

    Demos are numbered ``demo_0``, ``demo_1``, ... in the iteration order of ``dict``.
    Within a demo, the entries "obs" and "next_obs" are nested mappings and are
    written as sub-groups. Every other entry is written as a dataset. Each demo group
    gets a ``num_samples`` attribute. The ``data`` group gets ``total`` and
    ``num_samples`` (both the summed sample count) plus a ``timestamp`` attribute.

    Args:
        dict: mapping of demo key -> demo mapping. Each demo must contain an
            "actions" entry, whose first dimension is its sample count. (The
            parameter name shadows the builtin; it is kept for caller compatibility.)
        output_path: target path. Unless ``force_output_path`` is True, only its
            directory is used, and the file goes to
            ``<dir>/tmp/processed_date_%m_%d_%Y_time_%H_%M_%S.hdf5``.
        force_output_path: write to exactly ``output_path``.
    """
    demos = dict
    timestamp = time.time()
    if not force_output_path:
        # Write into a timestamped file under a sibling "tmp" folder.
        folder_path = os.path.join(os.path.dirname(output_path), "tmp")
        os.makedirs(folder_path, exist_ok=True)
        time_str = datetime.datetime.fromtimestamp(timestamp).strftime('processed_date_%m_%d_%Y_time_%H_%M_%S')
        output_path = os.path.join(folder_path, "{}.hdf5".format(time_str))

    total_samples = 0
    with h5py.File(output_path, "w") as f:
        data_group = f.create_group("data")
        for demo_idx, demo in enumerate(demos.values()):
            # A unique name per demo; a fixed "demo_0" would collide on the second demo.
            demo_group = data_group.create_group(f"demo_{demo_idx}")
            for key, value in demo.items():
                if key in ("obs", "next_obs"):
                    obs_group = demo_group.create_group(key)
                    for obs_key, obs_value in value.items():
                        obs_group.create_dataset(obs_key, data=obs_value)
                else:
                    demo_group.create_dataset(key, data=value)
            num_samples = demo_group["actions"].shape[0]
            demo_group.attrs["num_samples"] = num_samples
            total_samples += num_samples
        data_group.attrs["total"] = total_samples
        data_group.attrs["num_samples"] = total_samples
        data_group.attrs["timestamp"] = timestamp
    print('Finished writing the data to', output_path)



def mat2quat(rmat: th.Tensor) -> th.Tensor:
    """
    Converts given rotation matrix to quaternion.

    Each matrix uses exactly one of four formulas: positive trace, otherwise
    whichever diagonal entry is largest. The square roots and divisions of the
    unselected formulas are fed the placeholder 1.0, so they never produce
    NaN/inf. This matters for autograd: ``th.where`` passes zero gradient into
    unselected branches, and 0 * inf would otherwise give NaN gradients (e.g. for
    the identity or 180-degree rotations).

    Args:
        rmat (torch.Tensor): (3, 3) or (..., 3, 3) rotation matrix
    Returns:
        torch.Tensor: (4,) or (..., 4) (x,y,z,w) float quaternion angles
    """
    # Check if input is a single matrix or a batch
    is_single = rmat.dim() == 2
    if is_single:
        rmat = rmat.unsqueeze(0)

    batch_shape = rmat.shape[:-2]
    mat_flat = rmat.reshape(-1, 3, 3)

    m00, m01, m02 = mat_flat[:, 0, 0], mat_flat[:, 0, 1], mat_flat[:, 0, 2]
    m10, m11, m12 = mat_flat[:, 1, 0], mat_flat[:, 1, 1], mat_flat[:, 1, 2]
    m20, m21, m22 = mat_flat[:, 2, 0], mat_flat[:, 2, 1], mat_flat[:, 2, 2]

    trace = m00 + m11 + m22

    # Mutually exclusive branch masks that together cover every matrix.
    trace_positive = trace > 0
    cond1 = (m00 > m11) & (m00 > m22) & ~trace_positive
    cond2 = (m11 > m22) & ~(trace_positive | cond1)
    cond3 = ~(trace_positive | cond1 | cond2)

    def _branch_scale(mask, radicand):
        # 2 * sqrt(radicand) where the mask is set, and the finite placeholder 2.0 elsewhere.
        safe = th.where(mask, radicand, th.ones_like(radicand))
        return th.sqrt(safe) * 2.0

    sq0 = _branch_scale(trace_positive, trace + 1.0)
    sq1 = _branch_scale(cond1, 1.0 + m00 - m11 - m22)
    sq2 = _branch_scale(cond2, 1.0 + m11 - m00 - m22)
    sq3 = _branch_scale(cond3, 1.0 + m22 - m00 - m11)

    def _select(v0, v1, v2, v3):
        # Take each element from the branch whose mask is set (cond3 is the remainder).
        return th.where(trace_positive, v0, th.where(cond1, v1, th.where(cond2, v2, v3)))

    qw = _select(0.25 * sq0, (m21 - m12) / sq1, (m02 - m20) / sq2, (m10 - m01) / sq3)
    qx = _select((m21 - m12) / sq0, 0.25 * sq1, (m01 + m10) / sq2, (m02 + m20) / sq3)
    qy = _select((m02 - m20) / sq0, (m01 + m10) / sq1, 0.25 * sq2, (m12 + m21) / sq3)
    qz = _select((m10 - m01) / sq0, (m02 + m20) / sq1, (m12 + m21) / sq2, 0.25 * sq3)

    quat = th.stack([qx, qy, qz, qw], dim=-1)

    # Normalize the quaternion
    quat = quat / th.norm(quat, dim=-1, keepdim=True)

    # Reshape to match input batch shape
    quat = quat.reshape(batch_shape + (4,))

    # If input was a single matrix, remove the batch dimension
    if is_single:
        quat = quat.squeeze(0)

    return quat


def mat2pose(hmat):
    """
    Converts a homogeneous 4x4 matrix into pose.

    Args:
        hmat (torch.tensor): a 4x4 homogeneous matrix, or a batch reshapeable to (N, 4, 4)

    Returns:
        2-tuple:
            - (torch.tensor) (x,y,z) position array in cartesian coordinates
            - (torch.tensor) (x,y,z,w) orientation array in quaternion form
        A leading batch dimension of size 1 is squeezed away; otherwise the shapes
        are (N, 3) and (N, 4).
    """
    hmat = hmat.reshape(-1, 4, 4)
    rot = hmat[:, :3, :3]
    det = rot.det()
    # Compare against ones of the same dtype/device: th.allclose rejects mixed
    # dtypes, so a float32 literal would break float64 (or CUDA) inputs.
    assert th.allclose(det, th.ones_like(det)), "Rotation matrix must not be scaled"
    pos = hmat[:, :3, 3].squeeze(0)
    orn = mat2quat(rot).squeeze(0)
    return pos, orn


def quat2mat(quaternion):
    """
    Convert quaternions into rotation matrices.

    Args:
        quaternion (torch.Tensor): (..., 4) quaternions in (x, y, z, w) order
            (scalar last). They are normalized here, so they need not be unit length.

    Returns:
        torch.Tensor: (..., 3, 3) rotation matrices with the input's dtype and device.
    """
    quaternion = quaternion / th.norm(quaternion, dim=-1, keepdim=True)
    x, y, z, w = quaternion.unbind(dim=-1)

    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    rows = (
        (1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw)),
        (2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw)),
        (2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy)),
    )
    # Build out-of-place with stack (no writes into an uninitialized tensor).
    return th.stack([th.stack(row, dim=-1) for row in rows], dim=-2)


def pose2mat(pose: Tuple[th.Tensor, th.Tensor]) -> th.Tensor:
    """Build float32 4x4 homogeneous transform(s) from ``(position, quaternion)``.

    The position is reshaped to (B, 3) and the (x, y, z, w) quaternion to (B, 4).
    The result is (B, 4, 4), or (4, 4) when B == 1.
    """
    position, quaternion = pose
    position = position.to(dtype=th.float32).reshape(-1, 3)
    quaternion = quaternion.to(dtype=th.float32).reshape(-1, 4)

    count = position.shape[0]
    transform = th.zeros((count, 4, 4), dtype=th.float32)
    transform[:, 3, 3] = 1.0                       # homogeneous bottom-right entry
    transform[:, :3, :3] = quat2mat(quaternion)    # rotation block
    transform[:, :3, 3] = position                 # translation column
    return transform.squeeze(0)


def pose_transform(pos1, quat1, pos0, quat0):
    """
    Apply the transform (pos1, quat1) to the pose (pos0, quat0).

    The composition is pose1 @ pose0, NOT pose0 @ pose1.

    Args:
        pos1: (x,y,z) position of the transform being applied
        quat1: (x,y,z,w) orientation of the transform being applied
        pos0: (x,y,z) initial position
        quat0: (x,y,z,w) initial orientation

    Returns:
        2-tuple:
            - (torch.tensor) (x,y,z) position array in cartesian coordinates
            - (torch.tensor) (x,y,z,w) orientation array in quaternion form
    """
    initial = pose2mat((pos0, quat0))
    applied = pose2mat((pos1, quat1))
    return mat2pose(th.matmul(applied, initial))

