import torch.nn as nn

from src.model.hsnet import HypercorrSqueezeNetwork
from src.model.hsnet_transformer import TransformerHypercorrSqueezeNetwork

def get_model(args) -> nn.Module:
    """Build the model selected by ``args.model_type``.

    Args:
        args: Configuration object. Its optional ``model_type`` attribute
            selects the architecture (falls back to ``'pspnet'`` when the
            attribute is absent). ``args.layers`` is forwarded to the model
            as ``backbone``, and ``args`` itself is passed through as
            ``args``.

    Returns:
        The instantiated model, constructed with
        ``use_original_imgsize=True``.

    Raises:
        NotImplementedError: If ``model_type`` is not ``'hsnet'`` or
            ``'hsnet_transformer'``. The message names the rejected value.
    """
    # NOTE(review): the 'pspnet' fallback is not handled by any branch below,
    # so a config without ``model_type`` always raises. It is kept unchanged
    # for backward compatibility. Confirm whether a different default is
    # intended.
    model_type = getattr(args, 'model_type', 'pspnet')

    if model_type == 'hsnet':
        model_cls = HypercorrSqueezeNetwork
    elif model_type == 'hsnet_transformer':
        model_cls = TransformerHypercorrSqueezeNetwork
    else:
        # Name the offending value so misconfigurations are easy to diagnose.
        raise NotImplementedError(
            f"Unsupported model_type {model_type!r}; "
            "expected one of: 'hsnet', 'hsnet_transformer'"
        )

    # Both supported architectures share the same constructor signature.
    return model_cls(args=args, backbone=args.layers, use_original_imgsize=True)



