import copy

from expground.types import Dict, Any
from expground.common.models import torch as torch_models


def get_model(model_config: Dict[str, Any], framework="torch"):
    """Return a factory that builds the model described by ``model_config``.

    Args:
        model_config: Model configuration. Its ``"network"`` key selects the
            architecture: ``"mlp"`` (the default when the key is absent),
            ``"rnn"`` or ``"cnn"``. A deep copy of the whole dict is handed
            to the model constructor on every build.
        framework: Backend to build with. Only ``"torch"`` is supported.

    Returns:
        A callable ``builder(observation_space, action_space, use_cuda=False)``
        that instantiates a fresh model and, if ``use_cuda`` is true, calls
        ``.cuda()`` on it before returning it.

    Raises:
        NotImplementedError: If ``framework`` is not ``"torch"``, or the
            requested network type is unknown or not implemented yet.
    """
    model_type = model_config.get("network", "mlp")

    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would turn this into a confusing AttributeError.
    if framework != "torch":
        raise NotImplementedError(
            f"Framework {framework!r} is not supported now; only 'torch' is available."
        )

    # Config network name -> class name inside the framework's model module.
    handler_names = {"mlp": "MLP", "rnn": "RNN", "cnn": "Vision"}
    if model_type == "rcnn":
        raise NotImplementedError("Network type 'rcnn' is not implemented yet.")
    if model_type not in handler_names:
        raise NotImplementedError(f"Unsupported network type: {model_type!r}")

    module = torch_models
    # Resolve the class now (as before) so a missing class fails at
    # get_model() time rather than at build time.
    handler = getattr(module, handler_names[model_type])

    def builder(observation_space, action_space, use_cuda=False):
        # Deep-copy per build so each model gets its own config and cannot
        # mutate the caller's dict or another model's config.
        model = handler(observation_space, action_space, copy.deepcopy(model_config))
        if use_cuda:
            model.cuda()
        return model

    return builder
