from os import path as osp
import numpy as np
from legged_gym.envs.go1.go1_field_config import Go1FieldCfg, Go1FieldCfgPPO
from legged_gym.utils.helpers import merge_dict

class Go1JumpCfg(Go1FieldCfg):
    # Environment configuration for the Go1 "jump" skill.
    # Each nested class extends the Go1FieldCfg class of the same name and only
    # overrides the attributes listed here. The one exception is rewards.scales,
    # which is declared without a base class.

    class init_state(Go1FieldCfg.init_state):
        # Initial base position (presumably [x, y, z] -- confirm against the base config).
        pos = [0.0, 0.0, 0.45]

    # Proprioception latency override. The original note on this section says
    # it is intended for training on non-virtual terrain; it is currently enabled.
    class sensor(Go1FieldCfg.sensor):
        class proprioception(Go1FieldCfg.sensor.proprioception):
            # Written as 0.04 minus/plus an offset, kept as expressions on purpose.
            latency_range = [0.04 - 0.0025, 0.04 + 0.0075]

    class terrain(Go1FieldCfg.terrain):
        max_init_terrain_level = 10
        border_size = 5
        # NOTE: the spelling is part of the attribute name; check consumers before renaming.
        slope_treshold = 20.0
        curriculum = True

        # Overrides merged on top of the inherited BarrierTrack settings.
        BarrierTrack_kwargs = merge_dict(
            Go1FieldCfg.terrain.BarrierTrack_kwargs,
            {
                "options": ["jump"],
                "track_width": 1.6,
                "track_block_length": 1.6,
                "wall_thickness": 0.2,
                "jump": {
                    # alternative range kept from the original: (0.38, 0.46)
                    "height": (0.2, 0.46),
                    "depth": (0.1, 0.2),
                    # [m] an offset that makes it easier for the robot to get into the obstacle
                    "fake_offset": 0.0,
                    # probability of jumping down; use it in non-virtual terrain
                    "jump_down_prob": 0.0,
                },
                # Change this to False for real terrain.
                "virtual_terrain": True,
                "no_perlin_threshold": 0.06,
                "randomize_obstacle_order": True,
                "n_obstacles_per_track": 3,
            },
        )

        # Overrides merged on top of the inherited Perlin-terrain settings.
        TerrainPerlin_kwargs = merge_dict(
            Go1FieldCfg.terrain.TerrainPerlin_kwargs,
            {"zScale": [0.05, 0.12]},
        )

    class commands(Go1FieldCfg.commands):
        class ranges(Go1FieldCfg.commands.ranges):
            # Only the x range is non-degenerate; y and yaw are pinned to zero.
            lin_vel_x = [0.5, 1.5]
            lin_vel_y = [0.0, 0.0]
            ang_vel_yaw = [0.0, 0.0]

    class control(Go1FieldCfg.control):
        # Also drives the "_noComputerClip" tag in the PPO run name.
        computer_clip_torque = False

    class asset(Go1FieldCfg.asset):
        penalize_contacts_on = ["base", "thigh", "calf", "imu"]
        terminate_after_contacts_on = ["base"]

    class termination(Go1FieldCfg.termination):
        # Additional factors that determine whether the episode terminates.
        termination_terms = ["roll", "pitch", "z_high", "out_of_track"]
        timeout_at_finished = True
        timeout_at_border = True

    class domain_rand(Go1FieldCfg.domain_rand):
        class com_range(Go1FieldCfg.domain_rand.com_range):
            z = [-0.1, 0.1]

        init_base_pos_range = {
            "x": [0.2, 0.8],
            "y": [-0.25, 0.25],
        }
        init_base_rot_range = {
            "roll": [-0.1, 0.1],
            "pitch": [-0.1, 0.1],
        }

        # Pushing is disabled here (the original kept `push_robots = True` as a
        # commented-out alternative).
        push_robots = False

    class rewards(Go1FieldCfg.rewards):
        # Deliberately has no base class: only the terms below are declared here.
        class scales:
            tracking_ang_vel = 0.08
            # disabled alternative from the original: alive = 10.
            tracking_world_vel = 5.0  # original tag: "1. always"
            legs_energy_substeps = -6e-5  # alternative noted in the original: -6e-6
            jump_x_vel_cond = 0.1  # original tag: "1. segment"
            # disabled alternative from the original: hip_pos = -5.
            front_hip_pos = -5.0
            rear_hip_pos = -1e-6
            penetrate_depth = -6e-3
            penetrate_volume = -6e-4
            exceed_dof_pos_limits = -8e-4
            exceed_torque_limits_l1norm = -2.0  # alternatives noted in the original: -8e-2, -6e-2
            action_rate = -0.1
            delta_torques = -1e-5
            dof_acc = -1e-7
            torques = -4e-4  # alternative noted in the original: 2e-3
            yaw_abs = -0.1
            lin_pos_y = -0.4
            collision = -1.0
            sync_all_legs_cond = -0.3  # alternative noted in the original: -0.6
            dof_error = -0.06

        only_positive_rewards = False
        soft_dof_pos_limit = 0.8
        max_contact_force = 100.0
        tracking_sigma = 0.3

    class noise(Go1FieldCfg.noise):
        # Also drives the "_noPropNoise" / "_withPropNoise" tag in the PPO run name.
        add_noise = False

    class curriculum(Go1FieldCfg.curriculum):
        # Paired harder/easier thresholds for penetration volume and depth.
        penetrate_volume_threshold_harder = 6000
        penetrate_volume_threshold_easier = 12000
        penetrate_depth_threshold_harder = 600
        penetrate_depth_threshold_easier = 1200

