#!/home/unitree/agility_ziwenz_venv/bin/python
import os
import os.path as osp
import json
import numpy as np
import torch
from collections import OrderedDict
from functools import partial
from typing import Tuple

import rospy
from std_msgs.msg import Float32MultiArray
from sensor_msgs.msg import Image
import ros_numpy

from a1_real import UnitreeA1Real, resize2d
from rsl_rl import modules
from rsl_rl.utils.utils import get_obs_slice

@torch.no_grad()
def handle_forward_depth(ros_msg, model, publisher, output_resolution, device):
    """ ROS subscriber callback: encode one depth image and publish the embedding.

    The image message is converted to a float32 array, given two leading singleton
    dimensions, moved to ``device``, resized to ``output_resolution`` with ``resize2d``
    and fed to ``model``. The model output is flattened and published as a
    ``Float32MultiArray`` on ``publisher``. Runs without autograd.

    Args:
        ros_msg: incoming ``sensor_msgs/Image`` (presumably a single-channel depth image,
            so that the result is (1, 1, H, W) -- TODO confirm).
        model: callable visual encoder (e.g. a TorchScript module).
        publisher: ``rospy.Publisher`` for ``Float32MultiArray`` messages.
        output_resolution: target size forwarded to ``resize2d``.
        device: torch device the encoder runs on.
    """
    depth_np = ros_numpy.numpify(ros_msg).astype(np.float32)
    # Indexing with two ``None``s adds the same leading singleton dims as two unsqueeze(0) calls.
    depth_tensor = torch.from_numpy(depth_np)[None, None].to(device)
    resized_depth = resize2d(depth_tensor, output_resolution)
    flat_embedding = model(resized_depth).reshape(-1)
    payload = flat_embedding.cpu().numpy().astype(np.float32).tolist()
    publisher.publish(Float32MultiArray(data= payload))

class StandOnlyModel(torch.nn.Module):
    """ A fixed (non-learned) policy that pulls every joint back towards zero offset.

    ``obs[..., -36:-24]`` is read as 12 joint positions scaled by ``dof_pos_scale``
    (presumably offsets from the default pose -- confirm against the env). Joints whose
    unscaled value exceeds ``tolerance`` in magnitude are targeted ``delta`` closer to
    zero; the others are targeted to zero. Targets are divided by ``action_scale`` and
    clipped to [-1, 1].

    Args:
        action_scale: scalar, or per-joint tuple/list (stored as a buffer).
        dof_pos_scale: scalar, or per-joint tuple/list (stored as a buffer).
        tolerance: magnitude below which a joint is targeted to zero.
        delta: step by which out-of-tolerance joints are moved towards zero.
    """
    def __init__(self, action_scale, dof_pos_scale, tolerance= 0.1, delta= 0.1):
        rospy.loginfo("Using stand only model, please make sure the proprioception is 48 dim.")
        rospy.loginfo("Using stand only model, -36 to -24 must be joint position.")
        super().__init__()
        named_scales = (("action_scale", action_scale), ("dof_pos_scale", dof_pos_scale))
        for scale_name, scale_value in named_scales:
            # Sequences become buffers so they follow .to(device); anything else is a plain attribute.
            if isinstance(scale_value, (tuple, list)):
                self.register_buffer(scale_name, torch.tensor(scale_value))
            else:
                setattr(self, scale_name, scale_value)
        self.tolerance = tolerance
        self.delta = delta

    def forward(self, obs):
        offsets = obs[..., -36:-24] / self.dof_pos_scale
        shifted = offsets - self.delta * torch.sign(offsets)
        outside_tolerance = torch.abs(offsets) > self.tolerance
        targets = torch.where(outside_tolerance, shifted, torch.zeros_like(offsets))
        return torch.clip(targets / self.action_scale, -1.0, 1.0)

    def reset(self, *args, **kwargs):
        """ No internal state; exists so callers can call reset() as on recurrent policies. """
        pass

