import re
import json
from typing import Dict, Any
from openai import OpenAI


class UserSimulator:
    """Ask a chat model for a JSON object and return it as a dict.

    The reply is expected to contain a JSON object with at least the keys
    ``instruction_ids`` and ``user_query``; ``generate`` retries until such
    an object is obtained or ``max_retry`` attempts have been used.
    """

    # Keys that must be present in the parsed reply for it to be accepted.
    _REQUIRED_FIELDS = ("instruction_ids", "user_query")

    # Sampling temperature used for every attempt after the first one.
    _RETRY_TEMPERATURE = 0.7

    # A fenced ```json ... ``` block wrapping a JSON object.
    _FENCED_JSON_RE = re.compile(r"```json\s*(\{[\s\S]*?\})\s*```")

    def __init__(
        self,
        client: "OpenAI",
        model_name: str,
        temperature: float = 0.0,
        max_tokens: int = 2048,
        max_retry: int = 3,
    ):
        """Store the client and request settings.

        Args:
            client: Object exposing ``chat.completions.create(...)``.
            model_name: Value passed as ``model`` on every request.
            temperature: Stored on the instance.
                NOTE(review): ``generate`` does not read this attribute; it
                uses its own ``temperature`` argument instead. Confirm
                whether this setting is still meant to have an effect.
            max_tokens: Value passed as ``max_tokens`` on every request.
            max_retry: Maximum number of attempts made by ``generate``.
        """
        self.client = client
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retry = max_retry

    @staticmethod
    def _extract_json(text: str) -> Dict[str, Any]:
        """Best-effort extraction of a JSON object from model output.

        A fenced ```json block is tried first, then the whole text. Within
        each candidate, decoding is attempted from every ``{`` in order and
        the first value that decodes to a dict is returned, so surrounding
        prose (even prose containing braces) does not break parsing.

        Args:
            text: Raw model output.

        Returns:
            The first JSON object found, as a dict.

        Raises:
            ValueError: If ``text`` is empty or contains no JSON object.
                Lists and scalars are rejected because callers rely on
                dict operations.
        """
        if not text:
            raise ValueError("Empty response from model")

        text = text.strip()

        candidates = []
        fenced = UserSimulator._FENCED_JSON_RE.search(text)
        if fenced:
            candidates.append(fenced.group(1))
        # The full text is the fallback when there is no fenced block or
        # the fenced block does not parse.
        candidates.append(text)

        decoder = json.JSONDecoder()
        for candidate in candidates:
            start = candidate.find("{")
            while start != -1:
                try:
                    # raw_decode stops at the end of the first complete
                    # value, so trailing text after the object is ignored.
                    parsed, _ = decoder.raw_decode(candidate, start)
                except json.JSONDecodeError:
                    pass
                else:
                    if isinstance(parsed, dict):
                        return parsed
                start = candidate.find("{", start + 1)

        raise ValueError("No JSON object found in model response")

    def generate(self, prompt: str, temperature: float = 0.0) -> Dict[str, Any]:
        """Query the model and return the parsed JSON object.

        The first attempt uses ``temperature``; every later attempt uses
        0.7, presumably so that a retry does not reproduce the same failing
        output.

        Args:
            prompt: Sent as the single user message.
            temperature: Sampling temperature for the first attempt.

        Returns:
            A dict that always contains ``instruction_ids`` and
            ``user_query``. If every attempt fails, the most recently
            parsed object (or an empty dict if nothing parsed) is returned
            with the missing keys filled in as ``[]`` and ``""``.
        """
        data: Dict[str, Any] = {}

        for attempt in range(1, self.max_retry + 1):
            # Use a per-attempt value instead of rebinding the parameter.
            attempt_temperature = (
                temperature if attempt == 1 else self._RETRY_TEMPERATURE
            )
            try:
                resp = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=attempt_temperature,
                    max_tokens=self.max_tokens,
                )

                raw = resp.choices[0].message.content or ""
                data = self._extract_json(raw)

                # Field validation: a missing key triggers a retry.
                if any(field not in data for field in self._REQUIRED_FIELDS):
                    raise ValueError("Missing required fields: instruction_ids or user_query")

                # Reaching here means the JSON parsed and has every field.
                return data

            except ValueError as e:
                # json.JSONDecodeError is a ValueError subclass, so this
                # covers both parse failures and validation failures.
                print(f"[generate retry {attempt}/{self.max_retry}] failed: {e}")

        # Fallback once all retries are exhausted.
        data.setdefault("instruction_ids", [])
        data.setdefault("user_query", "")
        return data

        
