import robotic as ry
import numpy as np
import os
import matplotlib.pyplot as plt
import time
import vtamp.environments.bridge.manipulation as manip
import rowan 
from utils import isolate_red_shapes_from_rgb
import cv2

DRAW_METHOD = 1  # 1: draw using spheres, 2: draw using cylinders

SAVE_IMAGE = False  # whether to save the isolated drawing as output_<i>.png / output_<i>.npy
WHITEBOARD_OFFSET = .125  # base offset of the whiteboard above the table, in metres
WHITEBOARD_TILT = 40  # tilt of the whiteboard in degrees (rebound to radians below)
SPHERE_DISTANCE = 0.01  # desired distance between sphere centers if DRAW_METHOD is 1

C = ry.Config()
C.addFile(ry.raiPath('scenarios/pandaMops.g'))
C.getFrame('l_panda_finger_joint1').setJointState([.01])
C.getFrame("table").setColor([1, 1, 1])


# From here on WHITEBOARD_TILT holds radians.
WHITEBOARD_TILT = np.deg2rad(WHITEBOARD_TILT)
# Rotation by WHITEBOARD_TILT about the x-axis, as a (w, x, y, z) quaternion.
quaternion_wb = [np.cos(WHITEBOARD_TILT / 2), np.sin(WHITEBOARD_TILT / 2), 0, 0]
# Half of the table's z-size: offset from the table frame to its top surface
# (assumes the table is a box centred on its frame -- TODO confirm).
table_height = C.getFrame("table").getSize()[2] / 2
whiteboard_thickness = 0.005
# Presumably lifts the board so its tilted lower edge clears the table.
# NOTE(review): uses a 0.7 m extent rather than the board depth table_shape[1]; confirm intent.
whiteboard_height_adjustment = (.7) * np.sin(WHITEBOARD_TILT) / 2

# Width (x) and depth (y) of the whiteboard box -- despite the name, not the table's size.
table_shape = [.64, .48]
C.addFrame("whiteboard", "table")\
    .setColor([1, 1, 1])\
    .setShape(ry.ST.box, [table_shape[0], table_shape[1], whiteboard_thickness])\
    .setRelativePosition([0, .4, table_height + whiteboard_thickness / 2 + whiteboard_height_adjustment + WHITEBOARD_OFFSET])\
    .setQuaternion(quaternion_wb)


# Pen cylinder size; PEN_SIZE[0] is its length (penEnd sits at its tip).
PEN_SIZE = [.1, .01]
# Pen with which the robot will draw
C.addFrame("pen", "l_gripper").setColor([1, 0, 0]).setShape(ry.ST.cylinder, PEN_SIZE).setRelativePosition([0, 0, -.05])
C.addFrame("penEnd", "pen").setRelativePosition([0, 0, -PEN_SIZE[0] / 2])


# Joint configuration the robot is reset to after every stroke.
homeState = C.getJointState()

# Counters used to give marker frames unique names ("line_<n>", "sphere_<n>").
lines = 0
spheres_drawn = 0

