import numpy as np


class RunningTask(object):
    """Reward and termination logic for a forward-running quadruped task.

    The task is bound to an environment through `reset(env)`; every other
    method reads the simulator through that stored environment.  The
    environment is expected to expose `_pybullet_client` and `robot`
    (with a `quadruped` body id and `GetTrueBaseRollPitchYaw()`).

    The reward returned by `reward()` is

        weight * (pos_reward + 5 * vel_reward + rpy_reward)

    where `pos_reward` penalises the base's y offset from 0, `vel_reward`
    penalises deviation of the base linear velocity from
    `_TARGET_LIN_VEL`, and `rpy_reward` penalises roll, pitch and yaw.
    """

    # Joint indices 0 .. _NUM_JOINTS - 1 are sampled on every state refresh.
    _NUM_JOINTS = 12
    # An episode ends once more than this many rewards have been computed.
    _MAX_EPISODE_STEPS = 2000
    # Inclusive bounds on the base z coordinate for the robot to count as
    # healthy (presumably metres, PyBullet's default unit -- TODO confirm).
    _MIN_BASE_HEIGHT = 0.1
    _MAX_BASE_HEIGHT = 0.5
    # An episode ends when |roll|, |pitch| or |yaw| exceeds this angle (rad).
    _MAX_TILT = np.pi / 5
    # Base linear velocity (x, y, z) that earns the maximum velocity reward.
    _TARGET_LIN_VEL = (2, 0, 0)

    def __init__(self,
                 weight=1,
                 pose_weight=0.5,
                 velocity_weight=0.05,
                 end_effector_weight=0.2,
                 forward_reward_pos_weight=1,
                 forward_reward_vel_weight=1,
                 forward_reward_vel_y_weight=0.5,

                 ctrl_cost_weight=0.01,
                 healthy_reward_weight=1,
                 control_energy_weight=0.01,

                 roll_pitch_weight=0.01,
                 yaw_weight=0.001,

                 survive_reward=0.1,
                 healthy_reward=1,
                 ):
        """Store the reward parameters and initialise the cached state.

        Args:
            weight: Global multiplier applied to the value of `reward()`.
            roll_pitch_weight: Weight of |roll| + |pitch| in the
                orientation reward.
            yaw_weight: Weight of |yaw| in the orientation reward.
            healthy_reward: Value returned by `healthy_reward()` while the
                robot is healthy.
            pose_weight, velocity_weight, end_effector_weight,
            forward_reward_pos_weight, forward_reward_vel_weight,
            forward_reward_vel_y_weight, ctrl_cost_weight,
            healthy_reward_weight, control_energy_weight, survive_reward:
                Stored on the instance but not read by any method of this
                class.
        """
        self._env = None
        self._weight = weight

        # Reward function parameters.
        self._pose_weight = pose_weight
        self._velocity_weight = velocity_weight
        self._end_effector_weight = end_effector_weight
        self._forward_reward_pos_weight = forward_reward_pos_weight
        self._forward_reward_vel_weight = forward_reward_vel_weight
        self._forward_reward_vel_y_weight = forward_reward_vel_y_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._control_energy_weight = control_energy_weight
        self._healthy_reward_weight = healthy_reward_weight

        self._roll_pitch_weight = roll_pitch_weight
        self._yaw_weight = yaw_weight

        self._survive_reward = survive_reward
        self._healthy_reward = healthy_reward

        # Simulator state cached by `_get_pos_vel_info()`.
        self.body_pos = None
        self.body_ori = None
        self.body_lin_vel = None
        self.body_ang_vel = None
        self.joint_pos = None
        self.joint_vel = None
        # Components of the most recent `reward()` call.
        self.pos_reward = None
        self.vel_reward = None
        # Number of `reward()` calls since the last `reset()`.
        self.num = 0

    def __call__(self, env):
        """Shorthand for `reward(env)`."""
        return self.reward(env)

    def reset(self, env):
        """Bind the task to `env` and restart the step counter."""
        self._env = env
        self.num = 0

    def is_healthy(self):
        """Refresh the cached state and report whether the base is healthy.

        Returns:
            True when every base position coordinate is finite and the base
            z coordinate lies within
            [`_MIN_BASE_HEIGHT`, `_MAX_BASE_HEIGHT`]; False otherwise.
        """
        self._get_pos_vel_info()
        height = self.body_pos[2]
        return bool(np.isfinite(self.body_pos).all()
                    and self._MIN_BASE_HEIGHT <= height <= self._MAX_BASE_HEIGHT)

    def healthy_reward(self):
        """Return the configured healthy reward, or 0.0 when unhealthy."""
        return self._healthy_reward * float(self.is_healthy())

    def reward(self, env):
        """Compute the reward for the current simulator state.

        The simulator is read through the environment given to `reset()`;
        the `env` argument is ignored.  Each call refreshes the cached
        state once, stores the position and velocity components in
        `self.pos_reward` / `self.vel_reward`, and increments `self.num`.

        Args:
            env: Unused.

        Returns:
            `weight * (pos_reward + 5 * vel_reward + rpy_reward)`.
        """
        del env

        # One refresh per step is enough; every term below reads the cache.
        self._get_pos_vel_info()

        # exp(-5 * |y|): equals 1 on the y = 0 line and decays with offset.
        lateral_error = np.abs(self.body_pos[1])
        pos_reward = np.exp(-10 * (lateral_error / 2))

        # exp(-0.5 * L1 distance between the base velocity and the target).
        vel_error = np.abs(np.asarray(self._TARGET_LIN_VEL)
                           - self.body_lin_vel[0:3])
        vel_reward = np.exp(-np.sum(vel_error / 2))

        roll, pitch, yaw = self._GetTrueBaseRollPitchYaw()
        tilt_cost = (self._roll_pitch_weight * (np.abs(roll) + np.abs(pitch))
                     + self._yaw_weight * np.abs(yaw))
        rpy_reward = np.exp(-10 * tilt_cost)

        self.pos_reward = pos_reward
        self.vel_reward = vel_reward
        self.num += 1

        return (pos_reward + 5 * vel_reward + rpy_reward) * self._weight

    def done(self, env):
        """Check whether the episode is over.

        The episode ends when the robot is unhealthy (see `is_healthy()`),
        when more than `_MAX_EPISODE_STEPS` rewards have been computed since
        the last `reset()`, or when the magnitude of roll, pitch or yaw
        exceeds `_MAX_TILT`.

        Args:
            env: Unused; the environment given to `reset()` is read instead.

        Returns:
            True if the episode should terminate, False otherwise.
        """
        if not self.is_healthy() or self.num > self._MAX_EPISODE_STEPS:
            return True
        return any(abs(angle) > self._MAX_TILT
                   for angle in self._GetTrueBaseRollPitchYaw())

    def _get_pybullet_client(self):
        """Get bullet client from the environment."""
        return self._env._pybullet_client

    def _get_num_joints(self):
        """Get the number of joints in the character's body."""
        pyb = self._get_pybullet_client()
        return pyb.getNumJoints(self._env.robot.quadruped)

    def _get_pos_vel_info(self):
        """Refresh the cached base and joint state from the simulator.

        Every simulator getter is called once per refresh: one base pose
        query, one base velocity query and one joint state query for each of
        the first `_NUM_JOINTS` joint indices.
        """
        pyb = self._get_pybullet_client()
        quadruped = self._env.robot.quadruped

        base_pose = pyb.getBasePositionAndOrientation(quadruped)
        self.body_pos = base_pose[0]  # position: 3 floats
        self.body_ori = base_pose[1]  # orientation: 4 floats, [x, y, z, w]

        base_velocity = pyb.getBaseVelocity(quadruped)
        self.body_lin_vel = base_velocity[0]  # linear velocity [x, y, z]
        self.body_ang_vel = base_velocity[1]  # angular velocity [wx, wy, wz]

        joint_states = [pyb.getJointState(quadruped, joint_index)
                        for joint_index in range(self._NUM_JOINTS)]
        self.joint_pos = [state[0] for state in joint_states]
        self.joint_vel = [state[1] for state in joint_states]

    def _get_reward(self):
        """Return the robot's own reward terms (`robot.GetReward()`)."""
        return self._env.robot.GetReward()

    def _GetEnergyConsumptionPerControlStep(self):
        """Return the robot's per-control-step energy, clipped to [0, 100]."""
        return np.clip(self._env.robot.GetEnergyConsumptionPerControlStep(), 0, 100)

    def _GetTrueBaseRollPitchYaw(self):
        """Return the robot's base (roll, pitch, yaw)."""
        return self._env.robot.GetTrueBaseRollPitchYaw()

    def _GetPos(self):
        """Return `robot.GetPos()`."""
        return self._env.robot.GetPos()
