import os
from typing import Optional, Any

from .mt_channel import MtChannel
from .minetest import Minetest

import numpy as np
from gymnasium import Env
from gymnasium.spaces import Dict, Discrete, Box

# Names of the actions, in the exact order in which they must be sent to MT:
# movement/modifier keys, interaction keys, the nine hotbar slots
# ("slot_1" ... "slot_9") and, last, the mouse movement.
ACTION_ORDER = [
    "forward", "backward", "left", "right", "jump", "aux1", "sneak",
    "zoom", "dig", "place", "drop", "inventory",
    *(f"slot_{slot}" for slot in range(1, 10)),
    "mouse",
]


class CraftiumEnv(Env):
    """The main class implementing Gymnasium's [Env](https://gymnasium.farama.org/api/env/) API.

    :param env_dir: Directory of the environment to load (should contain `worlds` and `games` directories).
    :param obs_width: The width of the observation image in pixels.
    :param obs_height: The height of the observation image in pixels.
    :param init_frames: The number of frames to wait for Minetest to load.
    :param render_mode: Render mode ("human" or "rgb_array"), see [Env.render](https://gymnasium.farama.org/api/env/#gymnasium.Env.render).
    :param max_timesteps: Maximum number of timesteps until the episode is truncated. Disabled if set to `None`.
    :param run_dir: Path to save the artifacts created by the run. Will be automatically generated if not provided.
    :param run_dir_prefix: Prefix path to add to the automatically generated `run_dir`. This value is only used if `run_dir` is `None`.
    :param game_id: The name of the game to load. Defaults to the "original" minetest game.
    :param world_name: The name of the world to load. Defaults to "world".
    :param minetest_dir: Path to the craftium's minetest build directory. If not given, defaults to the directory where craftium is installed. This option is intended for debugging purposes.
    :param tcp_port: Port number used to communicate with minetest. If not provided, the OS chooses a free TCP port automatically.
    :param minetest_conf: Extra configuration options added to the default minetest.conf file generated by craftium. Setting options here will overwrite default values. Check [minetest.conf.example](https://github.com/minetest/minetest/blob/master/minetest.conf.example) for all available configuration options. `None` (the default) means no extra options.
    :param pipe_proc: If `True`, the minetest process stderr and stdout will be piped into two files (`stderr.txt` and `stdout.txt`) inside the run's directory. Otherwise, the minetest process will not be piped and its output will be shown in the terminal. This option is enabled by default to reduce verbosity; disabling it can be useful for debugging.
    :param mt_listen_timeout: Number of milliseconds to wait for MT to connect to the TCP channel. If the timeout is reached a Timeout exception is raised. **WARNING:** When using multiple (serial) MT environments, timeout can be easily reached for the last environment. In this case, you might want to increase the value of this parameter according to the number of environments.
    :param mt_port: TCP port to employ for MT's internal client<->server communication. If not provided a random port in the [49152, 65535] range is used.
    :param frameskip: The number of frames skipped between steps, 1 by default (disabled). Note that `max_timesteps` and `init_frames` parameters will be divided by the frameskip value.
    :param rgb_observations: Whether to use RGB images or gray scale images as observations. Note that RGB images are slower to send from MT to python via TCP. By default RGB images are used.
    :param gray_scale_keepdim: Only used when `rgb_observations` is `False`. If `True`, a singleton dimension will be added, i.e. observations are of the shape WxHx1. Otherwise, they are of shape WxH.
    :param seed: Random seed. Affects minetest's map generation and Lua's RNG (in mods).
    """
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}

    # Number of binary (key) actions: every entry of ACTION_ORDER except the
    # trailing "mouse" entry, which is sent separately as two coordinates.
    _N_KEYS = len(ACTION_ORDER) - 1

    def __init__(
            self,
            env_dir: os.PathLike,
            obs_width: int = 640,
            obs_height: int = 360,
            init_frames: int = 15,
            render_mode: Optional[str] = None,
            max_timesteps: Optional[int] = None,
            run_dir: Optional[os.PathLike] = None,
            run_dir_prefix: Optional[os.PathLike] = None,
            game_id: str = "minetest",
            world_name: str = "world",
            minetest_dir: Optional[str] = None,
            tcp_port: Optional[int] = None,
            minetest_conf: Optional[dict[str, Any]] = None,
            pipe_proc: bool = True,
            mt_listen_timeout: int = 60_000,
            mt_port: Optional[int] = None,
            frameskip: int = 1,
            rgb_observations: bool = True,
            gray_scale_keepdim: bool = False,
            seed: Optional[int] = None,
    ):
        super().__init__()

        # Build a fresh dict per instance: a shared mutable default would leak
        # any mutation (here or inside Minetest) into every later environment.
        if minetest_conf is None:
            minetest_conf = {}

        self.obs_width = obs_width
        self.obs_height = obs_height
        # Both limits are expressed in agent steps, and each step spans
        # `frameskip` MT frames.
        self.init_frames = init_frames // frameskip
        self.max_timesteps = None if max_timesteps is None else max_timesteps // frameskip
        self.gray_scale_keepdim = gray_scale_keepdim
        self.rgb_observations = rgb_observations

        # Action space: one Discrete(2) per key (1/0: key pressed/not pressed)
        # plus a 2D mouse movement in [-1, 1] under the last name ("mouse").
        action_dict = {act: Discrete(2) for act in ACTION_ORDER[:-1]}
        action_dict[ACTION_ORDER[-1]] = Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.action_space = Dict(action_dict)

        # Observation space: W x H, plus a channel axis of 3 (RGB) or 1 (gray
        # scale with `gray_scale_keepdim`).
        # NOTE(review): confirm this (width, height) order matches the array
        # layout produced by MtChannel.receive().
        shape = (obs_width, obs_height)
        if rgb_observations:
            shape += (3,)
        elif gray_scale_keepdim:
            shape += (1,)
        self.observation_space = Box(low=0, high=255, shape=shape, dtype=np.uint8)

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        # initialize the Python<->Minetest communication channel (server side)
        self.mt_chann = MtChannel(
            img_width=self.obs_width,
            img_height=self.obs_height,
            port=tcp_port,
            listen_timeout=mt_listen_timeout,
            rgb_imgs=rgb_observations,
        )

        # handles the MT configuration and process
        self.mt = Minetest(
            world_name=world_name,
            run_dir=run_dir,
            run_dir_prefix=run_dir_prefix,
            headless=render_mode != "human",
            seed=seed,
            game_id=game_id,
            sync_dir=env_dir,
            screen_w=obs_width,
            screen_h=obs_height,
            minetest_dir=minetest_dir,
            tcp_port=self.mt_chann.port,
            minetest_conf=minetest_conf,
            pipe_proc=pipe_proc,
            mt_port=mt_port,
            frameskip=frameskip,
            rgb_frames=rgb_observations,
        )

        self.last_observation = None  # used in render if "rgb_array"
        self.timesteps = 0  # the timesteps counter

    def _get_info(self):
        """Return the (currently always empty) info dict of reset/step."""
        return dict()

    def _process_obs(self, observation):
        """Drop the channel axis of gray-scale frames unless `gray_scale_keepdim` is set."""
        if not self.gray_scale_keepdim and not self.rgb_observations:
            observation = observation[:, :, 0]
        return observation

    def _report_launch_failure(self):
        """Print diagnostics after failing to connect to MT (best-effort, never raises on I/O)."""
        print("\n\x1b[1m[!] Error connecting to Minetest. Minetest probably failed to launch.")
        print("  => Run's scratch directory should be available, containing stderr.txt and")
        print("     stdout.txt useful for checking what went wrong.")
        print("** Content of stderr.txt in the run's sratch directory:\x1b[0m")
        print("~"*45, "\n")
        stderr_path = os.path.join(self.mt.run_dir, "stderr.txt")
        try:
            with open(stderr_path, "r", errors="replace") as f:
                print(f.read())
        except OSError as err:
            # e.g. with `pipe_proc=False` MT's output goes to the terminal and
            # the file does not exist; do not mask the original error.
            print(f"(could not read {stderr_path}: {err})")
        print("~"*45)
        print("\x1b[1mRaising catched exception (in case it's useful):\x1b[0m")
        print("~"*45, "\n")

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ):
        """Resets the environment to an initial internal state, returning an initial observation and info.

        Any MT process from a previous episode is terminated and a new one is
        launched. See [Env.reset](https://gymnasium.farama.org/api/env/#gymnasium.Env.reset) in the Gymnasium docs.

        :param seed: The random seed (forwarded to `Env.reset`).
        :param options: Options dictionary (currently unused).
        :return: A tuple `(observation, info)`.
        """
        super().reset(seed=seed)
        self.timesteps = 0

        # stop the MT process of the previous episode, if any
        if self.mt_chann.is_open():
            self.mt_chann.send_termination()
            self.mt_chann.close_conn()
            self.mt.close_pipes()
            self.mt.wait_close()

        # start the new MT process
        self.mt.start_process()

        # open communication channel with minetest
        try:
            self.mt_chann.open_conn()
        except Exception:
            self._report_launch_failure()
            raise

        # HACK skip some frames to let the game initialize
        # TODO This "waiting" should be implemented in Minetest not in python
        for _ in range(self.init_frames):
            _observation, _reward, _term = self.mt_chann.receive()
            self.mt_chann.send([0] * self._N_KEYS, 0, 0)  # nop action

        observation, _reward, _term = self.mt_chann.receive()
        observation = self._process_obs(observation)

        self.last_observation = observation

        return observation, self._get_info()

    def step(self, action):
        """Run one timestep of the environment's dynamics using the agent actions.

        See [Env.step](https://gymnasium.farama.org/api/env/#gymnasium.Env.step) in the Gymnasium docs.

        :param action: A dict mapping action names (see `ACTION_ORDER`) to values;
            keys missing from the dict are sent as "not pressed" / no mouse movement.
        :return: A tuple `(observation, reward, terminated, truncated, info)`;
            `truncated` becomes `True` once `max_timesteps` (if set) is reached.
        """
        self.timesteps += 1

        # convert the action dict to a format to be sent to MT through mt_chann
        keys = [0] * self._N_KEYS  # all commands (keys) except the mouse
        mouse_x, mouse_y = 0, 0
        for k, v in action.items():
            if k == "mouse":
                # The y component is negated before sending (presumably MT's
                # screen y axis grows downwards), then each component is scaled
                # to a pixel offset of at most half the observation size.
                x, y = v[0], -v[1]
                mouse_x = int(x*(self.obs_width // 2))
                mouse_y = int(y*(self.obs_height // 2))
            else:
                keys[ACTION_ORDER.index(k)] = v
        # send the action to MT
        self.mt_chann.send(keys, mouse_x, mouse_y)

        # receive the new info from minetest
        observation, reward, termination = self.mt_chann.receive()
        observation = self._process_obs(observation)

        self.last_observation = observation

        info = self._get_info()

        truncated = self.max_timesteps is not None and self.timesteps >= self.max_timesteps

        return observation, reward, termination, truncated, info

    def render(self):
        """Return the last observation in "rgb_array" mode; otherwise return `None`.

        In "human" mode MT is launched non-headless (see `__init__`), so nothing
        is returned here.
        """
        if self.render_mode == "rgb_array":
            return self.last_observation

    def close(self):
        """Terminate the running MT process (if any), close the channel, then call `self.mt.clear()`."""
        if self.mt_chann.is_open():
            self.mt_chann.send_termination()
            self.mt_chann.close()
            self.mt.close_pipes()
            self.mt.wait_close()
        self.mt.clear()
