from typing import List, Dict, Any, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def load(name: str) -> Tuple[Any, Any]:
    """Load a causal-LM checkpoint together with its tokenizer.

    Args:
        name: Identifier passed unchanged to ``from_pretrained`` for both
            the tokenizer and the model.

    Returns:
        A ``(model, tokenizer)`` tuple, in that order.
    """
    tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True)

    if not tokenizer.pad_token:
        # Preferred pad token; looks like the reserved padding token of
        # Llama-3.x vocabularies -- confirm against the target checkpoints.
        preferred_pad = "<|finetune_right_pad_id|>"
        # Only use the preferred token when the vocabulary really has it.
        # Assigning a string that is not in the vocabulary would leave the
        # tokenizer without a usable pad id, so fall back to the EOS token
        # in that case. When there is no EOS token either, keep the
        # historical assignment so behaviour is never worse than before.
        if preferred_pad in tokenizer.get_vocab() or not tokenizer.eos_token:
            tokenizer.pad_token = preferred_pad
        else:
            tokenizer.pad_token = tokenizer.eos_token

    # Request the FlashAttention-2 attention implementation and float16
    # weights when instantiating the model.
    model = AutoModelForCausalLM.from_pretrained(
        name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,
    )
    return model, tokenizer
