from __future__ import annotations
from typing import Optional
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from .scenarios import Scenario, META_PROMPTS
from langchain_deepseek import ChatDeepSeek


def _pick_device(dev: str) -> int:
    """Translate a device string into a Hugging Face pipeline device index.

    ``"auto"`` resolves to 0 when ``torch.cuda.is_available()`` is true and
    to -1 (CPU) otherwise. Any other string that begins with ``"cuda"``
    resolves to 0 -- a suffix such as ``":1"`` is not interpreted -- and
    every remaining value resolves to -1.
    """
    if dev != "auto":
        wants_cuda = dev.startswith("cuda")
    else:
        wants_cuda = torch.cuda.is_available()
    return 0 if wants_cuda else -1


class PromptRefiner:
    """Refine user prompts with scenario-specific meta-prompts.

    The refinement back end is fixed at construction time:

    * ``use_refine_llm=True``: the meta-prompt and the user prompt are sent
      to the external DeepSeek chat model.
    * ``use_refine_llm=False`` and ``gen_model`` given: a local causal LM is
      loaded and run on the meta-prompt.
    * otherwise: the formatted meta-prompt itself is returned.
    """

    def __init__(
        self,
        use_refine_llm: bool,
        gen_model: Optional[str] = None,
        device: str = "auto",
        max_new_tokens: int = 220,
    ) -> None:
        """Store the configuration and, if requested, load the local model.

        Args:
            use_refine_llm: When true, ``refine`` calls the external chat
                model and no local model is loaded.
            gen_model: Model name or path handed to ``from_pretrained``.
                Only used when ``use_refine_llm`` is false.
            device: Device string resolved by ``_pick_device``.
            max_new_tokens: Generation budget for the local model.
        """
        self.use_refine_llm = use_refine_llm
        self.max_new_tokens = max_new_tokens
        self.meta = META_PROMPTS
        self.device = _pick_device(device)
        self.generator = None
        # External chat client; created on first use and then reused so that
        # repeated refine() calls do not rebuild it every time.
        self._llm = None

        # Initialize local generator if needed
        if not self.use_refine_llm and gen_model is not None:
            print(f"[PromptRefiner] Loading local generator model: {gen_model}")
            tok = AutoTokenizer.from_pretrained(gen_model)
            # Half precision on a CUDA device (non-negative index), full
            # precision on CPU (-1).
            on_gpu = self.device >= 0
            mdl = AutoModelForCausalLM.from_pretrained(
                gen_model,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
            )
            self.generator = pipeline(
                "text-generation",
                model=mdl,
                tokenizer=tok,
                device=self.device,
            )

    def build_prompt(self, scenario: Scenario, user_prompt: str) -> str:
        """Return the meta-prompt for ``scenario`` with the user text filled in.

        When the template contains ``{user}`` it is rendered with
        ``str.format`` using the stripped user prompt, so any other literal
        braces in such a template must be escaped as ``{{`` / ``}}``.
        Templates without ``{user}`` are returned unchanged.

        Raises:
            ValueError: If no meta-prompt is registered for ``scenario``.
        """
        if scenario not in self.meta:
            raise ValueError(f"[PromptRefiner] No meta-prompt defined for scenario: {scenario}")
        meta_prompt = self.meta[scenario]
        if "{user}" not in meta_prompt:
            return meta_prompt
        return meta_prompt.format(user=user_prompt.strip())

    def _get_llm(self):
        """Return the external chat client, creating it on first use."""
        if self._llm is None:
            self._llm = ChatDeepSeek(
                model="deepseek-chat",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
            )
        return self._llm

    def refine(self, scenario: Scenario, user_prompt: str) -> str:
        """Refine the user prompt based on scenario-specific meta-prompts.

        Returns:
            The stripped reply of the external chat model, or the stripped
            completion of the local model, or -- when neither back end is
            configured -- the stripped meta-prompt.

        Raises:
            ValueError: If no meta-prompt is registered for ``scenario``.
        """
        print(f"[PromptRefiner] Scenario: {scenario.value}")
        meta_prompt = self.build_prompt(scenario, user_prompt)

        if self.use_refine_llm:
            print("[PromptRefiner] Using external LLM for refinement.")
            query = [
                ("system", meta_prompt),
                ("human", user_prompt),
            ]
            raw = self._get_llm().invoke(query)
            return raw.content.strip()

        if self.generator is not None:
            print("[PromptRefiner] Using local causal LM for refinement.")
            out = self.generator(
                meta_prompt,
                max_new_tokens=self.max_new_tokens,
                do_sample=True,
                temperature=0.7,
                # The text-generation pipeline returns prompt + completion by
                # default; request only the newly generated text so the
                # meta-prompt is not echoed back as part of the result.
                return_full_text=False,
            )
            return out[0]["generated_text"].strip()

        print("[PromptRefiner] Returning meta-prompt (no generator).")
        return meta_prompt.strip()
