from k_level_policy_gradients.src.core.multi_agent_core import MultiAgentCore
import numpy as np


class MultiAgentCoreHidden(MultiAgentCore):
    """
    Multi-agent core with hidden states for recurrent processing.

    Each agent network propagates a hidden state through time: the hidden
    state an agent returns when selecting an action is stored and passed
    back to that same agent at its next action selection.
    """

    def _get_actions(self):
        """
        Query every agent for an action, threading its hidden state.

        For agent ``i`` the policy input is ``self._obs[i]`` when the mdp
        exposes per-agent observations, otherwise the shared ``self._state``.
        When observations are used and the mdp also provides action masks,
        ``self._action_masks[i]`` is passed as a third argument. The hidden
        state returned by ``draw_action_hidden`` replaces
        ``self._hidden_states[i]``.

        Returns:
            list: one action per agent, in agent index order.
        """
        info = self.mdp.info
        actions = []
        for idx_agent in range(info.n_agents):
            hidden = self._hidden_states[idx_agent]
            if info.has_obs:
                args = [self._obs[idx_agent], hidden]
                if info.has_action_masks:
                    args.append(self._action_masks[idx_agent])
            else:
                # Without per-agent observations every agent sees the full state.
                args = [self._state, hidden]
            action, hidden_state = self.agents[idx_agent].draw_action_hidden(*args)
            actions.append(action)
            # Carry the new hidden state forward to this agent's next decision.
            self._hidden_states[idx_agent] = hidden_state
        return actions

    def reset(self):
        """
        Reset the state of the mdp and agents.

        Reads the mdp's initial step (``"state"`` is required, ``"obs"`` and
        ``"action_masks"`` are optional) and resets every agent's hidden state
        to ``None``. When ``self.obs_last_action`` is set and observations are
        present, an all-zeros "last action" vector is appended to each agent's
        observation. The vector is one-hot sized (``action_space[i].n``) for
        discrete actions and ``action_space[i].shape`` otherwise. Finally, it
        calls ``episode_start()`` on each agent, clears ``agent.next_action``,
        and zeroes the episode step counter.
        """
        init_step = self.mdp.reset()

        self._state = init_step["state"]
        self._obs = init_step.get("obs", None)
        self._action_masks = init_step.get("action_masks", None)
        self._hidden_states = [None for _ in self.agents]

        # With no "obs" entry there is nothing to augment; iterating None
        # would raise a TypeError.
        if self.obs_last_action and self._obs is not None:
            action_space = self.mdp.info.action_space
            discrete = self.mdp.info.discrete_actions
            # Build a new list rather than assigning into the container
            # returned by mdp.reset(). That container may be a tuple, a
            # fixed-width array, or an object the mdp still holds a
            # reference to.
            augmented = []
            for i, init_obs in enumerate(self._obs):
                if discrete:  # integer actions are encoded one-hot
                    no_action = np.zeros(action_space[i].n)
                else:
                    no_action = np.zeros(action_space[i].shape)
                augmented.append(np.concatenate([init_obs, no_action]))
            self._obs = augmented

        for agent in self.agents:
            agent.episode_start()
            agent.next_action = None
        self._episode_steps = 0
