import torch
import torchvision
from transformers import OPTForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling, LlamaForCausalLM

def get_model(model_name, weights):
    """Construct the model registered under ``model_name``.

    Args:
        model_name: Registry key. ``'vit_l_16'``, ``'vit_b_16'``,
            ``'resnet152'`` and ``'resnet50'`` map to the matching
            ``torchvision.models`` constructor; ``'opt'`` and ``'llama'``
            map to ``OPTForCausalLM`` / ``LlamaForCausalLM.from_pretrained``.
        weights: For torchvision entries, forwarded as the ``weights=``
            keyword argument. For ``'opt'`` / ``'llama'``, forwarded
            positionally to ``from_pretrained`` (presumably a checkpoint
            name or local path -- confirm against callers).

    Returns:
        The constructed model, or ``None`` when ``model_name`` matches no
        registered key; in that case a notice is printed to stdout.
    """
    # Ordered (key, builder) pairs, scanned with ``==`` in the same order as
    # the key comparisons they replace. Each builder is a lambda so that
    # torchvision / transformers attributes are only resolved for the
    # entry that actually matches.
    registry = (
        ('vit_l_16', lambda w: torchvision.models.vit_l_16(weights=w)),
        ('vit_b_16', lambda w: torchvision.models.vit_b_16(weights=w)),
        ('resnet152', lambda w: torchvision.models.resnet152(weights=w)),
        ('resnet50', lambda w: torchvision.models.resnet50(weights=w)),
        ('opt', lambda w: OPTForCausalLM.from_pretrained(w)),
        ('llama', lambda w: LlamaForCausalLM.from_pretrained(w)),
    )
    for key, build in registry:
        if model_name == key:
            return build(weights)

    # No key matched: report it and hand back None, as before.
    print('Model not registered!')
    return None