import numpy as np
# from robots.foot_trajectory_generator import FTG

class TerrainsTask(object):
    """Reward and termination task for quadruped locomotion over terrain.

    Call ``reset(env)`` before ``reward``/``done``: both read the robot through
    ``self._robot`` / ``self._env`` bound there, not through the ``env``
    argument they receive.

    The robot object is expected to expose ``FTG`` (with ``phi``, ``T`` and
    ``T_all``), velocity/orientation getters, foot position/contact getters,
    target-foot-position history (``last_tfp_2``, ``last_tfp_1``, ``now_tfp``)
    and ``GetMotorTorques``.
    """

    def __init__(self,
                 weight=1,
                 ):
        self._env = None
        self._robot = None
        # NOTE(review): ``weight`` is stored but not used by any visible method.
        self._weight = weight
        # Presumably the nominal foot z (base frame) when standing — confirm.
        self.foot_z_min = -0.324
        # Projected speed at/above which LinearVelocityReward saturates at 1.
        self.v_tar = 0.2
        self.foot_z_tar = 0.06  # not referenced by any active reward term
        self.foot_score_info = None
        # Either a scalar or a flat array of 9 samples per foot
        # (see FootClearanceReward).
        self.foot_surrounding_height = 0

    def __call__(self, env):
        return self.reward(env)

    def reset(self, env):
        """Bind ``env`` and its robot and clear per-episode terrain info."""
        self._env = env
        self._robot = self._env.robot
        self.foot_score_info = None
        self.foot_surrounding_height = 0

    def LinearVelocityReward(self, v_xy, v_xy_tar, v_pr):
        """Reward tracking of the commanded planar direction.

        Args:
            v_xy: base linear velocity in the x-y plane.
            v_xy_tar: desired planar direction; all zeros means no command.
            v_pr: projection ``dot(v_xy, v_xy_tar)``.

        Returns:
            ``np.array([1])`` once ``v_pr >= self.v_tar``, a Gaussian of the
            shortfall below ``self.v_tar`` otherwise; with no command, a
            Gaussian penalty on any planar motion. 0 if ``v_pr`` is NaN.
        """
        if not v_xy_tar.any():
            return np.exp(-20 * (np.sum(np.square(v_xy))))
        v_tar = np.array([self.v_tar])
        if v_pr < v_tar:
            return np.exp(-30 * ((v_pr - self.v_tar) ** 2))
        if v_pr >= v_tar:
            return np.array([1])
        return 0  # only reachable when v_pr is NaN

    def YawReward(self, yaw, v_xy_tar):
        """Reward aligning the heading ``(cos yaw, sin yaw)`` with ``v_xy_tar``.

        Returns 0 when no direction is commanded.
        """
        if not v_xy_tar.any():
            return 0
        # NOTE(review): the factor 2 applies only to the x-error term; if a
        # symmetric penalty was intended this should be -2 * (|dx| + |dy|).
        # Kept as-is so reward shaping does not change silently.
        return np.exp(-2 * np.abs(np.cos(yaw) - v_xy_tar[0])
                      - np.abs(np.sin(yaw) - v_xy_tar[1]))

    def BaseMotionReward(self, v_xy, v_pr, v_xy_tar, av_xy):
        """Sum of two Gaussian terms, each in (0, 1].

        The first penalizes ``v_xy - v_pr * v_xy_tar`` (the velocity component
        orthogonal to the command when ``v_xy_tar`` is a unit vector); the
        second penalizes the x/y angular velocity ``av_xy``.
        """
        r_b1 = np.exp(-1.5 * np.sum(np.square(v_xy - v_pr * v_xy_tar)))
        r_b2 = np.exp(-1.5 * np.sum(np.square(av_xy)))
        return r_b1 + r_b2

    def BaseRPYReward(self, rpy):
        """Reward zero roll and a pitch matching the diagonal stance-foot slope.

        When legs (0, 3) — or else (1, 2) — are both in stance
        (``T <= phi < T_all``), the desired pitch is minus the elevation angle
        of the vector between those two feet in the world x-z plane. Returns 0
        when neither diagonal pair is in stance.
        """
        phi = self._robot.FTG.phi
        T, T_all = self._robot.FTG.T, self._robot.FTG.T_all
        stance = {i for i, phase in enumerate(phi) if T <= phase < T_all}
        if {0, 3} <= stance:
            first, second = 0, 3
        elif {1, 2} <= stance:
            first, second = 1, 2
        else:
            return 0

        # Positional shape: the ``newshape=`` keyword is deprecated in NumPy 2.1+.
        foot_pos = np.reshape(self._robot.GetXYZFootPositionsInWorldFrame(), (4, 3))
        delta = foot_pos[first] - foot_pos[second]
        # Equivalent to -(pi/2 - arccos(dz / hypot(dx, dz))) = -arcsin(dz / r),
        # but arctan2 stays finite when the feet coincide (r == 0) and cannot
        # leave arccos's [-1, 1] domain through rounding.
        desired_p = -np.arctan2(delta[2], np.abs(delta[0]))
        desired_rp = np.array([0, desired_p])
        rp = rpy[:2]
        return np.exp(-1.5 * np.sum(np.square(rp - desired_rp)))

    def FootClearanceReward(self):
        """Fraction of swing feet (``0 <= phi < T``) clearing nearby heights.

        A swing foot counts as clear when its base-frame z above
        ``foot_z_min`` exceeds the maximum of its 9 samples in
        ``foot_surrounding_height``; a scalar ``foot_surrounding_height``
        (the value set by ``__init__``/``reset``) is used as a threshold for
        every foot. Returns 0 when no foot is swinging.
        """
        foot_r_f = self._robot.GetFootPositionsInBaseFrame()[:, 2] - self.foot_z_min
        phi = self._robot.FTG.phi
        T = self._robot.FTG.T
        heights = np.asarray(self.foot_surrounding_height)

        n_swing = n_clear = 0
        for i, phase in enumerate(phi):
            if not 0 <= phase < T:
                continue
            n_swing += 1
            if heights.ndim == 0:
                threshold = heights
            else:
                threshold = np.max(heights[9 * i:9 * (i + 1)])
            if foot_r_f[i] > threshold:
                n_clear += 1
        return n_clear / n_swing if n_swing else 0

    def BodyCollisionReward(self):
        """Return minus the ratio of non-foot contacts to foot contacts.

        Assumes ``all_contacts`` lists every robot contact including the feet.
        Returns 0 when no foot is in contact.
        """
        foot_contact_num = np.sum(self._robot.GetFootContacts())
        body_contact_num = len(self._robot.all_contacts) - foot_contact_num
        if foot_contact_num == 0:
            # NOTE(review): body contacts go unpenalized while no foot touches
            # the ground — confirm this is intended.
            return 0
        return -(body_contact_num / foot_contact_num)

    def TargetSmoothnessReward(self):
        """Negative L2 norm of the second difference of target foot positions."""
        tfp_2, tfp_1, tfp = self._robot.last_tfp_2, self._robot.last_tfp_1, self._robot.now_tfp
        return -np.linalg.norm((tfp - 2 * tfp_1 + tfp_2), ord=2)

    def TorqueReward(self):
        """Negative sum of absolute motor torques."""
        return -np.sum(np.abs(self._robot.GetMotorTorques()))

    def reward(self, env):
        """Get the reward without side effects.

        Weighted sum of the active terms. ``env`` is unused; the robot bound in
        ``reset`` is queried. Earlier experimental terms (foot clearance, foot
        score, angular velocity) are currently not included.

        Returns:
            ``np.squeeze`` of the weighted sum.
        """
        # One query for both linear and angular base velocity.
        v, av = self._robot.GetBaseVelocity()
        v_xy = v[:2]
        av_xy = av[:2]
        v_xy_tar = self._robot.GetDesiredDirection()
        v_pr = np.dot(v_xy, v_xy_tar)
        rpy = self._robot.GetTrueBaseRollPitchYaw()

        reward = (0.1 * self.LinearVelocityReward(v_xy=v_xy, v_xy_tar=v_xy_tar, v_pr=v_pr)
                  + 0.05 * self.YawReward(yaw=rpy[2], v_xy_tar=v_xy_tar)
                  + 0.005 * self.BaseMotionReward(v_xy=v_xy, v_pr=v_pr, v_xy_tar=v_xy_tar, av_xy=av_xy)
                  + 0.02 * self.BodyCollisionReward()
                  + 0.025 * self.TargetSmoothnessReward()
                  + 2 * (10 ** (-5)) * self.TorqueReward()
                  + 0.05 * self.BaseRPYReward(rpy=rpy))
        return np.squeeze(reward)

    def done(self, env):
        """Checks if the episode is over.

        Ends when the base position's y-coordinate exceeds 0.5 or when any of
        roll, pitch, yaw exceeds pi/4 in magnitude. ``env`` is unused; the env
        bound in ``reset`` is queried.
        """
        done = self._env.robot.GetBasePosition()[1] > 0.5
        roll, pitch, yaw = self._GetTrueBaseRollPitchYaw()
        if np.abs(roll) > np.pi / 4 or np.abs(pitch) > np.pi / 4 or np.abs(yaw) > np.pi / 4:
            done = True
        return done

    def _GetVelReward(self):
        return self._env.robot.GetVelReward()

    def _GetBasePosition(self):
        return self._env.robot.GetBasePosition()

    def _GetTrueBaseRollPitchYaw(self):
        return self._env.robot.GetTrueBaseRollPitchYaw()
    # def done(self, env):
    #     """Checks if the episode is over."""
    #     del env
    #     done = self._robot._step_counter//self._robot._action_repeat>3000
    #     for i,angle in enumerate(self._robot.GetTrueBaseRollPitchYaw()[:2]):
    #         if np.abs(angle) > np.pi / 4:
    #             done = True
    #     # done = False
    #     return done

    # def _get_pybullet_client(self):
    #     """Get bullet client from the environment"""
    #     return self._env._pybullet_client

    # def _get_num_joints(self):
    #     """Get the number of joints in the character's body."""
    #     pyb = self._get_pybullet_client()
    #     return pyb.getNumJoints(self._env.robot.quadruped)

    # def _get_pos_vel_info(self):
    #     pyb = self._get_pybullet_client()
    #     quadruped = self._env.robot.quadruped
    #     self.body_pos = pyb.getBasePositionAndOrientation(quadruped)[0]  # 3 list: position list of 3 floats
    #     self.body_ori = pyb.getBasePositionAndOrientation(quadruped)[
    #         1]  # 4 list: orientation as list of 4 floats in [x,y,z,w] order
    #     self.body_lin_vel = pyb.getBaseVelocity(quadruped)[0]  # 3 list: linear velocity [x,y,z]
    #     self.body_ang_vel = pyb.getBaseVelocity(quadruped)[1]  # 3 list: angular velocity [wx,wy,wz]
    #     self.joint_pos = []  # float: the position value of this joint
    #     self.joint_vel = []  # float: the velocity value of this joint
    #     for i in range(12):
    #         self.joint_pos.append(pyb.getJointState(quadruped, i)[0])
    #         self.joint_vel.append(pyb.getJointState(quadruped, i)[1])

    # def _get_reward(self):
    #     return self._env.robot.GetReward()
    #
    # def _GetEnergyConsumptionPerControlStep(self):
    #     return np.clip(self._env.robot.GetEnergyConsumptionPerControlStep(), 0, 100)
    #
    # def _GetTrueBaseRollPitchYaw(self):
    #     return self._env.robot.GetTrueBaseRollPitchYaw()
    #
    # def _GetPos(self):
    #     return self._env.robot.GetPos()

