import os
import sys
import numpy as np

from transformations import rotation_matrix, quaternion_from_matrix
from transformations import quaternion_matrix, quaternion_from_euler


def calc_box_init_FluidShake(dis_x, dis_z, box_info):
    """Return the initial box list for the FluidShake container.

    Every entry is ``[halfEdge, center, quat]`` (three numpy arrays). Only the
    half-edge lengths differ between entries. All boxes share the same centre
    ``(0, table_height, 0)`` and the quaternion ``[1, 0, 0, 0]``.
    NOTE(review): the real per-box positions appear to be supplied later
    (e.g. by calc_shape_states_FluidShake); confirm against the caller.

    Entry order: floor, left wall, right wall, back wall, front wall, the
    short wall under the side bar, upper side bar, lower side bar, middle
    (vertical) side bar.

    Args:
        dis_x, dis_z: ignored. They are overwritten by the first two values
            of ``box_info``.
        box_info: 9-sequence ``(dis_x, dis_z, height, border, bar_position_y,
            bar_diameter, bar_length_y, bar_length_x, table_height)``.

    Returns:
        list of nine ``[halfEdge, center, quat]`` lists. Paired boxes (the
        side walls, the front/back walls and the two horizontal bars) share
        the same halfEdge array object.
    """
    # The positional dis_x / dis_z are intentionally shadowed here.
    (dis_x, dis_z, height, border, bar_position_y, bar_diameter,
     bar_length_y, bar_length_x, table_height) = box_info

    center = np.array([0., table_height, 0.])
    quat = np.array([1., 0., 0., 0.])

    floor_half = np.array([dis_x / 2., border / 2., dis_z / 2.])
    side_wall_half = np.array(
        [border / 2., (height + border) / 2., dis_z / 2.])
    end_wall_half = np.array(
        [(dis_x + border * 2) / 2., (height + border) / 2., border / 2.])
    bar_wall_half = np.array(
        [bar_diameter / 2., (bar_position_y + border) / 2., bar_diameter / 2.])
    horizontal_bar_half = np.array(
        [bar_length_x / 2., bar_diameter / 2., bar_diameter / 2.])
    vertical_bar_half = np.array(
        [bar_diameter / 2., bar_length_y / 2., bar_diameter / 2.])

    ordered_halves = (
        floor_half,           # floor
        side_wall_half,       # left wall
        side_wall_half,       # right wall
        end_wall_half,        # back wall
        end_wall_half,        # front wall
        bar_wall_half,        # right wall for side bar
        horizontal_bar_half,  # upper side bar
        horizontal_bar_half,  # lower side bar
        vertical_bar_half,    # middle side bar
    )
    return [[half, center, quat] for half in ordered_halves]