def load_walk_policy(env, model_dir):
    """ Load the walk policy from the model directory.

    Args:
        env: the real-robot env. Reads ``action_scale``, ``obs_scales["dof_pos"]`` and
            ``get_num_obs_from_components(...)``.
        model_dir: log directory of the walk policy (holding ``config.json`` and
            ``model_<iteration>.pt`` checkpoints), or None to fall back to a
            TorchScript-compiled ``StandOnlyModel``.

    Returns:
        ``(policy, model)``: ``policy`` maps an observation tensor to actions; when the
        walk policy was trained with a different action scale than ``env.action_scale``,
        its output is multiplied by ``model_action_scale / env.action_scale``.
        ``model`` is the underlying ``torch.nn.Module`` (callers use its ``reset()``).

    Raises:
        FileNotFoundError: if ``model_dir`` contains no ``model_*`` checkpoint.
    """
    if model_dir is None:
        model = StandOnlyModel(
            action_scale= env.action_scale,
            dof_pos_scale= env.obs_scales["dof_pos"],
        )
        return torch.jit.script(model), model

    with open(osp.join(model_dir, "config.json"), "r") as f:
        config_dict = json.load(f, object_pairs_hook= OrderedDict)
    obs_components = config_dict["env"]["obs_components"]
    privileged_obs_components = config_dict["env"].get("privileged_obs_components", obs_components)
    model = getattr(modules, config_dict["runner"]["policy_class_name"])(
        num_actor_obs= env.get_num_obs_from_components(obs_components),
        num_critic_obs= env.get_num_obs_from_components(privileged_obs_components),
        num_actions= 12,
        **config_dict["policy"],
    )

    # Load the checkpoint with the largest iteration number ("model_<iter>.pt").
    model_names = [i for i in os.listdir(model_dir) if i.startswith("model_")]
    if not model_names:
        raise FileNotFoundError("No model_* checkpoint found in {}".format(model_dir))
    model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
    state_dict = torch.load(osp.join(model_dir, model_names[-1]), map_location= "cpu")
    model.load_state_dict(state_dict["model_state_dict"])
    # Inference only: switch to eval mode before the submodules are scripted below.
    model.eval()

    # torch.tensor gives a 0-d tensor for a scalar config value and a 1-d tensor for a list.
    model_action_scale = torch.tensor(config_dict["control"]["action_scale"])
    if (model_action_scale == env.action_scale).all():
        action_rescale_ratio = 1.0
    else:
        action_rescale_ratio = model_action_scale / env.action_scale
        print("walk_policy action scaling:", action_rescale_ratio.tolist())

    memory_module = model.memory_a
    actor_mlp = model.actor
    @torch.jit.script
    def policy_run(obs):
        recurrent_embedding = memory_module(obs)
        actions = actor_mlp(recurrent_embedding.squeeze(0))
        return actions

    if torch.is_tensor(action_rescale_ratio) and not (action_rescale_ratio == 1.).all():
        def policy(obs):
            # Re-express actions in the env's action scale.
            return policy_run(obs) * action_rescale_ratio
    else:
        policy = policy_run

    return policy, model