def draw_line(x0_bl, y0_bl, x1_bl, y1_bl):
    """
    Draw a line on the whiteboard with visual markers only (the robot does not move).

    The endpoints are mapped from whiteboard coordinates to world coordinates, lifted
    1 mm above the board surface, and marked with either evenly spaced spheres
    (DRAW_METHOD == 1) or a single cylinder (DRAW_METHOD == 2).

    Args:
        x0_bl, y0_bl (float): Start coordinates relative to the whiteboard *bottom-left corner*.
        x1_bl, y1_bl (float): End coordinates relative to the whiteboard *bottom-left corner*.

    Side effects:
        Adds "sphere_<n>" or "line_<n>" frames to the global config C and advances the
        global counters spheres_drawn / lines. Nothing is added (and lines is not
        advanced) for a near-zero-length line or a non-positive SPHERE_DISTANCE.
    """
    global lines, spheres_drawn

    whiteboard_frame = C.getFrame("whiteboard")
    wb_width, wb_height = table_shape[0], table_shape[1]
    wb_pos = whiteboard_frame.getPosition()  # world position of the board's centre
    # Board orientation read from the frame itself (robotic and rowan both use
    # (w, x, y, z) quaternions), so the markers follow the board however it is posed.
    R_wb = rowan.to_matrix(whiteboard_frame.getQuaternion())

    # Shift from bottom-left-corner coordinates to the board's centre, and lift the
    # points just above the top surface so the markers are not buried in the board.
    z_offset = whiteboard_thickness / 2 + 0.001
    world_p0 = wb_pos + R_wb @ np.array([x0_bl - wb_width / 2, y0_bl - wb_height / 2, z_offset])
    world_p1 = wb_pos + R_wb @ np.array([x1_bl - wb_width / 2, y1_bl - wb_height / 2, z_offset])

    vector = world_p1 - world_p0
    length = np.linalg.norm(vector)
    if length < 1e-6:  # avoid division by zero / degenerate markers
        print("Warning: Skipping near-zero length line.")
        return

    if DRAW_METHOD == 1:  # spheres
        if SPHERE_DISTANCE <= 1e-6:
            print("Warning: SPHERE_DISTANCE must be positive to draw spheres.")
            return
        # floor(...) + 1 spheres, so the spacing is >= SPHERE_DISTANCE; a line shorter
        # than SPHERE_DISTANCE gets a single sphere at its start point.
        num_spheres = int(np.floor(length / SPHERE_DISTANCE)) + 1
        # linspace puts the first sphere on world_p0 and, if num_spheres > 1, the last on world_p1.
        for sphere_pos in np.linspace(world_p0, world_p1, num_spheres):
            C.addFrame(f"sphere_{spheres_drawn}") \
                .setShape(ry.ST.sphere, [0.005]) \
                .setPosition(sphere_pos) \
                .setColor([1, 0, 0])  # red
            spheres_drawn += 1

    elif DRAW_METHOD == 2:  # one cylinder spanning the line
        center_pos = (world_p0 + world_p1) / 2
        # Rotate the cylinder's local z-axis (its length axis) onto the line direction.
        z_axis = np.array([0., 0., 1.])
        direction = vector / length
        axis = np.cross(z_axis, direction)
        dot_product = np.clip(np.dot(z_axis, direction), -1.0, 1.0)
        angle = np.arccos(dot_product)
        axis_norm = np.linalg.norm(axis)
        if axis_norm > 1e-6:
            quat = rowan.from_axis_angle(axis / axis_norm, angle)
        elif dot_product < -0.9999:
            quat = np.array([0., 1., 0., 0.])  # anti-parallel: 180 degrees about x
        else:
            quat = np.array([1., 0., 0., 0.])  # already aligned: identity

        C.addFrame(f"line_{lines}") \
            .setShape(ry.ST.cylinder, [length, 0.005]) \
            .setPosition(center_pos) \
            .setQuaternion(quat) \
            .setColor([1, 0, 0])  # red
    else:
        print(f"Warning: Unknown DRAW_METHOD: {DRAW_METHOD}. No markers added.")

    lines += 1  # counts every drawn line, whatever the marker type

# Place the "cameraTop" frame 1.5 m up, rotated 180 degrees about the x-axis
# ((w, x, y, z) = (0, 1, 0, 0)), and build the renderer from it. Its intrinsics are
# kept in the module-level names fx, fy, cx, cy.
camera_frame = C.getFrame("cameraTop")
camera_frame.setQuaternion([0, 1, 0, 0])
camera_frame.setPosition([0, .45, 1.5])
CameraView = ry.CameraView(C)
CameraView.setCamera(camera_frame)
fx, fy, cx, cy = CameraView.getFxycxy()

def project_onto_plane(v, n):
    """Project vector ``v`` onto the plane through the origin with normal ``n``.

    Args:
        v: Vector to project (array-like, same length as ``n``).
        n: Plane normal (array-like); it need not be unit length but must be non-zero.

    Returns:
        ``v`` with its component along ``n`` removed, as a float NumPy array.

    Raises:
        ValueError: If ``n`` has zero length.
    """
    # asarray lets callers pass plain lists: list / float would raise TypeError.
    v = np.asarray(v, dtype=float)
    n = np.asarray(n, dtype=float)
    n_norm = np.linalg.norm(n)
    if n_norm == 0:
        raise ValueError("plane normal n must be non-zero")
    n = n / n_norm  # unit normal
    return v - np.dot(v, n) * n
# Debug preview: print the whiteboard centre, then open the viewer and pause.
whiteboard_center = C.getFrame("whiteboard").getPosition()
print(whiteboard_center)
C.view(True)

def _whiteboard_point_to_world(x_bl, y_bl):
    """Return the world position of a whiteboard point given relative to its bottom-left corner.

    The point is lifted 1 mm above the board's top surface so that markers placed
    there are not hidden inside the board.
    """
    whiteboard_frame = C.getFrame("whiteboard")
    # Orientation taken from the frame itself; robotic and rowan both order
    # quaternions as (w, x, y, z).
    R_wb = rowan.to_matrix(whiteboard_frame.getQuaternion())
    local_p = np.array([
        x_bl - table_shape[0] / 2,
        y_bl - table_shape[1] / 2,
        whiteboard_thickness / 2 + 0.001,
    ])
    return whiteboard_frame.getPosition() + R_wb @ local_p


