import sys
import os

# Put the parent of this script's directory on sys.path so that the project
# packages imported below (real_uti, envs) can be found.
curPath = os.path.abspath(os.path.dirname(__file__))
rootPath = os.path.dirname(curPath)
sys.path.append(rootPath)

import argparse
from real_uti import a1_robot
import numpy as np
import time
import pybullet_data
from pybullet_utils import bullet_client
import pybullet
import torch
from envs.robots.foot_trajectory_generator import FTG

# Run the policy networks on the GPU when one is available, else on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report whether CUDA is in use.
print(device.type == "cuda")


def robot_stand_up(robot):
    """Ramp the robot into a standing pose over 1000 control steps.

    The first command uses the hard-coded lying pose ([0, 1.4, -2.8] per leg)
    as its starting point. Every later command starts from the robot's
    measured motor angles. Each command is a linear blend from that start
    toward the standing target ([0, 0.67, -1.25] per leg). The blend weight
    grows linearly from 0 to 1 over the first 700 steps and then stays at 1.
    """
    target_angles = np.array([0., 0.67, -1.25] * 4)
    lying_angles = np.array([0, 1.4, -2.8] * 4)
    for step in range(1000):
        weight = np.minimum(step / 700., 1)
        start_angles = lying_angles if step == 0 else robot.GetMotorAngles()
        robot.Step((1 - weight) * start_angles + weight * target_angles)
        time.sleep(0.001)


def robot_lie_down(robot):
    """Move the robot smoothly into the lying pose over ``lie_down_time`` seconds.

    Each iteration blends the currently measured motor angles toward the
    lying pose ([0, 1.4, -2.8] per leg). The blend weight rises linearly from
    0 to 1, so the first command reproduces the current pose and later
    commands converge on the target. This is the same blend
    ``robot_stand_up`` uses.

    The blend schedule treats each iteration as the module-level
    ``time_step * _num_action_repeat`` seconds. Wall-clock spacing is set by
    the sleep at the end of the loop.
    """
    lie_down_time = 3
    default_motor_angles = np.array([0, 1.4, -2.8] * 4)
    control_period = time_step * _num_action_repeat
    for t in np.arange(0, lie_down_time, control_period):
        current_motor_angles = robot.GetMotorAngles()
        blend_ratio = min(t / lie_down_time, 1)
        # Weight 0 holds the current pose and weight 1 commands the target.
        # Weighting the other way round would command the full lying pose on
        # the very first step, i.e. a sudden jump.
        action = ((1 - blend_ratio) * current_motor_angles
                  + blend_ratio * default_motor_angles)
        robot._Step(action)
        time.sleep(0.03)

def robot_stand(robot):
    """Move the robot smoothly into a standing pose over ``stand_time`` seconds.

    Each iteration blends the currently measured motor angles toward the
    standing pose ([0, 0.9, -1.8] per leg). The blend weight rises linearly
    from 0 to 1, so the first command reproduces the current pose and later
    commands converge on the target. This is the same blend
    ``robot_stand_up`` uses.

    The blend schedule treats each iteration as the module-level
    ``time_step * _num_action_repeat`` seconds. Wall-clock spacing is set by
    the sleep at the end of the loop.
    """
    stand_time = 3
    default_motor_angles = np.array([0, 0.9, -1.8] * 4)
    control_period = time_step * _num_action_repeat
    for t in np.arange(0, stand_time, control_period):
        current_motor_angles = robot.GetMotorAngles()
        blend_ratio = min(t / stand_time, 1)
        # Weight 0 holds the current pose and weight 1 commands the target.
        # Weighting the other way round would command the full standing pose
        # on the very first step, i.e. a sudden jump.
        action = ((1 - blend_ratio) * current_motor_angles
                  + blend_ratio * default_motor_angles)
        robot._Step(action)
        time.sleep(0.001)

