import os
import os.path as osp
import numpy as np
import torch
import json
from functools import partial
from collections import OrderedDict

from a1_real import UnitreeA1Real, resize2d
from rsl_rl import modules

import rospy
from unitree_legged_msgs.msg import Float32MultiArrayStamped
from sensor_msgs.msg import Image
import ros_numpy

import pyrealsense2 as rs

def get_encoder_script(logdir):
    """ Rebuild the trained policy from `logdir` and return its visual encoder as TorchScript.

    The policy class is looked up in `rsl_rl.modules` by the name stored under
    `config["runner"]["policy_class_name"]`, and the checkpoint with the largest
    iteration number (files named `model_<iteration>.<ext>`) is loaded into it.

    Args:
        logdir: directory containing `config.json` and the `model_*` checkpoints.

    Returns:
        (script, model_device): the scripted `visual_encoder` sub-module and the
        torch device the model was moved to.

    Raises:
        FileNotFoundError: if `logdir` contains no file starting with `model_`.
    """
    with open(osp.join(logdir, "config.json"), "r") as f:
        config_dict = json.load(f, object_pairs_hook= OrderedDict)

    model_device = torch.device("cuda")

    unitree_real_env = UnitreeA1Real(
        robot_namespace= "a112138",
        cfg= config_dict,
        forward_depth_topic= "", # this env only computes parameters to build the model
        forward_depth_embedding_dims= None,
        model_device= model_device,
    )

    model = getattr(modules, config_dict["runner"]["policy_class_name"])(
        num_actor_obs= unitree_real_env.num_obs,
        num_critic_obs= unitree_real_env.num_privileged_obs,
        num_actions= 12,
        obs_segments= unitree_real_env.obs_segments,
        privileged_obs_segments= unitree_real_env.privileged_obs_segments,
        **config_dict["policy"],
    )

    model_names = [i for i in os.listdir(logdir) if i.startswith("model_")]
    if not model_names:
        raise FileNotFoundError("No 'model_*' checkpoint found in {}".format(logdir))
    # Sort by the integer between the last "_" and the following ".", so the
    # last entry is the checkpoint with the highest iteration number.
    model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
    # Use the `logdir` argument here (not a global `args`), so the function
    # also works when this module is imported instead of run as a script.
    state_dict = torch.load(osp.join(logdir, model_names[-1]), map_location= "cpu")
    model.load_state_dict(state_dict["model_state_dict"])
    model.to(model_device)
    model.eval()

    visual_encoder = model.visual_encoder
    script = torch.jit.script(visual_encoder)

    return script, model_device

def get_input_filter(args):
    """ Build the depth pre-processing function applied before the visual encoder.

    This filter differs from the one in the simulator, but tries to close the gap.

    The forward camera settings are read from `config.json` under `args.logdir`;
    the crop margins and far limit come from `args.crop_*`.

    Returns:
        (input_filter, depth_range): `input_filter` expects a single depth image
        tensor of shape [1, 1, H, W] (every other parameter is already bound), and
        `depth_range` is the configured (min, max) range multiplied by 1000.
    """
    config_path = osp.join(args.logdir, "config.json")
    with open(config_path, "r") as f:
        config_dict = json.load(f, object_pairs_hook= OrderedDict)
    camera_cfg = config_dict["sensor"]["forward_camera"]

    # "resolution" is always looked up; it is the fallback for "output_resolution"
    image_resolution = camera_cfg.get("output_resolution", camera_cfg["resolution"])
    range_in_meters = camera_cfg.get("depth_range", [0.0, 3.0])
    depth_range = (range_in_meters[0] * 1000, range_in_meters[1] * 1000) # [m] -> [mm]

    def input_filter(depth_image: torch.Tensor,
            crop_top: int,
            crop_bottom: int,
            crop_left: int,
            crop_right: int,
            crop_far: float,
            depth_min: int,
            depth_max: int,
            output_height: int,
            output_width: int,
        ):
        """ depth_image must have shape [1, 1, H, W] """
        # The end index is `-crop - 1`, so one extra row / column is always
        # dropped at the bottom / right edge on top of the requested crop.
        cropped = depth_image[:, :,
            crop_top: -crop_bottom-1,
            crop_left: -crop_right-1,
        ]
        # in-place on the cropped view: readings beyond crop_far become depth_max
        cropped[cropped > crop_far] = depth_max
        clipped = torch.clip(cropped, depth_min, depth_max)
        scaled = clipped / (depth_max - depth_min)
        return resize2d(scaled, (output_height, output_width))

    bound_filter = partial(input_filter,
        crop_top= args.crop_top,
        crop_bottom= args.crop_bottom,
        crop_left= args.crop_left,
        crop_right= args.crop_right,
        crop_far= args.crop_far * 1000, # same x1000 scaling as depth_range
        depth_min= depth_range[0],
        depth_max= depth_range[1],
        output_height= image_resolution[0],
        output_width= image_resolution[1],
    )
    return bound_filter, depth_range