def standup_procedure(env, ros_rate, angle_tolerance= 0.05, kp= None, kd= None, device= "cpu"):
    """ Bring the robot to its default joint pose, then hold it until the operator confirms.

    Phase 1 (stand up): every tick, each joint whose error to ``env.default_dof_pos``
    exceeds ``angle_tolerance`` is commanded ``angle_tolerance`` rad further towards the
    default. The other joints are commanded straight to the default. The phase ends
    when all 12 joint errors are below ``angle_tolerance``, or when ROS shuts down.

    Phase 2 (wait): keep commanding the default pose until R1 is pressed on the wireless
    remote. Pressing L2 or R2 instead commands the default pose with kp=20, kd=0.5,
    requests a ROS shutdown and terminates the process with exit status 0.

    Args:
        env: the real-robot env. Reads ``low_state_buffer``, ``dof_map`` and
            ``default_dof_pos``, and sends commands with ``publish_legs_cmd``.
        ros_rate: ``rospy.Rate`` slept once per loop iteration.
        angle_tolerance: per-step joint increment and convergence threshold, in radians.
        kp, kd: PD gains forwarded to ``publish_legs_cmd`` (None is forwarded unchanged).
        device: torch device of the joint-target buffer.

    Raises:
        SystemExit: when the operator aborts with L2/R2 during phase 2.
    """
    rospy.loginfo("Robot standing up, please wait ...")

    target_pos = torch.zeros((1, 12), device= device, dtype= torch.float32)
    while not rospy.is_shutdown():
        dof_pos = [env.low_state_buffer.motorState[env.dof_map[i]].q for i in range(12)]
        diff = [env.default_dof_pos[i].item() - dof_pos[i] for i in range(12)]
        if all(abs(err) < angle_tolerance for err in diff):
            break
        print("max joint error (rad):", max(abs(err) for err in diff), end= "\r")
        for i, (q, err) in enumerate(zip(dof_pos, diff)):
            if abs(err) > angle_tolerance:
                # Move at most angle_tolerance per tick so the robot rises gradually.
                target_pos[0, i] = q + (angle_tolerance if err > 0 else -angle_tolerance)
            else:
                target_pos[0, i] = env.default_dof_pos[i]
        env.publish_legs_cmd(target_pos,
            kp= kp,
            kd= kd,
        )
        ros_rate.sleep()

    rospy.loginfo("Robot stood up! press R1 on the remote control to continue ...")
    while not rospy.is_shutdown():
        # Take one snapshot of the buttons per tick. low_state_buffer is reassigned by
        # update_low_state (presumably on every incoming message), so repeated reads
        # could mix two different messages.
        buttons = env.low_state_buffer.wirelessRemote.btn.components
        if buttons.R1:
            break
        if buttons.L2 or buttons.R2:
            env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
            rospy.signal_shutdown("Controller send stop signal, exiting")
            # Same as sys.exit(0); the builtin exit() is a site-module helper meant for
            # interactive use only and is absent under ``python -S``.
            raise SystemExit(0)
        env.publish_legs_cmd(env.default_dof_pos.unsqueeze(0), kp= kp, kd= kd)
        ros_rate.sleep()
    rospy.loginfo("Robot standing up procedure finished!")

class SkilledA1Real(UnitreeA1Real):
    """ Some additional methods to help the execution of skill policy.

    When the robot is driven by the wireless remote and the remote's ``ry`` axis is above
    ``skill_mode_threhold``, the robot is in "skill mode". In that mode
    ``update_low_state`` overrides the command buffer: the forward-velocity command is
    ``ry`` mapped linearly from ``(threshold, 1]`` onto ``skill_vel_range``, and the other
    two command entries are zeroed. Outside skill mode the parent implementation runs.
    (The mapping assumes ``ry <= 1`` -- TODO confirm the remote's axis range.)
    """
    def __init__(self, *args,
            skill_mode_threhold= 0.1,
            skill_vel_range= None,
            **kwargs,
        ):
        """
        Args:
            skill_mode_threhold: ``ry`` value above which skill mode is active
                (the spelling is kept because it is a public keyword/attribute name).
            skill_vel_range: ``[low, high]`` forward-velocity range for skill mode;
                defaults to ``[0.0, 1.0]``.
        """
        self.skill_mode_threhold = skill_mode_threhold
        # A fresh default list per instance (a list literal default would be shared by all instances).
        self.skill_vel_range = [0.0, 1.0] if skill_vel_range is None else skill_vel_range
        super().__init__(*args, **kwargs)

    def is_skill_mode(self):
        """ Return True when the wireless remote's ``ry`` exceeds ``skill_mode_threhold``. """
        if not self.move_by_wireless_remote:
            # Not implemented yet
            return False
        return self.low_state_buffer.wirelessRemote.ry > self.skill_mode_threhold

    def update_low_state(self, ros_msg):
        """ Store ``ros_msg``; in skill mode, derive the command from ``ry`` instead of calling the parent. """
        self.low_state_buffer = ros_msg
        if self.is_skill_mode():
            vel_low, vel_high = self.skill_vel_range[0], self.skill_vel_range[1]
            # Fraction of the way from the threshold to full deflection (ry == 1).
            ratio = (ros_msg.wirelessRemote.ry - self.skill_mode_threhold) / (1.0 - self.skill_mode_threhold)
            self.command_buf[0, 0] = ratio * (vel_high - vel_low) + vel_low
            self.command_buf[0, 1] = 0.
            self.command_buf[0, 2] = 0.
            return
        return super().update_low_state(ros_msg)

