import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from any_precision import AnyPrecisionForCausalLM

import os
# Root directory under which model directories are looked up;
# this is None when the DATA_ROOT environment variable is not set.
data_root = os.getenv("DATA_ROOT")

class Agent:
    """Single-turn chat wrapper around a causal language model.

    Loads either a regular Hugging Face checkpoint or an any-precision
    quantized one, and exposes a callable that turns one user prompt into
    one generated reply.
    """

    def __init__(self, model_path, is_anyprecision=False, device="cuda:5"):
        """Load the model and tokenizer found at ``model_path``.

        Args:
            model_path: Checkpoint location handed to the model and
                tokenizer loaders.
            is_anyprecision: When true, load through
                ``AnyPrecisionForCausalLM.from_quantized`` with precisions
                2 through 8; otherwise use
                ``AutoModelForCausalLM.from_pretrained``.
            device: Device the loaded model is moved to.
        """
        if is_anyprecision:
            model = AnyPrecisionForCausalLM.from_quantized(
                model_path, precisions=[2, 3, 4, 5, 6, 7, 8]
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_path, torch_dtype="auto"
            )
        self.model = model.to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.is_anyprecision = is_anyprecision

    def __call__(self, prompt, do_sample=True, precision=8, *,
                 max_new_tokens=32768, temperature=0.6):
        """Generate a reply to a single user prompt.

        Args:
            prompt: Text sent as the only ``user`` message of the chat.
            do_sample: Sample with ``temperature`` when truthy; otherwise
                sampling is disabled and no temperature is passed.
            precision: Forwarded to ``generate`` as ``precision``; only
                used when the agent was created with
                ``is_anyprecision=True``.
            max_new_tokens: Forwarded to ``generate`` as the cap on newly
                generated tokens.
            temperature: Sampling temperature, used only when sampling.

        Returns:
            A ``(response, num_tokens)`` tuple: the decoded reply with
            special tokens skipped, and the number of generated token ids
            (counted before special tokens are skipped).
        """
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        # Only the generated sequences are consumed below, so per-step
        # scores are deliberately not requested: keeping them for every
        # generated token would hold a full-vocabulary tensor per step.
        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "return_dict_in_generate": True,
            "do_sample": bool(do_sample),
        }
        if do_sample:
            generation_kwargs["temperature"] = temperature
        if self.is_anyprecision:
            generation_kwargs["precision"] = precision

        generated_output = self.model.generate(**model_inputs, **generation_kwargs)

        # The batch always holds exactly one prompt; strip the prompt tokens
        # that generate() echoes at the start of the returned sequence.
        prompt_length = len(model_inputs.input_ids[0])
        new_token_ids = generated_output.sequences[0][prompt_length:]

        response = self.tokenizer.batch_decode(
            [new_token_ids], skip_special_tokens=True
        )[0]
        return response, len(new_token_ids)

if __name__ == "__main__":
    # Smoke test: load QwQ-32B from DATA_ROOT and ask one question.
    # Fail early with a clear message instead of trying to load the
    # literal path "None/QwQ-32B" when the environment variable is missing.
    if data_root is None:
        raise SystemExit("DATA_ROOT environment variable is not set")
    agent = Agent(f"{data_root}/QwQ-32B", is_anyprecision=False, device="cuda:7")
    # Calling the agent returns (reply text, generated token count).
    response, num_tokens = agent("What is the capital of France?")
    print(response)
    print(f"generated tokens: {num_tokens}")