import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Type

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

from ..utils import str_to_dtype

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def get_model_path(model_name_or_path: str) -> str:
    """Resolve a model id to a local cached snapshot when Hugging Face offline mode is on.

    If ``HF_HUB_OFFLINE`` is not ``"1"``, or ``model_name_or_path`` already exists on disk, the
    argument is returned unchanged. Otherwise the model is looked up in the local Hugging Face hub
    cache. That cache is ``$HF_HUB_CACHE`` if set, else ``$HF_HOME/hub``. ``HF_HOME`` itself
    defaults to ``~/.cache/huggingface``.

    The snapshot referenced by ``refs/main`` is preferred. If that ref is missing or points to a
    snapshot that is not present, the most recently modified snapshot directory is used, with ties
    broken by name so the result is deterministic.

    Args:
        model_name_or_path: Hub repository id (e.g. ``"org/model"``) or a local path.

    Returns:
        A path string suitable for ``from_pretrained``.

    Raises:
        FileNotFoundError: If offline mode is on and no cached snapshot exists for the model.
    """
    if os.getenv("HF_HUB_OFFLINE") != "1":
        return model_name_or_path

    if Path(model_name_or_path).exists():
        return model_name_or_path

    hf_home = Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface")).expanduser()
    hub_cache = Path(os.environ.get("HF_HUB_CACHE", hf_home / "hub")).expanduser()
    repo_cache = hub_cache / f"models--{model_name_or_path.replace('/', '--')}"
    snapshots_dir = repo_cache / "snapshots"

    # `refs/main` stores the commit hash of the snapshot downloaded for the default branch,
    # so it identifies the intended snapshot when several revisions are cached.
    main_ref = repo_cache / "refs" / "main"
    if main_ref.is_file():
        main_snapshot = snapshots_dir / main_ref.read_text().strip()
        if main_snapshot.is_dir():
            return str(main_snapshot)

    snapshots = [p for p in snapshots_dir.iterdir() if p.is_dir()] if snapshots_dir.is_dir() else []
    if not snapshots:
        raise FileNotFoundError(
            f"HF_HUB_OFFLINE=1 but no cached snapshot of '{model_name_or_path}' was found in {snapshots_dir}."
        )

    # iterdir() order is unspecified; choose the newest snapshot deterministically.
    newest = max(snapshots, key=lambda p: (p.stat().st_mtime, p.name))
    return str(newest)


def load_tokenizer(
    model_name_or_path: str,
    tokenizer_cls: Type[PreTrainedTokenizer | PreTrainedTokenizerFast] = AutoTokenizer,
    tokenizer_kwargs: dict[str, Any] | None = None,
) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
    """Load a tokenizer and configure it to pad with its end-of-sequence token.

    Args:
        model_name_or_path: Hub id or local path; resolved through ``get_model_path`` first.
        tokenizer_cls: Class whose ``from_pretrained`` constructs the tokenizer.
        tokenizer_kwargs: Extra keyword arguments for ``from_pretrained``; ``None`` means none.

    Returns:
        The loaded tokenizer, with ``pad_token`` set to ``eos_token``.
    """
    resolved_path = get_model_path(model_name_or_path)
    extra_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
    tok = tokenizer_cls.from_pretrained(resolved_path, **extra_kwargs)
    # Reusing EOS as the padding token is standard practice for causal language modeling.
    # Note this unconditionally overwrites any pad token the tokenizer already defines.
    tok.pad_token = tok.eos_token
    return tok


