from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import logging

# Module-level logger for this file; load_model reports which checkpoints it loaded.
logger = logging.getLogger(__name__)

def load_model(cfg, *, device='cuda'):
    """Load the causal LM, its tokenizer, and the sentence-similarity model.

    Args:
        cfg: Config object exposing ``name`` (a Hugging Face checkpoint id
            that must be one of the supported ids listed below) and ``sim``
            (the identifier passed to ``SentenceTransformer``).
        device: Device both models are moved to. Defaults to ``'cuda'``,
            matching the previously hard-coded behaviour.

    Returns:
        Tuple ``(model, tokenizer, sim)``: the causal LM with gradients
        disabled, its tokenizer (with a pad token guaranteed to be set), and
        the ``SentenceTransformer`` similarity model.

    Raises:
        ValueError: If ``cfg.name`` is not a supported checkpoint.
    """
    supported = ('gpt2', 'gpt2-xl', 'meta-llama/Llama-2-7b-hf')
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let an unsupported checkpoint through.
    if cfg.name not in supported:
        raise ValueError(
            f'Unsupported model {cfg.name!r}; expected one of {supported}'
        )

    # Load the tokenizer first: it is cheap, and failing here avoids having
    # already moved a large model onto the device.
    tokenizer = AutoTokenizer.from_pretrained(cfg.name)
    # Reuse EOS as the padding token only when the tokenizer defines none,
    # so an existing pad token is never silently overwritten.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model = AutoModelForCausalLM.from_pretrained(cfg.name).to(device)
    # Freeze the LM: disable gradient tracking on all of its parameters.
    model.requires_grad_(False)
    logger.info('Load model: %s', cfg.name)

    # NOTE(review): unlike ``model``, the similarity model's parameters are
    # not frozen here — confirm that is intended.
    sim = SentenceTransformer(cfg.sim).to(device)
    logger.info('Load similarity: %s', cfg.sim)
    return model, tokenizer, sim
