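"""Model loader implementations for supported backends.

Provides entry points for loading Hugging Face causal LMs together with their
tokenizers, handling device/dtype placement and basic inference preparation
(eval mode, frozen parameters). Loading timm ViT models is not currently
implemented in this module.
"""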
from __future__ import annotations

from typing import Optional, Tuple

import torch
from torch import nn

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

import constants
from models.model_utils import prep_model_on_device


def _freeze_model_params(model: nn.Module) -> None:
    """Disable gradients for all model parameters (inference / PTQ)."""
    for p in model.parameters():
        p.requires_grad = False


def load_hf_causal_lm(
    model_name: str,
    device: torch.device,
    base_dtype: torch.dtype,
    *,
    model_path: Optional[str] = None,
    use_model_path: bool = False,
    revision: Optional[str] = None,
    trust_remote_code: bool = True,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
    """Load a Hugging Face causal LM and its tokenizer, ready for inference.

    Args:
        model_name: Model identifier. Used as the load source unless
            ``use_model_path`` is True and ``model_path`` is non-empty; it is
            always the name handed to ``prep_model_on_device``.
        device: Target device forwarded to ``prep_model_on_device``.
        base_dtype: Dtype forwarded to ``from_pretrained`` as ``dtype``.
        model_path: Optional alternative source (e.g. a local directory).
        use_model_path: When True and ``model_path`` is truthy, load from
            ``model_path`` instead of ``model_name``.
        revision: Optional revision forwarded to both loaders.
        trust_remote_code: Forwarded to both the tokenizer and model loaders.
        low_cpu_mem_usage: Forwarded to ``from_pretrained``.
        load_in_8bit: Request bitsandbytes 8-bit weights.
        load_in_4bit: Request bitsandbytes 4-bit weights.

    Returns:
        ``(model, tokenizer)``. The model is in eval mode with every parameter's
        ``requires_grad`` disabled. The tokenizer's ``pad_token`` falls back to
        its ``eos_token`` when it had no pad token.

    Raises:
        ValueError: If both ``load_in_8bit`` and ``load_in_4bit`` are set.
    """
    # The two bitsandbytes modes are mutually exclusive. Fail before any
    # (potentially slow) tokenizer/model download rather than deep inside
    # from_pretrained.
    if load_in_8bit and load_in_4bit:
        raise ValueError(
            "load_in_8bit and load_in_4bit are mutually exclusive; set at most one."
        )

    # Prefer the explicit path only when it is both requested and provided.
    model_source = model_path if use_model_path and model_path else model_name

    tokenizer = AutoTokenizer.from_pretrained(
        model_source,
        revision=revision,
        trust_remote_code=trust_remote_code,
    )

    # Without a pad token, padded batching fails; reuse EOS as the fallback.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = {
        "pretrained_model_name_or_path": model_source,
        "revision": revision,
        "dtype": base_dtype,
        "trust_remote_code": trust_remote_code,
        "low_cpu_mem_usage": low_cpu_mem_usage,
    }
    if load_in_8bit or load_in_4bit:
        # Passing load_in_8bit/load_in_4bit straight to from_pretrained is
        # deprecated. The supported form is a BitsAndBytesConfig under
        # ``quantization_config``. Import lazily so unquantized loads never
        # touch the quantization machinery.
        from transformers import BitsAndBytesConfig

        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
        )

    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(**load_kwargs)
    # NOTE(review): model_name (not model_source) is passed here, presumably
    # so name-specific handling still applies to path-loaded models; verify.
    # NOTE(review): bitsandbytes-quantized models may reject ``.to()``;
    # confirm prep_model_on_device handles quantized models correctly.
    model = prep_model_on_device(model, model_name, device)
    model.eval()
    _freeze_model_params(model)
    return model, tokenizer


def load_model(
    model_name: str,
    device: torch.device,
    base_dtype: torch.dtype,
    *,
    model_path: Optional[str] = None,
    use_model_path: bool = False,
    hf_revision: Optional[str] = None,
    hf_trust_remote_code: bool = True,
    hf_low_cpu_mem_usage: bool = True,
    hf_load_in_8bit: bool = False,
    hf_load_in_4bit: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
    """Generic model-loading entry point; currently loads HF causal LMs.

    Forwards every argument to ``load_hf_causal_lm``. The ``hf_``-prefixed
    keywords are mapped to that function's unprefixed parameters
    (``hf_revision`` -> ``revision``, ``hf_load_in_8bit`` -> ``load_in_8bit``,
    and so on). All other arguments pass through unchanged.

    Returns:
        The ``(model, tokenizer)`` pair produced by ``load_hf_causal_lm``.
        Any exception it raises propagates unchanged.
    """
    return load_hf_causal_lm(
        model_name=model_name,
        device=device,
        base_dtype=base_dtype,
        model_path=model_path,
        use_model_path=use_model_path,
        revision=hf_revision,
        trust_remote_code=hf_trust_remote_code,
        low_cpu_mem_usage=hf_low_cpu_mem_usage,
        load_in_8bit=hf_load_in_8bit,
        load_in_4bit=hf_load_in_4bit,
    )


# Public API of this module (what ``from <module> import *`` exposes).
__all__ = ["load_model", "load_hf_causal_lm"]