import os
import time
import cv2
import sys
import copy
import numpy as np

from env_utils import calc_box_init_FluidShake, calc_shape_states_FluidShake
from env_utils import calc_container_boxes_FluidPour
from env_utils import calc_table_shapes_FluidManip, calc_kuka_ee_state_FluidManip
from env_utils import calc_table_shapes_FluidShakeWithIce, calc_kuka_ee_state_FluidShakeWithIce
from env_utils import calc_gripper_shapes_FluidShakeWithIce
from env_utils import quatFromAxisAngle
from utils import set_seed, rand_int, rand_float

from kuka_container import KukaFleXContainer



def render_image_from_PyFleX(pyflex, height, width, debug_info, draw_objects=1, draw_shadow=1):
    """Render one frame through PyFleX, retrying until it is not all zeros.

    The flat buffer returned by ``pyflex.render`` is reshaped to
    ``(height, width, 4)``; the fourth channel is dropped and the remaining
    three channels are reversed in order. Each all-zero frame is reported
    with ``debug_info`` and rendering is attempted again (with no retry
    limit).

    Args:
        pyflex: the PyFleX module (or an object exposing ``render``).
        height, width: frame size used to reshape the render buffer.
        debug_info: text included in the "empty image" message.
        draw_objects, draw_shadow: forwarded to ``pyflex.render``.

    Returns:
        ndarray of shape ``(height, width, 3)``.
    """
    frame = None
    while frame is None:
        buffer = pyflex.render(draw_objects=draw_objects, draw_shadow=draw_shadow)
        candidate = buffer.reshape(height, width, 4)[..., :3][..., ::-1]
        if not candidate.any():
            # Nothing was drawn this time; report it and render again.
            print('empty image at %s' % debug_info)
            continue
        frame = candidate
    return frame



class PhysicsEngine(object):
    """Base class for the PyFleX-backed simulation engines in this module.

    Subclasses set ``self.pyflex`` (the PyFleX module) and ``self.args``
    (read for ``screenHeight``/``screenWidth`` in ``render_img``) and
    override ``init`` and ``step``.
    """

    def __init__(self):
        pass

    def init(self, scene_params):
        """Build the scene; overridden by subclasses."""
        pass

    def step(self):
        """Advance the simulation one step; overridden by subclasses."""
        pass

    def _set_camera(self, camPos=None, camAngle=None):
        """Push the camera position and/or angle to PyFleX when given."""
        if camPos is not None:
            self.pyflex.set_camPos(camPos)
        if camAngle is not None:
            self.pyflex.set_camAngle(camAngle)

    def get_viewMatrix(self, camPos=None, camAngle=None, camoAngle=None):
        """Optionally move the camera, then return PyFleX's 4x4 view matrix.

        Args:
            camPos: new camera position, or None to keep the current one.
            camAngle: new camera angle, or None to keep the current one.
            camoAngle: deprecated alias of ``camAngle`` (the parameter's
                former, misspelled name); used only when ``camAngle`` is None.

        Returns:
            ndarray of shape (4, 4).
        """
        if camAngle is None:
            camAngle = camoAngle
        self._set_camera(camPos, camAngle)

        viewMatrix = self.pyflex.get_viewMatrix().reshape(4, 4)
        return viewMatrix

    def get_projMatrix(self):
        """Return PyFleX's projection matrix reshaped to (4, 4)."""
        projMatrix = self.pyflex.get_projMatrix().reshape(4, 4)
        return projMatrix

    def render_img(self, camPos=None, camAngle=None, width=360, height=360, BGR2RGB=False):
        """Render a frame and return it with the view/projection matrices.

        The frame is rendered at ``self.args.screenHeight`` x
        ``self.args.screenWidth`` and resized to ``(width, height)`` with
        area interpolation. When ``BGR2RGB`` is true the channel order is
        reversed once more.

        Returns:
            (img, viewMatrix, projMatrix); both matrices are (4, 4) and
            are read before rendering.
        """
        self._set_camera(camPos, camAngle)

        viewMatrix = self.pyflex.get_viewMatrix().reshape(4, 4)
        projMatrix = self.pyflex.get_projMatrix().reshape(4, 4)

        # self.pyflex.set_hideShapes(self.hide_shape)
        img = render_image_from_PyFleX(
            self.pyflex, self.args.screenHeight, self.args.screenWidth,
            debug_info="")
        img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

        if BGR2RGB:
            img = img[..., ::-1]

        return img, viewMatrix, projMatrix

    def get_state(self):
        """Return particle positions as an (n_particles, 3) array.

        PyFleX stores 4 values per particle; only the first three are kept.
        """
        n_particles = self.pyflex.get_n_particles()
        positions = self.pyflex.get_positions().reshape((n_particles, 4))[:, :3]
        return positions

    def set_action(self, action):
        """Store a copy of ``action`` for the next ``step`` call."""
        self.action = action.copy()




