# from .choose_model import *
from .choose_optimizer import *

from .GPT import myGPT
from .GPT_specific import myGPT_specific


def get_model(args, device, **kwargs):
    """Build the model selected by ``args.model`` and move it to ``device``.

    Args:
        args: Configuration object. ``args.model`` selects the architecture
            (``'GPT'`` or ``'GPT_specific'``), and ``args`` itself is passed
            to the chosen model's constructor.
        device: Passed to the model constructor and then to ``.to(device)``.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        The result of ``<ModelClass>(args, device).to(device)``.

    Raises:
        ValueError: If ``args.model`` is not one of the supported names.
    """
    if args.model == 'GPT':
        model_cls = myGPT
    elif args.model == 'GPT_specific':
        model_cls = myGPT_specific
    else:
        # Without this branch an unknown name fell through and surfaced as an
        # opaque UnboundLocalError on the return statement.
        raise ValueError(
            f"Unknown model {args.model!r}; "
            "expected one of 'GPT', 'GPT_specific'."
        )

    return model_cls(args, device).to(device)