import re
import requests
import json
import time
import openai
import random
import asyncio
# import jsonlines
from openai import OpenAI
from typing import Optional
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf
from typing import List, Dict, Any, Tuple, Union


class InferHandler:
    """Helper for querying an OpenAI-compatible chat-completions endpoint.

    Endpoint settings (``base_url``, ``api_key``, ``model_name`` and
    ``max_tokens``) are read from ``cfg.Infer.qwen_model``.
    """

    def __init__(self, cfg: DictConfig):
        """Store the configuration; no client is created at this point."""
        self.cfg = cfg

    def closestream_call_qwen_model(self, query: str, assistant_input: Optional[str] = None) -> str:
        """Send a single non-streaming chat completion request.

        Args:
            query: Text sent as the ``user`` message.
            assistant_input: Text sent as the ``system`` message. When empty
                or ``None``, ``"You are a helpful assistant."`` is used.

        Returns:
            The content of the first choice, or ``""`` when the request fails
            or the model returns no text content.
        """
        import logging

        cfg = self.cfg.Infer.qwen_model
        # Client construction stays outside the ``try`` so that a broken
        # configuration surfaces immediately instead of looking like an
        # empty reply.
        client = OpenAI(
            base_url=cfg.base_url,
            api_key=cfg.api_key,
        )
        system_prompt = assistant_input or "You are a helpful assistant."

        try:
            response = client.chat.completions.create(
                model=cfg.model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": query},
                ],
                # Greedy decoding settings.
                temperature=0.0,
                top_p=1.0,
                stream=False,
                max_tokens=cfg.max_tokens,
            )
            output = response.choices[0].message.content
        except Exception:
            # Best-effort by design: callers treat "" as "no reply". Log the
            # traceback so failures are not silently indistinguishable from
            # an empty answer.
            logging.getLogger(__name__).exception(
                "Chat completion request failed for model %s", cfg.model_name
            )
            return ""

        # ``content`` may be None; normalise so the declared ``str`` return
        # type always holds.
        return output or ""

    def process_prompt(self, args_infer) -> Optional[Dict[str, Any]]:
        """Run one query described by a work-item tuple.

        Args:
            args_infer: Sequence whose first three items are
                ``(idx, query, assistant_input)``. Any further items (the
                batch method also passes the model config) are ignored.

        Returns:
            A dict with keys ``"query"``, ``"reply"`` and ``"idx"``, or
            ``None`` when the reply is empty.
        """
        idx, one_query, assistant_input, *_ = args_infer
        reply = self.closestream_call_qwen_model(one_query, assistant_input)
        if not reply:
            return None
        return {"query": one_query, "reply": reply, "idx": idx}

    def batch_call_qwen_model(
        self,
        query_list: list,
        assistant_input: Optional[str] = None,
        *,
        max_workers: int = 100,
    ) -> list:
        """Run many queries concurrently and collect the non-empty replies.

        Each request is a blocking HTTP call, so a thread pool is used for the
        fan-out rather than a pool of processes.

        Args:
            query_list: Queries to send; the position in this list becomes the
                ``"idx"`` of the corresponding result.
            assistant_input: Optional system prompt shared by all queries.
            max_workers: Upper bound on the number of concurrent requests.

        Returns:
            Result dicts (see :meth:`process_prompt`) ordered by ``"idx"``.
            Queries that produced no reply are omitted.
        """
        if not query_list:
            return []

        from tqdm.contrib.concurrent import thread_map

        args_infer = [
            (idx, one_query, assistant_input, self.cfg.Infer.qwen_model)
            for idx, one_query in enumerate(query_list)
        ]
        # No point in asking for more workers than there are work items.
        workers = max(1, min(max_workers, len(args_infer)))
        results = thread_map(self.process_prompt, args_infer, max_workers=workers)
        # The map preserves input order, so the filtered list is already
        # ordered by ``idx``.
        return [result for result in results if result is not None]
        
        
def get_batch_infer(config_path: str = "../../tool/tools/config", config_name: str = "eval_search") -> InferHandler:
    """Compose a Hydra configuration and wrap it in an ``InferHandler``.

    Args:
        config_path: Directory handed to Hydra's ``initialize``.
        config_name: Name of the config handed to Hydra's ``compose``.

    Returns:
        An ``InferHandler`` built from the composed configuration.
    """
    hydra_options = {"config_path": config_path, "version_base": "1.1"}
    with initialize(**hydra_options):
        composed_cfg = compose(config_name=config_name)
    handler = InferHandler(composed_cfg)
    return handler
    