def _move_pen_to(x_bl, y_bl, keep_straight):
    """Plan and execute a robot motion that brings the pen onto the board at (x_bl, y_bl).

    Args:
        x_bl, y_bl (float): Target relative to the whiteboard bottom-left corner.
        keep_straight (bool): True for the drawing stroke, where penEnd is constrained
            to the straight line from its current position to the target. False for the
            approach move, which instead has a pen/board clearance objective over the
            first 80% of the motion.

    Returns:
        bool: True if IK and path planning were both feasible and the motion was
        executed. False otherwise; C then keeps the joint state it had on entry.
    """
    label = "end" if keep_straight else "start"
    if keep_straight:
        # Anchor of the straight-line constraint: where the pen tip is now.
        C.addFrame("tmp").setPosition(C.getFrame("penEnd").getPosition())
    try:
        # --- IK for the target pose ---
        man_ik = manip.ManipulationModelling()
        man_ik.setup_inverse_kinematics(C, accumulated_collisions=False)
        man_ik.komo.addObjective([1], ry.FS.positionRel, ["l_gripper", "whiteboard"], ry.OT.eq, scale=[0, 0, 1], target=[0, 0, .1])
        man_ik.komo.addObjective([1], ry.FS.positionRel, ["l_gripper", "whiteboard"], ry.OT.eq, scale=np.diag([1, 1, 0]), target=[x_bl - table_shape[0] / 2, y_bl - table_shape[1] / 2, 0])
        man_ik.komo.addObjective([1], ry.FS.negDistance, ["pen", "whiteboard"], ry.OT.eq, scale=[1], target=[0.01])
        man_ik.solve()
        if not man_ik.feasible:
            print(f'  -- IK infeasible for {label} point')
            return False
        target_q = man_ik.path[0]

        # --- Path planning ---
        man_path = manip.ManipulationModelling()
        man_path.setup_point_to_point_motion(C, target_q, accumulated_collisions=False)

        if keep_straight:
            # Pen-tip world position at the IK solution; the current state is restored afterwards.
            q = C.getJointState()
            C.setJointState(target_q)
            target = np.array(C.getFrame("penEnd").getPosition())
            C.setJointState(q)

            delta = target - C.getFrame("penEnd").getPosition()
            dist = np.linalg.norm(delta)
            # Coincident start/end: there is no direction to constrain (and 0/0 would give NaNs).
            if dist > 1e-9:
                direction = delta / dist
                # Penalise only the part of (penEnd - tmp) orthogonal to the stroke
                # direction, i.e. keep the pen tip on the line through the start point.
                projection_matrix = np.eye(3) - np.outer(direction, direction)
                man_path.komo.addObjective([], ry.FS.positionDiff, ['penEnd', "tmp"], ry.OT.eq, 1e1 * projection_matrix)
        else:
            # Presumably keeps the pen at least 0.1 m from the board for most of the approach.
            man_path.komo.addObjective([0, .8], ry.FS.negDistance, ["pen", "whiteboard"], ry.OT.ineq, scale=[1], target=[-.1])

        man_path.solve()
        if not man_path.feasible:
            print(f'  -- Path planning infeasible for {label} point')
            return False

        # --- Execute the planned path (leaves C at its final configuration) ---
        print(f"  Animating path segment {2 if keep_straight else 1}...")
        for q_t in man_path.path:
            C.setJointState(q_t)
            #C.view(False, f"Moving segment ... - step ...")
            #time.sleep(0.02) # Adjust sleep time for desired animation speed
        return True
    finally:
        # Always remove the helper frame, even if a solver raises.
        if keep_straight:
            C.delFrame("tmp")


