import ast
import time
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()

import pandas as pd
import time
from tqdm import tqdm
import transformers
import torch
import json
import os
import argparse
from sklearn.metrics import confusion_matrix
import torch
import copy


class Qwen3Model():
    """Chat client for a model served behind a local OpenAI-compatible API.

    The client connects to ``http://localhost:{port}/v1`` and talks to the
    first model the server lists. Instances are callable: pass a list of chat
    messages and receive ``(thinking, raw_response)``.

    NOTE(review): the ``chat_template_kwargs`` request field and the
    ``reasoning_content`` response field are server extensions, not part of
    the core OpenAI API (presumably a vLLM-style server) -- confirm against
    the deployment.
    """

    # Per-request timeout handed to the OpenAI client.
    REQUEST_TIMEOUT = 3000
    # Upper bound on generated tokens per request.
    MAX_TOKENS = 8196

    def __init__(self, model_size=None, device=None, port=None, enable_reasoning=True):
        """Connect to the local server and select the model to query.

        Args:
            model_size: Unused; kept so existing callers keep working.
            device: Unused; kept so existing callers keep working.
            port: Port of the OpenAI-compatible server on localhost.
            enable_reasoning: Default for the ``enable_thinking`` chat
                template flag sent with every request.

        Raises:
            ValueError: If ``port`` is not provided.
        """
        if port is None:
            # Without a port the URL would read "http://localhost:None/v1"
            # and only fail later with an opaque connection error.
            raise ValueError("port is required to reach the local model server")

        openai_api_key = "EMPTY"
        openai_api_base = f"http://localhost:{port}/v1"
        # Honour the caller's choice (previously hard-coded to True).
        self.enable_reasoning = enable_reasoning

        print(openai_api_base)
        self.client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )

        # Use whichever model the server lists first.
        models = self.client.models.list()
        self.model = models.data[0].id

    def __call__(self, message, enable_thinking=False):
        """Send one chat completion request and return its outputs.

        Args:
            message: Chat messages, forwarded as the ``messages`` argument of
                the chat completions request.
            enable_thinking: Per-call opt-in. Thinking is requested when
                either this flag or the instance-level ``enable_reasoning``
                is true, so the default keeps the instance setting.

        Returns:
            A ``(thinking, raw_response)`` tuple. ``thinking`` is the
            message's ``reasoning_content`` or ``None`` when the server did
            not provide one; ``raw_response`` is the message ``content``.
        """
        use_thinking = bool(self.enable_reasoning or enable_thinking)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=message,
            extra_body={"chat_template_kwargs": {"enable_thinking": use_thinking}},
            timeout=self.REQUEST_TIMEOUT,
            max_tokens=self.MAX_TOKENS,
        )

        reply = response.choices[0].message
        # The field is absent when the server emits no separate reasoning
        # output; fall back to None instead of raising AttributeError.
        thinking = getattr(reply, "reasoning_content", None)
        raw_response = reply.content
        return thinking, raw_response