# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.17.0
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %%
import robotic as ry
import numpy as np
import rowan

from vtamp.environments.bridge.tasks import BuildPlanarTriangle, PushRed
from vtamp.environments.bridge.env import Environment, BridgeEnv

# %% [markdown]
# # Pick And Place

# %%
# Build the planar-triangle task and obtain its robotic configuration for viewing.
env = BuildPlanarTriangle(goal_str="test")
config = env.setup_cfg()
config.view()

# %%
# Nominal placement center (not referenced by the KOMO cells in this notebook).
center_x, center_y = 0.2, 0.2

# Height (z-size) of the red block; shown as the cell output.
red_size = config.getFrame("block_red").getSize()
block_size_z = red_size[2]
block_size_z

# %%
# KOMO over 4 phases (pick red @1, place red @2, pick blue @3, place blue @4),
# 1 slice per phase, first-order (velocity) objectives, collisions enabled.
komo = ry.KOMO(config, 4, 1, 1, True)

# Control regularizers: order 0 (homing) weighted 1e-2, order 1 (velocity) 1e-1.
for order, weight in ((0, 1e-2), (1, 1e-1)):
    komo.addControlObjective([], order, weight)

# Over the whole horizon: no accumulated collisions (equality) and joint limits (inequality).
komo.addObjective([], ry.FS.accumulatedCollisions, [], ry.OT.eq, [1e0])
komo.addObjective([], ry.FS.jointLimits, [], ry.OT.ineq, [1e0])


# Grasp the red block at t=1.0.
# Proper grasp: center the gripper on the block along the block's x-axis
# (only the x-row of the relative position is constrained to 0).
box_size = komo.getConfig().getFrame("block_red").getSize()[:3]
margin = 0.02
komo.addObjective([1.0], ry.FS.positionRel, ["l_gripper", "block_red"], ry.OT.eq,
                  np.array([[1e1, 0, 0]]))

# Stay within the box's y-z cross-section, i.e. within +-(half-size - margin).
# Upper and lower bounds are two one-sided inequalities; the lower bound flips the
# sign of both scale and target (presumably KOMO's scale*(f - target) <= 0 convention).
komo.addObjective([1.0], ry.FS.positionRel, ["l_gripper", "block_red"], ry.OT.ineq,
                  np.array([[0, 1, 0], [0, 0, 1]]) * 1e1, 0.5 * box_size - margin)
komo.addObjective([1.0], ry.FS.positionRel, ["l_gripper", "block_red"], ry.OT.ineq,
                  np.array([[0, 1, 0], [0, 0, 1]]) * (-1e1), -0.5 * box_size + margin)

# Orientation during the approach [0.8, 1.0]: zero XY and XZ axis scalar products
# between gripper and block (axis orthogonality).
komo.addObjective([0.8, 1.0], ry.FS.scalarProductXY,
                  ["l_gripper", "block_red"],
                  ry.OT.eq,
                  [1e0])
komo.addObjective([0.8, 1.0], ry.FS.scalarProductXZ,
                  ["l_gripper", "block_red"],
                  ry.OT.eq,
                  [1e0])

# Avoid collision between palm and box right before the grasp.
komo.addObjective([0.7, 1.0],
                  ry.FS.distance,
                  ["l_palm", "block_red"],
                  ry.OT.ineq,
                  [1e1],
                  [-0.001])

# Kinematic switch: from t=1.0 the block is attached to the gripper's kinematic chain
# (end time -1 leaves the interval open).
komo.addModeSwitch([1., -1.], ry.SY.stable, ["l_gripper", "block_red"], firstSwitch=True)


# Placement of the red block on the table at t=2.0.
tableSize = config.getFrame("table").getSize()[:3]
# Center-to-center height when the block rests on the table top.
relPos = 0.5 * (box_size[2] + tableSize[2])

# Row-selection scale matrices (weight 10) for the position features below.
z_row = 1e1 * np.array([[0, 0, 1]])
xy_rows = 1e1 * np.array([[1, 0, 0], [0, 1, 0]])
table_limit = 0.5 * tableSize - margin

# Vertical offset between block and table centers equals relPos.
komo.addObjective([2.0], ry.FS.positionDiff, ["block_red", "table"], ry.OT.eq,
                  z_row, np.array([0., 0., relPos]))

# Keep the block center inside the table footprint: upper and lower x/y bounds.
komo.addObjective([2.0], ry.FS.positionRel, ["block_red", "table"], ry.OT.ineq,
                  xy_rows, table_limit)
komo.addObjective([2.0], ry.FS.positionRel, ["block_red", "table"], ry.OT.ineq,
                  -xy_rows, -table_limit)

# Keep the block upright: its z-axis should point to world +z during [1.8, 2.0].
zVectorTarget = np.array([0., 0., 1.])
komo.addObjective([1.8, 2.0], ry.FS.vectorZ, ["block_red"], ry.OT.eq, [0.5], zVectorTarget)

