import re
import requests
import json
import time
import openai
import random
import asyncio
from openai import OpenAI
from typing import Optional
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf


class InferHandler:
    """Client wrapper for Qwen models served behind an OpenAI-compatible API.

    Settings are read from ``cfg.Infer.qwen_model``. The visible code uses the
    keys ``base_url``, ``api_key``, ``model_name``, ``star_temperate``,
    ``end_temperate``, ``temperature_decimal_places`` and ``max_tokens``.
    """

    # Streaming generation is cut right after this marker is produced.
    _STREAM_STOP_MARKER = "</search>"
    _DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
    # Upper bound (seconds) for the pause between failed retry rounds.
    _RETRY_BACKOFF_CAP_SECONDS = 5.0

    def __init__(self, cfg: DictConfig):
        """Store the composed configuration; no network activity happens here."""
        self.cfg = cfg

    @staticmethod
    def _collect_stream(chunks, stop_marker: str = "</search>") -> str:
        """Concatenate streamed delta tokens, stopping right after ``stop_marker``.

        Chunks without choices and deltas without content are skipped. If the
        iterable exposes ``close()``, it is called so that an early stop does
        not leave the underlying stream open.
        """
        output = ""
        try:
            for chunk in chunks:
                if not chunk.choices:
                    continue
                output += chunk.choices[0].delta.content or ""
                head, sep, _ = output.partition(stop_marker)
                if sep:
                    # Drop anything generated after the first stop marker.
                    output = head + sep
                    break
        finally:
            close = getattr(chunks, "close", None)
            if callable(close):
                close()
        return output

    @staticmethod
    def _extract_reply(response) -> str:
        """Return the text of the first choice of a chat-completion response.

        Both attribute-style responses and plain ``dict`` payloads are handled.

        Raises:
            ValueError: if the response has no choices or the first choice
                carries no text.
        """
        if hasattr(response, "choices"):
            choices = response.choices
        elif isinstance(response, dict):
            choices = response.get("choices")
        else:
            choices = None
        if not choices:
            raise ValueError("response has no choices")

        first = choices[0]
        if isinstance(first, dict):
            msg = first.get("message")
            if isinstance(msg, dict):
                output = msg.get("content") or msg.get("text")
            else:
                output = first.get("text") or msg
        else:
            output = None
            msg_obj = getattr(first, "message", None)
            if msg_obj is not None:
                output = getattr(msg_obj, "content", None) or getattr(msg_obj, "text", None)
            if not output:
                output = getattr(first, "text", None)

        if not output:
            raise ValueError("empty output from model")
        return output

    def stream_call_qwen_model(
        self,
        query: str,
        assistant_input: Optional[str] = None,
        max_retries: int = 3,
    ) -> str:
        """Query the configured endpoint in streaming mode.

        This is the streaming implementation that used to be declared as
        ``call_qwen_model`` and was shadowed by the later definition of the
        same name; it now has its own name so it is reachable.

        Args:
            query: User message.
            assistant_input: System prompt; a generic default is used when
                empty or ``None``.
            max_retries: Total number of attempts (at least one is made).

        Returns:
            The generated text, truncated right after the first
            ``</search>``, or ``""`` if every attempt failed.
        """
        infer_cfg = self.cfg.Infer.qwen_model
        client = OpenAI(
            base_url=infer_cfg.base_url,
            api_key=infer_cfg.api_key,
        )
        system_prompt = assistant_input or self._DEFAULT_SYSTEM_PROMPT
        total_attempts = max(1, max_retries)

        for attempt in range(1, total_attempts + 1):
            try:
                response_stream = client.chat.completions.create(
                    model=infer_cfg.model_name,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": query},
                    ],
                    # A fresh temperature is sampled from the configured range per attempt.
                    temperature=round(
                        random.uniform(infer_cfg.star_temperate, infer_cfg.end_temperate),
                        infer_cfg.temperature_decimal_places,
                    ),
                    max_tokens=infer_cfg.max_tokens,
                    stream=True,
                )
                return self._collect_stream(response_stream, self._STREAM_STOP_MARKER)
            except Exception as e:
                print(e)
                if attempt < total_attempts:
                    time.sleep(min(0.5 * 2 ** (attempt - 1), self._RETRY_BACKOFF_CAP_SECONDS))
        return ""

    def closestream_call_qwen_model(self, query: str, assistant_input: Optional[str] = None) -> str:
        """Query the hard-coded deployment in non-streaming mode until it answers.

        Every (base_url, api_key) pair is tried in random order. After a full
        round of failures the method pauses briefly, reshuffles and starts
        over, so it only returns on success.

        Args:
            query: User message.
            assistant_input: System prompt; a generic default is used when
                empty or ``None``.

        Returns:
            The non-empty text of the first choice.

        Raises:
            KeyboardInterrupt: re-raised after printing a notice when the user
                interrupts the retry loop.
        """
        cfg = self.cfg.Infer.qwen_model
        base_urls = ["http://10.221.105.108:35488/v1"]
        api_keys = ["auto-deploy-key"]
        model_list = ["auto-deployed-model"]

        pairs = [(u, k) for u in base_urls for k in api_keys]
        random.shuffle(pairs)

        messages = [
            {"role": "system", "content": assistant_input or self._DEFAULT_SYSTEM_PROMPT},
            {"role": "user", "content": query},
        ]
        attempt = 0
        failed_rounds = 0

        try:
            while True:  # loop forever until a call succeeds or the user interrupts
                for base_url, api_key in pairs:
                    attempt += 1
                    # Chosen before the try block so the failure log below can
                    # always reference it, even if client construction fails.
                    model_name = random.choice(model_list)
                    try:
                        client = OpenAI(base_url=base_url, api_key=api_key)
                        response = client.chat.completions.create(
                            model=model_name,
                            messages=messages,
                            temperature=0.0,
                            top_p=1.0,
                            stream=False,
                            max_tokens=cfg.max_tokens,
                        )
                        return self._extract_reply(response)
                    except Exception as e:
                        # The API key is deliberately left out of the log line.
                        print(f"[attempt {attempt}] 调用失败: base_url={base_url} model_name={model_name} error={e}")

                # Every endpoint failed: back off a little instead of spinning.
                failed_rounds += 1
                time.sleep(min(0.5 * failed_rounds, self._RETRY_BACKOFF_CAP_SECONDS))
                random.shuffle(pairs)
        except KeyboardInterrupt:
            print("KeyboardInterrupt: 用户中断调用。")
            raise

    def process_prompt(self, args_infer):
        """Run one query and package the reply together with its index.

        Args:
            args_infer: 4-tuple ``(idx, query, assistant_input, infer_cfg)``;
                the last element is accepted but not used here.

        Returns:
            ``{"query": ..., "reply": ..., "idx": ...}``, or ``None`` when the
            reply is empty.
        """
        idx, one_query, assistant_input, _ = args_infer
        reply = self.closestream_call_qwen_model(one_query, assistant_input)
        if not reply:
            return None
        return {"query": one_query, "reply": reply, "idx": idx}

    def call_qwen_model(self, query: str, assistant_input: Optional[str] = None) -> dict:
        """Run a single query through ``process_prompt``.

        Returns:
            The result dict (keys ``query``, ``reply``, ``idx``), or an empty
            dict when the reply is empty or the call raised. Failures always
            yield ``{}`` so callers get one falsy type to check.
        """
        try:
            result = self.process_prompt((0, query, assistant_input, self.cfg.Infer.qwen_model))
        except Exception as e:
            print(f"Error calling Qwen model: {e}")
            return {}
        return result if result is not None else {}

    def batch_call_qwen_model(
        self,
        query_list: list,
        assistant_input: Optional[str] = None,
        max_workers: int = 100,
        chunksize: int = 16,
    ) -> list:
        """Run many queries in parallel and return the successful results.

        Args:
            query_list: Queries to send; an empty list returns ``[]`` without
                starting a worker pool.
            assistant_input: System prompt shared by all queries.
            max_workers: Forwarded to ``process_map``.
            chunksize: Forwarded to ``process_map``.

        Returns:
            Result dicts (see ``process_prompt``) sorted by ``idx``; queries
            with an empty reply are dropped.
        """
        if not query_list:
            return []

        from tqdm.contrib.concurrent import process_map

        infer_cfg = self.cfg.Infer.qwen_model
        args_infer = [
            (idx, one_query, assistant_input, infer_cfg)
            for idx, one_query in enumerate(query_list)
        ]
        # NOTE(review): the work is network-bound; tqdm's thread_map may be a
        # lighter alternative to a process pool - confirm before switching.
        results = process_map(
            self.process_prompt,
            args_infer,
            max_workers=max_workers,
            chunksize=chunksize,
        )
        return sorted(
            (result for result in results if result is not None),
            key=lambda item: item["idx"],
        )

def get_infer_handler(config_path: str = "../llm_infer/config", config_name: str = "llm_config") -> InferHandler:
    """Compose the Hydra configuration and wrap it in an ``InferHandler``.

    Args:
        config_path: Config directory handed to ``hydra.initialize``.
        config_name: Name of the config handed to ``hydra.compose``.

    Returns:
        An ``InferHandler`` built from the composed configuration.
    """
    # Hydra is only initialized for the duration of the compose call.
    with initialize(version_base="1.1", config_path=config_path):
        composed_cfg = compose(config_name=config_name)

    handler = InferHandler(composed_cfg)
    return handler

if __name__ == "__main__":
    # Manual smoke test: send one question through the non-streaming path and
    # print the answer. The empty system prompt makes the handler fall back to
    # its default one.
    demo_query = """who are you?"""
    demo_system_prompt = ""
    demo_handler = get_infer_handler()
    print(demo_handler.closestream_call_qwen_model(demo_query, demo_system_prompt))
    