# Absolute path of the "logs" directory located four directory levels above
# this file (each osp.pardir strips one trailing path component).
logs_root = osp.join(
    osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir, osp.pardir)),
    "logs",
)
class Go1JumpCfgPPO( Go1FieldCfgPPO ):
    # PPO training configuration paired with Go1JumpCfg. The runner's run_name
    # is assembled at class-definition time from values read off Go1JumpCfg, so
    # the run directory name reflects the active environment settings.
    class algorithm( Go1FieldCfgPPO.algorithm ):
        # NOTE: run_name below hard-codes "_entropy0.01" and "_minStd0.21";
        # keep those tags in sync when changing these two values.
        entropy_coef = 0.01
        clip_min_std = 0.21
    
    class runner( Go1FieldCfgPPO.runner):
        experiment_name = "go1"
        task_name = 'go1_jump'
        resume = False
        # resume = True
        # Run to resume from (only used in run_name when `resume` is True).
        # Replace with "{Your trained walking model directory}".
        # Earlier value, kept for reference (it was immediately overwritten):
        # load_run = "Oct30_11-11-43_Skills_jump_pEnergySubsteps6e-05_rTrackVel5._pY-8e-01_pTorqueExceed1.2e+00_pTorque4e-04_pDTorques1e-06_propDelay0.04-0.05_noPropNoise_pushRobot_gamma0.999_noTanh_noComputerClip_jumpRange0.2-0.5_allowNegativeReward_fromOct29_20-17-31"
        load_run = osp.join(logs_root, "field_go1_noTanh_oracle", "Oct30_13-00-12_Skills_jump_pEnergySubsteps6e-05_rTrackVel5._pY-4e-01_pTorqueExceed1.8e+00_pTorque4e-04_pDTorques1e-05_propDelay0.04-0.05_noPropNoise_noPush_gamma0.999_noTanh_noComputerClip_jumpRange0.2-0.5_allowNegativeReward_fromOct29_20-17-31")

        # Each tuple element below contributes one tag (or "") to the run name.
        # Commented-out entries are optional tags that are currently not emitted.
        run_name = "".join(["Skills_",
        # "Multi" for several obstacle options, the single option name for one,
        # "PlaneWalking" when the option list is empty.
        ("Multi" if len(Go1JumpCfg.terrain.BarrierTrack_kwargs["options"]) > 1 else (Go1JumpCfg.terrain.BarrierTrack_kwargs["options"][0] if Go1JumpCfg.terrain.BarrierTrack_kwargs["options"] else "PlaneWalking")),
        ("_pEnergySubsteps" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.legs_energy_substeps, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "legs_energy_substeps", -2e-5) != -2e-5 else ""),
        # ("_rTrackVel" + np.format_float_positional(Go1JumpCfg.rewards.scales.tracking_world_vel) if getattr(Go1JumpCfg.rewards.scales, "tracking_world_vel", 0.) != 0. else ""),
        # ("_pLinVel" + np.format_float_positional(-Go1JumpCfg.rewards.scales.world_vel_l2norm) if getattr(Go1JumpCfg.rewards.scales, "world_vel_l2norm", 0.) != 0.0 else ""),
        # ("_rAlive{:.1f}".format(Go1JumpCfg.rewards.scales.alive) if getattr(Go1JumpCfg.rewards.scales, "alive", 0.) != 0. else ""),
        # ("_pPenD{:.0e}".format(Go1JumpCfg.rewards.scales.penetrate_depth) if getattr(Go1JumpCfg.rewards.scales, "penetrate_depth", 0.) < 0. else ""),
        # ("_pYaw{:.0e}".format(Go1JumpCfg.rewards.scales.yaw_abs) if getattr(Go1JumpCfg.rewards.scales, "yaw_abs", 0.) < 0. else ""),
        # ("_pY{:.0e}".format(Go1JumpCfg.rewards.scales.lin_pos_y) if getattr(Go1JumpCfg.rewards.scales, "lin_pos_y", 0.) < 0. else ""),
        # ("_pDofLimit{:.0e}".format(Go1JumpCfg.rewards.scales.exceed_dof_pos_limits) if getattr(Go1JumpCfg.rewards.scales, "exceed_dof_pos_limits", 0.) < 0. else ""),
        # ("_pDofErr" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.dof_error, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "dof_error", 0.) != 0. else ""),
        # ("_pDofErrCond" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.dof_error_cond, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "dof_error_cond", 0.) != 0. else ""),
        # ("_pHipPos" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.hip_pos, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "hip_pos", 0.) != 0. else ""),
        ("_pFHipPos" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.front_hip_pos, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "front_hip_pos", 0.) != 0. else ""),
        ("_pRHipPos" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.rear_hip_pos, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "rear_hip_pos", 0.) != 0. else ""),
        # ("_pTorqueExceedI" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.exceed_torque_limits_i, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "exceed_torque_limits_i", 0.) != 0. else ""),
        # getattr with the -8e-1 default mirrors the sibling entries: a missing
        # scale yields no tag instead of raising AttributeError at import time.
        ("_pTorqueExceed" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.exceed_torque_limits_l1norm, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "exceed_torque_limits_l1norm", -8e-1) != -8e-1 else ""),
        # ("_pContact" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.feet_contact_forces, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "feet_contact_forces", 0.) != 0. else ""),
        # ("_pJumpSameLegs{:.1f}".format(-Go1JumpCfg.rewards.scales.sync_legs_cond) if getattr(Go1JumpCfg.rewards.scales, "sync_legs_cond", 0.) != 0. else ""),
        # ("_pDofAcc" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.dof_acc, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "dof_acc", 0.) != 0. else ""),
        # ("_pTorque" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.torques, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "torques", 0.) != 0. else ""),
        # ("_rJumpXVel" + np.format_float_positional(Go1JumpCfg.rewards.scales.jump_x_vel_cond) if getattr(Go1JumpCfg.rewards.scales, "jump_x_vel_cond", 0.) != 0. else "_noJumpBonous"),
        # ("_pCollision" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.collision, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "collision", 0.) < 0. else ""),
        # ("_pSyncSymLegs" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.sync_all_legs_cond, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "sync_all_legs_cond", 0.) != 0. else ""),
        # ("_pZVel" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.lin_vel_z, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "lin_vel_z", 0.) != 0. else ""),
        # ("_pAngXYVel" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.ang_vel_xy, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "ang_vel_xy", 0.) != 0. else ""),
        # ("_pDTorques" + np.format_float_scientific(-Go1JumpCfg.rewards.scales.delta_torques, precision=1, trim="-") if getattr(Go1JumpCfg.rewards.scales, "delta_torques", 0.) != 0. else ""),
        ("_propDelay{:.2f}-{:.2f}".format(
            Go1JumpCfg.sensor.proprioception.latency_range[0],
            Go1JumpCfg.sensor.proprioception.latency_range[1],
        )),
        # ("_softDof{:.1f}".format(Go1JumpCfg.rewards.soft_dof_pos_limit) if Go1JumpCfg.rewards.soft_dof_pos_limit != 0.9 else ""),
        # ("_timeoutFinished" if getattr(Go1JumpCfg.termination, "timeout_at_finished", False) else ""),
        # ("_noPitchTerm" if "pitch" not in Go1JumpCfg.termination.termination_terms else ""),
        ("_noPropNoise" if not Go1JumpCfg.noise.add_noise else "_withPropNoise"),
        ("_noPush" if not Go1JumpCfg.domain_rand.push_robots else "_pushRobot"),
        # The next two tags are literals, not derived from `algorithm` above.
        ("_minStd0.21"),
        ("_entropy0.01"),
        # ("_gamma0.999"),
        ("_noTanh"),
        # ("_zeroResetAction" if Go1JumpCfg.init_state.zero_actions else ""),
        # ("_actionClip" + Go1JumpCfg.normalization.clip_actions_method if getattr(Go1JumpCfg.normalization, "clip_actions_method", "") != "" else ""),
        # ("_noDelayActObs" if not Go1JumpCfg.sensor.proprioception.delay_action_obs else ""),
        ("_noComputerClip" if not Go1JumpCfg.control.computer_clip_torque else ""),
        ("_jumpRange{:.1f}-{:.1f}".format(*Go1JumpCfg.terrain.BarrierTrack_kwargs["jump"]["height"]) if Go1JumpCfg.terrain.BarrierTrack_kwargs["jump"]["height"] != (0.1, 0.6) else ""),
        # ("_noCurriculum" if not Go1JumpCfg.terrain.curriculum else ""),
        # ("_allowNegativeReward" if not Go1JumpCfg.rewards.only_positive_rewards else ""),
        # First two "_"-separated fields of the resumed run's directory name.
        # osp.basename is used because load_run is built with osp.join, so its
        # separator is platform dependent.
        ("_from" + "_".join(osp.basename(load_run).split("_")[:2]) if resume else "_noResume"),
        ])
        max_iterations = 20000
        save_interval = 200

