from typing import Tuple

from policy.policy import PolicySpec
from policy.sample_batch import DEFAULT_POLICY_ID
from utils.from_config import from_config
from utils.typing import MultiAgentPolicyConfigDict, PartialTrainerConfigDict


def check_multi_agent(
    config: PartialTrainerConfigDict,
) -> Tuple[MultiAgentPolicyConfigDict, bool]:
    """Checks whether a (partial) config defines a multi-agent setup.

    Note that this normalizes ``config["multiagent"]`` in place:

    - If ``policies`` is missing/empty, or given as a set/list/tuple of
      PolicyIDs, it is replaced by a dict mapping each PolicyID to an empty
      ``PolicySpec()``. A missing/empty value becomes
      ``{DEFAULT_POLICY_ID: PolicySpec()}``.
    - If ``policy_mapping_fn`` is a dict, it is replaced by
      ``from_config(<that dict>)``. If it is a list, each element is replaced
      by ``from_config(<element>)`` inside the same list object.

    Args:
        config: The user/Trainer/Policy config to check for multi-agent.
            Must contain a "multiagent" key.

    Returns:
        Tuple consisting of the resulting (all fixed) multi-agent policy
        dict and a bool indicating whether we have a multi-agent setup.
        The bool is False only if DEFAULT_POLICY_ID is the one and only
        PolicyID in the policies dict.

    Raises:
        KeyError: If `config` has no "multiagent" key.
    """
    multiagent_config = config["multiagent"]
    policies = multiagent_config.get("policies")

    # Nothing specified in config dict -> Assume simple single agent setup
    # with DEFAULT_POLICY_ID as only policy.
    if not policies:
        policies = {DEFAULT_POLICY_ID}
    # Policies given as set/list/tuple (of PolicyIDs) -> Setup each policy
    # automatically via empty PolicySpec (will make RLlib infer obs- and action
    # spaces as well as the Policy's class). The resulting dict is also written
    # back into the config so the config carries the normalized form.
    if isinstance(policies, (set, list, tuple)):
        policies = multiagent_config["policies"] = {
            pid: PolicySpec() for pid in policies
        }

    # Attempt to create a `policy_mapping_fn` from config dict(s). Helpful
    # if users would like to specify custom callable classes in yaml files.
    mapping_fn = multiagent_config.get("policy_mapping_fn")
    if isinstance(mapping_fn, dict):
        multiagent_config["policy_mapping_fn"] = from_config(mapping_fn)
    elif isinstance(mapping_fn, list):
        # Convert element-wise in place, keeping the same list object that
        # is stored in the config.
        for i, fn_config in enumerate(mapping_fn):
            mapping_fn[i] = from_config(fn_config)

    # Is this a multi-agent setup? False iff DEFAULT_POLICY_ID is the only
    # PolicyID found in the policies dict; True otherwise.
    is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies
    return policies, is_multiagent
