from k_level_policy_gradients.src.core.multi_agent_core_hidden import (
    MultiAgentCoreHidden,
)
from k_level_policy_gradients.src.core.environment import Environment
import numpy as np


class MultiAgentCoreHiddenMixer(MultiAgentCoreHidden):
    """
    Multi-agent core with hidden states for recurrent processing and a mixer.

    Each agent network propagates a hidden state through time to be used
    in the next step to select actions. Whenever every agent's fit condition
    holds, the transitions collected so far are passed to each agent's
    ``fit`` and then to the mixer's ``fit``.
    """

    def __init__(
        self,
        agents: list = None,
        mdp: Environment = None,
        mixer=None,
        callbacks_step: list = None,
        preprocessors: list = None,
        record_dictionary: dict = None,
    ):
        """
        Constructor.

        Args:
            agents (list): list of agents moving according to a policy;
            mdp (Environment): the environment in which the agents move; it is
                required in practice, since ``mdp.info.n_agents`` is read here;
            mixer (Mixer): the mixer to use to factorise the agents' Q values;
            callbacks_step (list, None): list of callbacks to execute after
                each step; a fresh empty list is used when None;
            preprocessors (list, None): list of preprocessors, stored as
                ``self._preprocessors``; a fresh empty list is used when None;
            record_dictionary (dict, None): keyword arguments forwarded to
                ``_build_recorder_class``; an empty dict is used when None.
        """
        self.agents = agents
        self.mixer = mixer
        self.mdp = mdp
        # Build fresh lists per instance: a literal ``[]`` default would be
        # shared by every core constructed without these arguments.
        self.callbacks_step = [] if callbacks_step is None else callbacks_step
        self._preprocessors = [] if preprocessors is None else preprocessors

        self._state = None
        self._obs = None
        self._action_masks = None

        self._reset_run_counters()
        self._episode_steps = None

        if record_dictionary is None:
            record_dictionary = dict()
        self._record = self._build_recorder_class(**record_dictionary)

    def _reset_run_counters(self):
        """Zero the global and per-agent step/episode counters."""
        n_agents = self.mdp.info.n_agents
        self._total_episodes_counter = 0
        self._total_steps_counter = 0
        self._current_episodes_counter_per_agent = np.zeros(n_agents, dtype=int)
        self._current_steps_counter_per_agent = np.zeros(n_agents, dtype=int)

    def _run_impl(
        self,
        move_condition,
        fit_condition_per_agent,
        steps_progress_bar,
        episodes_progress_bar,
        render,
        record,
        need_complete_episodes,
    ):
        """
        Interaction loop: step the environment, fit agents and mixer jointly.

        Args:
            move_condition (callable): zero-argument callable; the loop runs
                while it returns True;
            fit_condition_per_agent (list): one zero-argument callable per
                agent; agents and mixer are fitted only when all return True;
            steps_progress_bar: progress bar updated by 1 after every step;
            episodes_progress_bar: progress bar updated by 1 after every
                finished episode;
            render: forwarded to ``_step``;
            record: forwarded to ``_step``; when truthy the recorder is
                stopped once the loop ends;
            need_complete_episodes (bool): if True, keep stepping after
                ``move_condition`` turns False until the current episode ends.

        Returns:
            tuple: ``(dataset, dataset_info)`` holding the samples and infos
            collected since the last fit (empty if the last step triggered a
            fit).
        """
        self._reset_run_counters()

        dataset = []
        dataset_info = []
        last = True

        if need_complete_episodes:
            step_count_move_condition = move_condition
            # ``last`` is looked up when the lambda is called (late binding),
            # so the loop continues until the episode in progress terminates.
            move_condition = lambda: step_count_move_condition() or not last

        while move_condition():
            if last:
                self.reset()

            sample, info = self._step(render, record)

            self._total_steps_counter += 1
            self._current_steps_counter_per_agent += 1
            steps_progress_bar.update(1)

            last = sample["last"]
            if last:
                self._total_episodes_counter += 1
                self._current_episodes_counter_per_agent += 1
                episodes_progress_bar.update(1)

            dataset.append(sample)
            dataset_info.append(info)

            # Fit every agent and then the mixer on the same batch, but only
            # once all agents agree it is time to fit.
            if all(fit_condition() for fit_condition in fit_condition_per_agent):
                for idx_agent, agent in enumerate(self.agents):
                    # presumably stores the dataset in the agent's replay memory
                    agent.fit(dataset)
                    self._current_episodes_counter_per_agent[idx_agent] = 0
                    self._current_steps_counter_per_agent[idx_agent] = 0
                self.mixer.fit(dataset)
                dataset = []
                dataset_info = []

            self._get_callbacks(sample, info)

        for agent in self.agents:
            agent.stop()
        self.mdp.stop()
        if record:
            self._record.stop()

        return dataset, dataset_info