# Zero XZ and YZ scalar products between table and block: the block's z-axis is kept
# orthogonal to the table's x- and y-axes.
for axis_feature in (ry.FS.scalarProductXZ, ry.FS.scalarProductYZ):
    komo.addObjective([1.8, 2.0], axis_feature, ["table", "block_red"], ry.OT.eq, [1e0])

# The palm must stay clear of the table while placing.
komo.addObjective([1.7, 2.0], ry.FS.distance, ["l_palm", "table"], ry.OT.ineq, [1e1], [-0.001])

# Absolute goal: block center at (0.1, 0.1), resting on the table top.
table_frame = config.getFrame("table")
table_offset = table_frame.getPosition()[2] + 0.5 * table_frame.getSize()[2]
target_pos = [.1, .1, table_offset + 0.5 * block_size_z]
komo.addObjective([2.], ry.FS.position, ["block_red"], ry.OT.eq, scale=[1e1], target=target_pos)

# Release: from t=2.0 the red block is attached to the table.
komo.addModeSwitch([2., -1.], ry.SY.stable, ["table", "block_red"], firstSwitch=False)

# %%
# ——————————————————————————————————————————————
# (3) BLUE BLOCK pick @ t=3.0
# ——————————————————————————————————————————————
box_blue = "block_blue"
box_size_b = config.getFrame(box_blue).getSize()[:3]
margin = 0.02
grasp_pair = ["l_gripper", box_blue]

# Center the gripper on the blue block along the block's x-axis.
komo.addObjective([3.0], ry.FS.positionRel, grasp_pair, ry.OT.eq, 1e1 * np.array([[1, 0, 0]]))

# Stay inside the block's y-z cross-section: the upper bound (sign +1) and the
# lower bound (sign -1) flip both scale and target.
yz_rows = 1e1 * np.array([[0, 1, 0], [0, 0, 1]])
half_extent_b = 0.5 * box_size_b - margin
for sign in (1.0, -1.0):
    komo.addObjective([3.0], ry.FS.positionRel, grasp_pair, ry.OT.ineq,
                      sign * yz_rows, sign * half_extent_b)

# Orientation during the approach [2.8, 3.0]: zero XY and XZ axis scalar products.
for axis_feature in (ry.FS.scalarProductXY, ry.FS.scalarProductXZ):
    komo.addObjective([2.8, 3.0], axis_feature, grasp_pair, ry.OT.eq, [1e0])

# Keep the palm out of the block right before grasping.
komo.addObjective([2.7, 3.0], ry.FS.distance, ["l_palm", box_blue], ry.OT.ineq, [1e1], [-0.001])

# Attach the blue block to the gripper from t=3.0 on.
komo.addModeSwitch([3., -1.], ry.SY.stable, ["l_gripper", "block_blue"], firstSwitch=False)


# ——————————————————————————————————————————————
# (4) BLUE BLOCK place atop RED @ t=4.0
# ——————————————————————————————————————————————
box_b = config.getFrame("block_blue").getSize()[:3]
box_r = config.getFrame("block_red").getSize()[:3]
# Vertical center-to-center distance when blue rests on red: half the summed heights.
relPos = 0.5 * (box_b[2] + box_r[2])

stack_pair = ["block_blue", "block_red"]

# (1) Blue's bottom meets red's top: fix the z-offset between the two centers.
komo.addObjective([4.0], ry.FS.positionDiff, stack_pair, ry.OT.eq,
                  1e1 * np.array([[0, 0, 1]]), [0., 0., relPos])

# (2) Center blue horizontally over red (x/y rows of the relative position are 0).
komo.addObjective([4.0], ry.FS.positionRel, stack_pair, ry.OT.eq,
                  1e1 * np.array([[1, 0, 0], [0, 1, 0]]), [0., 0., 0.])

# (3) Red-Y / blue-Z scalar product driven to 1: blue's z-axis aligned with red's y-axis.
komo.addObjective([3.8, 4.0], ry.FS.scalarProductYZ, ["block_red", "block_blue"], ry.OT.eq,
                  [0.5], [1.])

# (4) Align the x-axes between blue and red for proper placement.
komo.addObjective([3.8, 4.0], ry.FS.scalarProductXX, stack_pair, ry.OT.eq, [1e1], [1.])

# (5) Keep the palm clear of the blue block during the final drop.
komo.addObjective([3.7, 4.0], ry.FS.distance, ["l_palm", "block_blue"], ry.OT.ineq,
                  [1e1], [-0.001])

# Release: from t=4.0 the blue block is attached to "table".
# NOTE(review): blue rests on red but is parented to "table", not "block_red" — confirm intended.
komo.addModeSwitch([4., -1.], ry.SY.stable, ["table", "block_blue"], firstSwitch=False)