def IsDone(robot):
    """Return whether the episode should end.

    The episode ends in either of two cases:
    - the magnitude of the base's second position coordinate (index 1 of
      ``GetBasePosition()``) exceeds 0.5;
    - any of roll, pitch or yaw from ``GetTrueBaseRollPitchYaw()`` exceeds
      pi/4 in magnitude (angles presumably in radians).
    """
    off_course = np.abs(robot.GetBasePosition()[1]) > 0.5
    roll, pitch, yaw = robot.GetTrueBaseRollPitchYaw()
    max_tilt = np.pi / 4
    if any(np.abs(angle) > max_tilt for angle in (roll, pitch, yaw)):
        return True
    return off_course

def load_model(path, itr, map_location=None):
    """Load the actor, critic and VAE checkpoints for iteration ``itr``.

    Files are expected at ``<path>/actor_<itr>.pt``, ``<path>/critic_<itr>.pt``
    and ``<path>/vae_<itr>.pt``. Note that ``itr`` is formatted with ``str()``,
    so ``700000.0`` maps to ``actor_700000.0.pt``.

    Args:
        path: directory holding the checkpoints (trailing separator optional).
        itr: iteration tag embedded in the file names.
        map_location: forwarded to ``torch.load``. Defaults to the module-level
            ``device``, so checkpoints saved on a GPU still load on a
            CPU-only machine.

    Returns:
        ``(actor, critic, vae)`` as deserialized by ``torch.load``.

    Security: these checkpoints are whole pickled modules and are loaded with
    ``weights_only=False``, which executes arbitrary pickle code. Only load
    files you trust.
    """
    import inspect

    if map_location is None:
        map_location = device
    load_kwargs = {"map_location": map_location}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses full
    # nn.Module pickles. Older releases (< 1.13) have no such parameter.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    def _load(name):
        return torch.load(os.path.join(path, f"{name}_{itr}.pt"), **load_kwargs)

    actor = _load("actor")
    critic = _load("critic")
    vae = _load("vae")
    return actor, critic, vae


def select_action(state, actor, critic, vae, num_samples=100):
    """Choose an action for ``state`` by scoring sampled candidates with ``critic.q1``.

    The flattened state is replicated ``num_samples`` times. ``vae.decode``
    proposes an action for each copy, and ``actor`` maps each
    (state, proposal) pair to a candidate action. The candidate with the
    highest ``critic.q1`` value is returned.

    Args:
        state: observation; anything ``np.asarray`` accepts. It is cast to
            float32 and flattened into one row.
        actor, critic, vae: torch models. They are assumed to live on the
            module-level ``device`` (TODO confirm for custom loaders).
        num_samples: number of candidate actions to score (default 100).

    Returns:
        The selected action as a flat numpy array.
    """
    with torch.no_grad():
        row = torch.as_tensor(np.asarray(state, dtype=np.float32), device=device)
        states = row.reshape(1, -1).repeat(num_samples, 1)
        candidates = actor(states, vae.decode(states))
        best = critic.q1(states, candidates).argmax(0)
        return candidates[best].cpu().numpy().flatten()

# Control timing: the main loop paces itself to `_num_action_repeat * time_step`
# seconds per policy step.
_NUM_SIMULATION_ITERATION_STEPS = 300
_num_action_repeat = 8
time_step = 0.0025
# NOTE(review): not used in this file; the bullet clients are configured with
# numSolverIterations=30.
_num_bullet_solver_iterations = _NUM_SIMULATION_ITERATION_STEPS // _num_action_repeat

# Extra delay added to every main-loop iteration, and the PD gains shared by
# all joints. An earlier setting was time_sleep=0.006, kp=80., kd=1.3.
time_sleep = 0.006
kp = 85.
kd = 1.

# Every joint type uses the same gains (comments previously noted 50 for the
# abduction P gain and 150 for the hip P gain).
ABDUCTION_P_GAIN = HIP_P_GAIN = KNEE_P_GAIN = kp
ABDUCTION_D_GAIN = HIP_D_GAIN = KNEE_D_GAIN = kd

# Per-motor gain vectors: (abduction, hip, knee) repeated for each of 4 legs.
motor_kps = np.tile([ABDUCTION_P_GAIN, HIP_P_GAIN, KNEE_P_GAIN], 4)
motor_kds = np.tile([ABDUCTION_D_GAIN, HIP_D_GAIN, KNEE_D_GAIN], 4)

