from modelscope import AutoModelForCausalLM, AutoTokenizer
import torch
from openai import OpenAI
import numpy as np
import sys
from auto_gptq import AutoGPTQForCausalLM
import os

class LLamaPredictor:
    """Single-step predictor wrapping a Llama chat model and its tokenizer.

    The constructor loads the model and tokenizer; ``predict`` formats a chat
    with the tokenizer's chat template and runs exactly one generation step.
    """

    def __init__(self, model_name="meta-llama/Llama-3.1-8B-Instruct", quantize = False, **kwargs):
        """Load the model and tokenizer identified by ``model_name``.

        Args:
            model_name: Model identifier. When it contains ``'GPTQ'`` the
                weights are loaded with AutoGPTQ from
                ``../models/<model_name>``; otherwise they are loaded with
                ``AutoModelForCausalLM.from_pretrained``.
            quantize: Accepted for interface compatibility; not used here.
            **kwargs: Accepted and ignored.
        """
        if 'GPTQ' in model_name:
            # Quantized checkpoints are read from the local models directory.
            checkpoint_dir = os.path.join('../models', model_name)
            gptq_options = {
                "torch_dtype": torch.float16,
                "use_safetensors": True,
                "use_triton": True,
                "device_map": "sequential",
            }
            self.model = AutoGPTQForCausalLM.from_quantized(checkpoint_dir, **gptq_options)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto",
                device_map="balanced",
            )
        # The tokenizer is always resolved from the bare model name.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    @torch.no_grad()
    def predict(self, messages, **kwargs):
        """Run one generation step for a chat and return the raw output.

        Args:
            messages: Chat messages passed to the tokenizer's
                ``apply_chat_template``.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever ``self.model.generate`` returns when asked for a single
            new token with ``output_logits`` and ``return_dict_in_generate``
            enabled.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The chat template already produced the full prompt text, so the
        # tokenizer is told not to add special tokens again.
        encoded = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        encoded = encoded.to('cuda:0')

        generation_options = {
            "max_new_tokens": 1,
            "output_logits": True,
            "return_dict_in_generate": True,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        return self.model.generate(**encoded, **generation_options)

class QwenPredictor:
    """Single-step predictor wrapping a Qwen chat model and its tokenizer.

    The constructor loads the model and tokenizer; ``predict`` formats a chat
    with the tokenizer's chat template and runs exactly one generation step.
    """

    def __init__(self, model_name="Qwen/Qwen2.5-1.5B-Instruct", **kwargs):
        """Load the model and tokenizer identified by ``model_name``.

        Args:
            model_name: Model identifier passed to both
                ``AutoModelForCausalLM.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            **kwargs: Accepted and ignored.
        """
        # This class has no explicit base, so forwarding **kwargs to
        # super().__init__ would hand them to object.__init__, which raises
        # TypeError for any keyword argument. Extras are therefore dropped.
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="balanced",
            attn_implementation='flash_attention_2',
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    @torch.no_grad()
    def predict(self, messages, **kwargs):
        """Run one generation step for a chat and return the raw output.

        Args:
            messages: Chat messages passed to the tokenizer's
                ``apply_chat_template``.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever ``self.model.generate`` returns when asked for a single
            new token with ``output_logits`` and ``return_dict_in_generate``
            enabled.
        """
        input_texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # The chat template already produced the full prompt text, so the
        # tokenizer is told not to add special tokens again.
        model_inputs = self.tokenizer(input_texts, return_tensors="pt", add_special_tokens=False).to('cuda:0')
        res = self.model.generate(
            **model_inputs,
            max_new_tokens=1,
            output_logits=True,
            return_dict_in_generate=True,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return res