def load_pretrained_model(
    pretrained_model_name_or_path: str | None = None,
    pretrained_model_cls: Type[PreTrainedModel] = AutoModelForCausalLM,
    pretrained_model_kwargs: dict[str, Any] | None = None,
    generation_config: GenerationConfig | None = None,
) -> PreTrainedModel:
    """Load a pretrained model, or create a freshly initialized one.

    Args:
        pretrained_model_name_or_path: Hub id or local path passed to ``from_pretrained``.
            If ``None``, a model is created from scratch instead.
        pretrained_model_cls: A concrete ``PreTrainedModel`` subclass or a transformers
            ``Auto*`` class.
        pretrained_model_kwargs: Keyword arguments forwarded to ``from_pretrained``. When
            creating from scratch, they go to the constructor, or to ``from_config`` for
            ``Auto*``-style classes. An ``Auto*`` class needs a ``config`` entry here.
        generation_config: If given, replaces ``model.generation_config``.

    Returns:
        The loaded or newly created model.
    """
    if pretrained_model_kwargs is None:
        pretrained_model_kwargs = {}

    if pretrained_model_name_or_path is not None:
        logger.debug("Loading pretrained model from %s.", pretrained_model_name_or_path)
        model = pretrained_model_cls.from_pretrained(pretrained_model_name_or_path, **pretrained_model_kwargs)
    else:
        logger.debug("Creating pretrained model from scratch.")
        is_concrete_model = isinstance(pretrained_model_cls, type) and issubclass(
            pretrained_model_cls, PreTrainedModel
        )
        if not is_concrete_model and hasattr(pretrained_model_cls, "from_config"):
            # Auto classes (e.g. the default AutoModelForCausalLM) raise when instantiated
            # directly; they must be built from a config via `from_config`.
            model = pretrained_model_cls.from_config(**pretrained_model_kwargs)
        else:
            model = pretrained_model_cls(**pretrained_model_kwargs)

    if generation_config is not None:
        model.generation_config = generation_config

    return model


def save_pretrained_model(
    model: PreTrainedModel,
    save_path: str,
    include_filter: list[str] | None = None,
    exclude_filter: list[str] | None = None,
    save_kwargs: dict[str, Any] | None = None,
):
    """Persist ``model`` to ``save_path`` through its ``save_pretrained`` method, then log it.

    Args:
        model: Model whose ``save_pretrained`` is invoked.
        save_path: Destination, passed as the first positional argument.
        include_filter: Forwarded unchanged as the ``include_filter`` keyword.
        exclude_filter: Forwarded unchanged as the ``exclude_filter`` keyword.
        save_kwargs: Additional keyword arguments for ``save_pretrained``; ``None`` means none.
    """
    # NOTE(review): include_filter/exclude_filter are not named parameters of the stock
    # transformers save_pretrained. Presumably `model` is a project subclass that accepts
    # them; confirm.
    extra = save_kwargs if save_kwargs is not None else {}
    model.save_pretrained(
        save_path,
        include_filter=include_filter,
        exclude_filter=exclude_filter,
        **extra,
    )
    logger.info(f"Model saved to {save_path}.")


@dataclass
class BaseModelConfig:
    """Configuration describing which base model and tokenizer to load, and how.

    ``dtype`` may be given as a ``torch.dtype`` or as a string. In ``__post_init__``, every
    string other than ``"auto"`` is converted with ``str_to_dtype`` (presumably into a
    ``torch.dtype``). ``"auto"`` is kept as a string.
    """

    # Identifier for this configuration; presumably used to name or look up the setup. Confirm with callers.
    identifier: str
    # Hub id or local path of the base model (assumed to be passed to `from_pretrained`; verify).
    base_model_name_or_path: str
    # Context length; presumably the maximum sequence length in tokens. TODO confirm.
    ctx_length: int = 2048
    # For EntQuant, this determines the device placement for the model BEFORE compression.
    device_map: str | dict[str, str | torch.device | int] = "auto"
    # Model dtype; string values are normalized in `__post_init__`.
    dtype: torch.dtype | str = torch.bfloat16
    # Model class used for loading; defaults to the causal-LM auto class.
    pretrained_model_cls: Type[PreTrainedModel | AutoModelForCausalLM] = AutoModelForCausalLM
    # Extra model-loading keyword arguments (a fresh dict per instance); presumably
    # forwarded to `load_pretrained_model`. Verify.
    pretrained_model_kwargs: dict[str, Any] = field(default_factory=dict)
    # Optional generation config; `None` means none is specified by this config.
    generation_config: GenerationConfig | None = None
    # Tokenizer class used for loading; defaults to `AutoTokenizer`.
    tokenizer_cls: Type[PreTrainedTokenizer | PreTrainedTokenizerFast | AutoTokenizer] = AutoTokenizer
    # Tokenizer keyword arguments; by default request a fast tokenizer with left-side padding
    # (a fresh dict per instance).
    tokenizer_kwargs: dict[str, Any] = field(default_factory=lambda: {"use_fast": True, "padding_side": "left"})

    def __post_init__(self):
        """Convert a string ``dtype`` other than ``"auto"`` via ``str_to_dtype``."""
        if isinstance(self.dtype, str) and self.dtype != "auto":
            self.dtype = str_to_dtype(self.dtype)
