"""Model-specific utility helpers (architecture-aware operations).

Contains helpers that move selected submodules for known model families
(e.g., LLaMA) after generic loading, and that locate the stack of transformer
blocks across common architectures, centralizing architecture-specific quirks.
"""
from __future__ import annotations

from typing import Union, Any
import torch

import constants
from transformers import PreTrainedModel  # type: ignore


def prep_model_on_device(
    model: PreTrainedModel,
    model_name: str,
    device: Union[str, torch.device, None] = constants.DEFAULT_DEVICE,
) -> PreTrainedModel:
    """Move ``model`` to ``device``, then explicitly move family-specific core layers.

    The whole model is first moved with ``model.to(device)``. If ``model_name``
    contains the LLaMA family marker (``constants.MODEL_FAMILY_LLAMA``, matched
    case-insensitively), every attribute named in ``constants.LLAMA_LAYER_ATTRS``
    that exists on the model's root submodule (``constants.MODEL_ROOT_ATTR``) is
    additionally moved to ``device``. A missing root submodule or missing layer
    attributes are skipped silently.

    Uses centralized layer name tuples from `constants` to avoid magic strings.

    Args:
        model: The loaded HuggingFace model.
        model_name: Model identifier; used only to detect the model family.
        device: Target device; defaults to ``constants.DEFAULT_DEVICE``.

    Returns:
        The same ``model`` object that was passed in. The return value of
        ``model.to`` is not used, so this relies on ``.to`` moving the module
        in place.
    """
    model.to(device)

    # Lowercase both sides so the match does not depend on how the family
    # constant happens to be cased.
    name_lower = model_name.lower()
    family_marker = constants.MODEL_FAMILY_LLAMA.lower()

    root = getattr(model, constants.MODEL_ROOT_ATTR, None)

    if family_marker in name_lower and root is not None:
        for attr in constants.LLAMA_LAYER_ATTRS:
            layer = getattr(root, attr, None)
            if layer is not None:
                # NOTE(review): model.to(device) above should already cover these
                # layers; the explicit move is kept as a safeguard for this
                # family. Confirm whether it is still required.
                layer.to(device)

    return model


def _resolve_attr_path(obj: Any, path: tuple[str, ...]) -> tuple[bool, Any]:
    """Walk ``obj`` along the attribute names in ``path``.

    Returns ``(True, value)`` when every attribute in the chain exists (even if
    the final value is ``None``), otherwise ``(False, None)``.
    """
    # Sentinel distinguishes "attribute absent" from "attribute set to None".
    missing = object()
    current = obj
    for attr_name in path:
        # getattr with a default only swallows AttributeError, like hasattr.
        current = getattr(current, attr_name, missing)
        if current is missing:
            return False, None
    return True, current


def get_model_layers(model: PreTrainedModel) -> Any:
    """Get the transformer blocks/layers from a model, supporting multiple architectures.

    Candidate attribute paths are tried in order, and the first one that fully
    resolves wins:

    - Wrappers around a model with a decoder (e.g. OPT, BART):
      ``model.model.decoder.layers`` (only the decoder stack is returned)
    - LLaMA/Mistral/Qwen/etc. causal-LM wrappers: ``model.model.layers``
    - GPT-2/GPT-J: ``model.transformer.h``
    - BERT/RoBERTa base models: ``model.encoder.layer``
    - BERT models with a task head: ``model.bert.encoder.layer``
    - Bare decoder base models (e.g. ``LlamaModel``, ``GPT2Model``):
      ``model.layers``, then ``model.h``

    The concrete attribute names come from the ``constants`` module (``ATTR_*``).

    Args:
        model: The HuggingFace model

    Returns:
        The object found at the first matching path; presumably the
        list/ModuleList of transformer blocks.

    Raises:
        AttributeError: If no candidate path resolves. The message names the
            model's class and ``config.architectures[0]`` when that is available.
    """
    # Ordered most-specific first; the first three entries must keep this
    # relative order because they share the ``model.model`` prefix.
    candidate_paths = (
        (constants.ATTR_MODEL, constants.ATTR_DECODER, constants.ATTR_LAYERS),
        (constants.ATTR_MODEL, constants.ATTR_LAYERS),
        (constants.ATTR_TRANSFORMER, constants.ATTR_H),
        (constants.ATTR_ENCODER, constants.ATTR_LAYER),
        (constants.ATTR_BERT, constants.ATTR_ENCODER, constants.ATTR_LAYER),
        # Fallbacks for un-wrapped base models; tried only after every
        # wrapper path has failed.
        (constants.ATTR_LAYERS,),
        (constants.ATTR_H,),
    )
    for path in candidate_paths:
        found, layers = _resolve_attr_path(model, path)
        if found:
            return layers

    # If none of the common paths work, raise an error
    model_type = type(model).__name__
    config = getattr(model, constants.ATTR_CONFIG, None)
    architectures = getattr(config, constants.ATTR_ARCHITECTURES, None)
    # `architectures` may be None or empty; guard before indexing so the
    # informative AttributeError below is raised instead of TypeError/IndexError.
    config_arch = architectures[0] if architectures else 'Unknown'

    raise AttributeError(
        f"Unable to locate transformer layers for model type: {model_type} "
        f"(architecture: {config_arch}). "
        f"Please add support for this architecture in get_model_layers()."
    )


# Names exported by `from <module> import *`.
__all__ = [
    "prep_model_on_device",
    "get_model_layers",
]

