from modelscope import AutoModelForCausalLM, AutoTokenizer
import torch
from openai import OpenAI
import numpy as np
import sys
from auto_gptq import AutoGPTQForCausalLM
import os

class LLamaPredictor:
    """Wrap a Llama-style chat model and its tokenizer for one-token generation.

    Model names containing ``'GPTQ'`` are loaded through ``AutoGPTQForCausalLM``
    from a directory under ``../models``; every other name is loaded through
    ``AutoModelForCausalLM``. The tokenizer is always loaded from ``model_name``.
    """

    def __init__(self, model_name="meta-llama/Llama-3.1-8B-Instruct", quantize=False, **kwargs):
        """Load the model and tokenizer.

        Args:
            model_name: Model identifier; also decides which loader is used.
            quantize: Accepted but not read by this constructor.
            **kwargs: Accepted but not read by this constructor.
        """
        if 'GPTQ' in model_name:
            # Quantized checkpoints are read from a path relative to the
            # current working directory.
            checkpoint_dir = os.path.join('../models', model_name)
            self.model = AutoGPTQForCausalLM.from_quantized(
                checkpoint_dir,
                torch_dtype=torch.float16,
                use_safetensors=True,
                use_triton=True,
                device_map="sequential",
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto",
                device_map="balanced",
            )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    @torch.no_grad()
    def predict(self, messages, **kwargs):
        """Generate a single new token for a chat conversation.

        Args:
            messages: Conversation passed to the tokenizer's
                ``apply_chat_template``.
            **kwargs: Accepted but not read.

        Returns:
            The object returned by ``self.model.generate`` when called with
            ``max_new_tokens=1``, ``output_logits=True`` and
            ``return_dict_in_generate=True``.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The chat template has already rendered the prompt text, so the
        # tokenizer is told not to add special tokens a second time.
        encoded = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        encoded = encoded.to('cuda:0')

        generation_options = dict(
            max_new_tokens=1,
            output_logits=True,
            return_dict_in_generate=True,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.model.generate(**encoded, **generation_options)

class QwenPredictor:
    """Wrap a Qwen chat model and its tokenizer for one-token generation."""

    def __init__(self, model_name="Qwen/Qwen2.5-1.5B-Instruct", **kwargs):
        """Load the model (with flash-attention 2 requested) and tokenizer.

        Args:
            model_name: Identifier handed to both ``from_pretrained`` calls.
            **kwargs: Forwarded unchanged to ``super().__init__``.
        """
        # NOTE(review): this class declares no base, so unless a subclass puts
        # another class into the MRO, any extra keyword reaches
        # object.__init__ and raises TypeError -- confirm this is intended.
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="balanced",
            attn_implementation='flash_attention_2',
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    @torch.no_grad()
    def predict(self, messages, **kwargs):
        """Generate a single new token for a chat conversation.

        Args:
            messages: Conversation passed to the tokenizer's
                ``apply_chat_template``.
            **kwargs: Accepted but not read.

        Returns:
            The object returned by ``self.model.generate`` when called with
            ``max_new_tokens=1``, ``output_logits=True`` and
            ``return_dict_in_generate=True``.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The chat template has already rendered the prompt text, so the
        # tokenizer is told not to add special tokens a second time.
        encoded = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        encoded = encoded.to('cuda:0')

        generation_options = dict(
            max_new_tokens=1,
            output_logits=True,
            return_dict_in_generate=True,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.model.generate(**encoded, **generation_options)





class OpenAIAPIPredictor:
    """OpenAI API predictor for multiple choice questions.

    The model is asked for a single token together with its top log-probs;
    the log-probs of the answer labels are extracted from that list.
    """

    # Answer labels, in the order used for the returned score vector.
    labels = ['A', 'B', 'C', 'D']

    def __init__(self, model_name='gpt-4o-2024-11-20', **kwargs):
        """Create the API client.

        Args:
            model_name: Model identifier sent with every request.
            **kwargs: Forwarded unchanged to ``super().__init__``.
        """
        # NOTE(review): this class declares no base, so unless a subclass puts
        # another class into the MRO, any extra keyword reaches
        # object.__init__ and raises TypeError -- confirm this is intended.
        super().__init__(**kwargs)
        self.model_name = model_name
        self.model = OpenAI(base_url='')

    @classmethod
    def get_probs(cls, response):
        """Extract per-label log-probs from a chat completion response.

        Args:
            response: Object exposing
                ``choices[0].logprobs.content[0].top_logprobs``, an iterable
                of entries with ``token`` and ``logprob`` attributes.

        Returns:
            A float array with one slot per entry of ``cls.labels``. A label
            that does not appear among the returned candidates keeps the
            placeholder value ``-100``.
        """
        label_index = {label: idx for idx, label in enumerate(cls.labels)}
        res = np.full(len(cls.labels), -100.0)
        # Scan every returned candidate: the list may be longer or shorter
        # than the number of labels, and a label can rank anywhere in it.
        for candidate in response.choices[0].logprobs.content[0].top_logprobs:
            idx = label_index.get(candidate.token)
            if idx is not None:
                res[idx] = candidate.logprob
        return res

    def predict(self, messages, logits=True, **kwargs):
        """Query the API for one token and return the label log-probs.

        The request and the parsing of its response are attempted up to ten
        times; the first attempt that succeeds ends the loop.

        Args:
            messages: Chat messages sent as the ``messages`` request field.
            logits: Accepted but not read.
            **kwargs: Accepted but not read.

        Returns:
            The array produced by ``get_probs``, or ``None`` when all ten
            attempts failed (diagnostics are written to stderr).
        """
        # Pre-bind so the post-loop check works even when no request succeeded.
        response = None
        last_error = None
        for _ in range(10):
            try:
                response = self.model.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=1,
                    max_tokens=1,
                    top_p=1.0,
                    frequency_penalty=0.0,
                    presence_penalty=0.0,
                    logprobs=True,
                    top_logprobs=10,
                )
                return self.get_probs(response)
            except Exception as exc:
                # Exception (not a bare except) so Ctrl-C / SystemExit still
                # interrupt the retry loop.
                last_error = exc

        if response is None:
            print(f"Error in response for text: {messages}", file=sys.stderr)
        else:
            print(f"Error in response for text: {messages}, and the response is {response}", file=sys.stderr)
        print(f"Last error: {last_error!r}", file=sys.stderr)
        return None
        
