from typing import Tuple

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.shields.shield import Shield, ShieldResult, AgentResult, AgentUpdate, T
from centralized_verification.shields.utils import is_safe_action


class NoShield(Shield[MultiAgentSafetyEnv, None]):
    """Pass-through shield that never alters the agents' proposed actions.

    Every proposed per-agent action is forwarded unchanged, and the shield
    keeps no state: its shield state is always ``None``. The only observable
    effect is an optional reward penalty. When ``self.punish_unsafe_orig_action``
    is truthy and ``is_safe_action`` judges the joint action unsafe, each
    agent's update carries ``self.unsafe_action_punishment`` as its
    ``reward_modifier``. (Both attributes are presumably initialised by the
    ``Shield`` base class; confirm there.)
    """

    def get_initial_shield_state(self, state, initial_joint_obs) -> None:
        """Return the initial shield state, which is always ``None``.

        This shield is stateless, so neither ``state`` nor ``initial_joint_obs``
        is used.
        """
        # Annotated ``None`` rather than the base class's TypeVar ``T``: this
        # subclass binds the shield-state type parameter to ``None``, and an
        # unbound TypeVar appearing only in a return type is meaningless.
        return None

    def evaluate_joint_action(self, state, _, proposed_action, __) -> Tuple[ShieldResult, None]:
        """Forward each proposed action unchanged, optionally penalising unsafe joint actions.

        Args:
            state: Environment state, passed to ``is_safe_action`` together with
                ``self.env`` and ``proposed_action``.
            _: Unused. Presumably the joint observation; verify against ``Shield``.
            proposed_action: Iterable of per-agent actions. One ``AgentResult`` is
                produced per element, in order.
            __: Unused. Presumably the current shield state, which is always
                ``None`` for this shield.

        Returns:
            A pair ``(results, None)``. ``results`` is a list of
            ``AgentResult(AgentUpdate(action=...))``, one per proposed action.
            The ``AgentUpdate`` also gets
            ``reward_modifier=self.unsafe_action_punishment`` when punishment
            is enabled and the joint action is unsafe. The second element is
            the new (always ``None``) shield state.
        """
        # The ``and`` short-circuits, so the possibly costly safety check only
        # runs when punishment is enabled.
        punish = self.punish_unsafe_orig_action and not is_safe_action(self.env, state, proposed_action)

        if not punish:
            # No reward_modifier is passed here, so AgentUpdate's own default applies.
            return [AgentResult(AgentUpdate(action=action)) for action in proposed_action], None

        return [AgentResult(AgentUpdate(action=action, reward_modifier=self.unsafe_action_punishment))
                for action in proposed_action], None