# %%
# Solve the pick-and-place NLP verbosely, print the result and display the trajectory.
solver = ry.NLP_Solver(komo.nlp(), verbose=4)
ret = solver.solve()
print(ret)
komo.view(False, "IK solution")

# %%

# %% [markdown]
# # Box push

# %%
env = BuildPlanarTriangle(goal_str="test")
config = env.setup_cfg()

# Drop the three small blocks; this scene uses one big block plus a goal marker.
for frame_name in ("block_red", "block_green", "block_blue"):
    config.delFrame(frame_name)

# (frame name, position, rotation angle about z, RGB(A) color) for the block and its target marker.
push_frame_specs = [
    ("big_red_block", [-.1, .2, .7], -np.pi * 1.5, [.8, .2, .25]),
    ("target_pose", [.3, .3, .7], np.pi * 1.5, [.2, .8, .25, .2]),
]
for frame_name, xyz, yaw, rgba in push_frame_specs:
    (config.addFrame(frame_name)
        .setPosition(xyz)
        .setQuaternion(rowan.from_euler(0., 0., yaw, convention="xyz"))
        .setShape(ry.ST.ssBox, size=[.1, .1, .1, 0.005])
        .setColor(rgba)
        .setContact(1)
        .setMass(.1))
config.view()

# %%
# 2-phase push: t=1 gripper behind the block, t=2 push finished.
# 1 slice per phase, second-order objectives, collisions disabled.
komo = ry.KOMO(config, 2, 1, 2, False)

# Regularize control of orders 0, 1 and 2 with equal weight.
komo.addControlObjective([], 0, 1e-1)
komo.addControlObjective([], 1, 1e-1)
komo.addControlObjective([], 2, 1e-1)
komo.addObjective([], ry.FS.jointLimits, [], ry.OT.ineq, [1e0])

# Compute push params. Push along the horizontal line through the block center
# toward the target center, so the gripper actually meets the block head-on.
block = config.getFrame("big_red_block")
target = config.getFrame("target_pose")
block_size = block.getSize()
block_center = np.asarray(block.getPosition(), dtype=float)
target_center = np.asarray(target.getPosition(), dtype=float)

delta = target_center - block_center
delta[2] = 0.0  # push in the horizontal plane only
push_dist = np.linalg.norm(delta)
if push_dist < 1e-9:
    raise ValueError("big_red_block already sits at target_pose horizontally; no push direction")
delta /= push_dist  # unit push direction

# Conservative half-extent of the block footprint (max of its x/y half-sizes).
half_extent = 0.5 * max(block_size[0], block_size[1])
push_clearance = 0.05  # start gap so the gripper does not touch the block at t=1
start_pos = block_center - (half_extent + push_clearance) * delta
# Stop the contact point one half-extent short of the target center so the block
# center (rather than the gripper) ends at the target.
# NOTE(review): gripper thickness is not accounted for — tune if the block over/undershoots.
end_pos = target_center - half_extent * delta

# Avoid table collisions.
komo.addObjective([], ry.FS.distance, ["l_gripper", "table"], ry.OT.ineq, [1e1], [.01])
# While pushing: gripper z-axis points up, gripper y-axis points along the push direction.
komo.addObjective([1., 2.], ry.FS.vectorZ, ["l_gripper"], ry.OT.eq, [1e1], [0., 0., 1.])
komo.addObjective([1., 2.], ry.FS.vectorY, ["l_gripper"], ry.OT.eq, [1e1], delta)

# Execute the push: behind the block at t=1, push finished at t=2.
komo.addObjective([1.], ry.FS.position, ["l_gripper"], ry.OT.eq, [1e1], start_pos)
komo.addObjective([2.], ry.FS.position, ["l_gripper"], ry.OT.eq, [1e1], end_pos)

# %%
# Solve the push NLP verbosely, print the result and display the trajectory.
solver = ry.NLP_Solver(komo.nlp(), verbose=4)
ret = solver.solve()
print(ret)
komo.view(False, "IK solution")

# %%
env2 = BridgeEnv(task=PushRed(""))
# Replace the task's block with one at a fixed pose, and add a translucent goal marker.
env2.C.delFrame("big_red_block")
for frame_name, xyz, yaw, rgba in (
    ("big_red_block", [-.1, .2, .7], -np.pi * 1.5, [.8, .2, .25]),
    ("target_pose", [.3, .3, .7], np.pi * 1.5, [.2, .8, .25, .2]),
):
    (env2.C.addFrame(frame_name)
        .setPosition(xyz)
        .setQuaternion(rowan.from_euler(0., 0., yaw, convention="xyz"))
        .setShape(ry.ST.ssBox, size=[.1, .1, .1, 0.005])
        .setColor(rgba)
        .setContact(1)
        .setMass(.1))
env2.C.view()

# %%
# Inspect the environment state after the custom frames were added (cell output).
bridge_state = env2.getState()
bridge_state

# %%