class FluidManipEngine(PhysicsEngine):
    """Fluid-pouring scene (PyFleX scene index 17): a pourer above a catcher.

    Both containers are assembled from the box lists returned by
    ``calc_container_boxes_FluidPour``. ``init`` registers the pourer's
    boxes with PyFleX first and the catcher's second, so ``step`` must write
    shape-state rows in that same order.
    """

    def __init__(self, args):
        """Initialise PyFleX using ``args.screenWidth`` / ``args.screenHeight``."""
        super(FluidManipEngine, self).__init__()

        import pyflex
        self.pyflex = pyflex
        self.pyflex.set_screenWidth(args.screenWidth)
        self.pyflex.set_screenHeight(args.screenHeight)
        self.pyflex.init()

        self.args = args

    def init(self, scene_params=None, context=None):
        """Build the scene and add the pourer and catcher boxes.

        Args:
            scene_params: array passed to ``pyflex.set_scene``: the pourer's
                fluid-block corner and grid dims, the catcher's fluid-block
                corner and grid dims, then ``draw_mesh``. When None, a scene
                is generated with a randomised pourer position, and
                ``context`` is regenerated to match.
            context: list of the geometry values unpacked below (border,
                radius, pourer limits, pourer/catcher positions and sizes).
                Required whenever ``scene_params`` is supplied.

        Raises:
            ValueError: if ``scene_params`` is given without ``context``.
        """
        self.env_idx = 17

        if scene_params is None:
            border = 0.02
            radius = 0.055

            dim_x_fluid_pourer = 10
            dim_y_fluid_pourer = 20
            dim_z_fluid_pourer = 10
            size_x_pourer = dim_x_fluid_pourer * radius - 0.06
            size_y_pourer = 1.2
            size_z_pourer = dim_z_fluid_pourer * radius - 0.06

            pourer_lim_x = [-0.7, -0.5]
            pourer_lim_z = [-0.35, 0.35]
            x_pourer = rand_float(pourer_lim_x[0], pourer_lim_x[1] - 0.1)
            y_pourer = 1.3
            z_pourer = rand_float(pourer_lim_z[0], pourer_lim_z[1])
            x_fluid_pourer = x_pourer
            y_fluid_pourer = y_pourer - size_y_pourer / 2.
            z_fluid_pourer = z_pourer

            dim_x_fluid_catcher = 25
            dim_y_fluid_catcher = 5
            dim_z_fluid_catcher = 25
            size_x_catcher = dim_x_fluid_catcher * radius - 0.06
            size_y_catcher = 0.7
            size_z_catcher = dim_z_fluid_catcher * radius - 0.06

            x_catcher = 0.4
            y_catcher = size_y_catcher / 2. + border
            z_catcher = 0.
            x_fluid_catcher = x_catcher
            y_fluid_catcher = border
            z_fluid_catcher = z_catcher

            draw_mesh = 1

            scene_params = np.array([
                x_fluid_pourer - (dim_x_fluid_pourer - 1) / 2. * radius,
                y_fluid_pourer,
                z_fluid_pourer - (dim_z_fluid_pourer - 1) / 2. * radius,
                dim_x_fluid_pourer,
                dim_y_fluid_pourer,
                dim_z_fluid_pourer,
                x_fluid_catcher - (dim_x_fluid_catcher - 1) / 2. * radius,
                y_fluid_catcher,
                z_fluid_catcher - (dim_z_fluid_catcher - 1) / 2. * radius,
                dim_x_fluid_catcher,
                dim_y_fluid_catcher,
                dim_z_fluid_catcher,
                draw_mesh])

            context = [
                border, radius,
                pourer_lim_x, pourer_lim_z,
                x_pourer, y_pourer, z_pourer,
                size_x_pourer, size_y_pourer, size_z_pourer,
                x_catcher, y_catcher, z_catcher,
                size_x_catcher, size_y_catcher, size_z_catcher]
        elif context is None:
            raise ValueError(
                'context is required when scene_params is provided')

        self.scene_params = scene_params.copy()
        self.context = context.copy()

        border, radius, \
                pourer_lim_x, pourer_lim_z, \
                x_pourer, y_pourer, z_pourer, \
                size_x_pourer, size_y_pourer, size_z_pourer, \
                x_catcher, y_catcher, z_catcher, \
                size_x_catcher, size_y_catcher, size_z_catcher = context

        print(scene_params)

        self.pyflex.set_scene(self.env_idx, scene_params, 0)
        self.pyflex.set_fluid_color(np.array([0.529, 0.808, 0.98, 0.0]))

        self.pyflex.set_floorScaleSize(0.3)

        # set pourer (registered first -> its shape rows come first)
        pourer_pos = np.array([x_pourer, y_pourer, z_pourer])
        pourer_size = np.array([size_x_pourer, size_y_pourer, size_z_pourer])

        boxes_pourer, hide_shape_pourer = calc_container_boxes_FluidPour(
            pos=pourer_pos,
            angle=0.,
            direction=np.array([0., 0., 1.]),
            size=pourer_size,
            border=border)

        for halfEdge, center, quat in (box[:3] for box in boxes_pourer):
            self.pyflex.add_box(halfEdge, center, quat)

        # set catcher (registered right after the pourer)
        catcher_pos = np.array([x_catcher, y_catcher, z_catcher])
        catcher_size = np.array([size_x_catcher, size_y_catcher, size_z_catcher])

        boxes_catcher, hide_shape_catcher = calc_container_boxes_FluidPour(
            pos=catcher_pos,
            angle=0.,
            direction=np.array([0., 0., 1.]),
            size=catcher_size,
            border=border)

        for halfEdge, center, quat in (box[:3] for box in boxes_catcher):
            self.pyflex.add_box(halfEdge, center, quat)

        # record all necessary information
        self.pourer_lim_x = pourer_lim_x
        self.pourer_lim_z = pourer_lim_z

        self.border = border
        self.pourer_pos = pourer_pos
        self.catcher_pos = catcher_pos
        self.pourer_size = pourer_size
        self.catcher_size = catcher_size

        self.boxes_pourer = boxes_pourer
        self.boxes_catcher = boxes_catcher
        self.hide_shape = np.concatenate([hide_shape_pourer, hide_shape_catcher])
        self.pyflex.set_hideShapes(self.hide_shape)

        self.pyflex.step()

    @staticmethod
    def _fill_box_states(shape_states, offset, boxes, boxes_prev):
        """Write rows ``offset .. offset+len(boxes)-1`` of ``shape_states``.

        Each 14-column row is [center(3), center_prev(3), quat(4), quat_prev(4)],
        with center at index 1 and quat at index 2 of each box entry.
        """
        for idx_box in range(len(boxes)):
            box = boxes[idx_box]
            box_prev = boxes_prev[idx_box]
            row = shape_states[offset + idx_box]
            row[:3] = box[1]
            row[3:6] = box_prev[1]
            row[6:10] = box[2]
            row[10:] = box_prev[2]

    def step(self):
        """Move the pourer to ``self.action`` and advance PyFleX one step.

        ``self.action[:3]`` is the pourer position and ``self.action[3]`` its
        angle; the catcher is re-posed at its fixed position with angle 0.
        """
        pourer_pos, pourer_angle = self.action[:3], self.action[3]

        pourer_prev = self.boxes_pourer
        self.boxes_pourer, _ = calc_container_boxes_FluidPour(
            pourer_pos,
            angle=pourer_angle,
            direction=np.array([0., 0., 1.]),
            size=self.pourer_size,
            border=self.border)

        catcher_prev = self.boxes_catcher
        self.boxes_catcher, _ = calc_container_boxes_FluidPour(
            self.catcher_pos,
            angle=np.deg2rad(0),
            direction=np.array([0., 0., 1.]),
            size=self.catcher_size,
            border=self.border)

        n_pourer = len(self.boxes_pourer)
        shape_states = np.zeros((n_pourer + len(self.boxes_catcher), 14))

        # Rows must follow the registration order used in init():
        # pourer boxes first, then the catcher boxes starting at n_pourer.
        self._fill_box_states(shape_states, 0, self.boxes_pourer, pourer_prev)
        self._fill_box_states(shape_states, n_pourer, self.boxes_catcher, catcher_prev)

        self.pyflex.set_shape_states(shape_states)

        self.pyflex.step()

    def clean(self):
        """Release PyFleX resources."""
        self.pyflex.clean()



