"""Model utils."""

from typing import Tuple

import models
import tensorflow as tf


def get_simple_mlp(input_shape, num_classes):
  """Constructs a `models.custom_model.SimpleMLP`.

  Args:
    input_shape: Passed through unchanged as the first positional argument
      of the `SimpleMLP` constructor.
    num_classes: Passed through unchanged as the second positional argument.

  Returns:
    The freshly constructed `SimpleMLP` instance.
  """
  return models.custom_model.SimpleMLP(input_shape, num_classes)


def get_simple_convnet(input_shape, num_classes):
  """Constructs a `models.custom_model.SimpleConvNet`.

  Args:
    input_shape: Forwarded as the first positional constructor argument.
    num_classes: Forwarded as the second positional constructor argument.

  Returns:
    The new `SimpleConvNet` instance.
  """
  convnet_cls = models.custom_model.SimpleConvNet
  return convnet_cls(input_shape, num_classes)


def get_cifar_resnet(input_shape, num_classes):
  """Constructs a `models.custom_model.CifarResNet`.

  Args:
    input_shape: Handed to the constructor as its first positional argument.
    num_classes: Handed to the constructor as its second positional argument.

  Returns:
    The resulting `CifarResNet` instance.
  """
  ctor_args = (input_shape, num_classes)
  return models.custom_model.CifarResNet(*ctor_args)


def get_densenet121(input_shape, num_classes, weights='imagenet'):
  """Constructs a `models.custom_model.DenseNet` configured as DenseNet121.

  Args:
    input_shape: Forwarded as the `input_shape` keyword argument.
    num_classes: Forwarded as the `num_classes` keyword argument.
    weights: Forwarded as the `weights` keyword argument; defaults to
      'imagenet'. Presumably selects pretrained weights in the style of
      `tf.keras.applications` — confirm in `custom_model.DenseNet`.

  Returns:
    The new `DenseNet` instance built with `densenet_name='DenseNet121'`.
  """
  densenet_kwargs = dict(
      densenet_name='DenseNet121',
      input_shape=input_shape,
      num_classes=num_classes,
      weights=weights,
  )
  return models.custom_model.DenseNet(**densenet_kwargs)


def get_resnet50(input_shape, num_classes, weights='imagenet'):
  """Constructs a `models.custom_model.ResNet` configured as ResNet50.

  Args:
    input_shape: Forwarded as the `input_shape` keyword argument.
    num_classes: Forwarded as the `num_classes` keyword argument.
    weights: Forwarded as the `weights` keyword argument; defaults to
      'imagenet'. Presumably selects pretrained weights in the style of
      `tf.keras.applications` — confirm in `custom_model.ResNet`.

  Returns:
    The new `ResNet` instance built with `resnet_name='ResNet50'`.
  """
  return models.custom_model.ResNet(
      resnet_name='ResNet50',
      weights=weights,
      input_shape=input_shape,
      num_classes=num_classes,
  )


def get_roberta_mlp(input_shape, num_classes):
  """Constructs a `models.custom_model.RoBertaMLP`.

  Args:
    input_shape: Passed positionally as the first constructor argument.
    num_classes: Passed positionally as the second constructor argument.

  Returns:
    The newly built `RoBertaMLP` instance.
  """
  custom = models.custom_model
  return custom.RoBertaMLP(input_shape, num_classes)
