"""Adapted from salamander.c

https://github.com/cyberbotics/webots/blob/e3a5f1f4ed0deee0e0d51969e1e8b3e6a33a2ad9/projects/samples/rendering/controllers/salamander/salamander.c#L4

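At each controller step, the following JSON object is sent to the server
over the TCP socket. The first message additionally carries ``motor_n``,
``max_motor_position`` and ``min_motor_position``.

```
{
    "observation": [one target position per motor..., ds_left, ds_right],
    "position": [x, y, z],
    "info": {"ground_truth_mode": 0 (WALK) | 1 (SWIM)},
    "step_i": int,
}
```

Between status messages the controller receives one command from the server.
The command is "CLOSE", "RESET", or a JSON object whose "action" entry lists
one target position per motor.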
"""
from controller import Supervisor
import os
import json
from pathlib import Path
import socket

# Locomotion types, reported to the server as ``info["ground_truth_mode"]``.
WALK = 0
SWIM = 1

# Indices into the [x, y, z] list returned by the GPS device.
X = 0
Y = 1
Z = 2


# Set this variable to the directory in which the animation should be saved,
# or to None to disable recording. Note that the empty-string default becomes
# Path("") == Path("."), i.e. the animation is recorded into the current
# working directory.
output_dir = ""
if output_dir is not None:
    output_dir = Path(output_dir)


# Server parameters. These placeholders must be filled in before running:
# ``server_name`` is the host to connect to and ``port`` must be an int
# (socket.connect rejects a str port).
port = ""
server_name = ""

# NOTE(review): MAX_STEPS is not referenced anywhere in this controller; the
# episode length appears to be decided by the server (via "CLOSE") — confirm.
MAX_STEPS = 2000


def clamp(value: float, minv: float, maxv: float) -> float:
    """Limit *value* to the closed interval [minv, maxv].

    The lower bound is tested first, so when ``minv > maxv`` a value below
    ``minv`` yields ``minv``. A NaN *value* fails both comparisons and is
    returned unchanged.
    """
    if value < minv:
        return minv
    return maxv if value > maxv else value


def _write_animation_readme(directory):
    """Write ``README.md`` with viewing instructions into *directory*.

    Args:
        directory: ``pathlib.Path`` of the directory holding ``animation.html``.
    """
    animation_instructions = (
        "To view this animation `cd` to this directory and then run `python3 -m http.server`. "
        "Then, open `http://0.0.0.0:8000/animation.html` in your browser."
    )
    with open(directory / "README.md", "wt", encoding="utf-8") as fp:
        fp.write(animation_instructions)


def _run_controller():
    """Set up the devices and run the server-driven sense-compute-act loop.

    After the first simulation step one status message is sent, containing:
    the initial observation, the motor count, placeholder motor limits
    (+/-5.0), the GPS position, ``step_i`` and ``info``.

    After that, every control step does the following:

    * receive one command from the server;
    * apply the (clamped) motor targets;
    * send back a status message (observation, position, info, step_i).

    Returns when the server sends ``"CLOSE"``, closes the connection, or the
    simulation ends (``robot.step`` returns -1).

    Raises:
        ValueError: if an action does not hold exactly one target per motor.
    """
    # must be the same as in salamander_physics.c
    WATER_LEVEL = 0.0

    # 6 actuated body segments and 4 legs
    NUM_MOTORS = 10

    # virtual time between two calls to the run() function
    CONTROL_STEP = 32

    # Commanded positions of the body and leg motors. They are also reported
    # as the motor positions in the observation.
    target_position = [0.0] * NUM_MOTORS

    # get the motors device tags
    MOTOR_NAMES = [
        "motor_1",
        "motor_2",
        "motor_3",
        "motor_4",
        "motor_5",
        "motor_6",
        "motor_leg_1",
        "motor_leg_2",
        "motor_leg_3",
        "motor_leg_4",
    ]
    motor = [robot.getDevice(motor_name) for motor_name in MOTOR_NAMES]
    min_motor_position = [motor_i.getMinPosition() for motor_i in motor]
    max_motor_position = [motor_i.getMaxPosition() for motor_i in motor]

    # get and enable left and right distance sensors
    ds_left = robot.getDevice("ds_left")
    ds_left.enable(CONTROL_STEP)
    ds_right = robot.getDevice("ds_right")
    ds_right.enable(CONTROL_STEP)

    # get and enable gps device
    gps = robot.getDevice("gps")
    gps.enable(CONTROL_STEP)

    locomotion = WALK

    def get_observation():
        """Return the motor targets followed by the left/right distance readings."""
        # This is not the actual motor position, but it's close enough
        # for this robot
        return [*target_position, ds_left.getValue(), ds_right.getValue()]

    def send_status(**status):
        """Serialize *status* as JSON and send it to the server."""
        client_socket.sendall(json.dumps(status).encode('utf-8'))

    def actuate():
        """Clamp the targets of limited motors (in place) and command every motor."""
        for i, motor_i in enumerate(motor):
            minv = min_motor_position[i]
            maxv = max_motor_position[i]
            # Leg motors don't have min and max positions (both limits are
            # reported as 0.0), and their targets must not be clamped.
            if not (minv == 0.0 and maxv == 0.0):
                target_position[i] = clamp(target_position[i], minv, maxv)
            motor_i.setPosition(target_position[i])

    robot.step(CONTROL_STEP)  # synchronize data for the first time
    step_i = 0
    # Report approximate minimum and maximum motor values.
    send_status(
        observation=get_observation(),
        motor_n=len(motor),
        max_motor_position=[5.0] * len(max_motor_position),
        min_motor_position=[-5.0] * len(min_motor_position),
        position=gps.getValues(),
        step_i=step_i,
        info=dict(ground_truth_mode=locomotion),
    )

    # control loop: sense-compute-act
    while robot.step(CONTROL_STEP) != -1:
        step_i += 1

        # Receive the next command from the server.
        # NOTE(review): a single recv(1024) assumes every command arrives as
        # exactly one segment of at most 1024 bytes. TCP does not preserve
        # message boundaries, so long or split messages would break
        # json.loads. Fixing that needs a framing change on the server too.
        data = client_socket.recv(1024).decode('utf-8')
        if not data or data == "CLOSE":
            break
        if data == "RESET":
            robot.simulationReset()
        else:
            action = json.loads(data)["action"]
            if len(action) != NUM_MOTORS:
                raise ValueError(
                    f"expected {NUM_MOTORS} motor targets, got {len(action)}"
                )
            target_position = list(action)

        # Clamp and apply the targets before reporting. This way the
        # observation carries the positions actually commanded to the motors,
        # not the raw (possibly out-of-range) action.
        actuate()

        # Switch locomotion mode with hysteresis around the water level.
        position = gps.getValues()
        elevation = position[Z]
        if locomotion == SWIM and elevation > WATER_LEVEL - 0.003:
            locomotion = WALK
        elif locomotion == WALK and elevation < WATER_LEVEL - 0.015:
            locomotion = SWIM

        send_status(
            observation=get_observation(),
            position=position,
            info=dict(ground_truth_mode=locomotion),
            step_i=step_i,
        )


def main():
    """Run the salamander controller, optionally recording an animation.

    When ``output_dir`` is not None, three things happen:

    * the directory is created if needed;
    * the run is recorded to ``animation.html`` inside it;
    * a ``README.md`` with viewing instructions is written next to it.

    The recording is stopped even if the control loop raises.
    """
    if output_dir is None:
        _run_controller()
        return

    output_dir.mkdir(parents=True, exist_ok=True)
    robot.animationStartRecording(str(output_dir / "animation.html"))
    try:
        _run_controller()
    finally:
        # Stop animation and save viewing instructions
        robot.animationStopRecording()
        _write_animation_readme(output_dir)


robot = Supervisor()
# Connect, as a client, to the server that drives this controller.
server_address = (server_name, port)
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    # Connect inside the ``try`` so the simulation is still quit (and the
    # socket closed) when the server is unreachable or misconfigured.
    client_socket.connect(server_address)
    main()
finally:
    print("Webots controller: Closing")
    # Close the connection before quitting, because the controller process may
    # be terminated soon after simulationQuit() is called.
    client_socket.close()
    # os.EX_OK is only available on Unix; 0 is the portable success status.
    robot.simulationQuit(0)
