import robotic as ry
import numpy as np
import os
import matplotlib.pyplot as plt
import time
import vtamp.environments.bridge.manipulation as manip
import rowan 
from utils import isolate_red_shapes_from_rgb
import cv2

SAVE_IMAGE = False  # also save the isolated drawing as output_<i>.png / output_<i>.npy
WHITEBOARD_OFFSET = 0.01  # gap between table top and whiteboard, in meters (0.01 m = 1 cm)
WHITEBOARD_TILT_DEG = 0  # tilt of the whiteboard about its x axis, in degrees

C = ry.Config()
C.addFile(ry.raiPath('scenarios/pandaSingle.g'))
C.getFrame('l_panda_finger_joint1').setJointState([.01])
C.getFrame("table").setColor([1, 1, 1])


# Tilt in radians; WHITEBOARD_TILT keeps its radian value for the rest of the script.
WHITEBOARD_TILT = np.deg2rad(WHITEBOARD_TILT_DEG)
# Rotation by WHITEBOARD_TILT about the x axis, quaternion in (w, x, y, z) order.
quaternion_wb = [np.cos(WHITEBOARD_TILT / 2), np.sin(WHITEBOARD_TILT / 2), 0, 0]
# Half of the table box's z size: the height of its top surface above the table
# frame (assumes the table frame sits at the box center).
table_height = C.getFrame("table").getSize()[2] / 2
whiteboard_thickness = 0.005
whiteboard_depth = 1 / np.sqrt(2)  # board extent along y; the x extent is 1
# Tilting about x lowers one edge by (depth / 2) * sin(tilt); lift the board by that much.
whiteboard_height_adjustment = whiteboard_depth * np.sin(WHITEBOARD_TILT) / 2

C.addFrame("whiteboard", "table")\
    .setColor([1, 1, 1])\
    .setShape(ry.ST.box, [1, whiteboard_depth, whiteboard_thickness])\
    .setRelativePosition([0, .35, table_height + whiteboard_thickness / 2 + whiteboard_height_adjustment + WHITEBOARD_OFFSET])\
    .setQuaternion(quaternion_wb)

# Pen the robot draws with: a red cylinder attached below the gripper.
C.addFrame("pen", "l_gripper").setColor([1, 0, 0]).setShape(ry.ST.cylinder, [.1, .01]).setRelativePosition([0, 0, -.05])

# Joint configuration right after scene setup.
homeState = C.getJointState()

# Global stroke counter, incremented by draw_line and used in marker frame names.
lines = 0

# Camera rotated 180 degrees about x (quaternion w, x, y, z) and placed at (0, .3, 1.5).
C.getFrame("cameraTop").setQuaternion([0, 1, 0, 0]).setPosition([0, .3, 1.5])
CameraView = ry.CameraView(C)
CameraView.setCamera(C.getFrame("cameraTop"))
fx, fy, cx, cy = CameraView.getFxycxy()  # pinhole intrinsics of cameraTop

def project_onto_plane(v, n):
    """Return the component of vector *v* that lies in the plane with normal *n*.

    Args:
        v: Vector to project (array-like, same length as *n*).
        n: Plane normal (array-like, any non-zero length; normalized here).

    Returns:
        np.ndarray: ``v`` minus its component along the unit normal.

    Raises:
        ValueError: If *n* has zero length, so no plane is defined.
    """
    v = np.asarray(v, dtype=float)
    n = np.asarray(n, dtype=float)
    norm = np.linalg.norm(n)
    if norm == 0:
        # Dividing by zero would silently produce NaNs.
        raise ValueError("plane normal must be non-zero")
    n = n / norm
    return v - np.dot(v, n) * n


def _pixel_coordinates(point, cam_rot, cam_pos, fx, fy, cx, cy):
    """Project a world-frame *point* to ``(u, v)`` pixel coordinates (pinhole model).

    ``cam_rot`` (3x3) and ``cam_pos`` (3,) are the camera frame's world orientation
    and position.
    """
    # World -> camera frame is R^T (p - t). For the camera pose set at the top of
    # this script (180 deg about x at (0, .3, 1.5)) this equals R p + t.
    # NOTE(review): assumes the optical axis is the camera's local z axis, as the
    # formulas below do -- confirm against rai's camera convention.
    p_cam = cam_rot.T @ (np.asarray(point) - cam_pos)
    return np.array([fx * p_cam[0] / p_cam[2] + cx,
                     fy * p_cam[1] / p_cam[2] + cy])


def _project_onto_board(point, board_rot, board_pos):
    """Drop *point* orthogonally onto the board plane; return it in world coordinates."""
    local = board_rot.T @ (np.asarray(point) - board_pos)
    local[2] = 0  # the board plane is z = 0 in the board's local frame
    return board_rot @ local + board_pos