def _make_bullet_client(connection_mode):
    """Create a PyBullet client with the physics settings shared by both modes.

    Args:
        connection_mode: a PyBullet connection constant such as
            ``pybullet.DIRECT`` or ``pybullet.GUI``.

    Returns:
        The configured ``bullet_client.BulletClient``.
    """
    client = bullet_client.BulletClient(connection_mode=connection_mode)
    client.setPhysicsEngineParameter(numSolverIterations=30)
    client.setTimeStep(time_step)
    client.setGravity(0, 0, -9.8)
    client.setPhysicsEngineParameter(enableConeFriction=0)
    client.setAdditionalSearchPath(pybullet_data.getDataPath())
    return client


def _follow_with_camera(client, robot):
    """Re-center the debug camera on the robot base, keeping yaw, pitch and distance."""
    base_pos = robot.GetBasePosition()
    # getDebugVisualizerCamera() returns yaw, pitch and distance at indices 8..10.
    yaw, pitch, dist = client.getDebugVisualizerCamera()[8:11]
    client.resetDebugVisualizerCamera(dist, yaw, pitch, base_pos)


if __name__ == "__main__":
    # Toggle between driving the real A1 and the PyBullet simulation.
    isRealRobot = False

    ftg = FTG()
    desired_yaw = np.array([1, 0])

    if isRealRobot:
        # Headless physics client paired with the real-robot interface.
        p = _make_bullet_client(pybullet.DIRECT)
        robot = a1_robot.A1Robot(pybullet_client=p,
                                 urdf_filename="envs/a1/urdf/a1.urdf",
                                 enable_action_interpolation=True,
                                 action_repeat=_num_action_repeat,
                                 time_step=time_step,
                                 motor_kps=motor_kps,
                                 motor_kds=motor_kds,
                                 FTG=ftg,
                                 desired_yaw=desired_yaw,
                                 )
    else:
        from envs.robots import a1
        p = _make_bullet_client(pybullet.GUI)
        p.configureDebugVisualizer(pybullet.COV_ENABLE_GUI, False)
        p.resetDebugVisualizerCamera(1.0, 0., -30., [0, 0, 0])

        from envs.terrains.terrain_generator import Complex_terrain
        terrain_list = ["Plane", "Hills", "Steps", "Stairs UP", "Stairs DOWN", "Stairs MIX"]
        # Terrain is fixed to "Steps" (index 2).
        _terrain = Complex_terrain(terrain_type=terrain_list[2], _pybullet_client=p)
        height_map = _terrain.get_height_map()
        terrain_scale = _terrain.scale

        robot = a1.A1(p,
                      urdf_filename="envs/a1/urdf/a1.urdf",
                      enable_action_interpolation=True,
                      time_step=time_step,
                      action_repeat=_num_action_repeat,
                      height_map=height_map,
                      terrain_scale=terrain_scale,
                      FTG=ftg,
                      desired_yaw=desired_yaw,
                      )

    try:
        actor, critic, vae = load_model(path="models/", itr=700000.0)

        robot.ReceiveObservation()
        obs = robot.get_obs()

        T = 300
        control_period = _num_action_repeat * time_step
        _last_frame_time = 0
        for _ in range(T):
            # Sleep off whatever remains of one control period since the
            # previous measurement, then add the fixed extra delay.
            time_spent = time.time() - _last_frame_time
            _last_frame_time = time.time()
            time_to_sleep = control_period - time_spent
            if time_to_sleep > 0:
                time.sleep(time_to_sleep)
            time.sleep(time_sleep)

            if not isRealRobot:
                _follow_with_camera(p, robot)

            action = select_action(obs, actor, critic, vae)
            robot.Step(action)
            obs = robot.get_obs()

            if not isRealRobot and IsDone(robot):
                robot.Reset(False)
                # Re-read the observation so the next action is chosen from
                # the post-reset state rather than the terminal one (assumes
                # Reset() refreshes the robot state, as Step() does).
                obs = robot.get_obs()

        if isRealRobot:
            robot_stand(robot)
    finally:
        # Always call Terminate(), even if the loop raises or is interrupted.
        robot.Terminate()