class FluidManipFullEngine(PhysicsEngine):
    """Fluid-pouring scene on a table, with a Kuka arm holding the pourer.

    Uses PyFleX scene index 17. Shapes are registered in this order, and
    ``step`` writes shape-state rows in the same order:
    Kuka link meshes, table boxes, pourer boxes, catcher boxes.
    """

    def __init__(self, args):
        """Initialise PyFleX using ``args.screenWidth`` / ``args.screenHeight``."""
        super(FluidManipFullEngine, self).__init__()

        import pyflex
        self.pyflex = pyflex
        self.pyflex.set_screenWidth(args.screenWidth)
        self.pyflex.set_screenHeight(args.screenHeight)
        self.pyflex.init()

        self.args = args

    def init(self, scene_params=None, context=None):
        """Build the scene: fluid, Kuka arm, table, pourer and catcher.

        Args:
            scene_params: array passed to ``pyflex.set_scene`` (pourer fluid
                block corner and grid dims, catcher fluid block corner and
                grid dims, then ``draw_mesh``). When None, a scene is
                generated with a randomised pourer position, and ``context``
                is regenerated to match.
            context: list of the geometry values unpacked below (table
                height/size, border, radius, pourer limits, pourer/catcher
                positions and sizes). Required whenever ``scene_params`` is
                supplied.

        Raises:
            ValueError: if ``scene_params`` is given without ``context``.
        """
        self.env_idx = 17

        if scene_params is None:
            table_height = 1.2
            table_size = 4.

            border = 0.02
            radius = 0.055

            dim_x_fluid_pourer = 10
            dim_y_fluid_pourer = 20
            dim_z_fluid_pourer = 10
            size_x_pourer = dim_x_fluid_pourer * radius - 0.06
            size_y_pourer = 1.2
            size_z_pourer = dim_z_fluid_pourer * radius - 0.06

            pourer_lim_x = [-0.7, -0.5]
            pourer_lim_z = [-0.35, 0.35]
            x_pourer = rand_float(pourer_lim_x[0], pourer_lim_x[1] - 0.1)
            y_pourer = 1.3 + table_height
            z_pourer = rand_float(pourer_lim_z[0], pourer_lim_z[1])
            x_fluid_pourer = x_pourer
            y_fluid_pourer = y_pourer - size_y_pourer / 2.
            z_fluid_pourer = z_pourer

            dim_x_fluid_catcher = 25
            dim_y_fluid_catcher = 5
            dim_z_fluid_catcher = 25
            size_x_catcher = dim_x_fluid_catcher * radius - 0.06
            size_y_catcher = 0.7
            size_z_catcher = dim_z_fluid_catcher * radius - 0.06

            x_catcher = 0.4
            y_catcher = size_y_catcher / 2. + border + table_height
            z_catcher = 0.
            x_fluid_catcher = x_catcher
            y_fluid_catcher = border + table_height
            z_fluid_catcher = z_catcher

            draw_mesh = 1

            scene_params = np.array([
                x_fluid_pourer - (dim_x_fluid_pourer - 1) / 2. * radius,
                y_fluid_pourer,
                z_fluid_pourer - (dim_z_fluid_pourer - 1) / 2. * radius,
                dim_x_fluid_pourer,
                dim_y_fluid_pourer,
                dim_z_fluid_pourer,
                x_fluid_catcher - (dim_x_fluid_catcher - 1) / 2. * radius,
                y_fluid_catcher,
                z_fluid_catcher - (dim_z_fluid_catcher - 1) / 2. * radius,
                dim_x_fluid_catcher,
                dim_y_fluid_catcher,
                dim_z_fluid_catcher,
                draw_mesh])

            context = [
                table_height, table_size,
                border, radius,
                pourer_lim_x, pourer_lim_z,
                x_pourer, y_pourer, z_pourer,
                size_x_pourer, size_y_pourer, size_z_pourer,
                x_catcher, y_catcher, z_catcher,
                size_x_catcher, size_y_catcher, size_z_catcher]
        elif context is None:
            raise ValueError(
                'context is required when scene_params is provided')

        self.scene_params = scene_params.copy()
        self.context = context.copy()

        table_height, table_size, \
                border, radius, \
                pourer_lim_x, pourer_lim_z, \
                x_pourer, y_pourer, z_pourer, \
                size_x_pourer, size_y_pourer, size_z_pourer, \
                x_catcher, y_catcher, z_catcher, \
                size_x_catcher, size_y_catcher, size_z_catcher = context

        print(scene_params)

        self.pyflex.set_scene(self.env_idx, scene_params, 0)
        self.pyflex.set_fluid_color(np.array([0.529, 0.808, 0.98, 0.0]))

        self.pyflex.set_floorScaleSize(0.3)

        # set pourer
        pourer_pos = np.array([x_pourer, y_pourer, z_pourer])
        pourer_size = np.array([size_x_pourer, size_y_pourer, size_z_pourer])

        ### set kuka robot (registered first: rows [0, n_link_kuka))

        self.kuka_helper = KukaFleXContainer()
        scaling = 4.5
        self.n_link_kuka = 8
        rest_pos = [-3.5, 0.14, 0.]
        rest_orn = [-0.70710678, 0., 0., 0.70710678]

        self.kuka_helper.add_kuka(rest_pos=rest_pos, rest_orn=rest_orn, scaling=scaling)

        color_kuka = np.array([
            [0.9, 0.9, 0.9],
            [0.9, 0.9, 0.9],
            [0.9, 0.9, 0.9],
            [1.0, 0.3, 0.0],
            [0.9, 0.9, 0.9],
            [0.9, 0.9, 0.9],
            [1.0, 0.3, 0.0],
            [0.9, 0.9, 0.9]])

        for i in range(self.n_link_kuka):
            hideShape = 0
            color = color_kuka[i]
            self.pyflex.add_mesh("assets/kuka_iiwa/meshes/link_%d.obj" % i, scaling, hideShape, color)

        pos, orn = calc_kuka_ee_state_FluidManip(
            pos=pourer_pos,
            angle=0.,
            direction=np.array([0., 0., 1.]),
            size=pourer_size,
            border=border)

        # refresh the kuka state history; the link states themselves are not
        # needed here, but the call is kept in case it updates helper state
        # (NOTE(review): confirm whether get_link_state has side effects)
        self.kuka_helper.set_ee_state(pos=pos, orn=orn)
        self.kuka_helper.get_link_state()

        hide_shape_kuka = np.zeros(self.n_link_kuka)

        ### set table

        boxes = calc_table_shapes_FluidManip(
            table_size, border, table_height)

        self.n_table_shape = len(boxes)

        for i in range(self.n_table_shape):
            halfEdge = boxes[i][0]
            center = boxes[i][1]
            quat = boxes[i][2]

            hideShape = 0
            color = np.ones(3) * 0.9
            self.pyflex.add_box(halfEdge, center, quat, hideShape, color)

        hide_shape_table = np.zeros(self.n_table_shape)

        ### set container (pourer first, then catcher)

        boxes_pourer, hide_shape_pourer = calc_container_boxes_FluidPour(
            pos=pourer_pos,
            angle=0.,
            direction=np.array([0., 0., 1.]),
            size=pourer_size,
            border=border)

        for i in range(len(boxes_pourer)):
            halfEdge = boxes_pourer[i][0]
            center = boxes_pourer[i][1]
            quat = boxes_pourer[i][2]

            hideShape = hide_shape_pourer[i]
            color = np.ones(3) * 0.9
            self.pyflex.add_box(halfEdge, center, quat, hideShape, color)

        # set catcher
        catcher_pos = np.array([x_catcher, y_catcher, z_catcher])
        catcher_size = np.array([size_x_catcher, size_y_catcher, size_z_catcher])

        boxes_catcher, hide_shape_catcher = calc_container_boxes_FluidPour(
            pos=catcher_pos,
            angle=0.,
            direction=np.array([0., 0., 1.]),
            size=catcher_size,
            border=border)

        for i in range(len(boxes_catcher)):
            halfEdge = boxes_catcher[i][0]
            center = boxes_catcher[i][1]
            quat = boxes_catcher[i][2]

            hideShape = hide_shape_catcher[i]
            color = np.ones(3) * 0.9
            self.pyflex.add_box(halfEdge, center, quat, hideShape, color)

        # record all necessary information
        self.table_height = table_height
        self.table_size = table_size

        self.pourer_lim_x = pourer_lim_x
        self.pourer_lim_z = pourer_lim_z

        self.border = border
        self.pourer_pos = pourer_pos
        self.catcher_pos = catcher_pos
        self.pourer_size = pourer_size
        self.catcher_size = catcher_size

        self.boxes_pourer = boxes_pourer
        self.boxes_catcher = boxes_catcher
        self.hide_shape = np.concatenate(
            [hide_shape_kuka, hide_shape_table, hide_shape_pourer, hide_shape_catcher])
        # self.pyflex.set_hideShapes(self.hide_shape)

        self.pyflex.step()

    @staticmethod
    def _fill_box_states(shape_states, offset, boxes, boxes_prev):
        """Write rows ``offset .. offset+len(boxes)-1`` of ``shape_states``.

        Each 14-column row is [center(3), center_prev(3), quat(4), quat_prev(4)],
        with center at index 1 and quat at index 2 of each box entry.
        """
        for idx_box in range(len(boxes)):
            box = boxes[idx_box]
            box_prev = boxes_prev[idx_box]
            row = shape_states[offset + idx_box]
            row[:3] = box[1]
            row[3:6] = box_prev[1]
            row[6:10] = box[2]
            row[10:] = box_prev[2]

    def step(self):
        """Move pourer and Kuka arm to ``self.action``; advance PyFleX one step.

        ``self.action[:3]`` is the pourer position and ``self.action[3]`` its
        angle. The table and catcher are re-posed at their fixed positions.
        """
        pourer_pos, pourer_angle = self.action[:3], self.action[3]

        pourer_prev = self.boxes_pourer
        self.boxes_pourer, _ = calc_container_boxes_FluidPour(
            pourer_pos,
            angle=pourer_angle,
            direction=np.array([0., 0., 1.]),
            size=self.pourer_size,
            border=self.border)

        catcher_prev = self.boxes_catcher
        self.boxes_catcher, _ = calc_container_boxes_FluidPour(
            self.catcher_pos,
            angle=np.deg2rad(0),
            direction=np.array([0., 0., 1.]),
            size=self.catcher_size,
            border=self.border)

        # Row layout mirrors registration order in init():
        # kuka links | table | pourer | catcher
        offset_table = self.n_link_kuka
        offset_pourer = offset_table + self.n_table_shape
        offset_catcher = offset_pourer + len(self.boxes_pourer)

        shape_states = np.zeros(
            (offset_catcher + len(self.boxes_catcher), 14))

        # set shape state for kuka
        pos, orn = calc_kuka_ee_state_FluidManip(
            pos=pourer_pos,
            angle=pourer_angle,
            direction=np.array([0., 0., 1.]),
            size=self.pourer_size,
            border=self.border)

        self.kuka_helper.set_ee_state(pos=pos, orn=orn)
        shape_states[:self.n_link_kuka] = self.kuka_helper.get_link_state()

        # set shape state for the table (static: current pose == previous pose)
        table_shapes = calc_table_shapes_FluidManip(
            self.table_size, self.border, self.table_height)
        self._fill_box_states(shape_states, offset_table, table_shapes, table_shapes)

        # set shape state for pourer and catcher
        self._fill_box_states(shape_states, offset_pourer, self.boxes_pourer, pourer_prev)
        self._fill_box_states(shape_states, offset_catcher, self.boxes_catcher, catcher_prev)

        self.pyflex.set_shape_states(shape_states)

        self.pyflex.step()

    def clean(self):
        """Release PyFleX resources."""
        self.pyflex.clean()