def _add_line_markers(world_p0, world_p1):
    """Add red sphere or cylinder frames along world_p0 -> world_p1, according to DRAW_METHOD."""
    global lines, spheres_drawn

    line_vector = world_p1 - world_p0
    line_length = np.linalg.norm(line_vector)
    if line_length < 1e-6:
        print("Warning: Skipping near-zero length line for marker drawing.")
        return

    if DRAW_METHOD == 1:  # spheres
        if SPHERE_DISTANCE <= 1e-6:
            print("Warning: SPHERE_DISTANCE must be positive to draw spheres.")
            return
        num_spheres = int(np.floor(line_length / SPHERE_DISTANCE)) + 1
        # First sphere on world_p0 and, if num_spheres > 1, the last on world_p1.
        for sphere_pos in np.linspace(world_p0, world_p1, num_spheres):
            C.addFrame(f"sphere_{spheres_drawn}") \
                .setShape(ry.ST.sphere, [0.005]) \
                .setPosition(sphere_pos) \
                .setColor([1, 0, 0])
            spheres_drawn += 1

    elif DRAW_METHOD == 2:  # one cylinder spanning the line
        center_pos = (world_p0 + world_p1) / 2
        # Rotate the cylinder's local z-axis (its length axis) onto the line direction.
        z_axis = np.array([0., 0., 1.])
        direction = line_vector / line_length
        axis = np.cross(z_axis, direction)
        dot_product = np.clip(np.dot(z_axis, direction), -1.0, 1.0)
        angle = np.arccos(dot_product)
        axis_norm = np.linalg.norm(axis)
        if axis_norm > 1e-6:
            quat = rowan.from_axis_angle(axis / axis_norm, angle)
        elif dot_product < -0.9999:
            quat = np.array([0., 1., 0., 0.])  # anti-parallel: 180 degrees about x
        else:
            quat = np.array([1., 0., 0., 0.])  # already aligned: identity

        C.addFrame(f"line_{lines}") \
            .setShape(ry.ST.cylinder, [line_length, 0.005]) \
            .setPosition(center_pos) \
            .setQuaternion(quat) \
            .setColor([1, 0, 0])
        lines += 1
    else:
        print(f"Warning: Unknown DRAW_METHOD: {DRAW_METHOD}. No markers added.")


def _project_segment_to_image(world_p0, world_p1):
    """Return the pixel vector from the cameraTop projection of world_p0 to that of world_p1.

    Returns a zero vector if either endpoint cannot be projected.
    """
    # Read the intrinsics directly instead of relying on module-level globals,
    # which can be rebound elsewhere in the script.
    fx_px, fy_px, cx_px, cy_px = CameraView.getFxycxy()

    cam_frame = C.getFrame("cameraTop")
    cam_pos = cam_frame.getPosition()
    R_cam = rowan.to_matrix(cam_frame.getQuaternion())

    def to_pixel(p_world, label):
        # World -> camera frame: R^T (p - cam_pos).
        p_cam = R_cam.T @ (p_world - cam_pos)
        # NOTE(review): assumes points in front of the camera have positive z in the
        # camera frame; confirm against robotic's camera convention.
        if p_cam[2] <= 1e-6:
            print(f"Warning: {label} point is behind or too close to the camera.")
            return None
        return np.array([fx_px * p_cam[0] / p_cam[2] + cx_px,
                         fy_px * p_cam[1] / p_cam[2] + cy_px])

    start_px = to_pixel(world_p0, "Start")
    end_px = to_pixel(world_p1, "End")
    if start_px is None or end_px is None:
        # Without both endpoints there is no meaningful vector.
        return np.array([0., 0.])
    return end_px - start_px


def draw_line_robot(x0_bl, y0_bl, x1_bl, y1_bl):
    """
    Move the pen from (x0_bl, y0_bl) to (x1_bl, y1_bl) on the whiteboard and mark the stroke.

    First the robot moves to the start point, with a pen/board clearance objective for
    most of that approach. It then draws a stroke to the end point with the pen tip
    constrained to the straight line. Only if both motions are feasible are visual
    markers (spheres or a cylinder, per the global DRAW_METHOD) added along the
    *intended* segment.

    Args:
        x0_bl, y0_bl (float): Start coordinates relative to the whiteboard *bottom-left corner*.
        x1_bl, y1_bl (float): End coordinates relative to the whiteboard *bottom-left corner*.

    Returns:
        np.ndarray of shape (2,): pixel-space vector from the projected start point to
        the projected end point in the cameraTop image, or a zero vector if either
        endpoint cannot be projected. It is computed even when the motion is infeasible.
    """
    world_p0 = _whiteboard_point_to_world(x0_bl, y0_bl)
    world_p1 = _whiteboard_point_to_world(x1_bl, y1_bl)

    # Short-circuit: the stroke is only attempted if the approach succeeded.
    if _move_pen_to(x0_bl, y0_bl, keep_straight=False) and \
            _move_pen_to(x1_bl, y1_bl, keep_straight=True):
        _add_line_markers(world_p0, world_p1)

    return _project_segment_to_image(world_p0, world_p1)