def main(args):
    """ Run a trained policy from ``args.logdir`` on the A1 through ROS.

    ``args.mode`` selects what this process does:
      - "jetson": script only the policy's visual encoder. Subscribe to
        ``<namespace>/camera/depth/image_rect_raw``, publish each encoded embedding on
        ``<namespace>/visual_embedding`` (see ``handle_forward_depth``), then spin.
      - "upboard": the env is pointed at the "/visual_embedding" topic (presumably fed by
        the jetson process), so only the recurrent actor runs here. The wireless remote
        switches between this skill policy and the walk policy loaded from
        ``args.walkdir``. Unless ``args.debug`` is set, the robot first stands up.
      - "full": run the visual encoder and the recurrent actor together in this process.
    Any other mode logs a fatal message and returns.

    In "upboard" and "full" modes, pressing L2 or R2 on the remote commands the default
    joint pose with kp=20, kd=0.5 and shuts ROS down, which ends the control loop.

    Args:
        args: parsed command-line namespace; reads ``debug``, ``mode``, ``logdir``,
            ``namespace`` and ``walkdir``.
    """
    log_level = rospy.DEBUG if args.debug else rospy.INFO
    rospy.init_node("a1_legged_gym_" + args.mode, anonymous= True, log_level= log_level)

    with open(osp.join(args.logdir, "config.json"), "r") as f:
        config_dict = json.load(f, object_pairs_hook= OrderedDict)
    duration = config_dict["sim"]["dt"] * config_dict["control"]["decimation"] # in sec
    # config_dict["control"]["stiffness"]["joint"] -= 2.5 # kp

    # Inference device: CPU on the upboard, CUDA in the "jetson" and "full" modes.
    model_device = torch.device("cpu") if args.mode == "upboard" else torch.device("cuda")

    # On the upboard the env reads the embedding topic (with the latent size); otherwise it reads the raw depth topic.
    unitree_real_env = SkilledA1Real(
        robot_namespace= args.namespace,
        cfg= config_dict,
        forward_depth_topic= "/visual_embedding" if args.mode == "upboard" else "/camera/depth/image_rect_raw",
        forward_depth_embedding_dims= config_dict["policy"]["visual_latent_size"] if args.mode == "upboard" else None,
        move_by_wireless_remote= True,
        skill_vel_range= config_dict["commands"]["ranges"]["lin_vel_x"],
        model_device= model_device,
        # extra_cfg= dict(
        #     motor_strength= torch.tensor([
        #         1., 1./0.9, 1./0.9,
        #         1., 1./0.9, 1./0.9,
        #         1., 1., 1.,
        #         1., 1., 1.,
        #     ], dtype= torch.float32, device= model_device, requires_grad= False),
        # ),
    )

    # Rebuild the policy network with the class and kwargs recorded in the training config.
    model = getattr(modules, config_dict["runner"]["policy_class_name"])(
        num_actor_obs= unitree_real_env.num_obs,
        num_critic_obs= unitree_real_env.num_privileged_obs,
        num_actions= 12,
        obs_segments= unitree_real_env.obs_segments,
        privileged_obs_segments= unitree_real_env.privileged_obs_segments,
        **config_dict["policy"],
    )
    # NOTE(review): this runs after the env was built from config_dict; whether the env
    # reads this flag again later is not visible here -- confirm it has any effect.
    config_dict["terrain"]["measure_heights"] = False
    # load the model with the latest checkpoint (largest <iter> in "model_<iter>.pt")
    model_names = [i for i in os.listdir(args.logdir) if i.startswith("model_")]
    model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
    state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
    model.load_state_dict(state_dict["model_state_dict"])
    model.to(model_device)
    model.eval()

    rospy.loginfo("duration: {}, motor Kp: {}, motor Kd: {}".format(
        duration,
        config_dict["control"]["stiffness"]["joint"],
        config_dict["control"]["damping"]["joint"],
    ))
    rospy.loginfo("[Env] torque limit: {:.1f}".format(unitree_real_env.torque_limits.mean().item()))
    # NOTE(review): "{:.1f}" assumes action_scale is a scalar (float or 0-d tensor).
    rospy.loginfo("[Env] action scale: {:.1f}".format(unitree_real_env.action_scale))
    rospy.loginfo("[Env] motor strength: {}".format(unitree_real_env.motor_strength))

    if args.mode == "jetson":
        # Encoder-only process: depth image in, visual embedding out.
        embeding_publisher = rospy.Publisher(
            args.namespace + "/visual_embedding",
            Float32MultiArray,
            queue_size= 1,
        )
        # extract and build the torch ScriptFunction
        visual_encoder = model.visual_encoder
        visual_encoder = torch.jit.script(visual_encoder)

        # queue_size=1: only the most recent depth frame is kept for encoding.
        forward_depth_subscriber = rospy.Subscriber(
            args.namespace + "/camera/depth/image_rect_raw",
            Image,
            partial(handle_forward_depth,
                model= visual_encoder,
                publisher= embeding_publisher,
                output_resolution= config_dict["sensor"]["forward_camera"].get(
                    "output_resolution",
                    config_dict["sensor"]["forward_camera"]["resolution"],
                ),
                device= model_device,
            ),
            queue_size= 1,
        )
        # Block here; all work happens in the subscriber callback.
        rospy.spin()
    elif args.mode == "upboard":
        # extract and build the torch ScriptFunction
        # (TorchScript captures memory_module and actor_mlp from this scope.)
        memory_module = model.memory_a
        actor_mlp = model.actor
        @torch.jit.script
        def policy(obs):
            recurrent_embedding = memory_module(obs)
            actions = actor_mlp(recurrent_embedding.squeeze(0))
            return actions
        
        # args.walkdir=None yields the stand-only fallback (see load_walk_policy).
        walk_policy, walk_model = load_walk_policy(unitree_real_env, args.walkdir)

        using_walk_policy = True # switch between skill policy and walk policy
        unitree_real_env.start_ros()
        unitree_real_env.wait_untill_ros_working()
        rate = rospy.Rate(1 / duration)
        with torch.no_grad():
            if not args.debug:
                standup_procedure(unitree_real_env, rate,
                    angle_tolerance= 0.1,
                    kp= 50,
                    kd= 1.,
                    device= model_device,
                )
            while not rospy.is_shutdown():
                # inference_start_time = rospy.get_time()
                # check remote controller and decide which policy to use
                if unitree_real_env.is_skill_mode():
                    if using_walk_policy:
                        rospy.loginfo_throttle(0.1, "switch to skill policy")
                        using_walk_policy = False
                        # Reset the policy being switched to (presumably clears its recurrent state).
                        model.reset()
                else:
                    if not using_walk_policy:
                        rospy.loginfo_throttle(0.1, "switch to walk policy")
                        using_walk_policy = True
                        walk_model.reset()
                if not using_walk_policy:
                    obs = unitree_real_env.get_obs()
                    actions = policy(obs)
                else:
                    # The walk policy is fed only the proprioceptive part of the observation.
                    walk_obs = unitree_real_env._get_proprioception_obs()
                    actions = walk_policy(walk_obs)
                unitree_real_env.send_action(actions)
                # unitree_real_env.send_action(torch.zeros((1, 12)))
                # inference_duration = rospy.get_time() - inference_start_time
                # rospy.loginfo("inference duration: {:.3f}".format(inference_duration))
                # rospy.loginfo("visual_latency: %f", rospy.get_time() - unitree_real_env.forward_depth_embedding_stamp.to_sec())
                # motor_temperatures = [motor_state.temperature for motor_state in unitree_real_env.low_state_buffer.motorState]
                # rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
                rate.sleep()
                # D-pad "down": reset both policies.
                if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.down:
                    rospy.loginfo_throttle(0.1, "model reset")
                    model.reset()
                    walk_model.reset()
                # L2/R2: command the default pose with soft gains and stop; the while condition then exits.
                if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.L2 or unitree_real_env.low_state_buffer.wirelessRemote.btn.components.R2:
                    unitree_real_env.publish_legs_cmd(unitree_real_env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
                    rospy.signal_shutdown("Controller send stop signal, exiting")
    elif args.mode == "full":
        # extract and build the torch ScriptFunction
        # visual_obs_slice is used below as (slice of the depth image in the flat obs, image shape).
        visual_obs_slice = get_obs_slice(unitree_real_env.obs_segments, "forward_depth")
        visual_encoder = model.visual_encoder
        memory_module = model.memory_a
        actor_mlp = model.actor
        # Encode the depth slice, splice the latent back in place of the raw pixels, then run the recurrent actor.
        @torch.jit.script
        def policy(observations: torch.Tensor, obs_start: int, obs_stop: int, obs_shape: Tuple[int, int, int]):
            visual_latent = visual_encoder(
                observations[..., obs_start:obs_stop].reshape(-1, *obs_shape)
            ).reshape(1, -1)
            obs = torch.cat([
                observations[..., :obs_start],
                visual_latent,
                observations[..., obs_stop:],
            ], dim= -1)
            recurrent_embedding = memory_module(obs)
            actions = actor_mlp(recurrent_embedding.squeeze(0))
            return actions

        unitree_real_env.start_ros()
        unitree_real_env.wait_untill_ros_working()
        rate = rospy.Rate(1 / duration)
        with torch.no_grad():
            while not rospy.is_shutdown():
                # inference_start_time = rospy.get_time()
                obs = unitree_real_env.get_obs()
                actions = policy(obs,
                    obs_start= visual_obs_slice[0].start.item(),
                    obs_stop= visual_obs_slice[0].stop.item(),
                    obs_shape= visual_obs_slice[1],
                )
                unitree_real_env.send_action(actions)
                # inference_duration = rospy.get_time() - inference_start_time
                # Log the first 12 motor temperatures at most once every 10 s.
                motor_temperatures = [motor_state.temperature for motor_state in unitree_real_env.low_state_buffer.motorState]
                rospy.loginfo_throttle(10, " ".join(["motor_temperatures:"] + ["{:d},".format(t) for t in motor_temperatures[:12]]))
                rate.sleep()
                # L2/R2: command the default pose with soft gains and stop; the while condition then exits.
                if unitree_real_env.low_state_buffer.wirelessRemote.btn.components.L2 or unitree_real_env.low_state_buffer.wirelessRemote.btn.components.R2:
                    unitree_real_env.publish_legs_cmd(unitree_real_env.default_dof_pos.unsqueeze(0), kp= 20, kd= 0.5)
                    rospy.signal_shutdown("Controller send stop signal, exiting")
    else:
        rospy.logfatal("Unknown mode, exiting")

def build_arg_parser():
    """ Build the command-line parser for this script.

    ``--logdir`` and ``--mode`` are required: ``main`` reads both unconditionally and
    would otherwise fail with a TypeError on None.

    Returns:
        argparse.ArgumentParser: parser producing the namespace consumed by ``main``.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description= "Run a trained legged_gym policy on a Unitree A1 through ROS.",
    )
    parser.add_argument("--namespace",
        type= str,
        default= "/a112138",
        help= "The ROS namespace of the robot.",
    )
    parser.add_argument("--logdir",
        type= str,
        required= True,
        help= "The log directory of the trained model",
    )
    parser.add_argument("--walkdir",
        type= str,
        help= "The log directory of the walking model, not for the skills.",
        default= None,
    )
    parser.add_argument("--mode",
        type= str,
        required= True,
        help= "The mode to determine which computer to run on.",
        choices= ["jetson", "upboard", "full"],
    )
    parser.add_argument("--debug",
        action= "store_true",
        help= "Log at DEBUG level and skip the stand-up procedure in upboard mode.",
    )
    return parser

if __name__ == "__main__":
    # The script to run the A1 script in ROS.
    # It's designed as a main function and not designed to be a scalable code.
    args = build_arg_parser().parse_args()
    main(args)