from agentlace.action import ActionConfig
from agentlace.trainer import TrainerConfig
import tensorflow as tf
import numpy as np


def make_trainer_config():
    """Build the ``TrainerConfig`` used by this setup.

    Returns:
        A ``TrainerConfig`` with port 5488, broadcast port 5489, and the
        "send-stats" and "get-model-config" request types.
    """
    supported_requests = ["send-stats", "get-model-config"]
    config = TrainerConfig(
        port_number=5488,
        broadcast_port=5489,
        request_types=supported_requests,
    )
    return config

def observation_format():
    """Return the spec of a single robot observation.

    A new dict is built on every call, so callers may mutate the result
    without affecting each other. Insertion order is meaningful: the key
    order of this dict is what ``observation_keys()`` and
    ``make_action_config()`` expose.

    Returns:
        dict mapping each observation field name to its ``tf.TensorSpec``.
    """
    return {
        # Raw sensor
        "image": tf.TensorSpec((), tf.string, name="image"),
        "imu_accel": tf.TensorSpec((3,), tf.float32, name="imu_accel"),
        "imu_gyro": tf.TensorSpec((3,), tf.float32, name="imu_gyro"),
        "odom_pose": tf.TensorSpec((3,), tf.float32, name="odom_pose"),
        "linear_velocity": tf.TensorSpec((3,), tf.float32, name="linear_velocity"),
        "angular_velocity": tf.TensorSpec((3,), tf.float32, name="angular_velocity"),

        # Hazards from IRobot
        "cliff": tf.TensorSpec((), tf.bool, name="cliff"),
        "crash": tf.TensorSpec((), tf.bool, name="crash"),
        "crash_left": tf.TensorSpec((), tf.bool, name="crash_left"),
        "crash_right": tf.TensorSpec((), tf.bool, name="crash_right"),
        "crash_center": tf.TensorSpec((), tf.bool, name="crash_center"),
        "stall": tf.TensorSpec((), tf.bool, name="stall"),
        "keepout": tf.TensorSpec((), tf.bool, name="keepout"),

        # Estimator
        "position": tf.TensorSpec((3,), tf.float32, name="position"),
        "orientation": tf.TensorSpec((4,), tf.float32, name="orientation"),
        "pose_std": tf.TensorSpec((6,), tf.float32, name="pose_std"),

        # State machine and action
        "action_state_source": tf.TensorSpec((), tf.string, name="action_state_source"),
        "last_action_linear": tf.TensorSpec((3,), tf.float32, name="last_action_linear"),
        # Listed once: a repeated key in a dict literal is silently
        # overwritten by the later entry.
        "last_action_angular": tf.TensorSpec((3,), tf.float32, name="last_action_angular"),
    }

def robot_data_format():
    """Return the robot data spec.

    Returns:
        dict with a single "observation" entry holding the dict produced
        by ``observation_format()``.
    """
    robot_format = {}
    robot_format["observation"] = observation_format()
    return robot_format

def rlds_data_format():
    """Return the per-step spec: the observation plus three boolean flags.

    Returns:
        dict with "observation" (from ``observation_format()``) followed by
        scalar boolean specs "is_first", "is_last" and "is_terminal".
    """
    # NOTE: "action_state_source" is not stripped from the observation here;
    # the statement that would delete it is disabled.
    step_format = {"observation": observation_format()}
    for flag_name in ("is_first", "is_last", "is_terminal"):
        step_format[flag_name] = tf.TensorSpec((), tf.bool, name=flag_name)
    return step_format

def task_data_format():
    """Return the task data spec: observation with a goal, action, and flags.

    The "observation" entry is the dict from ``observation_format()`` with an
    extra nested "goal" dict appended as its last key.

    Returns:
        dict with keys "observation", "action", "is_first", "is_last" and
        "is_terminal", in that order.
    """
    observation = observation_format()

    # Nested under goal["sample_info"].
    sample_info = {
        "position": tf.TensorSpec((3,), tf.float32, name="position"),
        "orientation": tf.TensorSpec((4,), tf.float32, name="orientation"),
        "offset": tf.TensorSpec((), tf.float32, name="offset"),
    }

    # Added after all the base observation fields.
    observation["goal"] = {
        "image": tf.TensorSpec((), tf.string, name="image"),
        "position": tf.TensorSpec((3,), tf.float32, name="position"),
        "orientation": tf.TensorSpec((4,), tf.float32, name="orientation"),
        "reached": tf.TensorSpec((), tf.bool, name="reached"),
        "sample_info": sample_info,
    }

    task_format = {
        "observation": observation,
        "action": tf.TensorSpec((6,), tf.float32, name="action"),
    }
    for flag_name in ("is_first", "is_last", "is_terminal"):
        task_format[flag_name] = tf.TensorSpec((), tf.bool, name=flag_name)
    return task_format

def observation_keys():
    """Return the observation field names as a list.

    Returns:
        A new list of the keys of ``observation_format()``, in the dict's
        insertion order.
    """
    # NOTE: appending "latest_action" here is currently disabled. Per the
    # original note, it would make the action part of the observation for
    # the robot server / client only, without it being saved.
    return list(observation_format())

def make_action_config():
    """Build the ``ActionConfig`` used by this setup.

    Returns:
        An ``ActionConfig`` on port 1111 whose observation keys are the keys
        of ``observation_format()`` and whose action keys are listed below.
    """
    supported_actions = [
        "action_vw",
        "action_pose",
        "reset",
        "dock",
        "undock",
        "new_goal",
        "q_vals",
    ]
    return ActionConfig(
        port_number=1111,
        action_keys=supported_actions,
        observation_keys=list(observation_format()),
    )
    # return ActionConfig(
    #     port_number=5486,
    #     action_keys=["action_vw", "reset", "move_marker", "dock", "undock"],
    #     observation_keys=observation_keys(),
    #     broadcast_port=5487,
    # )