def get_started_pipeline(
        height= 480,
        width= 640,
        fps= 30,
        enable_rgb= False,
    ):
    """ Start a realsense pipeline and build the post-processing chain for its frames.

    A z16 depth stream is always enabled; an rgb8 color stream with the same
    size and fps is enabled only when `enable_rgb` is set (off by default).

    Returns:
        (pipeline, filter_func): the started `rs.pipeline` and a function that
        passes a frame through the hole-filling, spatial and temporal filters,
        in that order, and returns the processed frame.
    """
    pipeline = rs.pipeline()
    stream_config = rs.config()
    stream_config.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
    if enable_rgb:
        stream_config.enable_stream(rs.stream.color, width, height, rs.format.rgb8, fps)
    pipeline.start(stream_config)

    # build the sensor filters
    hole_filling = rs.hole_filling_filter(2)

    spatial = rs.spatial_filter()
    for option, value in (
            (rs.option.filter_magnitude, 5),
            (rs.option.filter_smooth_alpha, 0.75),
            (rs.option.filter_smooth_delta, 1),
            (rs.option.holes_fill, 4),
        ):
        spatial.set_option(option, value)

    temporal = rs.temporal_filter()
    for option, value in (
            (rs.option.filter_smooth_alpha, 0.75),
            (rs.option.filter_smooth_delta, 1),
        ):
        temporal.set_option(option, value)

    # application order matters; no decimation filter is applied
    stages = (hole_filling, spatial, temporal)

    def filter_func(frame):
        for stage in stages:
            frame = stage.process(frame)
        return frame

    return pipeline, filter_func

