"""Specialization of the Webots Gym environment class to the salamander robot.
"""
import gymnasium as gym
from pathlib import Path
from salamander_env.webots_gym import WebotsEnv
from salamander_env.map_parser import write_webots_project
from salamander_env.map_parser import get_random_map
from pathlib import Path
from pathlib import PurePath
import numpy as np
import random
import tempfile


# Locations of the world file and the controller script, relative to the root
# of a generated Webots project directory.
relative_wbt_path = PurePath("world") / "salamandertcp.wbt"
relative_controller_path = PurePath("controllers") / "salamandertcp" / "salamandertcp.py"


def get_preprocessed_controller_file(
        port: int,
        server_name: str,
        controller_path: Path,
        animation_output_dir: Path | None,
        ) -> str:
    """Returns the controller file as a string with the given
    animation output directory inserted.

    The output of this function is to be saved in the same directory as
    the controller and opened with webots.
    """
    with open(controller_path, "rt") as fp:
        webots_file = fp.readlines()

        def assign_variable(variable: str, value: str):
            original_line = f'{variable} = ""\n'
            new_line = f'{variable} = {value}\n'
            preprocessed_controller_file = [
                line if line != original_line else new_line
                for line in webots_file
            ]
            return preprocessed_controller_file

        if animation_output_dir is None:
            output_dir_str = "None"
        else:
            output_dir_str = f'"{str(animation_output_dir.absolute())}"'
        changes = [
            ('output_dir', f'{output_dir_str}'),
            ('server_name', f'"{server_name}"'),
            ('port', f'{port}'),
        ]
        for variable, new_value in changes:
            webots_file = assign_variable(
                variable=variable,
                value=new_value,
            )
    return "".join(webots_file)


class SalamanderEnv(WebotsEnv):
    """An environment for the salamander robot on a randomly chosen map.

    Every call to :meth:`prepare_dir` (at construction and on each
    :meth:`reset`) writes a fresh Webots project for a random map into a new
    temporary directory. The project's controller file is customized with
    this environment's port, server name and per-episode animation output
    directory.
    """
    def __init__(
        self,
        animation_output_dir: (Path | None) = None,
        port: (int | None) = None,
        server_name: str = "localhost",
        render_mode: (str | None) = None,
        max_steps: int = 1000,
    ):
        """Create the environment and generate the first project directory.

        Args:
            animation_output_dir: Existing directory under which one
                sub-directory per episode (named after the episode number)
                is created and handed to the controller. ``None`` disables
                animation output.
            port: Port used for the controller connection. When ``None``, a
                random port in [10000, 65535] is chosen.
            server_name: Host name passed to the controller and base class.
            render_mode: Forwarded to the base class.
            max_steps: An episode is truncated once ``step_i`` exceeds this.
        """
        self.episode_i = 0
        self.tmpdir = None
        self.animation_output_dir = animation_output_dir

        if port is None:
            port = random.randint(10_000, 65_535)
        self.port = port
        self.server_name = server_name

        self.max_steps = max_steps

        self.prepare_dir()

        super().__init__(
            wbt_path=self.wbt_path,
            render_mode=render_mode,
            port=port,
            server_name=server_name,
        )

    def prepare_dir(self):
        """Generate a new Webots project for the next episode.

        Increments ``episode_i`` and deletes the previous temporary project,
        if any. Then writes a random map's project into a new temporary
        directory, customizes its controller file, and points
        ``self.wbt_path`` at the new world file. When ``animation_output_dir``
        is set, its ``<episode_i>`` sub-directory is created; ``mkdir()``
        raises if it already exists.
        """
        self.episode_i += 1

        if self.tmpdir is not None:
            self.tmpdir.cleanup()

        # Choose a map
        chosen_map = get_random_map(seed=str(random.random()))

        # Create webots directory
        self.tmpdir = tempfile.TemporaryDirectory()
        src_dir = Path(self.tmpdir.name)/"webots"
        write_webots_project(
            input_map=chosen_map,
            output_dir=src_dir,
        )

        # Get a preprocessed version of the controller file
        if self.animation_output_dir is not None:
            episode_animation_output_dir = self.animation_output_dir/f"{self.episode_i}"
            episode_animation_output_dir.mkdir()
        else:
            episode_animation_output_dir = None
        new_controller_file = get_preprocessed_controller_file(
            controller_path=src_dir/relative_controller_path,
            animation_output_dir=episode_animation_output_dir,
            port=self.port,
            server_name=self.server_name,
        )

        # Customize controller file
        with open(src_dir/relative_controller_path, "wt") as fp:
            fp.write(new_controller_file)

        self.wbt_path = src_dir/relative_wbt_path

    @property
    def observation(self) -> np.ndarray:
        """Latest observation as an array; all zeros if it contains any NaN."""
        current_observation = np.array(self.latest_data["observation"])
        if np.any(np.isnan(current_observation)):
            current_observation = np.zeros_like(current_observation)
        return current_observation

    @property
    def latest_reward(self) -> float:
        """Reward the agent for moving around.

        Returns the mean distance travelled per step over the last (up to)
        10 recorded positions. Returns 0.0 when fewer than two positions
        have been received.
        """
        n_points = min(len(self.received_data), 10)
        if n_points < 2:
            return 0.0

        # Integrate the path length over the window: sum the n_points - 1
        # displacements between consecutive recorded positions.
        distance = 0.0
        for i in range(1, n_points):
            xp = np.array(self.received_data[-i-1]["position"])
            xn = np.array(self.received_data[-i]["position"])
            distance += np.linalg.norm(xn - xp).item()
        # Average over the number of displacements (not points), so the
        # reward scale does not depend on how much history exists.
        return distance / (n_points - 1)

    @property
    def truncated(self) -> bool:
        """True if the latest observation has a NaN or step_i exceeds max_steps."""
        current_observation = self.latest_data["observation"]
        if np.any(np.isnan(current_observation)):
            return True
        if self.latest_data["step_i"] > self.max_steps:
            return True
        return False

    @property
    def terminated(self) -> bool:
        """Never terminates; episodes end only through truncation."""
        return False

    def reset(self, *args, **kwargs):
        """Generate a fresh project directory, then delegate to the base reset."""
        self.prepare_dir()
        return super().reset(*args, **kwargs)

    def close(self):
        """Close the base environment, then delete the temporary project."""
        try:
            super().close()
        finally:
            # Delete the project only after the base environment has shut
            # down, since its world/controller files are presumably still in
            # use by the simulator until then. Runs even if close() raises.
            if self.tmpdir is not None:
                self.tmpdir.cleanup()
                self.tmpdir = None


if __name__ == "__main__":
    # Imported only for its side effects; presumably this is what registers
    # the "Salamander-v0" id used by gym.make below.
    import salamander_env
    import numpy as np

    output_dir = Path("./salamander_animation")
    output_dir.mkdir()

    env = gym.make(
        "Salamander-v0",
        animation_output_dir=output_dir,
        port=65434,
        server_name='localhost',
    )
    observation, info = env.reset(seed=42)

    action = env.action_space.sample()  # this is where you would insert your policy
    for step in range(1000):
        if not step % 100:
            action = env.action_space.sample()  # this is where you would insert your policy
        # Every component of the action is then overwritten with the step index.
        action = np.full(len(action), step)

        observation, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            break

    env.close()
    print(f"Wrote {output_dir}")
