"""Registry of model and action dist names."""

from models.communication import AttComModel, InvariantAttComModel, AttComActionMaskModel, InvariantAttComActionMaskModel
from models.communication.action_dist import TorchHomogeneousMultiActionDistribution
from models.torch.action_mask_model import TorchActionMaskModel
from models.tf.action_mask_model import TFActionMaskModel
from models.tf.fcnet import FullyConnectedNetwork as TFFullyConnectedNetwork
from models.tf.transformer import TransformerNetwork as TFTransformerNetwork
from models.tf.tf_nmmo_model import TFNMMOModel

# Registry of model classes keyed by the string name used to select them.
# get_model_class() resolves names against this mapping.
MODELS = {
    # Imported from models.torch.
    "action_mask_model": TorchActionMaskModel,
    # Imported from models.communication.
    "att_com_model": AttComModel,
    "invariant_att_com_model": InvariantAttComModel,
    "att_com_action_mask_model": AttComActionMaskModel,
    "invariant_att_com_action_mask_model": InvariantAttComActionMaskModel,
    # Imported from models.tf; registry names carry a "tf_" prefix.
    "tf_action_mask_model": TFActionMaskModel,
    "tf_fcnet": TFFullyConnectedNetwork,
    "tf_transformer": TFTransformerNetwork,
    "tf_nmmo_model": TFNMMOModel
}


def get_model_class(model: str) -> type:
    """Return the model class registered under the given name.

    Args:
        model: Key to look up in ``MODELS``.

    Returns:
        The class stored in ``MODELS`` under ``model``.

    Raises:
        ValueError: If ``model`` is not a key of ``MODELS``. The message
            lists the registered names to make typos easy to spot.
    """
    try:
        return MODELS[model]
    except KeyError:
        known = ", ".join(sorted(MODELS))
        # ValueError is still caught by callers using ``except Exception``;
        # ``from None`` hides the internal KeyError from the traceback.
        raise ValueError(
            f"Unknown model {model}. Known models: {known}."
        ) from None


# Registry of action distribution classes keyed by the string name used to
# select them. get_action_dist_class() resolves names against this mapping.
ACTION_DISTS = {
    # Imported from models.communication.action_dist.
    "hom_multi_action": TorchHomogeneousMultiActionDistribution,
}


def get_action_dist_class(action_dist: str) -> type:
    """Return the action distribution class registered under the given name.

    Args:
        action_dist: Key to look up in ``ACTION_DISTS``.

    Returns:
        The class stored in ``ACTION_DISTS`` under ``action_dist``.

    Raises:
        ValueError: If ``action_dist`` is not a key of ``ACTION_DISTS``.
            The message lists the registered names to make typos easy to
            spot.
    """
    try:
        return ACTION_DISTS[action_dist]
    except KeyError:
        known = ", ".join(sorted(ACTION_DISTS))
        # ValueError is still caught by callers using ``except Exception``;
        # ``from None`` hides the internal KeyError from the traceback.
        raise ValueError(
            f"Unknown action distribution {action_dist}. "
            f"Known action distributions: {known}."
        ) from None