def draw_line(x0, y0, x1, y1):
    """Move the pen to (x0, y0), then draw a straight stroke to (x1, y1) on the whiteboard.

    Coordinates are measured in the whiteboard frame from its corner at
    (-size_x / 2, -size_y / 2), so (0, 0) is a corner of the board. During the
    drawing motion, a red sphere ``circle_<lines>_<t>`` is added at the pen's
    projection onto the board. This happens for the start configuration and every
    path configuration where the pen is in contact (negDistance > 0).

    Args:
        x0, y0 (float): Start point on the board.
        x1, y1 (float): End point on the board.

    Returns:
        np.ndarray: (2,) pixel displacement ``end - start`` of the stroke in the
        ``cameraTop`` image. An endpoint contributes only if the pen touches the
        board there. The result stays ``[0, 0]`` when a motion is infeasible.
    """
    global lines

    fx, fy, cx, cy = CameraView.getFxycxy()
    vector = np.array([0., 0.])

    # Whiteboard and camera do not move while the arm moves: look them up once.
    whiteboard = C.getFrame("whiteboard")
    board_pose = whiteboard.getPose()
    board_pos = board_pose[:3]
    board_rot = rowan.to_matrix(board_pose[3:])
    board_half_x = whiteboard.getSize()[0] / 2
    board_half_y = whiteboard.getSize()[1] / 2

    camera = C.getFrame("cameraTop")
    cam_rot = rowan.to_matrix(camera.getQuaternion())
    cam_pos = camera.getPosition()

    tmp_created = False
    try:
        # i == 0: approach the start point; i == 1: draw the stroke to the end point.
        for i in range(2):
            drawing = i == 1
            if drawing:
                # Anchor at the stroke start; the motion below keeps the gripper on
                # the line through it. Deleted in the finally clause so later calls
                # don't pick up a stale frame with the same name.
                C.addFrame("tmp").setPosition(C.getFrame("l_gripper").getPosition())
                tmp_created = True
            goal_x, goal_y = (x1, y1) if drawing else (x0, y0)

            # --- Inverse kinematics for the goal configuration ---
            man = manip.ManipulationModelling()
            man.setup_inverse_kinematics(C, accumulated_collisions=False)
            # Gripper 0.1 along the board's z axis ...
            man.komo.addObjective([1], ry.FS.positionRel, ["l_gripper", "whiteboard"], ry.OT.eq, [0, 0, 1], [0, 0, .1])
            # ... above the goal point in the board's x/y ...
            man.komo.addObjective([1], ry.FS.positionRel, ["l_gripper", "whiteboard"], ry.OT.eq, np.diag([1, 1, 0]),
                                  [-board_half_x + goal_x, -board_half_y + goal_y, 1])
            # ... with the pen slightly penetrating the board.
            man.komo.addObjective([1], ry.FS.negDistance, ["pen", "whiteboard"], ry.OT.eq, 1, [0.01])
            man.solve()
            if not man.feasible:
                continue
            goal_q = man.path[0]

            # World position of the gripper at the IK solution.
            q = C.getJointState()
            C.setJointState(goal_q)
            target = np.array(C.getFrame("l_gripper").getPosition())
            C.setJointState(q)

            # --- Motion from the current configuration to the goal ---
            man = manip.ManipulationModelling()
            man.setup_point_to_point_motion(C, goal_q, accumulated_collisions=False)
            if drawing:
                delta = target - C.getFrame("l_gripper").getPosition()
                length = np.linalg.norm(delta)
                if length > 1e-9:  # a zero-length stroke has no direction to follow
                    direction = delta / length
                    # Penalize only the offset from "tmp" orthogonal to the stroke,
                    # i.e. keep the gripper on the straight line start -> target.
                    projection_matrix = np.eye(3) - np.outer(direction, direction)
                    man.komo.addObjective([], ry.FS.positionDiff, ['l_gripper', "tmp"], ry.OT.eq, 1e1 * projection_matrix)
            else:
                # Keep the pen away from the board for the first 80% of the approach.
                man.komo.addObjective([0, .8], ry.FS.negDistance, ["pen", "whiteboard"], ry.OT.ineq, 1, [-.1])
            man.solve()
            if not man.feasible:
                continue
            path = man.path
            num_steps = path.shape[0]

            # Step through the start configuration (t == 0) and every path
            # configuration (t == k + 1 is path[k]), so both endpoints are checked.
            for t in range(num_steps + 1):
                if t > 0:
                    C.setJointState(path[t - 1])
                if not drawing:
                    continue
                if C.eval(ry.FS.negDistance, ["pen", "whiteboard"])[0] <= 0:
                    continue  # pen not touching the board: nothing is drawn
                point = _project_onto_board(C.getFrame("pen").getPosition(), board_rot, board_pos)
                C.addFrame(f"circle_{lines}_{t}")\
                    .setShape(ry.ST.sphere, size=[.01])\
                    .setPosition(point)\
                    .setColor([1, 0, 0])
                if t == 0:
                    vector -= _pixel_coordinates(point, cam_rot, cam_pos, fx, fy, cx, cy)
                if t == num_steps:
                    vector += _pixel_coordinates(point, cam_rot, cam_pos, fx, fy, cx, cy)

            lines += 1
    finally:
        if tmp_created:
            C.delFrame("tmp")

    return vector
