from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

from .base_client import ModelResponseBase


class HFAgent(ModelResponseBase):
    """Agent that runs a Hugging Face causal language model locally.

    The tokenizer and model are loaded with ``from_pretrained`` and wrapped in
    a ``text-generation`` pipeline. Token costs reported to the base class are
    zero.
    """

    def __init__(self, model_name_or_path, device=None, **kwargs):
        """Load the tokenizer and model, and build the generation pipeline.

        Args:
            model_name_or_path: Identifier handed to ``from_pretrained``; it is
                also used as the agent name.
            device: Target device, e.g. ``"cuda"``, ``"cuda:1"`` or ``"cpu"``.
                When falsy, ``"cuda"`` is used if available, otherwise
                ``"cpu"``.
            **kwargs: Optional ``api_config`` (forwarded to the base class)
                and the pipeline sampling defaults ``do_sample``,
                ``temperature`` and ``top_p``.
        """
        super().__init__(
            name=model_name_or_path,
            in_token_costs=0.0,
            out_token_costs=0.0,
            api_config=kwargs.get("api_config", {}),
        )
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Match any CUDA device ("cuda", "cuda:0", "cuda:1", ...) rather than
        # only the bare "cuda" string, so an explicit GPU index is not
        # silently treated as CPU.
        on_cuda = str(self.device).startswith("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16 if on_cuda else torch.float32,
        )
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=self._pipeline_device(self.device),
            do_sample=kwargs.get("do_sample", True),
            temperature=kwargs.get("temperature", 0.7),
            top_p=kwargs.get("top_p", 0.95),
        )

    @staticmethod
    def _pipeline_device(device):
        """Translate the agent's device into the ``pipeline`` ``device`` argument."""
        if device == "cuda":
            return 0
        if device == "cpu":
            return -1
        # Anything else (e.g. "cuda:1", "mps", a torch.device) is passed on
        # unchanged for the pipeline to interpret.
        return device

    @staticmethod
    def _extract_completion(generated_text, prompt):
        """Return only the newly generated text of one pipeline output.

        ``generated_text`` is either a string (plain-text prompts) or a list of
        role/content messages (chat prompts) whose last entry is the reply.
        """
        if isinstance(generated_text, str):
            # Only strip the prompt when it is actually echoed back; with
            # ``return_full_text=False`` the output holds just the new text.
            if isinstance(prompt, str) and generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):]
            return generated_text.strip()
        return str(generated_text[-1]["content"]).strip()

    @property
    def client(self):
        """Return ``None``; this agent holds no separate client object."""
        return None

    @property
    def config(self):
        """Return the loaded model's ``config`` object."""
        return self.model.config

    async def query_response(self, prompt, apply_chat_template=False, **kwargs):
        """Generate completions for ``prompt`` without blocking the event loop.

        Args:
            prompt: A string, a list of strings, a chat (list of role/content
                dicts), or a list of chats.
            apply_chat_template: When true, ``prompt`` is first rendered to
                text with the tokenizer's chat template, including the
                generation prompt that opens the assistant turn.
            **kwargs: Extra arguments forwarded to the pipeline call.

        Returns:
            A list of stripped completion strings. For batched input only the
            first returned sequence of each prompt is kept.
        """
        import asyncio

        if apply_chat_template:
            prompt = self.tokenizer.apply_chat_template(
                prompt, tokenize=False, add_generation_prompt=True
            )

        # The pipeline call is synchronous; run it in a worker thread.
        result = await asyncio.to_thread(self.pipeline, prompt, **kwargs)

        generated = []
        for index, item in enumerate(result):
            if isinstance(item, list):
                # Batched input: one list of outputs per prompt, so each one
                # is paired with its own prompt rather than the first.
                if isinstance(prompt, (list, tuple)) and index < len(prompt):
                    source = prompt[index]
                else:
                    source = prompt
                generated.append(
                    self._extract_completion(item[0]["generated_text"], source)
                )
            else:
                generated.append(
                    self._extract_completion(item["generated_text"], prompt)
                )
        return generated


# Registry mapping a backend key to the agent class that implements it.
huggingface_model_map = dict(hf=HFAgent)

if __name__ == "__main__":
    # Manual smoke test: loads a model and prints one completion.
    # NOTE: this file uses a relative import, so it has to be run as a module
    # of its package (``python -m <package>.<module>``), not as a bare script.
    import asyncio

    async def _demo():
        """Ask a single chat question and print the returned completions."""
        messages = [
            {"role": "user", "content": "What is the capital of France?"},
        ]
        # For a Hugging Face model, pass the model name or path.
        model = HFAgent("meta-llama/Llama-3.2-1B-Instruct")
        # The prompt is a chat (list of role/content dicts), so it must be
        # rendered through the chat template before generation.
        out = await model.query_response(messages, apply_chat_template=True)
        print(out)

    asyncio.run(_demo())