def calc_shape_states_FluidShake(x_curr, x_last, z_curr, z_last, box_info):
    """Return the (9, 14) state matrix for the FluidShake container boxes.

    Row order: floor, left wall, right wall, back wall, front wall, the short
    wall under the side bar, upper horizontal side bar, lower horizontal side
    bar, vertical side bar.

    Column layout:
        0:3    box centre with the container at (x_curr, z_curr)
        3:6    box centre with the container at (x_last, z_last)
        6:10   quaternion [1, 0, 0, 0]
        10:14  quaternion [1, 0, 0, 0]

    All centres are shifted up by ``table_height`` along y.

    Args:
        x_curr, z_curr: current container x / z position.
        x_last, z_last: previous container x / z position.
        box_info: 9-sequence ``(dis_x, dis_z, height, border, bar_position_y,
            bar_diameter, bar_length_y, bar_length_x, table_height)``.
    """
    (dis_x, dis_z, height, border, bar_position_y, bar_diameter,
     bar_length_y, bar_length_x, table_height) = box_info

    wall_y = (height + border) / 2.

    def box_centers(x, z):
        # Centres of all nine boxes (before the table lift) for a container
        # located at (x, z).
        return np.array([
            [x, border / 2., z],
            [x - (dis_x + border) / 2., wall_y, z],
            [x + (dis_x + border) / 2., wall_y, z],
            [x, wall_y, z - (dis_z + border) / 2.],
            [x, wall_y, z + (dis_z + border) / 2.],
            [x + (dis_x + border) / 2., (bar_position_y + border) / 2., z],
            [x + dis_x / 2. + border + bar_length_x / 2.,
             border + bar_position_y - bar_diameter / 2.,
             z],
            [x + dis_x / 2. + border + bar_length_x / 2.,
             border + bar_position_y - bar_length_y + bar_diameter / 2.,
             z],
            [x + dis_x / 2. + border + bar_length_x + bar_diameter / 2.,
             border + bar_position_y - bar_length_y / 2.,
             z],
        ])

    lift = np.array((0., table_height, 0.))
    states = np.zeros((9, 14))
    states[:, 0:3] = box_centers(x_curr, z_curr)
    states[:, 3:6] = box_centers(x_last, z_last)
    # Same lift applied to both the current and the previous centres.
    states[:, 0:6] += np.tile(lift, 2)
    # Current and previous orientation are both [1, 0, 0, 0].
    states[:, 6:14] = np.tile(np.array([1., 0., 0., 0.]), 2)
    return states


def quatFromAxisAngle(axis, angle):
    """Return the quaternion for a rotation of ``angle`` radians about ``axis``.

    The result is laid out as ``[x, y, z, w]``, with the scalar part last.

    Args:
        axis: rotation axis; any non-zero 3-vector (list, tuple or array, int
            or float). It is NOT modified.
        angle: rotation angle in radians.

    Returns:
        float ndarray ``[x, y, z, w]`` where ``(x, y, z)`` is the unit axis
        scaled by ``sin(angle / 2)`` and ``w = cos(angle / 2)``.

    Raises:
        ValueError: if ``axis`` has zero length (the rotation is undefined).
    """
    # Work on a private float copy. The previous in-place ``/=`` / ``*=``
    # mutated the caller's array, raised TypeError for lists, and raised a
    # casting error for integer arrays.
    axis = np.array(axis, dtype=float)
    norm = np.linalg.norm(axis)
    if norm == 0.:
        # Previously this produced a NaN quaternion silently.
        raise ValueError("rotation axis must be non-zero")

    half = angle * 0.5
    xyz = axis / norm * np.sin(half)
    return np.array([xyz[0], xyz[1], xyz[2], np.cos(half)])


