""" The main Environment class for NASim: NASimEnv.

The NASimEnv class is the main interface for agents interacting with NASim.
"""
import gymnasium as gym
from gymnasium import spaces
import numpy as np

from nasim.envs.state import State
from nasim.envs.render import Viewer
from nasim.envs.network import Network
from nasim.envs.observation import Observation
from nasim.envs.action import Action, FlatActionSpace, ParameterisedActionSpace


class NASimEnv(gym.Env):
    """ A simulated computer network environment for pen-testing.

    Implements the gymnasium interface.

    ...

    Attributes
    ----------
    name : str
        the environment scenario name
    scenario : Scenario
        Scenario object, defining the properties of the environment
    action_space : FlatActionSpace or ParameterisedActionSpace
        Action space for environment.
        If *flat_actions=True* then this is a discrete action space (which
        subclasses gymnasium.spaces.Discrete), so each action is represented
        by an integer.
        If *flat_actions=False* then this is a parameterised action space
        (which subclasses gymnasium.spaces.MultiDiscrete), so each action is
        represented using a list of parameters.
    observation_space : gymnasium.spaces.Box
        observation space for environment.
        If *flat_obs=True* then observations are represented by a 1D vector,
        otherwise observations are represented as a 2D matrix.
    current_state : State
        the current state of the environment
    last_obs : Observation
        the last observation that was generated by environment
    steps : int
        the number of steps performed since last reset (this does not include
        generative steps)

    """
    metadata = {'render_modes': ["human", "ansi"]}
    render_mode = None
    reward_range = (-float('inf'), float('inf'))

    action_space = None
    observation_space = None
    current_state = None
    last_obs = None

    def __init__(self,
                 scenario,
                 fully_obs=False,
                 flat_actions=True,
                 flat_obs=True,
                 render_mode=None):
        """Build the environment for a scenario.

        Parameters
        ----------
        scenario : Scenario
            Scenario object, defining the properties of the environment
        fully_obs : bool, optional
            If True the environment is fully observable, otherwise it is
            partially observable (default=False)
        flat_actions : bool, optional
            If True a flat action space is used, otherwise a parameterised
            action space is used (default=True)
        flat_obs : bool, optional
            If True observations are 1D, otherwise they are 2D
            (default=True)
        render_mode : str, optional
            The render mode to use for the environment (default=None)
        """
        self.name = scenario.name
        self.scenario = scenario
        self.fully_obs = fully_obs
        self.flat_actions = flat_actions
        self.flat_obs = flat_obs
        self.render_mode = render_mode

        self.network = Network(scenario)
        self.current_state = State.generate_initial_state(self.network)
        # The viewer is only created once one of the render_* methods needs it.
        self._renderer = None
        # reset() fills in last_obs, which is used below to size the
        # observation space.
        self.reset()

        space_cls = (
            FlatActionSpace if self.flat_actions else ParameterisedActionSpace
        )
        self.action_space = space_cls(self.scenario)

        shape = (
            self.last_obs.shape_flat() if self.flat_obs
            else self.last_obs.shape()
        )
        low, high = Observation.get_space_bounds(self.scenario)
        self.observation_space = spaces.Box(low=low, high=high, shape=shape)

        self.steps = 0

    def seed(self, seed):
        """Seed NumPy's global random number generator.

        Parameters
        ----------
        seed : int
            value passed straight to ``numpy.random.seed``
        """
        np.random.seed(seed)

    def reset(self, *, seed=None, options=None):
        """Reset the state of the environment and returns the initial state.

        Implements gymnasium.Env.reset().

        Parameters
        ----------
        seed : int, optional
            the optional seed for the environments RNG. When given it is
            passed to gymnasium.Env.reset() and also to :func:`seed`, so
            that NumPy's global RNG is seeded as well.
        options : dict, optional
            optional environment options; forwarded to
            gymnasium.Env.reset() and not otherwise used here

        Returns
        -------
        numpy.Array
            the initial observation of the environment
        dict
            auxiliary information regarding reset (always empty)
        """
        super().reset(seed=seed, options=options)
        if seed is not None:
            # This class's own seeding hook (`seed`) targets NumPy's global
            # RNG, which the base-class reset does not touch. Forward the
            # seed so `reset(seed=...)` and `seed(...)` agree.
            # NOTE(review): assumes the simulation draws its randomness from
            # NumPy's global RNG, as `seed` implies - confirm.
            # The legacy NumPy seeding API only accepts values that fit in
            # 32 unsigned bits, hence the modulo.
            self.seed(seed % (2 ** 32))

        self.steps = 0
        self.current_state = self.network.reset(self.current_state)
        self.last_obs = self.current_state.get_initial_observation(
            self.fully_obs
        )

        if self.flat_obs:
            obs = self.last_obs.numpy_flat()
        else:
            obs = self.last_obs.numpy()

        return obs, {}

    def step(self, action):
        """Advance the environment by one step using the given action.

        Implements gymnasium.Env.step().

        Parameters
        ----------
        action : Action or int or list or NumpyArray
            Action to perform. If not Action object, then if using
            flat actions this should be an int and if using non-flat actions
            this should be an indexable array.

        Returns
        -------
        numpy.Array
            observation from performing action
        float
            reward from performing action
        bool
            whether the episode reached a terminal state or not (i.e. all
            target machines have been successfully compromised)
        bool
            whether the episode has reached the step limit (if one exists)
        dict
            auxiliary information regarding step
            (see :func:`nasim.env.action.ActionResult.info`)
        """
        next_state, observation, reward, done, info = self.generative_step(
            self.current_state, action
        )
        self.current_state = next_state
        self.last_obs = observation

        obs_array = (
            observation.numpy_flat() if self.flat_obs else observation.numpy()
        )

        self.steps += 1

        # A step limit of None means the episode is never truncated.
        limit = self.scenario.step_limit
        truncated = limit is not None and self.steps >= limit

        return obs_array, reward, done, truncated, info

    def generative_step(self, state, action):
        """Simulate one step from an arbitrary state.

        Unlike :func:`step`, this does not update the environment's current
        state, last observation or step counter.

        Parameters
        ----------
        state : State
            The state to perform the action in
        action : Action, int, list, NumpyArray
            Action to perform. If not Action object, then if using
            flat actions this should be an int and if using non-flat actions
            this should be an indexable array.

        Returns
        -------
        State
            the next state after action was performed
        Observation
            observation from performing action
        float
            reward from performing action (action result value minus the
            action cost)
        bool
            whether a terminal state has been reached or not
        dict
            auxiliary information regarding step
            (see :func:`nasim.env.action.ActionResult.info`)
        """
        # Anything that is not already an Action is looked up in the
        # action space.
        resolved = (
            action if isinstance(action, Action)
            else self.action_space.get_action(action)
        )

        next_state, result = self.network.perform_action(state, resolved)
        observation = next_state.get_observation(
            resolved, result, self.fully_obs
        )
        done = self.goal_reached(next_state)
        reward = result.value - resolved.cost
        return next_state, observation, reward, done, result.info()

    def generate_random_initial_state(self):
        """Create a randomised initial state for the environment's network.

        Only the host configurations (os, services) are randomised, using a
        uniform distribution, so the resulting network may be one where the
        goal cannot be reached.

        Returns
        -------
        State
            A random initial state
        """
        random_state = State.generate_random_initial_state(self.network)
        return random_state

    def generate_initial_state(self):
        """Create the initial state for the environment's network.

        Returns
        -------
        State
            The initial state

        Notes
        -----
        The environment's current state is left untouched (use
        :func:`reset` to reset it).
        """
        initial_state = State.generate_initial_state(self.network)
        return initial_state

    def render(self):
        """Render the last observation using the configured render mode.

        Implements gymnasium.Env.render().

        Does nothing and returns None when the environment was created
        without a render mode. See render module for more details on modes
        and symbols.
        """
        mode = self.render_mode
        if mode is not None:
            return self.render_obs(mode=mode, obs=self.last_obs)
        return None

    def render_obs(self, mode="human", obs=None):
        """Render observation.

        See render module for more details on modes and symbols.

        Parameters
        ----------
        mode : str
            rendering mode, one of "human" or "ansi". If None nothing is
            rendered.
        obs : Observation or numpy.ndarray, optional
            the observation to render, if None will render last observation.
            If numpy.ndarray it must be in format that matches Observation
            (i.e. ndarray returned by step method) (default=None)

        Returns
        -------
        object or None
            whatever the viewer's ``render_readable`` returns, or None when
            *mode* is None

        Raises
        ------
        NotImplementedError
            if *mode* is not a supported render mode
        """
        if mode is None:
            return None

        # Reject unsupported modes up front, before converting the
        # observation or building a viewer that would never be used.
        if mode not in ("human", "ansi"):
            raise NotImplementedError(
                "Please choose correct render mode from : "
                f"{self.metadata['render_modes']}"
            )

        if obs is None:
            obs = self.last_obs

        if not isinstance(obs, Observation):
            obs = Observation.from_numpy(obs, self.current_state.shape())

        # The viewer is created on first use and then reused.
        if self._renderer is None:
            self._renderer = Viewer(self.network)

        return self._renderer.render_readable(obs)

    def render_state(self, mode="human", state=None):
        """Render a state in human readable form.

        See render module for more details on modes and symbols.

        Parameters
        ----------
        mode : str
            rendering mode, if None nothing is rendered
        state : State or numpy.ndarray, optional
            the State to render, if None will render current state
            If numpy.ndarray it must be in format that matches State
            (i.e. ndarray returned by generative_step method) (default=None)

        Raises
        ------
        NotImplementedError
            if *mode* is neither "human" nor "ansi"
        """
        if mode is None:
            return

        target = self.current_state if state is None else state

        # Raw arrays are converted using the current state's shape and
        # host number map.
        if not isinstance(target, State):
            target = State.from_numpy(
                target,
                self.current_state.shape(),
                self.current_state.host_num_map,
            )

        if self._renderer is None:
            self._renderer = Viewer(self.network)

        if mode not in ("human", "ansi"):
            raise NotImplementedError(
                "Please choose correct render mode from : "
                f"{self.metadata['render_modes']}"
            )
        return self._renderer.render_readable_state(target)

    def render_action(self, action):
        """Print a human readable version of an action.

        Mostly useful for getting the text description of the action that
        a given integer (or parameter list) corresponds to.

        Parameters
        ----------
        action : Action or int or list or NumpyArray
            Action to render. If not Action object, then if using
            flat actions this should be an int and if using non-flat actions
            this should be an indexable array.
        """
        described = action
        if not isinstance(described, Action):
            described = self.action_space.get_action(described)
        print(described)

    def render_episode(self, episode, width=7, height=7):
        """Render an episode as a sequence of network graphs.

        An episode is a sequence of (state, action, reward, done) tuples
        generated from interactions with the environment.

        Parameters
        ----------
        episode : list
            list of (State, Action, reward, done) tuples
        width : int
            width of GUI window
        height : int
            height of GUI window
        """
        viewer = self._renderer
        if viewer is None:
            # Build the viewer on first use and keep it for later calls.
            viewer = self._renderer = Viewer(self.network)
        viewer.render_episode(episode, width, height)

    def render_network_graph(self, ax=None, show=False):
        """Plot the network in its current state as a graph.

        Hosts are nodes arranged into subnets, with the connections between
        subnets shown.

        Parameters
        ----------
        ax : Axes
            matplotlib axis to plot graph on, or None to plot on new axis
        show : bool
            whether to display the plot, or only set it up so that showing
            it can be handled elsewhere by the user
        """
        viewer = self._renderer
        if viewer is None:
            # Build the viewer on first use and keep it for later calls.
            viewer = self._renderer = Viewer(self.network)
        viewer.render_graph(self.current_state, ax, show)

    def get_minimum_hops(self):
        """Return the minimum number of network hops needed to reach targets.

        That is the minimum number of hosts that must be traversed in the
        network in order to reach all sensitive hosts on the network,
        starting from the initial state.

        Returns
        -------
        int
            minimum possible number of network hops to reach target hosts
        """
        min_hops = self.network.get_minimal_hops()
        return min_hops

    def get_action_mask(self):
        """Build a validity mask over the flat action space.

        An action counts as valid when the host it targets has been
        discovered in the current state.

        Returns
        -------
        ndarray
            numpy int64 vector of 1's and 0's, one for each action. Where an
            index will be 1 if action is valid given current state, or
            0 if action is invalid.
        """
        space = self.action_space
        assert isinstance(space, FlatActionSpace), \
            "Can only use action mask function when using flat action space"

        state = self.current_state
        # Every action is looked up individually and checked against its
        # target host.
        flags = [
            1 if state.get_host(space.get_action(idx).target).discovered
            else 0
            for idx in range(space.n)
        ]
        return np.array(flags, dtype=np.int64)
    
    def action_masks(self):
        """Get a vector mask for valid actions. The mask is based on whether
        a host has been discovered or not.

        The mask is built by repeating each host's discovered flag once per
        action of that host, so it relies on every host having the same
        number of actions.

        Returns
        -------
        ndarray
            numpy vector with one entry for each action, holding the
            discovered flag of the host that action belongs to.

        Raises
        ------
        AssertionError
            if the environment is not using a flat action space
        ValueError
            if the state has no hosts, or the number of actions is not an
            exact multiple of the number of hosts
        """
        assert isinstance(self.action_space, FlatActionSpace), \
            "Can only use action mask function when using flat action space"

        # One flag per host, in the order the state lists its hosts.
        # NOTE(review): assumes the flat action space lists its actions
        # grouped by host in this same order - confirm against
        # FlatActionSpace.
        discovered = [h[1].discovered for h in self.current_state.hosts]
        if not discovered:
            raise ValueError(
                "Cannot build an action mask for a state without hosts"
            )

        # Integer division: np.repeat needs an integer repeat count, and the
        # remainder is the only reliable evidence of an uneven split.
        num_actions_per_host, leftover = divmod(
            self.action_space.n, len(discovered)
        )
        if leftover:
            raise ValueError("Hosts don't all have the same amout of actions")

        # Repeat each host's flag num_actions_per_host times.
        return np.repeat(discovered, num_actions_per_host)


    def get_score_upper_bound(self):
        """Return the theoretical upper bound on total reward for scenario.

        The bound corresponds to an agent that exploits only a single host
        in each subnet needed to reach the sensitive hosts along the
        shortest path in the network graph, and exploits all sensitive
        hosts (i.e. the minimum network hops). It assumes an action cost of
        1 and that each sensitive host is exploitable from any other
        connected subnet (which may not be true, hence an upper bound).

        Returns
        -------
        float
            theoretical max score
        """
        net = self.network
        # Total value on offer, less one unit of cost per required hop.
        gains = (
            net.get_total_sensitive_host_value()
            + net.get_total_discovery_value()
        )
        return gains - net.get_minimal_hops()

    def goal_reached(self, state=None):
        """Tell whether a state is the goal state.

        The goal state is the one where all sensitive hosts have been
        compromised.

        Parameters
        ----------
        state : State, optional
            a state, if None will use current_state of environment
            (default=None)

        Returns
        -------
        bool
            True if state is goal state, otherwise False.
        """
        checked = self.current_state if state is None else state
        return self.network.all_sensitive_hosts_compromised(checked)

    def __str__(self):
        """Return a multi-line summary of the environment's configuration."""
        header = "NASimEnv:"
        details = (
            f"name={self.name}",
            f"fully_obs={self.fully_obs}",
            f"flat_actions={self.flat_actions}",
            f"flat_obs={self.flat_obs}",
        )
        # Every line after the header is indented by two spaces.
        return "\n  ".join((header, *details))

    def close(self):
        """Close the viewer, if one was created, and forget it."""
        if self._renderer is None:
            return
        self._renderer.close()
        self._renderer = None
