from typing import List

from centralized_verification.agents.multi_agent import MultiAgentLearner
from centralized_verification.agents.single_agent import SingleAgentLearner
from centralized_verification.shields.shield import ShieldResult
from centralized_verification.utils import prefix_dict


class MultiAgentLearnerWrapper(MultiAgentLearner):
    """Expose a list of independent single-agent learners as one multi-agent learner.

    Element ``i`` of every joint (per-agent) quantity is routed to ``self.learners[i]``;
    each learner only ever sees its own agent's slice of the joint data.
    """

    def __init__(self, learners: List[SingleAgentLearner]):
        """
        :param learners: One learner per agent, ordered by agent index. The list is stored
            by reference, not copied.
        """
        self.learners = learners

    def observe_transition(self, joint_obs, shield_result: ShieldResult, joint_next_obs, joint_rew, done, step_num,
                           training_progress):
        """Forward one environment transition to every learner.

        ``joint_obs``, ``shield_result``, ``joint_next_obs`` and ``joint_rew`` are iterated in
        lockstep with ``self.learners`` (one element per agent). ``done``, ``step_num`` and
        ``training_progress`` are passed unchanged to every learner.

        Note: ``zip`` stops at the shortest input, so a joint argument with fewer elements than
        there are learners means the trailing learners receive no transition.
        """
        for learner, obs, agent_shield_result, next_obs, rew in zip(self.learners, joint_obs, shield_result,
                                                                    joint_next_obs, joint_rew):
            learner.observe_transition(obs, agent_shield_result, next_obs, rew, done, step_num, training_progress)

    def get_joint_action(self, joint_observation, training_progress):
        """Ask each learner for an action on its own observation.

        :return: A tuple of actions ordered like ``self.learners``.
        """
        return tuple(
            learner.get_action(obs, training_progress) for learner, obs in zip(self.learners, joint_observation))

    def num_agents(self) -> int:
        """Return the number of wrapped learners (one per agent)."""
        return len(self.learners)

    def state_dict(self):
        """Return ``{"learners": [...]}``, holding each learner's own state dict in agent order."""
        return {
            "learners": [learner.state_dict() for learner in self.learners]
        }

    def load_state_dict(self, state_dict):
        """Restore every learner from a dict produced by :meth:`state_dict`.

        :param state_dict: Mapping with a ``"learners"`` sequence containing one entry per learner.
        :raises ValueError: If the number of saved learner states differs from the number of
            wrapped learners. Without this check ``zip`` would silently leave some learners
            unrestored, or silently drop saved entries.
        """
        learner_states = state_dict["learners"]
        if len(learner_states) != len(self.learners):
            raise ValueError(
                f"state_dict contains {len(learner_states)} learner states, "
                f"but this wrapper has {len(self.learners)} learners"
            )
        for learner, lsd in zip(self.learners, learner_states):
            learner.load_state_dict(lsd)

    def get_log_dict(self):
        """Merge every learner's log dict into one dict.

        Each learner's entries are passed through ``prefix_dict`` with the prefix
        ``"learner_{i}/"`` (presumably prepended to every key; verify in ``utils``), so
        identically named metrics from different agents do not collide.
        """
        log_dict = {}
        for i, learner in enumerate(self.learners):
            agent_log_dict = learner.get_log_dict()
            agent_log_dict_prefixed = prefix_dict(agent_log_dict, f"learner_{i}/")
            log_dict.update(agent_log_dict_prefixed)

        return log_dict