# Reference 5-pointed star: 10 vertices alternating between the outer radius and
# half of it, starting at the bottom (-pi/2). It is kept for comparison with the
# hand-perturbed `points` below, which are what actually gets drawn. It uses its own
# names so the camera intrinsics `cx` / `cy` defined at the top are not overwritten.
star_center = (0.3, 0.3)
star_radius = 0.2
star_step = np.pi / 5  # angle between consecutive vertices, in radians
star_radii = [star_radius if k % 2 == 0 else 0.5 * star_radius for k in range(10)]
ideal_star_points = [
    (star_center[0] + r * np.cos(k * star_step - np.pi / 2),
     star_center[1] + r * np.sin(k * star_step - np.pi / 2))
    for k, r in enumerate(star_radii)
]

# Polygon vertices to draw, in whiteboard coordinates measured from its corner.
points = [
    (0.30532767, 0.06591802),
    (0.36911164, 0.20275572),
    (0.51178779, 0.22213555),
    (0.41909385, 0.32310958),
    (0.42719094, 0.48421075),
    (0.30112623, 0.39726068),
    (0.17646194, 0.47671372),
    (0.20015445, 0.32341007),
    (0.10529946, 0.21950413),
    (0.2402701, 0.19801328)
]
# Alternative test shape (square):
# points = [
#     (0.3, 0.3),
#     (0.5, 0.3),
#     (0.5, 0.5),
#     (0.3, 0.5),
# ]

# Draw the closed polygon. Each vertex connects to the next, and the last vertex
# wraps back to the first. The arm returns to the home configuration after every stroke.
vectors = []
for (x0, y0), (x1, y1) in zip(points, points[1:] + points[:1]):
    vectors.append(draw_line(x0, y0, x1, y1))
    C.setJointState(homeState)


# Save the per-stroke pixel displacement vectors.
np.save("vectors.npy", np.array(vectors))
print("Saved vectors array as vectors.npy", vectors)


C.view(True)
# Frames kept for the final rendering; apart from these, only "circle_*" markers survive.
to_stay = ["world", "table", "whiteboard", "cameraTop"]



# Re-apply the camera pose: rotated 180 degrees about x, at (0, .3, 1.5).
C.getFrame("cameraTop").setQuaternion([0, 1, 0, 0]).setPosition([0, .3, 1.5])
C.view(True)
# Remove the robot, pen and any helper frames so only the board and the drawn markers render.
for frame in C.getFrameNames():
    if frame not in to_stay and "circle_" not in frame:
        C.delFrame(frame)


print("Camera: ", CameraView.getFxycxy())
print(10*"\n")

image, _ = CameraView.computeImageAndDepth(C)
plt.imshow(image)
plt.show()

image_without_background = isolate_red_shapes_from_rgb(image, background_color=(255, 255, 255))
plt.imshow(image_without_background)

# Save the image without background
if SAVE_IMAGE:
    # Pick the first index i for which neither output_i.png nor output_i.npy exists.
    i = 0
    while os.path.exists(f"output_{i}.png") or os.path.exists(f"output_{i}.npy"):
        i += 1

    filename_png = f"output_{i}.png"
    filename_npy = f"output_{i}.npy"

    # Convert to grayscale for NPY saving
    if image_without_background.ndim == 3 and image_without_background.shape[2] == 3:
        gray = cv2.cvtColor(image_without_background, cv2.COLOR_RGB2GRAY)
    else:
        gray = image_without_background

    # Save .npy
    np.save(filename_npy, gray)
    print(f"Saved grayscale numpy array as {filename_npy}")

    # Bring the image to uint8 for cv2.imwrite without resizing. Data in [0, 1]
    # (float image or 0/1 mask) is scaled to [0, 255]. Any non-uint8 result is
    # clipped before the cast, so out-of-range values saturate instead of wrapping.
    image_to_save = np.asarray(image_without_background)
    if image_to_save.max() <= 1.0:
        image_to_save = image_to_save * 255
    if image_to_save.dtype != np.uint8:
        image_to_save = np.clip(image_to_save, 0, 255).astype(np.uint8)

    # Convert RGB to BGR for OpenCV
    if image_to_save.ndim == 3 and image_to_save.shape[2] == 3:
        image_to_save = cv2.cvtColor(image_to_save, cv2.COLOR_RGB2BGR)

    cv2.imwrite(filename_png, image_to_save)
    print(f"Saved image as {filename_png}")



plt.show()