class FluidShakeWithIceEngine(PhysicsEngine):
    """Fluid-shaking scene with a rigid block (PyFleX scene index 8).

    The container is made of the boxes from ``calc_box_init_FluidShake``;
    ``step`` translates it in the x/z plane to the position in
    ``self.action``.
    """

    def __init__(self, args):
        """Initialise PyFleX using ``args.screenWidth`` / ``args.screenHeight``."""
        super(FluidShakeWithIceEngine, self).__init__()

        import pyflex
        self.pyflex = pyflex
        self.pyflex.set_screenWidth(args.screenWidth)
        self.pyflex.set_screenHeight(args.screenHeight)
        self.pyflex.init()

        self.args = args

    def init(self, scene_params=None, context=None):
        """Build the scene and add the container boxes.

        Args:
            scene_params: array passed to ``pyflex.set_scene`` (fluid block
                corner/dims/viscosity, rigid block corner/size/invMass,
                container extents, ``draw_mesh``). When None, a scene is
                generated with a randomised container centre, and ``context``
                is regenerated to match.
            context: ``[box_info, x_center, z_center]``; required whenever
                ``scene_params`` is supplied.

        Raises:
            ValueError: if ``scene_params`` is given without ``context``.
        """
        self.env_idx = 8

        if scene_params is None:
            height = 2.5
            border = 0.025

            bar_position_y = 0.6
            bar_diameter = 0.04
            bar_length_y = 0.4
            bar_length_x = 0.2

            # fluid block
            # dim_x = rand_int(10, 12)
            # dim_y = rand_int(15, 20)
            # dim_z = rand_int(10, 12)
            dim_x = 16
            dim_y = 8
            dim_z = 16
            x_center = rand_float(-0.2, 0.2)
            z_center = rand_float(-0.2, 0.2)
            x = x_center - (dim_x - 1) / 2. * 0.055
            y = 0.055 / 2. + border + 0.01
            z = z_center - (dim_z - 1) / 2. * 0.055
            # box_dis_x = dim_x * 0.055 + rand_float(0., 0.3)
            # box_dis_z = dim_z * 0.055 + rand_float(0., 0.3)
            box_dis_x = dim_x * 0.055
            box_dis_z = dim_z * 0.055
            draw_mesh = 1

            box_info = [box_dis_x, box_dis_z, height, border,
                        bar_position_y, bar_diameter,
                        bar_length_y, bar_length_x]

            viscosity = 2.0

            # rigid block
            sx_r = 0.25
            sy_r = 0.25
            sz_r = 0.25
            px_r = x_center - sx_r / 2.
            py_r = y + dim_y * 0.052
            pz_r = z_center - sz_r / 2.
            invMass = 0.4

            scene_params = np.array([
                x, y, z, dim_x, dim_y, dim_z, viscosity,
                px_r, py_r, pz_r, sx_r, sy_r, sz_r, invMass,
                box_dis_x, box_dis_z, draw_mesh])

            context = [box_info, x_center, z_center]
        elif context is None:
            raise ValueError(
                'context is required when scene_params is provided')

        self.scene_params = scene_params.copy()
        self.context = context.copy()

        box_info, x_center, z_center = context

        print(scene_params)

        self.pyflex.set_scene(self.env_idx, scene_params, 0)
        self.pyflex.set_fluid_color(np.array([0.529, 0.808, 0.98, 0.0]))

        self.pyflex.set_floorScaleSize(0.3)

        box_dis_x, box_dis_z = box_info[:2]
        boxes = calc_box_init_FluidShake(box_dis_x, box_dis_z, box_info)

        for halfEdge, center, quat in (box[:3] for box in boxes):
            self.pyflex.add_box(halfEdge, center, quat)

        # make the container invisible: one entry per container box (nonzero
        # presumably hides the box -- confirm against pyflex.set_hideShapes).
        # Alternative presets:
        #   all boxes:       [1, 1, 1, 1, 1, 1, 1, 1, 1]
        #   show bar only:   [1, 1, 1, 1, 1, 0, 0, 1, 1]
        hideShapes_off = np.array([0, 1, 1, 1, 1, 0, 0, 1, 1])
        self.hide_shape = hideShapes_off
        self.pyflex.set_hideShapes(self.hide_shape)

        # record all necessary information
        self.x_box = x_center
        self.z_box = z_center
        self.box_info = box_info

        self.pyflex.step()

    def step(self):
        """Move the container to ``self.action`` = (x, z); advance PyFleX.

        The previous (x, z) is passed to ``calc_shape_states_FluidShake``
        alongside the new one, and the new position is recorded afterwards.
        """
        x_box, z_box = self.action

        x_box_last = self.x_box
        z_box_last = self.z_box

        shape_states_ = calc_shape_states_FluidShake(
            x_box, x_box_last, z_box, z_box_last, self.box_info)
        self.pyflex.set_shape_states(shape_states_)

        self.pyflex.step()

        # record the current box position
        self.x_box = x_box
        self.z_box = z_box

    def clean(self):
        """Release PyFleX resources."""
        self.pyflex.clean()









