from transformers import AutoTokenizer


def get_torch_dtype(type_: str):
    """Return the attribute of the ``torch`` module named by ``type_``.

    Typically called with a dtype name such as ``"float32"`` or
    ``"bfloat16"``. No check is made that the resolved attribute is actually
    a dtype; any attribute name that exists on ``torch`` is accepted.

    Args:
        type_: Attribute name to look up on the ``torch`` module.

    Returns:
        Whatever object ``torch`` exposes under that name.

    Raises:
        AttributeError: If ``torch`` has no attribute called ``type_``.
    """
    # torch is imported at call time rather than at module import time.
    import torch as torch_module

    resolved = getattr(torch_module, type_)
    return resolved


def get_tokenizer_with_pad_token_id(pretrained_model_name_or_path, **kwargs):
    """Load a tokenizer and make sure it has a padding token id.

    Args:
        pretrained_model_name_or_path: Model id or local path handed to
            ``AutoTokenizer.from_pretrained``.
        **kwargs: Forwarded unchanged to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer. When it defines no ``pad_token_id`` but does
        define an ``eos_token_id``, the end-of-sequence id is reused for
        padding; otherwise the tokenizer is returned as loaded.
    """
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
    # Fall back to EOS only when padding is actually missing. The previous
    # check tested ``eos_token_id is None``, which could only ever assign
    # ``None`` and therefore never supplied a pad token.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


def load_llm2vec(
    pretrained_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp",
    *args,
    **kwargs,
):
    """Load an LLM2Vec model with its LoRA adapters merged into the weights.

    The MNTP checkpoint is always loaded first (in bfloat16, on CUDA when
    available, otherwise on CPU) and its LoRA adapter is merged. If
    ``pretrained_model_name_or_path`` names a different adapter (for example
    an unsupervised SimCSE one), that adapter is then loaded on top and merged
    as well, giving: base model + MNTP (LoRA) + extra adapter (LoRA).

    Args:
        pretrained_model_name_or_path: Adapter repo id or path to apply on top
            of the merged MNTP model. When it is the MNTP repo itself (the
            default), no second adapter is applied.
        *args: Accepted for call-site compatibility; currently ignored.
        **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        The model with all applicable adapters merged and unloaded.
    """
    import torch
    from transformers import AutoConfig, AutoModel
    from peft.peft_model import PeftModel

    # Single source of truth for the MNTP checkpoint used as the foundation.
    mntp_name = "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp"

    config = AutoConfig.from_pretrained(mntp_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        mntp_name,
        trust_remote_code=True,
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="cuda" if torch.cuda.is_available() else "cpu",
    )
    model = PeftModel.from_pretrained(model, mntp_name)
    model = model.merge_and_unload()

    # Apply the requested adapter on top of the merged MNTP weights. Skip it
    # when the request is the MNTP adapter itself: merging the same LoRA a
    # second time would add its weight delta twice.
    if pretrained_model_name_or_path != mntp_name:
        model = PeftModel.from_pretrained(model, pretrained_model_name_or_path)
        model = model.merge_and_unload()
    return model