# Five-pointed star: 10 vertices alternating between an outer radius and half of it,
# starting at the bottom (-pi/2). Kept as an alternative drawing. Its centre uses its
# own names so the camera intrinsics cx, cy read from CameraView are not overwritten.
star_cx = 0.3
star_cy = 0.3
radius = 0.2
step = np.pi / 5  # angle between consecutive vertices (radians)
star_points = [
    (star_cx + (radius if k % 2 == 0 else radius * 0.5) * np.cos(k * step - np.pi / 2),
     star_cy + (radius if k % 2 == 0 else radius * 0.5) * np.sin(k * step - np.pi / 2))
    for k in range(10)
]

# Image-space vector of every drawn segment (see draw_line_robot).
vectors = []

# Corner points of the whiteboard, relative to its bottom-left corner.
# NOTE(review): the last entry repeats tuple(table_shape), so the top edge is traced
# twice and the closing segment is the diagonal back to (0, 0); confirm this is intended.
points = [
    (0., 0),
    (table_shape[0], .0),
    tuple(table_shape),
    (0, table_shape[1]),
    tuple(table_shape),
]

# Alternative point sets:
# points = star_points
# points = [
#     (0.30532767, 0.06591802),
#     (0.36911164, 0.20275572),
#     (0.51178779, 0.22213555),
#     (0.41909385, 0.32310958),
#     (0.42719094, 0.48421075),
#     (0.30112623, 0.39726068),
#     (0.17646194, 0.47671372),
#     (0.20015445, 0.32341007),
#     (0.10529946, 0.21950413),
#     (0.2402701, 0.19801328)
# ]
# points = [
#     (0.3, 0.3),
#     (0.5, 0.3),
#     (0.5, 0.5),
#     (0.3, 0.5),
# ]

# Draw a closed polyline: each point is joined to the next, and the last back to the
# first. The robot returns home after every segment.
for (x0, y0), (x1, y1) in zip(points, points[1:] + points[:1]):
    vectors.append(draw_line_robot(x0, y0, x1, y1))
    C.setJointState(homeState)


# Save the vectors array to a numpy file
# np.save("vectors.npy", np.array(vectors))
# print("Saved vectors array as vectors.npy", vectors)


C.view(True)

# Delete every frame except the scene frames listed here and the drawn markers
# (named "sphere_<n>" / "line_<n>" by the drawing functions), so that only the
# board and the drawing remain in the rendered image. Prefix matching avoids
# keeping unrelated frames whose names merely contain "line" or "sphere".
to_stay = ["world", "table", "whiteboard", "cameraTop"]
for frame in C.getFrameNames():
    if frame not in to_stay and not frame.startswith(("sphere_", "line_")):
        C.delFrame(frame)


print("Camera: ", CameraView.getFxycxy())
print(10*"\n")

image, _ = CameraView.computeImageAndDepth(C)
plt.imshow(image)
plt.show()
# Early exit: the red-shape isolation and saving below are currently skipped on
# purpose; delete this line to run them. (`raise SystemExit` replaces `quit()`,
# which is a `site`-module helper meant for the interactive shell, not scripts.)
raise SystemExit

image_without_background = isolate_red_shapes_from_rgb(image, background_color=(255, 255, 255))
plt.imshow(image_without_background)

# Save the image without background
if SAVE_IMAGE:
    # Get the next available filename
    i = 0
    while os.path.exists(f"output_{i}.png") or os.path.exists(f"output_{i}.npy"):
        i += 1

    filename_png = f"output_{i}.png"
    filename_npy = f"output_{i}.npy"

    # Convert to grayscale for NPY saving
    if image_without_background.ndim == 3 and image_without_background.shape[2] == 3:
        gray = cv2.cvtColor(image_without_background, cv2.COLOR_RGB2GRAY)
    else:
        gray = image_without_background

    # Save .npy
    np.save(filename_npy, gray)
    print(f"Saved grayscale numpy array as {filename_npy}")

    # Save original image (RGB or grayscale) without resizing
    if image_without_background.dtype == np.float32 or image_without_background.max() <= 1.0:
        # Convert float image to 0–255 uint8
        image_to_save = (image_without_background * 255).astype(np.uint8)
    else:
        image_to_save = image_without_background

    # Convert RGB to BGR for OpenCV
    if image_to_save.ndim == 3 and image_to_save.shape[2] == 3:
        image_to_save = cv2.cvtColor(image_to_save, cv2.COLOR_RGB2BGR)

    cv2.imwrite(filename_png, image_to_save)
    print(f"Saved image as {filename_png}")



plt.show()