def calc_container_boxes_FluidPour(pos, angle, direction, size, border=0.02):
    """Build the box shapes of an open-top container rotated about an axis.

    The container's inner size is ``size = (dx, dy, dz)`` and its walls are
    ``border`` thick. Local box centres are rotated by
    ``rotation_matrix(angle, direction)`` (applied to homogeneous 4-vectors)
    and then translated by ``pos``. Every box gets the same quaternion,
    derived from that rotation matrix.

    Box order: bottom, left, right, back, front, four top-rim bars, four
    vertical edge bars.

    Returns:
        (boxes, hide_shape): ``boxes`` is a list of ``[halfEdge, center,
        quat]``. ``hide_shape`` is an int ndarray that is 1 for the four
        walls and 0 for the bottom and all bars.
        NOTE(review): presumably consumed as a per-shape hide flag; confirm
        with the caller.
    """
    dx, dy, dz = size
    r_mtx = rotation_matrix(angle, direction)
    quat = quaternion_from_matrix(r_mtx)

    # Distance from the container centre to the mid-plane of each wall.
    ox = (dx + border) / 2.
    oy = (dy + border) / 2.
    oz = (dz + border) / 2.

    bottom_half = np.array([dx / 2. + border, border / 2., dz / 2. + border])
    x_wall_half = np.array([border / 2., dy / 2. + border, dz / 2. + border])
    z_wall_half = np.array([dx / 2. + border, dy / 2. + border, border / 2.])
    z_bar_half = np.array([border / 2., border / 2., dz / 2. + border])
    x_bar_half = np.array([dx / 2. + border, border / 2., border / 2.])
    y_bar_half = np.array([border / 2., dy / 2. + border, border / 2.])

    # (half-edge array, local centre, hide flag), in output order.
    layout = (
        (bottom_half, (0., -oy, 0.), 0),   # bottom
        (x_wall_half, (-ox, 0., 0.), 1),   # left
        (x_wall_half, (ox, 0., 0.), 1),    # right
        (z_wall_half, (0., 0., -oz), 1),   # back
        (z_wall_half, (0., 0., oz), 1),    # front
        (z_bar_half, (ox, oy, 0.), 0),     # top bars
        (z_bar_half, (-ox, oy, 0.), 0),
        (x_bar_half, (0., oy, oz), 0),
        (x_bar_half, (0., oy, -oz), 0),
        (y_bar_half, (ox, 0., oz), 0),     # side bars
        (y_bar_half, (ox, 0., -oz), 0),
        (y_bar_half, (-ox, 0., -oz), 0),
        (y_bar_half, (-ox, 0., oz), 0),
    )

    boxes = []
    hide_flags = []
    for half_edge, (lx, ly, lz), hidden in layout:
        world_center = r_mtx.dot(np.array([lx, ly, lz, 1.]))[:3] + pos
        boxes.append([half_edge, world_center, quat])
        hide_flags.append(hidden)

    return boxes, np.array(hide_flags)



def calc_kuka_ee_state_FluidManip(pos, angle, direction, size, border=0.02):
    """Return ``(center, quat)`` for the KUKA end effector in FluidManip.

    The orientation is ``rotation_matrix(angle, direction)`` right-multiplied
    by the fixed rotation from Euler angles ``(pi/2, pi/2, pi/2)``, then
    converted to a quaternion. The position is a fixed offset in the
    container frame, ``(-(dx + border)/2 - 0.25, -(dy + border)/4, 0)``,
    rotated by the same ``rotation_matrix`` and translated by ``pos``.

    ``size`` must have three entries; only ``dx`` and ``dy`` are used.
    """
    dx, dy, _ = size
    r_mtx = rotation_matrix(angle, direction)

    # Constant extra rotation applied in the container frame
    # (presumably the gripper mounting orientation; TODO confirm).
    fixed_rot = quaternion_matrix(
        quaternion_from_euler(np.pi / 2., np.pi / 2., np.pi / 2.))
    quat = quaternion_from_matrix(np.matmul(r_mtx, fixed_rot))

    local_center = np.array(
        [-(dx + border) / 2. - 0.25, -(dy + border) / 4., 0., 1.])
    center = r_mtx.dot(local_center)[:3] + pos
    return center, quat


def calc_table_shapes_FluidManip(table_size, border, table_height):
    """Return a one-element box list describing the FluidManip table slab.

    The slab is ``table_size`` wide in x and z and ``border`` thick in y. Its
    centre is at ``(0.4, table_height - border / 2, 0)``, so its top face lies
    at ``y = table_height``. The orientation is
    ``quaternion_from_euler(0, 0, 0)``.
    """
    half_extent = np.array([table_size / 2., border / 2., table_size / 2.])
    slab_center = np.array([0.4, table_height - border / 2., 0.])
    orientation = quaternion_from_euler(0., 0., 0.)
    return [[half_extent, slab_center, orientation]]


