import gym
import networkx as nx
import numpy as np
from gym import ObservationWrapper
from gym.spaces import Box
from karateclub.graph_embedding.feathergraph import FeatherGraph

from yawning_titan.envs.generic.generic_env import GenericNetworkEnv


class FeatherGraphEmbedObservation(ObservationWrapper):
    """
    Gym Observation Space Wrapper that embeds the underlying environment graph using the Feather-G algorithm.

    This wrapper uses the Feather-G Whole Graph embedding algorithm to embed the underlying environment
    graph and then re-creates the observation space to include the embedding and all other
    observation space settings from the configuration file.

    The embedding is cached and only recomputed when the environment's adjacency
    matrix differs from the one the cached embedding was built from.
    """

    def __init__(self, env: GenericNetworkEnv, max_num_nodes: int = 100):
        """
        Initialise a Feather-G observation space wrapper.

        Args:
            env: the OpenAI Gym environment to be wrapped
            max_num_nodes: the maximum number of nodes required to be supported in the
                           observation space

        Note:
            The max_num_nodes is for defining the maximum number of nodes you want
            the agent to support within its observation space. This is in
            order to support the Training of agents which can work across a number of
            YAWNING TITAN environments with variable node counts.

            For example, if set to 100 (like the default), the agent could be trained in
            an environment with 10 nodes, 50 nodes or 100 nodes.

            This wrapper does not currently read ``max_num_nodes`` itself: the size of
            the new observation space is taken from
            ``env.calculate_observation_space_size(with_feather=True)``.
        """
        super().__init__(env)
        self.env: GenericNetworkEnv = env
        self.network_interface = env.network_interface
        self.new_ob_space_dim = env.calculate_observation_space_size(with_feather=True)
        # Kept so the unwrapped space remains available after it is replaced below.
        self.original_observation_space: gym.spaces.Box = env.observation_space
        self.observation_space: gym.spaces.Box = Box(
            -np.inf, np.inf, shape=(self.new_ob_space_dim,)
        )
        # Embedding cache: the adjacency matrix the embedding was computed from,
        # and the embedding itself. Both are populated on the first observation.
        self.latest_adj_matrix = None
        self.latest_graph_embedding = None

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """
        Observation Transformation Function.

        1. Recomputes the cached Feather graph embedding if the environment's
           adjacency matrix has changed since the embedding was last built
        2. Fetches the current standard observation from the network interface
        3. If node connections are part of the configured observation space, drops
           the leading ``total_num_nodes ** 2`` entries of the standard observation
           (presumably the flattened adjacency matrix - confirm against
           ``get_current_observation``) and puts the graph embedding in their place
        4. Otherwise returns the standard observation unchanged

        Args:
            observation: The base, unwrapped observation generated by the environment.
                It is not read; the observation is re-fetched from the network interface.

        Returns:
            A newly formatted environment observation
        """
        current_adj_matrix = self.env.network_interface.adj_matrix

        # Compare the matrices themselves. Comparing ``a.all() != b.all()`` only
        # compares two booleans ("is every entry non-zero?"), which are equal for
        # almost any pair of adjacency matrices, so topology changes went unnoticed.
        # ``np.array_equal`` also returns False when the shapes differ.
        if self.latest_adj_matrix is None or not np.array_equal(
            current_adj_matrix, self.latest_adj_matrix
        ):
            # Store a copy so that in-place edits made by the environment cannot
            # silently alter the cached matrix and hide a later change.
            self.latest_adj_matrix = np.array(current_adj_matrix, copy=True)
            self.latest_graph_embedding = self.make_embedding()

        standard_obs = self.env.network_interface.get_current_observation()
        if not self.network_interface.game_mode.observation_space.node_connections.value:
            return standard_obs

        size_standard_adj = self.network_interface.get_total_num_nodes() ** 2
        extra_obs = standard_obs[size_standard_adj:]

        # axis=None flattens both inputs, so the embedding returned by the
        # embedder is joined with the remaining observation as one 1-D vector.
        return np.concatenate(
            (self.latest_graph_embedding, extra_obs), axis=None, dtype=np.float32
        )

    def make_embedding(self) -> np.ndarray:
        """
        Create a FeatherGraph embedding of the most recently cached adjacency matrix.

        Builds a NetworkX graph from ``self.latest_adj_matrix`` and fits a fresh
        ``FeatherGraph`` embedder on that single graph.

        Returns:
            A numpy array containing the Feather embedding
        """
        current_graph = nx.from_numpy_array(self.latest_adj_matrix)
        embedder = FeatherGraph()
        # The embedder is fitted on a list of graphs; here the list holds one graph.
        embedder.fit([current_graph])
        embedding = embedder.get_embedding()

        return embedding
