"""
Local HuggingFace Transformers backend.

Loads any causal LM from the Hub (or a local path) and generates via the
model's chat template.  Supports 4-bit quantisation via bitsandbytes.

Install: pip install transformers accelerate
Optional: pip install bitsandbytes  (for load_in_4bit=True)
"""
from __future__ import annotations

import torch

from meta_rg.backends.base import BaseBackend


class HFBackend(BaseBackend):
    """Generation backend built on a local HuggingFace causal language model.

    Prompts are rendered through the tokenizer's chat template and passed to
    ``model.generate``; only the newly generated tokens are decoded and
    returned.

    Args:
        model_id: Hub model id or local path given to ``from_pretrained``.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding, because ``generate`` rejects a non-positive temperature
            when sampling is enabled. ``None`` samples with the model's own
            default temperature.
        max_new_tokens: Maximum number of tokens generated per call.
        load_in_4bit: If true, load weights with a bitsandbytes 4-bit NF4
            quantisation config (double quantisation, bfloat16 compute).
        device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        system_prompt: Optional system message prepended by :meth:`generate`.
        **_: Extra keyword arguments are accepted and ignored.

    Raises:
        ImportError: If ``transformers`` cannot be imported.
    """

    def __init__(
        self,
        model_id: str,
        temperature: float = 0.7,
        max_new_tokens: int = 256,
        load_in_4bit: bool = False,
        device_map: str = "auto",
        system_prompt: str | None = None,
        **_,
    ) -> None:
        # Imported lazily so the rest of the package works without transformers.
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        except ImportError as e:
            raise ImportError("pip install transformers accelerate") from e

        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.system_prompt = system_prompt

        quant_cfg = None
        if load_in_4bit:
            quant_cfg = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Some causal-LM tokenizers ship without a pad token; reuse EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=torch.bfloat16,
            device_map=device_map,
            quantization_config=quant_cfg,
        )
        self.model.eval()

    def set_seed(self, seed: int) -> None:
        """Seed torch's random number generator so sampling is reproducible."""
        torch.manual_seed(seed)

    @property
    def tokenize_fn(self):
        """Return a callable mapping text to its token count (no special tokens)."""
        tok = self.tokenizer
        return lambda text: len(tok.encode(text, add_special_tokens=False))

    def _sampling_kwargs(self) -> dict:
        """Return the decoding-mode kwargs for ``generate`` based on ``temperature``."""
        if self.temperature is not None and self.temperature <= 0:
            # generate() raises on temperature <= 0 with do_sample=True;
            # treat it as a request for deterministic greedy decoding.
            return {"do_sample": False}
        return {"do_sample": True, "temperature": self.temperature}

    def generate(self, prompt_text: str) -> str:
        """Generate a reply to a single user message.

        If ``system_prompt`` was configured, it is sent as a system message
        ahead of the user message.

        Args:
            prompt_text: Content of the user turn.

        Returns:
            The decoded completion, with special tokens stripped.
        """
        messages: list[dict] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": prompt_text})
        return self.generate_chat(messages)

    def generate_chat(self, messages: list[dict]) -> str:
        """Generate the next assistant turn for a full chat history.

        Args:
            messages: Chat messages in the format expected by the tokenizer's
                ``apply_chat_template`` (dicts with ``role`` / ``content``).
                ``system_prompt`` is NOT added here.

        Returns:
            The decoded completion (prompt tokens excluded, special tokens
            stripped).
        """
        formatted = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The rendered template already contains the model's special tokens
        # (e.g. BOS); letting the tokenizer add them again would duplicate them.
        # NOTE(review): truncation removes tokens from
        # ``tokenizer.truncation_side`` (right by default), which would drop
        # the generation prompt for over-long inputs — confirm this is intended.
        inputs = self.tokenizer(
            formatted,
            return_tensors="pt",
            truncation=True,
            add_special_tokens=False,
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
                **self._sampling_kwargs(),
            )
        # generate() returns prompt + continuation; keep only the continuation.
        return self.tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
