import numpy as np
from k_level_policy_gradients.src.core.multi_agent_core_shared import (
    MultiAgentCoreShared,
)


class MultiAgentCoreSharedMixer(MultiAgentCoreShared):
    """
    Multi-agent core that, whenever every agent is ready to fit, fits all
    agents on the collected samples and then fits a mixer on the same samples.
    """

    def __init__(self, mixer=None, **kwargs):
        """
        Constructor.

        Args:
            mixer (Mixer): the mixer used to factorise the agents' Q values.
                ``_run_impl`` calls ``mixer.fit(dataset)`` and unpacks its
                result into ``(actor_loss, critic_loss)``.
                NOTE(review): defaults to ``None``, but a fit step would then
                fail on ``None.fit`` — presumably callers always pass one or
                set ``self.mixer`` before running; confirm.
            **kwargs: forwarded unchanged to ``MultiAgentCoreShared``.
        """
        # Assigned before the parent constructor runs, as in the original
        # ordering (the parent may rely on it — not verifiable from here).
        self.mixer = mixer
        super().__init__(**kwargs)

    def _fit_agents_and_mixer(self, dataset, info):
        """
        Fit each agent on ``dataset``, reset its per-agent counters, then fit
        the mixer and record its losses in ``info``.

        Args:
            dataset (list): samples collected since the previous fit.
            info (dict): info dict of the current step; receives the
                ``"actor_loss"`` and ``"critic_loss"`` entries.
        """
        for position, agent in enumerate(self.agents):
            # Hands the samples to the agent (the original comment says this
            # stores them in its replay memory).
            agent.fit(dataset)
            self._current_episodes_counter_per_agent[position] = 0
            self._current_steps_counter_per_agent[position] = 0

        losses = self.mixer.fit(dataset)
        actor_loss, critic_loss = losses
        info["actor_loss"] = actor_loss
        info["critic_loss"] = critic_loss

    def _run_impl(
        self,
        move_condition,
        fit_condition_per_agent,
        steps_progress_bar,
        episodes_progress_bar,
        render,
        record,
        need_complete_episodes,
    ):
        """
        Step the environment until ``move_condition`` says to stop, fitting
        the agents and the mixer whenever all per-agent fit conditions hold.

        Args:
            move_condition (callable): zero-argument predicate evaluated before
                every step; the loop runs while it is truthy.
            fit_condition_per_agent (iterable of callable): zero-argument
                predicates; a fit happens only when all of them are truthy
                (evaluated with short-circuiting after every step).
            steps_progress_bar: object whose ``update(1)`` is called per step.
            episodes_progress_bar: object whose ``update(1)`` is called each
                time a step ends an episode.
            render: forwarded to ``self._step``.
            record: forwarded to ``self._step``; when truthy,
                ``self._record.stop()`` is called at the end.
            need_complete_episodes (bool): when truthy, keep stepping after
                ``move_condition`` turns falsy until the running episode ends.

        Returns:
            tuple: ``(dataset, dataset_info)`` — the samples and info dicts
            collected after the last fit (empty if the final step triggered a
            fit).
        """
        agent_count = self.mdp.info.n_agents
        self._total_episodes_counter = 0
        self._total_steps_counter = 0
        self._current_episodes_counter_per_agent = np.zeros(agent_count, dtype=int)
        self._current_steps_counter_per_agent = np.zeros(agent_count, dtype=int)

        pending_samples = []
        pending_infos = []
        # True means the next iteration must start a fresh episode.
        last = True

        def _budget_or_open_episode():
            # Reads ``last`` at call time, so an unfinished episode keeps the
            # loop alive after the caller's own condition is exhausted.
            return move_condition() or not last

        if need_complete_episodes:
            keep_running = _budget_or_open_episode
        else:
            keep_running = move_condition

        while keep_running():
            if last:
                self.reset()

            sample, info = self._step(render, record)

            self._total_steps_counter += 1
            self._current_steps_counter_per_agent += 1
            steps_progress_bar.update(1)

            last = sample["last"]
            if last:
                self._total_episodes_counter += 1
                self._current_episodes_counter_per_agent += 1
                episodes_progress_bar.update(1)

            pending_samples.append(sample)
            pending_infos.append(info)

            ready_to_fit = all(
                condition() for condition in fit_condition_per_agent
            )
            if ready_to_fit:
                self._fit_agents_and_mixer(pending_samples, info)
                pending_samples = []
                pending_infos = []

            self._get_callbacks(sample, info)

        for agent in self.agents:
            agent.stop()
        self.mdp.stop()
        if record:
            self._record.stop()

        return pending_samples, pending_infos
