import numpy as np
from k_level_policy_gradients.src.core.multi_agent_core_hidden import (
    MultiAgentCoreHidden,
)


class MultiAgentCoreHiddenShared(MultiAgentCoreHidden):
    """
    Multi-agent core with hidden states for recurrent processing.

    On top of what is inherited from ``MultiAgentCoreHidden``, every agent's
    observation is extended with a one-hot encoding of the agent's index (see
    ``encode_obs``) and, when ``self.obs_last_action`` is truthy, with the
    agent's last action (one-hot for discrete action spaces). Presumably the
    index lets a single network shared among the agents tell them apart.

    """

    def _step(self, render, record):
        """
        Perform a single step of the mdp and build the transition sample.

        Args:
            render (bool): whether to render the environment after the step;
            record (bool): whether to pass the rendered frame to
                ``self._record``; only considered when ``render`` is True.

        Returns:
            A tuple ``(sample, info)``. ``sample`` is a dictionary with the
            keys ``state``, ``obs``, ``action_masks``, ``actions``,
            ``rewards``, ``next_state``, ``next_obs``, ``next_action_masks``,
            ``absorbing`` and ``last``; ``info`` is the ``"info"`` entry of the
            mdp step, or None if the mdp does not provide one.

        """
        actions = self._get_actions()

        step = self.mdp.step(actions)

        next_state = step["state"]
        rewards = step["rewards"]
        absorbing = step["absorbing"]
        next_action_masks = step.get("action_masks", None)
        info = step.get("info", None)

        self._episode_steps += 1

        next_obs = self.encode_obs(step.get("obs", None))
        if self.obs_last_action:
            next_obs = [
                np.concatenate(
                    [agent_obs, self._last_action_vector(i, actions[i])]
                )
                for i, agent_obs in enumerate(next_obs)
            ]

        if render:
            frame = self.mdp.render()

            if record:
                self._record(frame)

        # The episode ends when the horizon is reached or the mdp reports an
        # absorbing state.
        last = self._episode_steps >= self.mdp.info.horizon or absorbing

        obs = self._obs
        next_obs = next_obs.copy()
        self._obs = next_obs

        state = self._state
        next_state = next_state.copy()
        self._state = next_state

        # The mdp is allowed to omit action masks (they are read with
        # ``.get(..., None)``), so only copy them when they actually exist.
        action_masks = (
            None if self._action_masks is None else self._action_masks.copy()
        )
        self._action_masks = next_action_masks

        sample = {
            "state": state,
            "obs": obs,
            "action_masks": action_masks,
            "actions": actions,
            "rewards": rewards,
            "next_state": next_state,
            "next_obs": next_obs,
            "next_action_masks": next_action_masks,
            "absorbing": absorbing,
            "last": last,
        }

        return sample, info

    def reset(self):
        """
        Reset the state of the mdp and agents.

        The stored observations are the encoded ones (see ``encode_obs``);
        when ``self.obs_last_action`` is truthy an all-zero "last action" is
        appended to each of them, since no action has been taken yet.

        """
        init_step = self.mdp.reset()

        self._state = init_step["state"]
        self._action_masks = init_step.get("action_masks", None)
        self._hidden_states = [None for _ in self.agents]

        initial_obs = self.encode_obs(init_step.get("obs", None))
        if self.obs_last_action:
            initial_obs = [
                np.concatenate([agent_obs, self._last_action_vector(i)])
                for i, agent_obs in enumerate(initial_obs)
            ]
        self._obs = initial_obs

        for agent in self.agents:
            agent.episode_start()
            agent.next_action = None
        self._episode_steps = 0

    def encode_obs(self, obs):
        """
        Encode the observation with the indices of the agents.

        Args:
            obs (list(np.ndarray)): the observations to encode, one per agent
                and ordered by agent index.

        Returns:
            list(np.ndarray): the observations, each one concatenated with the
            one-hot encoding of the index of the corresponding agent.

        """
        eye_matrix = np.eye(self.mdp.info.n_agents)
        return [np.concatenate([o, eye_matrix[i]]) for i, o in enumerate(obs)]

    def _last_action_vector(self, agent_idx, action=None):
        """
        Build the last-action vector that is appended to an observation.

        Args:
            agent_idx (int): index of the agent in the mdp action spaces;
            action: the last action of the agent, or None if no action has
                been taken yet (i.e. at the start of an episode).

        Returns:
            For discrete action spaces, a one-hot vector of the action (all
            zeros if ``action`` is None); otherwise the action itself, or a
            zero array shaped like the action space if ``action`` is None.

        """
        if self.mdp.info.discrete_actions:  # make integer action one-hot
            one_hot = np.zeros(self.mdp.info.action_space[agent_idx].n)
            if action is not None:
                one_hot[action] = 1
            return one_hot

        if action is None:
            return np.zeros(self.mdp.info.action_space[agent_idx].shape)
        return action