def main(args):
    """ Run the node that turns realsense depth frames into visual embeddings.

    Each loop iteration waits for a frame set, runs the realsense filters and
    the input filter on the depth frame, encodes it with the scripted visual
    encoder and publishes the result as `Float32MultiArrayStamped` on
    `<namespace>/visual_embedding`. With `args.enable_vis`, the raw depth
    image and the network input (and the rgb image if `args.enable_rgb`) are
    published as `sensor_msgs/Image` too. The realsense pipeline is stopped
    when the loop exits for any reason.

    Args:
        args: parsed command line arguments (namespace, logdir, height, width,
            fps, crop_*, enable_rgb, enable_vis).
    """
    rospy.init_node("a1_legged_gym_jetson")

    input_filter, depth_range = get_input_filter(args)
    model_script, model_device = get_encoder_script(args.logdir)
    with open(osp.join(args.logdir, "config.json"), "r") as f:
        config_dict = json.load(f, object_pairs_hook= OrderedDict)
    if config_dict.get("sensor", dict()).get("forward_camera", dict()).get("refresh_duration", None) is not None:
        refresh_duration = config_dict["sensor"]["forward_camera"]["refresh_duration"]
        ros_rate = rospy.Rate(1.0 / refresh_duration)
        rospy.loginfo("Using refresh duration {}s".format(refresh_duration))
    else:
        ros_rate = rospy.Rate(args.fps)

    # The frame timeout never changes, so compute it once. Doing it before the
    # camera is started also means a missing "latency_range" entry fails here,
    # not after the realsense pipeline is already running.
    frame_timeout_ms = int(
        config_dict["sensor"]["forward_camera"]["latency_range"][1] * 1000
    ) # ms

    rs_pipeline, rs_filters = get_started_pipeline(
        height= args.height,
        width= args.width,
        fps= args.fps,
        enable_rgb= args.enable_rgb,
    )

    embedding_publisher = rospy.Publisher(
        args.namespace + "/visual_embedding",
        Float32MultiArrayStamped,
        queue_size= 1,
    )

    if args.enable_vis:
        depth_image_publisher = rospy.Publisher(
            args.namespace + "/camera/depth/image_rect_raw",
            Image,
            queue_size= 1,
        )
        network_input_publisher = rospy.Publisher(
            args.namespace + "/camera/depth/network_input_raw",
            Image,
            queue_size= 1,
        )
        if args.enable_rgb:
            rgb_image_publisher = rospy.Publisher(
                args.namespace + "/camera/color/image_raw",
                Image,
                queue_size= 1,
            )

    rospy.loginfo("Depth range is clipped to [{}, {}] and normalized".format(depth_range[0], depth_range[1]))
    rospy.loginfo("ROS, model, realsense have been initialized.")
    if args.enable_vis:
        rospy.loginfo("Visualization enabled, sending depth{} images".format(", rgb" if args.enable_rgb else ""))
    try:
        # One message object is reused; only stamp, seq and data change per frame.
        embedding_msg = Float32MultiArrayStamped()
        embedding_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
        frame_got = False
        while not rospy.is_shutdown():
            # Wait for the depth image
            frames = rs_pipeline.wait_for_frames(frame_timeout_ms)
            # Stamp the embedding with the time the frame set was received,
            # not the time the encoder finished.
            embedding_msg.header.stamp = rospy.Time.now()
            depth_frame = frames.get_depth_frame()
            if not depth_frame:
                continue
            if not frame_got:
                frame_got = True
                rospy.loginfo("Realsense frame recieved. Sending embeddings...")
            if args.enable_rgb:
                color_frame = frames.get_color_frame()
                # Use this branch to log the time when image is acquired.
                # The frame is tested for truthiness, the same way as the depth
                # frame above, so that an empty frame is skipped.
                if args.enable_vis and color_frame:
                    color_frame = np.asanyarray(color_frame.get_data())
                    rgb_image_msg = ros_numpy.msgify(Image, color_frame, encoding= "rgb8")
                    rgb_image_msg.header.stamp = rospy.Time.now()
                    rgb_image_msg.header.frame_id = args.namespace + "/camera_color_optical_frame"
                    rgb_image_publisher.publish(rgb_image_msg)

            # Process the depth image and publish
            depth_frame = rs_filters(depth_frame)
            depth_image_ = np.asanyarray(depth_frame.get_data())
            # add the two leading dims to get the [1, 1, H, W] shape input_filter requires
            depth_image = torch.from_numpy(depth_image_.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(model_device)
            depth_image = input_filter(depth_image)
            with torch.no_grad():
                depth_embedding = model_script(depth_image).reshape(-1).cpu().numpy()
            embedding_msg.header.seq += 1
            embedding_msg.data = depth_embedding.tolist()
            embedding_publisher.publish(embedding_msg)

            # Publish the acquired image if needed
            if args.enable_vis:
                depth_image_msg = ros_numpy.msgify(Image, depth_image_, encoding= "16UC1")
                depth_image_msg.header.stamp = rospy.Time.now()
                depth_image_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
                depth_image_publisher.publish(depth_image_msg)
                # scale the normalized network input back up so it can be sent as a 16UC1 image
                network_input_np = (
                    depth_image.detach().cpu().numpy()[0, 0] * (depth_range[1] - depth_range[0])
                    + depth_range[0]
                ).astype(np.uint16)
                network_input_msg = ros_numpy.msgify(Image, network_input_np, encoding= "16UC1")
                network_input_msg.header.stamp = rospy.Time.now()
                network_input_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
                network_input_publisher.publish(network_input_msg)

            ros_rate.sleep()
    finally:
        rs_pipeline.stop()

if __name__ == "__main__":
    # This script is designed to load the model and process the realsense image
    # directly from the realsense SDK, without the realsense ROS wrapper.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--namespace", type= str, default= "/a112138")
    parser.add_argument("--logdir", type= str, help= "The log directory of the trained model")

    # realsense stream settings
    parser.add_argument("--height", type= int, default= 270, help= "The height of the realsense image")
    parser.add_argument("--width", type= int, default= 480, help= "The width of the realsense image")
    parser.add_argument("--fps", type= int, default= 30, help= "The fps of the realsense image")

    # crop margins, all sharing one help text
    crop_help = "num of pixel to crop in the original pyrealsense readings."
    for side, margin in (("left", 60), ("right", 46), ("top", 0), ("bottom", 0)):
        parser.add_argument("--crop_" + side, type= int, default= margin, help= crop_help)
    parser.add_argument("--crop_far",
        type= float,
        default= 3.0,
        help= "asside from the config far limit, make all depth readings larger than this value to be 3.0 in un-normalized network input."
    )

    # optional outputs
    parser.add_argument("--enable_rgb", action= "store_true", help= "Whether to enable rgb image")
    parser.add_argument("--enable_vis", action= "store_true", help= "Whether to publish realsense image")

    main(parser.parse_args())
