import os
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import torch
from huggingface_hub import list_repo_files
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM


def load_hf_model(
    model_name: Union[str, os.PathLike],
    torch_dtype=torch.bfloat16,
    device_map: str = "cuda",
    attn_implementation: Union[str, None] = "eager",
    requires_grad: bool = False,
    use_cache: bool = False,  # Disable KV cache by default to avoid device mismatch errors
):
    """Load a causal LM from the Hugging Face Hub or a local directory.

    If the checkpoint is a PEFT adapter (it has a root-level
    ``adapter_config.json``), it is loaded with ``AutoPeftModelForCausalLM``
    and ``merge_and_unload()`` is applied to fold the adapter into the base
    model. Otherwise it is loaded with ``AutoModelForCausalLM``.

    Args:
        model_name: Hub repo id (e.g. ``"org/model"``) or path to a local
            checkpoint directory.
        torch_dtype: dtype the weights are loaded in.
        device_map: Forwarded to ``from_pretrained`` for weight placement.
        attn_implementation: Attention backend forwarded to
            ``from_pretrained``. ``None`` selects one automatically:
            ``"eager"`` when the name contains "gpt2" or "gemma"
            (case-insensitive), otherwise ``"flash_attention_2"``.
        requires_grad: If False, ``requires_grad_(False)`` is applied to the
            whole model.
        use_cache: Forwarded to ``from_pretrained`` (KV-cache setting).

    Returns:
        The loaded model, switched to eval mode.
    """
    # Accept str or os.PathLike; everything below works on a plain str.
    model_name = os.fspath(model_name)

    # Choose attention implementation if not specified
    if attn_implementation is None:
        # Make sure models that don't support FlashAttention aren't forced to
        # use it. Compare case-insensitively so e.g. "Gemma" is also caught.
        lowered_name = model_name.lower()
        if "gpt2" in lowered_name or "gemma" in lowered_name:
            attn_implementation = "eager"
        else:
            attn_implementation = "flash_attention_2"

    # Detect a PEFT adapter checkpoint. Local directories are inspected on
    # disk because list_repo_files only understands Hub repo ids and raises on
    # filesystem paths. Only a root-level adapter_config.json counts, since
    # from_pretrained below reads the checkpoint root (no subfolder).
    if os.path.isdir(model_name):
        has_adapter_config = os.path.isfile(
            os.path.join(model_name, "adapter_config.json")
        )
    else:
        has_adapter_config = "adapter_config.json" in list_repo_files(model_name)

    # Shared by both loaders so the two code paths cannot drift apart.
    load_kwargs = dict(
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_implementation,
        device_map=device_map,
        # SECURITY: trust_remote_code=True lets the checkpoint run its own
        # Python code on load; only use with repos/directories you trust.
        trust_remote_code=True,
        use_cache=use_cache,
    )

    if has_adapter_config:
        # merge_and_unload folds the adapter into the base weights so callers
        # receive the underlying model rather than a PEFT wrapper.
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_name, **load_kwargs
        ).merge_and_unload()
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    model.eval()

    # Disable model grad if we're not training
    if not requires_grad:
        model.requires_grad_(False)

    return model
