from .text_predict import LLMSentimentPredictor
from openai import OpenAI
from modelscope import AutoModelForCausalLM, AutoTokenizer,GenerationConfig
import numpy as np
import torch
import sys

class QwenPredictor(LLMSentimentPredictor):
    """Sentiment predictor that scores labels with a local Qwen chat model.

    Each text is wrapped in a system/user chat prompt, run through the model
    for a single generation step, and the logits of that step at the label
    token ids are returned (optionally softmax-normalised).
    """

    def __init__(self, model_name="qwen/Qwen2.5-1.5B-Instruct", *,
                 attn_implementation="flash_attention_2", **kwargs):
        """Load the Qwen model and tokenizer and pre-encode the labels.

        Args:
            model_name: model id or local path passed to ``from_pretrained``.
            attn_implementation: attention backend forwarded to
                ``AutoModelForCausalLM.from_pretrained``. Defaults to
                ``"flash_attention_2"`` (the previously hard-coded value),
                which needs the ``flash_attn`` package and a supported GPU;
                pass e.g. ``"sdpa"`` or ``None`` (library default) otherwise.
            **kwargs: forwarded to ``LLMSentimentPredictor.__init__``.
        """
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="sequential",
            attn_implementation=attn_implementation,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Label token ids as a (1, n) tensor (return_tensors="pt" on a single
        # input). add_special_tokens=False keeps BOS/EOS ids out of the label
        # columns, matching LLamaPredictor.
        # NOTE(review): assumes each entry of self.labels (set by the base
        # class) maps to exactly one token — TODO confirm.
        self.encoded_label = self.tokenizer.encode(
            self.labels, return_tensors="pt", add_special_tokens=False
        )

    @torch.no_grad()
    def _predict(self, texts, logits=True, **kwargs):
        """Score each text against the label tokens.

        Args:
            texts: iterable of input strings.
            logits: if True, return the raw logits at the label token ids;
                otherwise return their softmax over the label axis.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            float32 ``np.ndarray`` with one row per text and one column per
            label token id. An empty ``texts`` yields an array with zero rows.
        """
        rows = []
        for text in texts:
            user_prompt = self.get_user_prompt(text)[0]
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": user_prompt},
            ]
            prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # The chat template already inserted the special tokens.
            model_inputs = self.tokenizer(
                prompt, return_tensors="pt", add_special_tokens=False
            ).to(self.model.device)
            # A single step is enough: only the raw logits of the first
            # generated position are used, not the sampled token.
            out = self.model.generate(
                **model_inputs,
                max_new_tokens=1,
                output_logits=True,
                return_dict_in_generate=True,
            )
            # out.logits holds one (batch, vocab) tensor per generated step;
            # [0][0] is the single prompt's step-1 vocabulary logits, and
            # indexing with the (1, n) label ids gives a (1, n) row.
            # .float() guards against half/bfloat16 logits, which .numpy()
            # cannot convert.
            rows.append(out.logits[0][0][self.encoded_label].float().cpu())

        if not rows:
            # torch.cat raises on an empty list; return a well-shaped empty result.
            return np.empty((0, self.encoded_label.shape[-1]), dtype=np.float32)

        scores = torch.cat(rows, dim=0)
        if logits:
            return scores.numpy()
        return torch.softmax(scores, dim=1).numpy()
    



class LLamaPredictor(LLMSentimentPredictor):
    """Sentiment predictor that scores labels with a local Llama chat model.

    Each text is wrapped in a system/user chat prompt, run through the model
    for a single generation step, and the logits of that step at the label
    token ids are returned (optionally softmax-normalised).
    """

    def __init__(self, model_name="LLama/Llama-3.1-8B-Instruct", quantize = False, **kwargs):
        """Load the Llama model and tokenizer and pre-encode the labels.

        Args:
            model_name: model id or local path passed to ``from_pretrained``.
            quantize: currently ignored — accepted for interface compatibility
                only; the model is always loaded with ``torch_dtype="auto"``.
            **kwargs: forwarded to ``LLMSentimentPredictor.__init__``.
        """
        super().__init__(**kwargs)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="sequential"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Label token ids as a (1, n) tensor; add_special_tokens=False keeps
        # the BOS id out of the label columns.
        # NOTE(review): assumes each entry of self.labels (set by the base
        # class) maps to exactly one token — TODO confirm.
        self.encoded_label = self.tokenizer.encode(
            self.labels, return_tensors="pt", add_special_tokens=False
        )

    @torch.no_grad()
    def _predict(self, texts, logits=True, **kwargs):
        """Score each text against the label tokens.

        Args:
            texts: iterable of input strings.
            logits: if True, return the raw logits at the label token ids;
                otherwise return their softmax over the label axis.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            float32 ``np.ndarray`` with one row per text and one column per
            label token id. An empty ``texts`` yields an array with zero rows.
        """
        rows = []
        for text in texts:
            user_prompt = self.get_user_prompt(text)[0]
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": user_prompt},
            ]
            prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # The chat template already inserted the special tokens.
            model_inputs = self.tokenizer(
                prompt, return_tensors="pt", add_special_tokens=False
            ).to(self.model.device)
            # A single step is enough: only the raw logits of the first
            # generated position are used, not the sampled token. pad_token_id
            # is supplied explicitly so generate need not fall back to it when
            # the tokenizer defines no pad token.
            out = self.model.generate(
                **model_inputs,
                max_new_tokens=1,
                output_logits=True,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # out.logits holds one (batch, vocab) tensor per generated step;
            # [0][0] is the single prompt's step-1 vocabulary logits, and
            # indexing with the (1, n) label ids gives a (1, n) row.
            # .float() guards against half/bfloat16 logits, which .numpy()
            # cannot convert.
            rows.append(out.logits[0][0][self.encoded_label].float().cpu())

        if not rows:
            # torch.cat raises on an empty list; return a well-shaped empty result.
            return np.empty((0, self.encoded_label.shape[-1]), dtype=np.float32)

        scores = torch.cat(rows, dim=0)
        if logits:
            return scores.numpy()
        return torch.softmax(scores, dim=1).numpy()
