"""
Adapters for open_spiel environments.

@Description:
- Essentials of `rl_environment.Environment.observation_spec`:
    * info_state: a list of observation tensors for current players (if `use_observation` is enabled)
    * legal_actions: a list of lists of legal actions for current players
    * current_player: a list of player ids (int)
    * serialized_state: serialized global state (if `_include_full_state` is enabled)

- Essentials of `rl_environment.Environment.action_spec`:
    * num_actions: int, indicates the number of distinct actions
    * min: float/int, the lower bound of the action range
    * max: float/int, the upper bound of the action range
    * dtype: indicates the data type 

@Functions:
    1. convert dict description (observation_spec, action_spec) to gym.Space
    2. convert raw environment observation and action with adapters to match the space definitions
"""

import gym
import numpy as np
import pyspiel

from gym import spaces

from expground import types
from expground.logger import Log


def ObservationSpace(observation_spec: types.Dict, **kwargs) -> spaces.Dict:
    """Analyzes accepted observation spec and returns a truncated observation space.

    Only non-empty entries of the spec produce a sub-space:

    * ``info_state`` -> ``"info_state"``: unbounded ``Box``.
    * ``legal_actions`` -> ``"action_mask"``: ``Box`` bounded to ``[0, 1]``.
    * ``serialized_state`` -> ``"serialize_state"``: unbounded ``Box``.

    Args:
        observation_spec (Dict): The raw observation spec in dict. The keys
            ``info_state``, ``legal_actions`` and ``serialized_state`` are
            read; each value is measured with ``len()`` and, when non-empty,
            passed as the ``shape`` of the corresponding ``Box``.
        **kwargs: Accepted but not used by this function.

    Returns:
        gym.spaces.Dict: The truncated observation space in Dict.

    Raises:
        KeyError: If one of the three spec keys above is missing.
    """

    _spaces = {}

    info_state_shape = observation_spec["info_state"]
    if len(info_state_shape) > 0:
        _spaces["info_state"] = spaces.Box(
            low=-np.inf, high=np.inf, shape=info_state_shape
        )

    legal_actions_shape = observation_spec["legal_actions"]
    if len(legal_actions_shape) > 0:
        _spaces["action_mask"] = spaces.Box(
            low=0.0, high=1.0, shape=legal_actions_shape
        )

    # The shape must come from the same key whose length is tested; reading a
    # differently spelled key here would raise KeyError whenever the entry is
    # non-empty. The output key is intentionally left as "serialize_state".
    serialized_state_shape = observation_spec["serialized_state"]
    if len(serialized_state_shape) > 0:
        _spaces["serialize_state"] = spaces.Box(
            low=-np.inf, high=np.inf, shape=serialized_state_shape
        )

    Log.debug(f"Truncted Observation space: {_spaces}")
    return spaces.Dict(_spaces)


def ActionSpace(action_spec: types.Dict) -> gym.Space:
    """Analyzes accepted action spec and returns a truncated action space.

    Args:
        action_spec (types.Dict): The raw action spec in dict. Reads ``dtype``
            and ``num_actions``; for a ``float`` dtype also ``min`` and
            ``max``.

    Returns:
        gym.Space: A one-dimensional ``Box`` of length ``num_actions`` bounded
            by ``min``/``max`` when ``dtype`` is ``float``, or a ``Discrete``
            space with ``num_actions`` choices when ``dtype`` is ``int``.

    Raises:
        TypeError: If ``dtype`` is neither ``float`` nor ``int``.
    """

    dtype = action_spec["dtype"]

    if dtype == float:
        return spaces.Box(
            low=action_spec["min"],
            high=action_spec["max"],
            # Trailing comma matters: `(n)` is the bare int `n`, whereas the
            # Box shape has to be a tuple.
            shape=(action_spec["num_actions"],),
        )
    if dtype == int:
        return spaces.Discrete(action_spec["num_actions"])

    raise TypeError(
        f"Data type for action space is not allowed, expected are `float` or `int`, but {dtype} received."
    )


def _legal_action_mask(shape, legal_actions=None) -> np.ndarray:
    """Return zeros of ``shape`` with 1.0 written at the ``legal_actions`` indices."""
    mask = np.zeros(shape)
    if legal_actions is not None:
        mask[legal_actions] = 1.0
    return mask


def observation_adapter(
    raw_observation: types.Dict, observation_space: spaces.Dict, player_id=None
) -> types.Dict:
    """Convert a raw observation into the layout described by ``observation_space``.

    Args:
        raw_observation: Either a dict holding ``current_player`` plus
            per-player entries (``legal_actions`` and every non-mask key of
            ``observation_space``), or a ``pyspiel.State``.
        observation_space: Target space. With a dict observation its
            ``spaces`` mapping is iterated directly; with a ``pyspiel.State``
            it may also be a non-``Dict`` space.
        player_id: Only used for a ``pyspiel.State`` combined with a
            non-``Dict`` space, where it selects the player whose information
            state tensor is returned. ``None`` requests the tensor through
            the argument-less call.

    Returns:
        A dict keyed like ``observation_space`` for the dict-space cases, a
        ``numpy`` array for a ``pyspiel.State`` with a non-``Dict`` space, or
        an empty dict when ``raw_observation`` is of neither supported type.
    """
    res = {}
    if isinstance(raw_observation, types.Dict):
        player = raw_observation["current_player"]
        for k, space in observation_space.spaces.items():
            if k == "action_mask":
                # A negative current player gets an all-zero mask.
                legal_actions = (
                    raw_observation["legal_actions"][player] if player >= 0 else None
                )
                res[k] = _legal_action_mask(space.shape, legal_actions)
            elif player >= 0:
                res[k] = raw_observation[k][player]
            else:
                # No valid current player: fill the entry with a sample drawn
                # from its space (a placeholder, not an all-zero frame).
                res[k] = space.sample()
    elif isinstance(raw_observation, pyspiel.State):
        if isinstance(observation_space, spaces.Dict):
            # NOTE(review): this branch ignores `player_id` and relies on the
            # argument-less state queries — confirm that is intended.
            # Keys other than "info_state" / "action_mask" are skipped.
            for k, space in observation_space.spaces.items():
                if k == "info_state":
                    res[k] = raw_observation.information_state_tensor()
                elif k == "action_mask":
                    res[k] = _legal_action_mask(
                        space.shape, raw_observation.legal_actions()
                    )
        elif player_id is None:
            # `player_id` defaults to None; use the argument-less form instead
            # of forwarding None as a player index.
            res = np.asarray(raw_observation.information_state_tensor())
        else:
            res = np.asarray(raw_observation.information_state_tensor(player_id))
    return res


def action_adapter(raw_action: types.Any, action_space: gym.Space):
    """Validate ``raw_action`` against ``action_space`` and return it unchanged.

    Args:
        raw_action: The action to validate.
        action_space: Space whose ``contains`` method decides legality.

    Returns:
        ``raw_action`` itself, with no conversion applied.

    Raises:
        AssertionError: If ``action_space.contains(raw_action)`` is falsy.
            Raised explicitly rather than through an ``assert`` statement so
            the check is not stripped when Python runs with ``-O``; the
            exception type and message are unchanged.
    """
    if not action_space.contains(raw_action):
        raise AssertionError(
            f"Illegal action value: {raw_action}, space={action_space}"
        )
    return raw_action
