from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase

try:
    from transformers import BitsAndBytesConfig
except ImportError:
    BitsAndBytesConfig = None
from pruning.ffn_mask import wrap_model_with_maskable_ffn


@dataclass
class Llama3ModelBundle:
    """Tokenizer/model pair handed back by ``load_llama3_model``.

    Both members are loaded from the same ``base_model`` identifier, so they
    are meant to be used together.
    """

    # Tokenizer obtained via ``AutoTokenizer.from_pretrained``.
    tokenizer: PreTrainedTokenizerBase
    # Causal LM obtained via ``AutoModelForCausalLM.from_pretrained``; the
    # loader returns it already switched to eval mode.
    model: PreTrainedModel


def _resolve_dtype(name_or_dtype: Any) -> torch.dtype:
    if isinstance(name_or_dtype, torch.dtype):
        return name_or_dtype
    if isinstance(name_or_dtype, str):
        normalized = name_or_dtype.lower()
        if normalized in {"bfloat16", "bf16"}:
            return torch.bfloat16
        if normalized in {"float16", "fp16", "half"}:
            return torch.float16
        if normalized in {"float32", "fp32"}:
            return torch.float32
        raise ValueError(f"Unsupported torch dtype string: {name_or_dtype}")
    raise TypeError(f"Expected torch.dtype or string, got {type(name_or_dtype)}")


def load_llama3_model(config: Dict[str, Any]) -> Llama3ModelBundle:
    """Load a causal LM and its tokenizer as described by ``config``.

    Keys read from ``config``:
        base_model: Required. Identifier passed to the ``from_pretrained``
            calls for the tokenizer, the model and (optionally) the HF config.
        trust_remote_code: Forwarded to every ``from_pretrained`` call
            (default ``False``).
        load_in_8bit: When truthy, the model is loaded with a
            ``BitsAndBytesConfig(load_in_8bit=True, ...)`` quantization config.
        bitsandbytes: Optional mapping of extra ``BitsAndBytesConfig`` keyword
            arguments (8-bit mode only). A missing or null value means "none".
        device_map: Forwarded to ``from_pretrained`` in 8-bit mode only
            (default ``"auto"``).
        torch_dtype: ``torch.dtype`` or string handled by ``_resolve_dtype``
            (non-8-bit mode only). Missing or null means ``torch.float16``.
        attn_implementation: Forwarded to ``from_pretrained`` when not ``None``.
        max_position_embeddings: When truthy, overrides the value in the
            model's HF config before the weights are loaded.
        use_cache: When not ``None``, written to ``model.config.use_cache`` and
            to the model's generation config if it has one.
        device: Target device for ``model.to`` in non-8-bit mode
            (default ``"cuda"``).

    Returns:
        A ``Llama3ModelBundle`` holding the tokenizer and the model; the model
        has been passed through ``wrap_model_with_maskable_ffn`` and put in
        eval mode.

    Raises:
        KeyError: If ``base_model`` is missing or ``None``.
        ImportError: If 8-bit loading is requested but ``BitsAndBytesConfig``
            could not be imported.
        ValueError, TypeError: Propagated from ``_resolve_dtype`` for an
            unsupported ``torch_dtype`` value.
    """
    base_model = config.get("base_model")
    if base_model is None:
        raise KeyError("`base_model` must be specified in the configuration.")

    trust_remote_code = bool(config.get("trust_remote_code", False))

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code)
    if tokenizer.pad_token is None:
        # No pad token defined: reuse EOS, then UNK, if either exists.
        fallback = tokenizer.eos_token or tokenizer.unk_token
        if fallback is not None:
            tokenizer.pad_token = fallback

    load_in_8bit = bool(config.get("load_in_8bit", False))
    model_kwargs: Dict[str, Any] = {}
    if load_in_8bit:
        if BitsAndBytesConfig is None:
            raise ImportError(
                "BitsAndBytesConfig is unavailable. Please upgrade `transformers` or "
                "install the optional bitsandbytes dependency to enable 8-bit loading."
            )
        # `or {}` covers a key that is present but null; copying keeps the
        # caller's config untouched. Setting the flag on the copy (rather than
        # passing it as a separate keyword) avoids a duplicate-keyword
        # TypeError when the mapping itself contains `load_in_8bit`.
        bnb_kwargs = dict(config.get("bitsandbytes") or {})
        bnb_kwargs["load_in_8bit"] = True
        model_kwargs["quantization_config"] = BitsAndBytesConfig(**bnb_kwargs)
        model_kwargs["device_map"] = config.get("device_map", "auto")
    else:
        # An explicit null in the config means "use the default", same as a
        # missing key.
        dtype_spec = config.get("torch_dtype")
        if dtype_spec is None:
            dtype_spec = torch.float16
        model_kwargs["torch_dtype"] = _resolve_dtype(dtype_spec)

    attn_impl = config.get("attn_implementation")
    if attn_impl is not None:
        model_kwargs["attn_implementation"] = attn_impl

    max_position_embeddings = config.get("max_position_embeddings")
    hf_config = None
    if max_position_embeddings:
        # The override has to be applied to the config object that is handed
        # to from_pretrained, so load that config explicitly first.
        hf_config = AutoConfig.from_pretrained(base_model, trust_remote_code=trust_remote_code)
        hf_config.max_position_embeddings = max_position_embeddings

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        config=hf_config,
        trust_remote_code=trust_remote_code,
        **model_kwargs,
    )

    use_cache = config.get("use_cache")
    if use_cache is not None:
        use_cache = bool(use_cache)
        model.config.use_cache = use_cache
        # NOTE(review): the attribute may exist but be None on models that
        # cannot generate (e.g. some remote-code models), so test the value
        # rather than mere presence — confirm against the transformers docs.
        generation_config = getattr(model, "generation_config", None)
        if generation_config is not None:
            generation_config.use_cache = use_cache

    # Project helper from pruning.ffn_mask; called for its effect on `model`
    # (the return value is not used here).
    wrap_model_with_maskable_ffn(model)

    if not load_in_8bit:
        # In 8-bit mode placement is left to `device_map`; otherwise move the
        # whole model to the configured device.
        device = torch.device(config.get("device", "cuda"))
        model.to(device)
    model.eval()
    return Llama3ModelBundle(tokenizer=tokenizer, model=model)
