from models.gpt import GPT
from models.pythia import Pythia
from models.config import GPTConfig


def get_model(args):
    """Build or load the model selected by ``args.model``.

    Dispatch rules, checked in this order:

    * ``'gpt'`` (exact match): build a ``GPTConfig`` from ``args`` and
      instantiate a fresh ``GPT`` from it. ``bias=True`` and ``dropout=0``
      are fixed here and are not read from ``args``.
    * name starting with ``'gpt2'``: load through ``GPT.from_pretrained`` and,
      when ``args.block_size`` is below 1024, call ``crop_block_size`` on the
      loaded model.
    * name starting with ``'pythia'``: load through ``Pythia.from_pretrained``.

    Args:
        args: Object exposing the attributes read below (``model``,
            ``teacherless_token``, and for the other branches ``block_size``,
            ``n_layers``, ``n_head``, ``n_embd``, ``vocab_size``, ``use_flash``,
            ``use_top``, ``use_mtp``, ``n_future_tokens``). Presumably an
            argparse namespace -- confirm against the caller.

    Returns:
        The constructed or loaded model instance.

    Raises:
        ValueError: If ``args.model`` matches none of the supported names.
    """
    name = args.model

    if name == 'gpt':
        config = GPTConfig(
            n_layers=args.n_layers,
            n_heads=args.n_head,
            n_embd=args.n_embd,
            block_size=args.block_size,
            bias=True,
            vocab_size=args.vocab_size,
            dropout=0,
            use_flash=args.use_flash,
            teacherless_token=args.teacherless_token,
            use_top=args.use_top,
            use_mtp=args.use_mtp,
            n_future_tokens=args.n_future_tokens,
        )
        return GPT(config)

    if name.startswith('gpt2'):
        model = GPT.from_pretrained(name, teacherless_token=args.teacherless_token)
        # NOTE(review): 1024 is presumably the pretrained GPT-2 context length;
        # the model is only cropped when a shorter block size is requested.
        if args.block_size < 1024:
            model.crop_block_size(args.block_size)
        return model

    if name.startswith('pythia'):
        return Pythia.from_pretrained(name, teacherless_token=args.teacherless_token)

    # Previously an unmatched name fell through to `return model` and surfaced
    # as an UnboundLocalError; fail with an explicit, actionable message instead.
    raise ValueError(
        f"Unknown model {name!r}: expected 'gpt', or a name starting with 'gpt2' or 'pythia'"
    )
