from transformers import AutoTokenizer, AutoModelForCausalLM

# Registry of short model aliases accepted by get_model() / get_tokenizer(),
# mapped to the identifier each loader passes to ``from_pretrained``.
pretrained_model_dic = {"opt": "facebook/opt-6.7b"}


def get_model(pretrained_model_name):
    """Load the causal language model registered under a short alias.

    The alias is resolved through ``pretrained_model_dic``. The resolved
    identifier is printed, and the model is loaded with
    ``AutoModelForCausalLM.from_pretrained(..., device_map="auto")``.

    Args:
        pretrained_model_name: A key of ``pretrained_model_dic`` (e.g. ``"opt"``).

    Returns:
        The model object returned by ``AutoModelForCausalLM.from_pretrained``.

    Raises:
        ValueError: If ``pretrained_model_name`` is not a registered alias.
    """
    # Fail loudly on an unknown alias. Previously this path printed a message
    # and then crashed on an unassigned local variable at ``return``.
    if pretrained_model_name not in pretrained_model_dic:
        raise ValueError(
            f"Error: Model Type {pretrained_model_name!r} is not supported; "
            f"expected one of {sorted(pretrained_model_dic)}"
        )

    model_path = pretrained_model_dic[pretrained_model_name]
    print(model_path)
    # device_map="auto" delegates weight placement across available devices
    # to transformers/accelerate.
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    return model



def get_tokenizer(pretrained_model_name):
    """Load the tokenizer registered under a short alias.

    The alias is resolved through ``pretrained_model_dic``, and the tokenizer
    is loaded with ``AutoTokenizer.from_pretrained`` using left-side padding.

    Args:
        pretrained_model_name: A key of ``pretrained_model_dic`` (e.g. ``"opt"``).

    Returns:
        The tokenizer object returned by ``AutoTokenizer.from_pretrained``.

    Raises:
        ValueError: If ``pretrained_model_name`` is not a registered alias.
    """
    # Fail loudly on an unknown alias. Previously this path printed a message
    # and then crashed on an unassigned local variable at ``return``.
    if pretrained_model_name not in pretrained_model_dic:
        raise ValueError(
            f"Error: Tokenizer Type {pretrained_model_name!r} is not supported; "
            f"expected one of {sorted(pretrained_model_dic)}"
        )

    tokenizer_path = pretrained_model_dic[pretrained_model_name]
    # padding_side='left' pads on the left, the usual choice for batched
    # generation with decoder-only models.
    # NOTE(review): padding / return_tensors / truncation / max_length are
    # per-call encoding options of tokenizer(...). Passed to from_pretrained
    # they are presumably only stored as init kwargs and not applied when
    # encoding. Confirm against the installed transformers version, and pass
    # them at call time (or use model_max_length) if that behaviour is wanted.
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        padding_side='left',
        padding=True,
        return_tensors='pt',
        truncation=True,
        max_length=2048,
    )
    return tokenizer