class FluidShakeWithIceFullEngine(PhysicsEngine):
    """Fluid-shaking scene on a table, driven by a Kuka arm with a gripper.

    Uses PyFleX scene index 8. Shapes are registered in this order, and
    ``step`` writes shape-state rows in the same order:
    Kuka link meshes, gripper boxes (palm + 2 fingers), table boxes,
    container boxes.
    """

    def __init__(self, args):
        """Initialise PyFleX using ``args.screenWidth`` / ``args.screenHeight``."""
        super(FluidShakeWithIceFullEngine, self).__init__()

        import pyflex
        self.pyflex = pyflex
        self.pyflex.set_screenWidth(args.screenWidth)
        self.pyflex.set_screenHeight(args.screenHeight)
        self.pyflex.init()

        self.args = args

    def init(self, scene_params=None, context=None):
        """Build the scene: fluid, Kuka arm, gripper, table and container.

        Args:
            scene_params: array passed to ``pyflex.set_scene`` (fluid block
                corner/dims/viscosity, rigid block corner/size/invMass,
                container extents, ``draw_mesh``). When None, a scene is
                generated with a randomised container centre, and ``context``
                is regenerated to match.
            context: ``[table_height, table_size, box_info, x_center,
                z_center]``; ``box_info`` must end with ``table_height``.
                Required whenever ``scene_params`` is supplied.

        Raises:
            ValueError: if ``scene_params`` is given without ``context``.
        """
        self.env_idx = 8

        if scene_params is None:
            table_height = 0.1
            table_size = 4.

            height = 2.5
            border = 0.025

            bar_position_y = 0.6
            bar_diameter = 0.04
            bar_length_y = 0.4
            bar_length_x = 0.2

            # fluid block
            # dim_x = rand_int(10, 12)
            # dim_y = rand_int(15, 20)
            # dim_z = rand_int(10, 12)
            dim_x = 16
            dim_y = 8
            dim_z = 16
            x_center = rand_float(-0.2, 0.2)
            z_center = rand_float(-0.2, 0.2)
            x = x_center - (dim_x - 1) / 2. * 0.055
            y = 0.055 / 2. + border + 0.01 + table_height
            z = z_center - (dim_z - 1) / 2. * 0.055
            # box_dis_x = dim_x * 0.055 + rand_float(0., 0.3)
            # box_dis_z = dim_z * 0.055 + rand_float(0., 0.3)
            box_dis_x = dim_x * 0.055
            box_dis_z = dim_z * 0.055
            draw_mesh = 1

            box_info = [box_dis_x, box_dis_z, height, border,
                        bar_position_y, bar_diameter,
                        bar_length_y, bar_length_x,
                        table_height]

            viscosity = 2.0

            # rigid block
            sx_r = 0.25
            sy_r = 0.25
            sz_r = 0.25
            px_r = x_center - sx_r / 2.
            py_r = y + dim_y * 0.052
            pz_r = z_center - sz_r / 2.
            invMass = 0.4

            scene_params = np.array([
                x, y, z, dim_x, dim_y, dim_z, viscosity,
                px_r, py_r, pz_r, sx_r, sy_r, sz_r, invMass,
                box_dis_x, box_dis_z, draw_mesh])

            context = [table_height, table_size, box_info, x_center, z_center]
        elif context is None:
            raise ValueError(
                'context is required when scene_params is provided')

        self.scene_params = scene_params.copy()
        self.context = context.copy()

        table_height, table_size, box_info, x_center, z_center = context
        box_dis_x, box_dis_z, height, border, bar_position_y, bar_diameter, \
                bar_length_y, bar_length_x, table_height = box_info

        print(scene_params)

        self.pyflex.set_scene(self.env_idx, scene_params, 0)
        self.pyflex.set_fluid_color(np.array([0.529, 0.808, 0.98, 0.0]))

        self.pyflex.set_floorScaleSize(0.01)

        ### set kuka robot (registered first: rows [0, n_link_kuka))
        palm_x = 0.2
        palm_y = 0.25
        palm_z = 0.45
        finger_x = 0.06
        finger_y = 0.25
        finger_z = 0.06
        finger_dis = 0.06

        self.kuka_helper = KukaFleXContainer()
        scaling = 4.
        self.n_link_kuka = 8
        rest_pos = [2.6, 0.14, 0.]
        rest_orn = [-0.70710678, 0., 0., 0.70710678]

        self.kuka_helper.add_kuka(rest_pos=rest_pos, rest_orn=rest_orn, scaling=scaling)

        color_kuka = np.array([
            [0.9, 0.9, 0.9],
            [0.9, 0.9, 0.9],
            [0.9, 0.9, 0.9],
            [1.0, 0.3, 0.0],
            [0.9, 0.9, 0.9],
            [0.9, 0.9, 0.9],
            [1.0, 0.3, 0.0],
            [0.9, 0.9, 0.9]])

        for i in range(self.n_link_kuka):
            hideShape = 0
            color = color_kuka[i]
            self.pyflex.add_mesh("assets/kuka_iiwa/meshes/link_%d.obj" % i, scaling, hideShape, color)

        pos, orn = calc_kuka_ee_state_FluidShakeWithIce(
            pos=(x_center, table_height + bar_length_y + palm_y + finger_y, z_center),
            angle=0.,
            direction=np.array([0., 0., 1.]),
            size=(box_dis_x, bar_length_y, box_dis_z),
            border=border)

        # refresh the kuka state history; the link states themselves are not
        # needed here, but the call is kept in case it updates helper state
        # (NOTE(review): confirm whether get_link_state has side effects)
        self.kuka_helper.set_ee_state(pos=pos, orn=orn)
        self.kuka_helper.get_link_state()

        ### set gripper: palm box followed by two finger boxes
        self.gripper_info = [palm_x, palm_y, palm_z, finger_x, finger_y, finger_z, finger_dis]

        self.n_gripper_shape = 3
        color = np.ones(3) * 0.9
        hideShape = 0
        center = np.zeros(3)
        quat = np.array([1., 0., 0., 0.])

        self.pyflex.add_box(
            np.array([palm_x, palm_y, palm_z]) / 2., center, quat, hideShape, color)
        self.pyflex.add_box(
            np.array([finger_x, finger_y, finger_z]) / 2., center, quat, hideShape, color)
        self.pyflex.add_box(
            np.array([finger_x, finger_y, finger_z]) / 2., center, quat, hideShape, color)

        ### set table

        boxes = calc_table_shapes_FluidShakeWithIce(
            table_size, border, table_height)

        self.n_table_shape = len(boxes)

        for i in range(self.n_table_shape):
            halfEdge = boxes[i][0]
            center = boxes[i][1]
            quat = boxes[i][2]

            hideShape = 0
            color = np.ones(3) * 0.9
            self.pyflex.add_box(halfEdge, center, quat, hideShape, color)

        ### set container

        # make the container invisible: one entry per container box (nonzero
        # presumably hides the box -- confirm against pyflex add_box/hideShape).
        # Alternative presets:
        #   all boxes:       [1, 1, 1, 1, 1, 1, 1, 1, 1]
        #   show bar only:   [1, 1, 1, 1, 1, 0, 0, 1, 1]
        hideShapes_off = np.array([0, 1, 1, 1, 1, 0, 0, 1, 1])

        boxes = calc_box_init_FluidShake(box_dis_x, box_dis_z, box_info)

        self.n_container_shape = len(boxes)

        for i in range(len(boxes)):
            halfEdge = boxes[i][0]
            center = boxes[i][1]
            quat = boxes[i][2]

            hideShape = hideShapes_off[i]
            color = np.ones(3) * 0.9
            self.pyflex.add_box(halfEdge, center, quat, hideShape, color)

        self.hide_shape = hideShapes_off
        # self.pyflex.set_hideShapes(self.hide_shape)

        # record all necessary information
        self.table_height = table_height
        self.table_size = table_size

        self.x_box = x_center
        self.z_box = z_center
        self.box_info = box_info

        self.pyflex.step()

    def step(self):
        """Move container, gripper and arm to ``self.action`` = (x, z).

        Writes one 14-column shape-state row per registered shape (kuka
        links | gripper | table | container), advances PyFleX one step and
        records the new container position.
        """
        x_box, z_box = self.action

        x_box_last = self.x_box
        z_box_last = self.z_box

        offset_table = self.n_link_kuka + self.n_gripper_shape
        offset_container = offset_table + self.n_table_shape

        shape_states = np.zeros(
            (offset_container + self.n_container_shape, 14))

        # set shape state for kuka
        palm_x, palm_y, palm_z, finger_x, finger_y, finger_z, finger_dis = self.gripper_info
        box_dis_x, box_dis_z, height, border, bar_position_y, bar_diameter, \
                bar_length_y, bar_length_x, table_height = self.box_info

        pos, orn = calc_kuka_ee_state_FluidShakeWithIce(
            pos=(x_box, self.table_height + bar_length_y + palm_y + finger_y, z_box),
            angle=0.,
            direction=np.array([0., 0., 1.]),
            size=(box_dis_x, bar_length_y, box_dis_z),
            border=border)

        self.kuka_helper.set_ee_state(pos=pos, orn=orn)
        shape_states[:self.n_link_kuka] = self.kuka_helper.get_link_state()

        # set shape state for gripper
        gripper_shapes = calc_gripper_shapes_FluidShakeWithIce(
            x_box, x_box_last, z_box, z_box_last, self.box_info, self.gripper_info)
        shape_states[self.n_link_kuka:offset_table] = gripper_shapes

        # set shape state for the table: static, so current pose == previous
        # pose. Use the same helper that built the table in init() so the
        # poses (and box count) match the registered table shapes.
        table_shapes = calc_table_shapes_FluidShakeWithIce(
            self.table_size, border, self.table_height)

        for idx_box in range(self.n_table_shape):
            center = table_shapes[idx_box][1]
            quat = table_shapes[idx_box][2]
            row = shape_states[offset_table + idx_box]
            row[:3] = center
            row[3:6] = center
            row[6:10] = quat
            row[10:] = quat

        # set shape state for the container
        shape_states[offset_container:] = calc_shape_states_FluidShake(
            x_box, x_box_last, z_box, z_box_last, self.box_info)
        self.pyflex.set_shape_states(shape_states)

        self.pyflex.step()

        # record the current box position
        self.x_box = x_box
        self.z_box = z_box

    def clean(self):
        """Release PyFleX resources."""
        self.pyflex.clean()