def calc_kuka_ee_state_FluidShakeWithIce(
        pos, angle, direction, size, border=0.02):
    """Return ``(center, quat)`` for the KUKA end effector in FluidShakeWithIce.

    The orientation is ``rotation_matrix(angle, direction)`` right-multiplied
    by the fixed rotation from Euler angles ``(pi/2, 0, 0)``, then converted
    to a quaternion. The position is the container-frame offset
    ``((dx + border)/2 + 0.2, dy - 0.08, 0)``, rotated by the same
    ``rotation_matrix`` and translated by ``pos``.

    ``size`` must have three entries; only ``dx`` and ``dy`` are used.
    """
    dx, dy, _ = size
    r_mtx = rotation_matrix(angle, direction)

    # Constant extra rotation applied in the container frame.
    fixed_rot = quaternion_matrix(quaternion_from_euler(np.pi / 2., 0., 0.))
    quat = quaternion_from_matrix(np.matmul(r_mtx, fixed_rot))

    local_center = np.array([(dx + border) / 2. + 0.2, dy - 0.08, 0., 1.])
    center = r_mtx.dot(local_center)[:3] + pos
    return center, quat


def calc_gripper_shapes_FluidShakeWithIce(
        x_curr, x_last, z_curr, z_last, box_info, gripper_info):
    """Return the (3, 14) state matrix for the gripper palm and two fingers.

    Rows: palm, finger 0 (+z side), finger 1 (-z side). All three sit at
    ``x + dis_x/2 + border + 2/3 * bar_length_x``. The fingers are spread
    apart in z by ``finger_dis / 2 + finger_z / 2`` each.

    Column layout:
        0:3    centre with the container at (x_curr, z_curr)
        3:6    centre with the container at (x_last, z_last)
        6:10   quaternion [1, 0, 0, 0]
        10:14  quaternion [1, 0, 0, 0]

    All centres are shifted up by ``table_height - 0.08`` along y.

    Args:
        box_info: 9-sequence ``(dis_x, dis_z, height, border, bar_position_y,
            bar_diameter, bar_length_y, bar_length_x, table_height)``.
        gripper_info: 7-sequence ``(palm_x, palm_y, palm_z, finger_x,
            finger_y, finger_z, finger_dis)``. Only palm_y, finger_y,
            finger_z and finger_dis affect the result.
    """
    (dis_x, dis_z, height, border, bar_position_y, bar_diameter,
     bar_length_y, bar_length_x, table_height) = box_info
    (palm_x, palm_y, palm_z, finger_x, finger_y, finger_z,
     finger_dis) = gripper_info

    # Heights (before the lift) of the palm centre and the finger centres.
    palm_height = border + bar_position_y + finger_y + palm_y / 2.
    finger_height = border + bar_position_y + finger_y / 2.

    def gripper_centers(x, z):
        # Palm / finger centres (before the lift) for a container at (x, z).
        grip_x = x + dis_x / 2. + border + bar_length_x * 2. / 3.
        return np.array([
            [grip_x, palm_height, z],
            [grip_x, finger_height, z + finger_dis / 2. + finger_z / 2.],
            [grip_x, finger_height, z - finger_dis / 2. - finger_z / 2.],
        ])

    lift = np.array((0., table_height - 0.08, 0.))
    states = np.zeros((3, 14))
    states[:, 0:3] = gripper_centers(x_curr, z_curr)
    states[:, 3:6] = gripper_centers(x_last, z_last)
    states[:, 0:6] += np.tile(lift, 2)
    states[:, 6:14] = np.tile(np.array([1., 0., 0., 0.]), 2)
    return states


def calc_table_shapes_FluidShakeWithIce(table_size, border, table_height):
    """Return a one-element box list describing the FluidShakeWithIce table.

    The slab is ``table_size`` wide in x and z and ``border`` thick in y. Its
    centre is at ``(0, table_height - border / 2, 0)``, so its top face lies
    at ``y = table_height``. The orientation is
    ``quaternion_from_euler(0, 0, 0)``.
    """
    half_extent = np.array([table_size / 2., border / 2., table_size / 2.])
    slab_center = np.array([0., table_height - border / 2., 0.])
    orientation = quaternion_from_euler(0., 0., 0.)
    return [[half_extent, slab_center, orientation]]
