"""Generic Webots Gym TCP/IP integration.

The protocol is described informally by the following example exchange with
a hypothetical `controller.py` (there is no formal specification yet):

webots_gym.py: opens socket
controller.py: connects to socket
controller.py: sends a dictionary with the following keys:
```
observation,
motor_n,
max_motor_position,
min_motor_position,
```
webots_gym.py: completes `__init__`
webots_gym.py: sends `"RESET"`
controller.py: calls `simulationReset()`
controller.py: sends a dictionary with the following key:
```
observation
```
webots_gym.py: sends a dictionary with the following key:
```
action
```
controller.py: sends a dictionary with the following key:
```
observation
```

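Your controller may also send any additional keys in these dictionaries; an
`info` entry, for example, is read and merged into the `info` dict returned by
`reset()` and `step()`.

Note that the current implementation of `reset()` does not send `"RESET"`: it
sends `"CLOSE"`, stops and relaunches Webots, and then waits for the controller
to reconnect and send a fresh dictionary. `close()` likewise sends `"CLOSE"`
before terminating Webots.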
"""
import gymnasium as gym
from pathlib import Path
from gymnasium import spaces
import subprocess
import socket
import json
import numpy as np
from abc import ABC
from abc import abstractmethod
import random


class WebotsEnv(gym.Env, ABC):
    """Gymnasium environment backed by a Webots controller over TCP/IP.

    The environment runs its own TCP server, launches Webots headlessly on
    the given world file and exchanges JSON messages with the world's
    controller: it sends actions and receives observations (see the module
    docstring for the message flow). Subclasses define the task by
    implementing ``latest_reward``, ``terminated`` and ``truncated``.

    Parameters:
        wbt_path: Webots world (``.wbt``) file to open.
        port: TCP port the environment's server listens on; the controller
            connects to it.
        server_name: host name / address the server binds to.
        render_mode: must be ``None``; rendering through the Gym API is not
            supported.
    """
    metadata = {"render_modes": []}

    # Size of each individual ``recv`` call. Longer messages are read in
    # several chunks, so this bounds a single read, not a message.
    maxbufsize = 1024

    # Bytes received from the controller that do not yet form a decoded
    # message (see ``_receive``).
    _recv_buffer = b""

    def __init__(
            self,
            wbt_path: Path,
            port: int,
            server_name: str,
            render_mode=None,
            ):
        assert render_mode is None, "Rendering with Gym API not supported!"
        self.server_name = server_name
        self.port = port
        self.wbt_path = wbt_path
        self.connection = None
        self.webots_process = None
        self.server_socket = None

        # The port passed to Webots via --port is irrelevant to us because we
        # run our own TCP server; just pick a random one different from ours.
        self._controller_port = random.randint(10_000, 65_535)
        while self._controller_port == self.port:
            self._controller_port = random.randint(10_000, 65_535)

        self._restart_server()

        # Action/observation spaces are derived from the controller's first
        # message.
        min_action = self.latest_data["min_motor_position"]
        max_action = self.latest_data["max_motor_position"]
        observation_size = len(self.latest_data["observation"])
        self.action_space = spaces.Box(
            np.array(min_action),
            np.array(max_action),
            dtype=np.float64,
        )
        self.observation_space = spaces.Box(
            -np.inf,
            np.inf,
            shape=(observation_size,),
            dtype=np.float64,
        )

    def _shutdown(self):
        """Release the controller connection, Webots process and server socket.

        Each handle is closed at most once and then cleared, so calling this
        repeatedly (e.g. ``close()`` twice) is safe.
        """
        if self.connection is not None:
            try:
                # Ask the controller to exit cleanly before terminate().
                self.connection.sendall("CLOSE".encode('utf-8'))
            except OSError:
                # Best effort: the controller is already gone, so there is
                # nobody to notify and nothing to wait for.
                pass
            else:
                if self.webots_process is not None:
                    try:
                        self.webots_process.wait(timeout=2)
                    except subprocess.TimeoutExpired:
                        pass
            self.connection.close()
            self.connection = None
        if self.webots_process is not None:
            self.webots_process.terminate()
            self.webots_process.wait()
            self.webots_process = None
        if self.server_socket is not None:
            self.server_socket.close()
            self.server_socket = None

    def _restart_server(self):
        """Tear down any running session, relaunch Webots and reconnect.

        Blocks until the controller connects and has sent its first message,
        which becomes the only entry of ``self.received_data``;
        ``self.sent_data`` is cleared.
        """
        self._shutdown()
        self._recv_buffer = b""

        # Start TCP server
        server_address = (self.server_name, self.port)
        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # The previous session's connection can linger in TIME_WAIT on this
        # port; without SO_REUSEADDR the re-bind done by reset() may fail
        # with "Address already in use".
        self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.server_socket.bind(server_address)
        self.server_socket.listen(1)

        # Open world file with Webots
        command = [
            "xvfb-run",
            "--auto-servernum",
            "webots",
            "--stdout",
            "--stderr",
            "--batch",
            "--no-rendering",
            "--mode=fast",
            f"--port={self._controller_port}",
            str(self.wbt_path),
        ]
        self.webots_process = subprocess.Popen(
            command,
        )

        # Wait for the controller to connect
        print("Gym wrapper is waiting for a connection")
        connection, _ = self.server_socket.accept()
        self.connection = connection

        # Receive the controller's first message
        self.received_data = [self._receive()]
        self.sent_data = list()

    def _receive(self) -> dict:
        """Block until one complete JSON message arrives and return it.

        TCP is a byte stream, so a single ``recv`` may return only part of a
        message (e.g. an observation longer than ``maxbufsize``) or the start
        of the next one. Bytes are accumulated until a full JSON document can
        be decoded; anything received after it is kept in ``_recv_buffer``
        for the next call.

        Note: a controller that sends malformed JSON makes this wait for more
        data instead of raising.

        Raises:
            ConnectionResetError: if the controller closes the connection
                before a complete message has been received.
        """
        decoder = json.JSONDecoder()
        buffer = self._recv_buffer
        while True:
            try:
                text = buffer.decode('utf-8').lstrip()
                message, end = decoder.raw_decode(text)
            except (UnicodeDecodeError, json.JSONDecodeError):
                # Incomplete message (possibly cut inside a multi-byte
                # character): read more bytes and try again.
                chunk = self.connection.recv(self.maxbufsize)
                if not chunk:
                    raise ConnectionResetError(
                        "Webots controller closed the connection"
                    )
                buffer += chunk
                continue
            self._recv_buffer = text[end:].encode('utf-8')
            return message

    @property
    @abstractmethod
    def latest_reward(self) -> float:
        """Return the current reward. To implement this method you could make
        use of `self.received_data` and `self.sent_data`, which are
        automatically populated.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def terminated(self) -> bool:
        """Return the current terminated status. To implement this method you
        could make use of `self.received_data` and `self.sent_data`, which
        are automatically populated.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def truncated(self) -> bool:
        """Return the current truncated status. To implement this method you
        could make use of `self.received_data` and `self.sent_data`, which
        are automatically populated.
        """
        raise NotImplementedError()

    @property
    def latest_data(self) -> dict:
        """The most recent message received from the controller."""
        return self.received_data[-1]

    @property
    def observation(self) -> np.ndarray:
        """The latest observation as a float64 array, matching
        ``observation_space``."""
        current_observation = self.latest_data["observation"]
        return np.array(current_observation, dtype=np.float64)

    @property
    def info(self) -> dict:
        """The latest observation under ``state``, merged with the
        controller's ``info`` dictionary when it sent one."""
        info = dict(
            state=self.observation,
            **self.latest_data.get("info", {}),
        )
        return info

    def reset(self, seed=None, options=None):
        """Restart Webots and return the first observation and info.

        ``seed`` is forwarded to ``gym.Env.reset`` so ``self.np_random`` is
        seeded as Gymnasium expects; ``options`` is currently unused.
        """
        super().reset(seed=seed)
        self._restart_server()
        return self.observation, self.info

    def step(self, action: np.ndarray):
        """Send ``action`` to the controller and read back the next message.

        ``action`` may be anything ``np.asarray`` accepts. If the connection
        to the controller fails while sending or receiving, the episode is
        reported as truncated and the previously received message is used
        for the observation, reward and info.
        """
        # Send action to controller
        status = dict(
            action=np.asarray(action).tolist(),
        )
        status_str = json.dumps(status)

        server_died = False
        try:
            self.connection.sendall(status_str.encode('utf-8'))
            self.sent_data.append(status)
        except OSError:
            # BrokenPipeError / ConnectionResetError: the controller is gone.
            server_died = True

        # Receive new data from controller (pointless if the send failed)
        if not server_died:
            try:
                self.received_data.append(self._receive())
            except OSError:
                server_died = True

        # Compute reward
        reward = self.latest_reward
        terminated = self.terminated
        truncated = self.truncated or server_died
        info = self.info
        return self.observation, reward, terminated, truncated, info

    def close(self):
        """Ask the controller to exit, stop Webots and close all sockets.

        Safe to call more than once.
        """
        self._shutdown()
