import torch.nn as nn
from algs.utilities.SequentialNetwork import SequentialNetwork

def create_model(input_dim, model_structure, output_dim, activation=None):
    """Build a ``SequentialNetwork`` whose hidden widths come from a named layout.

    Args:
        input_dim: Width of the first entry in the layer-size list.
        model_structure: Layout key. ``'512256'`` selects hidden widths
            512 -> 256 and ``'512256128'`` selects 512 -> 256 -> 128. Any other
            value (including unrecognised strings) silently falls back to
            512 -> 256.
        output_dim: Width of the last entry in the layer-size list.
        activation: Passed straight through as the third argument of
            ``SequentialNetwork`` (presumably the output activation; ``None``
            by default -- confirm against ``SequentialNetwork``).

    Returns:
        A new ``SequentialNetwork`` built from
        ``[input_dim, *hidden_widths, output_dim]`` with a fresh ``nn.ReLU()``
        as its second argument (presumably the hidden-layer activation).
    """
    # Ordered (key, hidden widths) pairs. A tuple scan is used instead of a
    # dict so that unhashable keys fall back to the default, the same way
    # they do with a plain ``==`` comparison.
    known_layouts = (
        ('512256', (512, 256)),
        ('512256128', (512, 256, 128)),
    )
    default_hidden = (512, 256)

    hidden_widths = next(
        (widths for key, widths in known_layouts if model_structure == key),
        default_hidden,
    )

    layer_sizes = [input_dim]
    layer_sizes.extend(hidden_widths)
    layer_sizes.append(output_dim)
    return SequentialNetwork(layer_sizes, nn.ReLU(), activation)
