from transformers import AutoModelForCausalLM, AutoTokenizer
# Pad token string that `create_tokenizer` registers on tokenizers of models
# whose name contains "Llama".
DEFAULT_PAD_TOKEN = "[PAD]"


def _fill_new_rows_with_mean(weight, new_rows):
    """Overwrite the last ``new_rows`` rows of ``weight`` in place with the mean of the rows before them."""
    weight[-new_rows:] = weight[:-new_rows].mean(dim=0, keepdim=True)


def resize_token_embeddings(tokenizer, model):
    """Grow the model's embedding matrices so they cover every token in ``tokenizer``.

    If ``len(tokenizer)`` exceeds the number of rows in the model's input
    embedding matrix (e.g. after ``add_special_tokens``), the model is resized
    with ``model.resize_token_embeddings(len(tokenizer))``. Each newly added
    row of the input embeddings, and of the output embeddings when present, is
    then initialised to the mean of the pre-existing rows.

    The model is never shrunk. If the embedding matrix already has at least
    ``len(tokenizer)`` rows, nothing is changed. Some checkpoints pad their
    vocabulary to a round size, so a tokenizer smaller than the matrix is normal.

    Args:
        tokenizer: Object whose ``len()`` is the target vocabulary size.
        model: Model exposing ``get_input_embeddings``,
            ``get_output_embeddings`` and ``resize_token_embeddings``.
    """
    current_rows = model.get_input_embeddings().weight.data.size(0)
    extra_token_count = len(tokenizer) - current_rows
    # Only grow. A negative difference used to shrink the model and then
    # overwrite nearly every remaining row with the mean of the first few.
    if extra_token_count <= 0:
        return

    model.resize_token_embeddings(len(tokenizer))

    # Re-fetch after resizing: the resize may replace the embedding modules.
    _fill_new_rows_with_mean(model.get_input_embeddings().weight.data, extra_token_count)

    output_embeddings = model.get_output_embeddings()
    # get_output_embeddings() may return None (no separate LM head); skip then.
    if output_embeddings is not None:
        _fill_new_rows_with_mean(output_embeddings.weight.data, extra_token_count)


def create_tokenizer(model_name: str,
                     padding_side: str = "right",
                     model_max_length: int = 2048):
    """Load the tokenizer for ``model_name`` through ``AutoTokenizer``.

    Args:
        model_name: Hugging Face model id or local path.
        padding_side: Forwarded to ``from_pretrained`` ("right" by default).
        model_max_length: Forwarded to ``from_pretrained``.

    Returns:
        The loaded tokenizer. When ``model_name`` contains "Llama", the pad
        token is set to ``DEFAULT_PAD_TOKEN``. This may add a new token to the
        vocabulary, so the model's embeddings then need resizing (as
        ``create_model`` does).
    """
    tokenizer_kwargs = {
        "padding_side": padding_side,
        "model_max_length": model_max_length,
    }
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)

    # NOTE(review): this replaces any pad token the checkpoint already defines
    # (it runs even if tokenizer.pad_token is set) — confirm that is intended.
    is_llama = "Llama" in model_name
    if is_llama:
        tokenizer.add_special_tokens({"pad_token": DEFAULT_PAD_TOKEN})

    # Pad/EOS relationship differs between Qwen generations:
    #   Qwen 2.5: pad_token == eos_token
    #   Qwen 3:   pad_token != eos_token (eos is <|im_end|>)
    # NOTE(review): the pad token previously noted here as "<|eos_token|>" is
    # probably "<|endoftext|>" — confirm against tokenizer_config.json.
    return tokenizer
    
    
def create_model(model_name: str,
                 torch_dtype: str = "auto",
                 attn_implementation: str = None,
                 device_map: str = "auto",
                 model_max_length: int = 2048
                 ):
    """Load a causal LM and a matching tokenizer for ``model_name``.

    Args:
        model_name: Hugging Face model id or local path, used for both the
            model and ``create_tokenizer``.
        torch_dtype: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        attn_implementation: Forwarded to ``from_pretrained``. It also selects
            left padding for Qwen models when it is "flash_attention_2".
        device_map: Forwarded to ``from_pretrained``.
        model_max_length: Forwarded to ``create_tokenizer``.

    Returns:
        ``(model, tokenizer)``. ``model.config.pad_token_id`` is set to the
        tokenizer's pad id. For model names containing "Llama", the embeddings
        are resized to the tokenizer via ``resize_token_embeddings``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        attn_implementation=attn_implementation,
        device_map=device_map,
    )

    # Qwen running FlashAttention-2 pads on the left; everything else on the right.
    # NOTE(review): presumably FA2 batched generation needs left padding — confirm.
    pad_left = "Qwen" in model_name and attn_implementation == "flash_attention_2"
    side = "left" if pad_left else "right"
    tokenizer = create_tokenizer(
        model_name,
        padding_side=side,
        model_max_length=model_max_length,
    )

    model.config.pad_token_id = tokenizer.pad_token_id

    # create_tokenizer installs a pad token for "Llama" names, so the
    # embedding matrices are grown to cover any newly added token id.
    if "Llama" in model_name:
        resize_token_embeddings(tokenizer, model)

    return model, tokenizer