"""
llama-cpp-python backend — GGUF model inference with optional ROCm/HIP GPU acceleration.

Install (CPU):
  pip install llama-cpp-python

Install (AMD ROCm/HIP — gfx1151 = Strix Halo / RDNA 3.5; set AMDGPU_TARGETS to match your GPU):
  CMAKE_ARGS="-DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1151 -DGGML_HIP_ROCWMMA_FATTN=ON" \\
    pip install llama-cpp-python==0.3.4 --force-reinstall --no-cache-dir

Install (TurboQuant KV-cache fork — experimental):
  CMAKE_ARGS="-DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1151 -DGGML_HIP_ROCWMMA_FATTN=ON" \\
    pip install "git+https://github.com/Pascal-SAPUI5/llama.cpp-turboquant.git" \\
      --force-reinstall --no-cache-dir

Usage example (HuggingFace cache — recommended):
  cfg = {
    "backend_type": "llamacpp",
    "hf_repo_id": "bartowski/Llama-3.2-1B-Instruct-GGUF",
    "hf_filename": "Llama-3.2-1B-Instruct-Q4_K_M.gguf",
    "n_ctx": 8192,
    "n_gpu_layers": -1,   # -1 = offload ALL layers to GPU
    "flash_attn": True,   # Flash Attention (rocWMMA kernels when built with -DGGML_HIP_ROCWMMA_FATTN=ON)
    "max_tokens": 64,
    # TurboQuant KV cache (requires TurboQuant fork):
    # "type_k": 14,  # GGML_TYPE_TURBO3_0 — verify enum with llama_cpp.GGML_TYPE
    # "type_v": 14,
  }
"""
from __future__ import annotations

from meta_rg.backends.base import BaseBackend
from meta_rg.exceptions import ContextOverflowError


def _resolve_model_path(
    model_path: str | None,
    hf_repo_id: str | None,
    hf_filename: str | None,
) -> str:
    if hf_repo_id and hf_filename:
        try:
            from huggingface_hub import hf_hub_download
        except ImportError as e:
            raise ImportError("pip install huggingface_hub") from e
        return hf_hub_download(repo_id=hf_repo_id, filename=hf_filename)
    if model_path:
        return model_path
    raise ValueError("Provide either hf_repo_id+hf_filename or model_path")


class LlamaCppBackend(BaseBackend):
    """Chat backend that runs a GGUF model in-process via ``llama_cpp.Llama``.

    The model file is located with ``_resolve_model_path`` (HuggingFace Hub
    download when ``hf_repo_id`` and ``hf_filename`` are both set, otherwise
    ``model_path``) and loaded once in ``__init__``.
    """

    def __init__(
        self,
        model_path: str | None = None,
        hf_repo_id: str | None = None,
        hf_filename: str | None = None,
        n_ctx: int = 4096,
        n_gpu_layers: int = -1,
        temperature: float = 0.7,
        max_tokens: int = 256,
        system_prompt: str | None = None,
        verbose: bool = False,
        flash_attn: bool = True,
        n_threads: int | None = None,
        type_k: int | None = None,
        type_v: int | None = None,
        **_,
    ) -> None:
        """Resolve the model file and load it into ``llama_cpp.Llama``.

        Args:
            model_path: Local path to a ``.gguf`` file; used when the Hub
                pair below is not fully specified.
            hf_repo_id: HuggingFace Hub repository id to download from.
            hf_filename: File name inside ``hf_repo_id`` to download.
            n_ctx: Context window size passed to ``Llama``.
            n_gpu_layers: Number of layers to offload to the GPU; ``-1``
                offloads all of them.
            temperature: Sampling temperature used for every completion.
            max_tokens: Maximum number of tokens generated per completion.
            system_prompt: Optional system message that ``generate`` (but
                not ``generate_chat``) prepends to each request.
            verbose: Forwarded to ``Llama``.
            flash_attn: Forwarded to ``Llama``.
            n_threads: CPU thread count; left to llama.cpp's default when
                ``None``.
            type_k: ggml type id for the KV-cache keys; omitted when ``None``.
            type_v: ggml type id for the KV-cache values; omitted when
                ``None``.
            **_: Extra configuration keys (e.g. ``backend_type`` from a
                config dict) are accepted and ignored.

        Raises:
            ImportError: ``llama_cpp`` (or ``huggingface_hub`` for Hub
                downloads) is not installed.
            ValueError: no model source was provided.
        """
        # Imported lazily so the rest of the package works without the
        # (often GPU-specific) llama-cpp-python build installed.
        try:
            from llama_cpp import Llama
        except ImportError as e:
            raise ImportError(
                "pip install llama-cpp-python\n"
                "  (ROCm/HIP: CMAKE_ARGS=\"-DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1151 "
                "-DGGML_HIP_ROCWMMA_FATTN=ON\" pip install llama-cpp-python)"
            ) from e

        self.temperature = temperature
        self.max_tokens = max_tokens
        self.system_prompt = system_prompt

        llama_kwargs: dict = {
            "model_path": _resolve_model_path(model_path, hf_repo_id, hf_filename),
            "n_ctx": n_ctx,
            "n_gpu_layers": n_gpu_layers,
            "verbose": verbose,
            "flash_attn": flash_attn,
        }
        # Only forward optional settings that were explicitly provided, so
        # llama-cpp-python keeps its own defaults for the rest.
        optional = {"n_threads": n_threads, "type_k": type_k, "type_v": type_v}
        llama_kwargs.update(
            {name: value for name, value in optional.items() if value is not None}
        )

        self._llm = Llama(**llama_kwargs)

    @property
    def tokenize_fn(self):
        """Return a callable mapping a string to its token count for this model.

        The text is UTF-8 encoded and passed to ``Llama.tokenize`` with its
        default arguments. NOTE(review): those defaults may add a BOS token
        to every count; confirm if exact per-fragment counts matter.
        """
        llm = self._llm
        return lambda text: len(llm.tokenize(text.encode()))

    def generate(self, prompt_text: str) -> str:
        """Complete a single user prompt, prefixed by ``system_prompt`` if set."""
        system_turn = (
            [{"role": "system", "content": self.system_prompt}]
            if self.system_prompt
            else []
        )
        return self._create_completion(
            [*system_turn, {"role": "user", "content": prompt_text}]
        )

    def generate_chat(self, messages: list[dict]) -> str:
        """Pass the full message list directly to llama.cpp chat completion.

        ``system_prompt`` is NOT injected here; callers own the whole list.
        """
        return self._create_completion(messages)

    def _create_completion(self, messages: list[dict]) -> str:
        """Run one chat completion and return the assistant text.

        Returns:
            The first choice's message content, or ``""`` when it is empty
            or ``None``.

        Raises:
            ContextOverflowError: the prompt does not fit the context window.
            ValueError: any other ``ValueError`` from llama-cpp-python.
        """
        # Keep the try minimal: only the llama-cpp-python call is expected to
        # raise the ValueError being translated below.
        try:
            response = self._llm.create_chat_completion(
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
        except ValueError as exc:
            # An over-long prompt is expected to surface as a plain
            # ValueError mentioning "exceed context window"; map only that
            # case to the domain exception and re-raise everything else.
            if "exceed context window" in str(exc):
                raise ContextOverflowError(str(exc)) from exc
            raise
        return response["choices"][0]["message"]["content"] or ""
