from typing import Dict, Tuple, Union

from gymnasium.spaces import Discrete, MultiDiscrete, Space
from jumanji.specs import DiscreteArray, MultiDiscreteArray, Spec

# Tags returned by ``get_action_head`` next to the head config, naming the action-space kind.
_DISCRETE, _CONTINUOUS = "discrete", "continuous"


def get_action_head(action_types: Union[Spec, Space]) -> Tuple[Dict[str, str], str]:
    """Pick the action-head config that matches an environment's action spec/space.

    Args:
        action_types: A jumanji ``Spec`` or gymnasium ``Space`` describing the
            environment's actions.

    Returns:
        A ``(config, kind)`` pair. ``config`` is a one-entry dict whose ``_target_``
        value is the import path of the head class (presumably resolved later by a
        Hydra-style ``instantiate`` — confirm with callers). ``kind`` is
        ``_DISCRETE`` for ``DiscreteArray``, ``MultiDiscreteArray``, ``Discrete`` or
        ``MultiDiscrete`` (including their subclasses), and ``_CONTINUOUS`` otherwise.
    """
    discrete_families = (DiscreteArray, MultiDiscreteArray, Discrete, MultiDiscrete)

    if not isinstance(action_types, discrete_families):
        # Every spec/space type not listed above falls back to the continuous head.
        return {"_target_": "mava.networks.heads.ContinuousActionHead"}, _CONTINUOUS

    return {"_target_": "mava.networks.heads.DiscreteActionHead"}, _DISCRETE
