# openai_vllm_orchestrator.py
from openai import OpenAI
from openai import APIError, RateLimitError, APITimeoutError, APIConnectionError, APIStatusError

import os, json, time, re
from copy import deepcopy

from LC_Agent.orchestrator import Orchestrator as BaseOrchestrator  # <-- adjust if name differs
from LC_Agent.utils import read_json

def parse_openai_response(resp):
    """
    Extract the fields the orchestrator needs from a chat-completions response.

    Only the first tool call is returned: the orchestrator executes one tool
    call per step.

    Args:
        resp: An OpenAI SDK ``ChatCompletion`` (or any object exposing the same
            ``choices[0].finish_reason`` / ``choices[0].message`` shape).

    Returns:
        ``(text, tool_name, tool_args, tool_call_id, finish_reason)``.
        ``text`` is the stripped assistant content (``""`` when absent). The
        three tool fields are ``None`` unless ``finish_reason == "tool_calls"``
        and at least one tool call is present. In that case ``tool_args`` is
        the decoded JSON arguments, and empty or ``null`` arguments become
        ``{}``. A response without choices yields ``("", None, None, None, None)``.

    Raises:
        json.JSONDecodeError: if the tool-call arguments are not valid JSON.
    """
    if not resp.choices:
        return "", None, None, None, None

    choice = resp.choices[0]
    finish_reason = choice.finish_reason
    msg = choice.message
    text = (msg.content or "").strip()       # assistant plain text
    tool_calls = msg.tool_calls or []        # list[ChatCompletionMessageToolCall]

    if finish_reason != "tool_calls" or not tool_calls:
        return text, None, None, None, finish_reason

    call = tool_calls[0]                     # we allow one call per step
    raw_args = call.function.arguments
    # Treat missing/blank arguments and a JSON ``null`` as "no arguments" so
    # callers can always use dict methods on the result.
    args = json.loads(raw_args) if raw_args and raw_args.strip() else None
    if args is None:
        args = {}
    return (
        text,                                # assistant response
        call.function.name,                  # tool name
        args,                                # decoded arguments
        call.id,                             # tool_call_id
        finish_reason,
    )


class OpenAIOrchestratorVLLM(BaseOrchestrator):
    """
    Adapter: inherit from your existing Orchestrator and override only
    the LLM-facing parts + unified msg deletion + graceful Ctrl-C exit.

    Requests go to an OpenAI-compatible chat-completions endpoint (e.g. a
    vLLM server). The Claude-oriented base class is constructed with
    placeholder credentials, and its Claude call entry point is redirected
    to that endpoint.
    """

    # Text substituted for the content of deleted messages in API payloads.
    _DELETED_NOTICE = "Content has been deleted to save space."

    def __init__(
        self,
        openai_cfg: dict,
        document_content: str,
        temperature: float,
        tokenizer,
        logger=None,
        max_context_exp: int = 32000,
        max_turns_exp: int = 50,
        max_output_tokens: int = 4096,
        tool_config_path: str = None,
        system_prompt_name: str = None,
        *,
        topp: float = 0.95,
        topk: int = 20,
        model_name: str = "Qwen3-8B-SLM"):
        """
        Build the orchestrator and its OpenAI-compatible client.

        Args:
            openai_cfg: Endpoint settings. ``base_url`` is required.
                ``api_key`` falls back to ``"EMPTY"`` when missing or empty.
                ``timeout`` (seconds) defaults to 120.
            document_content, temperature, tokenizer, logger, max_context_exp,
            max_turns_exp, max_output_tokens, tool_config_path,
            system_prompt_name: Forwarded unchanged to the base Orchestrator.
            topp: ``top_p`` sent with every request.
            topk: ``top_k`` sent to the server through ``extra_body``.
            model_name: Must match the server's ``--served-model-name``.
        """
        # 1) The base constructor expects Claude credentials; give it
        #    placeholders (per the original note, its client is lazy, so this is safe).
        dummy_claude = {
            "CLAUDE_ACCESS_KEY": "",
            "CLAUDE_SECRET_KEY": "",
            "CLAUDE_REGION":     "us-east-1",
            "CLAUDE_MODEL_ID":   "placeholder"
        }

        super().__init__(
            claude_config    = dummy_claude,
            document_content = document_content,
            temperature      = temperature,
            tokenizer        = tokenizer,
            logger           = logger,
            max_context_exp  = max_context_exp,
            max_turns_exp    = max_turns_exp,
            max_output_tokens= max_output_tokens,
            tool_config_path = tool_config_path,
            system_prompt_name = system_prompt_name
        )

        # 2) OpenAI/vLLM client + sampling params
        self.openai_client = OpenAI(
            base_url=openai_cfg["base_url"],
            api_key=openai_cfg.get("api_key") or "EMPTY",
            timeout=openai_cfg.get("timeout", 120),
        )
        self.model_name = model_name
        self.topp = topp
        self.topk = topk
        # Set by the base class and used below: temperature, max_output_tokens,
        # state_manager, full_history, deleted_msg_ids, ctx_counter,
        # system_prompt, tools, logger, tokenizer, tool_library, max_turns.

    # ------- LLM call override -------

    @staticmethod
    def _is_retryable_api_error(err):
        """Return True when ``err`` looks transient (transport problem or server-side failure)."""
        if isinstance(err, APIStatusError):
            status = err.status_code
            # Request timeout, conflict, rate limit, or any 5xx.
            return status in (408, 409, 429) or status >= 500
        # Connection errors, timeouts and other non-HTTP API errors.
        return True

    def _call_llm_api(self, messages):
        """
        Send ``messages`` to the chat-completions endpoint and return the raw response.

        Transient failures are retried up to 3 times with exponential back-off
        (5 s, 10 s, 20 s):
          * connection errors and timeouts;
          * HTTP 408, 409, 429;
          * 5xx responses.
        Any other 4xx response is re-raised immediately, because repeating the
        same request cannot succeed.

        Raises:
            openai.APIError: when retries are exhausted or the error is not retryable.
        """
        body_kwargs = {
            "model": self.model_name,      # must match vLLM --served-model-name
            "messages": messages,
            "tools": self.tools,
            "temperature": self.temperature,
            "top_p": self.topp,
            "max_tokens": self.max_output_tokens,
            # Non-standard sampling parameters are passed through ``extra_body``.
            "extra_body": {"repetition_penalty": 1.0, "top_k": self.topk},
        }

        tries, max_tries = 0, 3
        while True:
            try:
                resp = self.openai_client.chat.completions.create(**body_kwargs)
            except APIError as e:
                # RateLimitError, APITimeoutError, APIConnectionError and
                # APIStatusError are all subclasses of APIError.
                if tries >= max_tries or not self._is_retryable_api_error(e):
                    raise
                wait = 5 * (2 ** tries)
                print(f"[API] {e} – retrying in {wait}s")
                time.sleep(wait)
                tries += 1
                continue

            self.api_call_counter = getattr(self, "api_call_counter", 0) + 1
            if getattr(self, "logger", None):
                self.logger.log_api_call(body_kwargs, resp.model_dump(), self.api_call_counter)
            return resp

    # Claude-compatibility shim (if your base still calls these)
    def _call_claude_api(self, messages):
        """Route the base class's Claude call to the OpenAI-compatible endpoint."""
        return self._call_llm_api(messages)

    # ------- Parser override -------

    def _parse_llm_output(self, resp):
        """Return ``(text, tool_name, tool_args, tool_call_id, finish_reason)`` for ``resp``."""
        return parse_openai_response(resp)

    # ------- Unified deletion-aware payload -------

    @staticmethod
    def _normalize_tool_call(tc):
        """Convert a stored tool call (SDK object, dict, or other) into a plain dict."""
        if hasattr(tc, "model_dump"):
            return tc.model_dump()
        if isinstance(tc, dict):
            return tc
        return json.loads(json.dumps(tc, default=str))

    def _assistant_to_openai(self, msg):
        """Render one stored assistant turn as an OpenAI assistant message (deletion-aware)."""
        tool_calls = [self._normalize_tool_call(tc) for tc in (msg.get("tool_calls") or [])]

        if msg["msg_id"] in self.deleted_msg_ids:
            stub = {"role": "assistant", "content": self._DELETED_NOTICE}
            if tool_calls:
                # Keep each call's id and name so the following tool results
                # still pair up. Arguments MUST remain a JSON string.
                stub["tool_calls"] = [
                    {
                        "id": tcd.get("id"),
                        "type": "function",
                        "function": {
                            "name": (tcd.get("function") or {}).get("name") or "",
                            "arguments": json.dumps(
                                {"message": self._DELETED_NOTICE},
                                ensure_ascii=False,
                            ),
                        },
                    }
                    for tcd in tool_calls
                ]
            return stub

        raw_text = " ".join(
            blk.get("text", "")
            for blk in msg.get("content", [])
            if blk.get("type") == "text"
        )
        # Text is replayed as-is (no <think> tag stripping); empty text is sent as None.
        cleaned_text = raw_text.strip()
        assistant_msg = {"role": "assistant", "content": cleaned_text or None}
        if tool_calls:
            assistant_msg["tool_calls"] = tool_calls
        return assistant_msg

    def _tool_to_openai(self, msg):
        """Render one stored tool result as an OpenAI ``tool`` message (deletion-aware)."""
        msg_id = msg["msg_id"]
        msg_id_ia = msg["msg_id(invoking_assistant)"]
        content = msg["content"]

        if isinstance(content, dict):
            # The model needs the ids to reference this result later (e.g. in
            # deleteContext). This annotates the stored history entry in place,
            # so saved trajectories also carry these keys.
            content["msg_id"] = msg_id
            content["msg_id(invoking_assistant)"] = msg_id_ia
            result = content
        else:
            # Non-dict results cannot hold the ids; wrap them instead of crashing.
            result = {
                "result": content,
                "msg_id": msg_id,
                "msg_id(invoking_assistant)": msg_id_ia,
            }

        if msg_id in self.deleted_msg_ids:
            tool_name = msg.get("tool_name", "unknown")
            if tool_name not in ("nextChunk", "readChunk", "note"):
                print(f"[INFO] Attempting to delete {tool_name}")
            result = {
                "msg_id": msg_id,
                "msg_id(invoking_assistant)": msg_id_ia,
                "status": "success",
                "message": self._DELETED_NOTICE,
                "original_tool": tool_name,
            }
            # Keep reading progress visible even after the chunk text is dropped,
            # when the original result recorded it.
            if tool_name == "nextChunk" and isinstance(content, dict) and "reading_progress" in content:
                result["reading_progress"] = content["reading_progress"]

        return {
            "role": "tool",
            "content": json.dumps(result, ensure_ascii=False),
            "tool_call_id": msg["tool_use_id"],
        }

    def _build_api_payload(self):
        """
        Build a list[dict] `messages` that conforms to the OpenAI / vLLM format.

        * role == "system"  ➜ first element, contains the long system prompt
        * user              ➜ string; the first history entry also gets the notes summary
        * assistant         ➜ string (or None); may also include `tool_calls`
        * tool results      ➜ role == "tool" (must carry the matching tool_call_id)

        Messages whose msg_id is in ``self.deleted_msg_ids`` are replaced by
        short placeholders that keep their ids, so call/result pairs stay valid.
        """
        messages = [{"role": "system", "content": self.system_prompt}]

        notes_summary = (
            f"\n\n### Notes Summary\n<notes_summary>\n"
            f"{self.state_manager.get_notes_summary()}\n</notes_summary>"
        )

        for idx, msg in enumerate(self.full_history):
            role = msg.get("role")
            if role == "user":
                text = msg["content"] + (notes_summary if idx == 0 else "")
                messages.append({"role": "user", "content": text})
            elif role == "assistant":
                messages.append(self._assistant_to_openai(msg))
            elif role == "tool":
                messages.append(self._tool_to_openai(msg))
        return messages

    # ------- Tool execution override (checkBudget / deleteContext) -------

    def _execute_tool(self, action, params):
        """
        Execute one tool call and return its result (normally a dict).

        ``checkBudget`` and ``deleteContext`` are handled here. Any other name
        is dispatched to the matching public, callable attribute of
        ``self.tool_library``. Unknown or private names yield an error dict.
        """
        params = params or {}

        if action == "checkBudget":
            api_payload = self._build_api_payload()
            tokenized_messages = self.tokenizer.apply_chat_template(
                api_payload, tools=self.tools, add_generation_prompt=False, tokenize=True
            )
            conv_rounds = len(self.full_history) // 2
            # NOTE(review): 1500 is a fixed margin added to the templated prompt
            # length; its origin is not documented here — confirm.
            message_len = len(tokenized_messages) + 1500
            return {
                "conv_rounds": conv_rounds,
                "available_tokens": self.max_context_exp - message_len - self.max_output_tokens,
                "available_rounds": self.max_turns - conv_rounds,  # expected rounds to finish
            }

        if action == "deleteContext":
            raw_id = params.get("msg_id")
            if raw_id is None:
                return {"error": "msg_id is required"}
            try:
                msg_id = int(raw_id)
            except (TypeError, ValueError):
                # The id comes from model output and may not be numeric.
                return {"error": f"msg_id {raw_id} is not a valid integer"}
            _, entry = self._resolve_msg_entry(msg_id)
            if entry is None:
                return {"error": f"msg_id {raw_id} not found"}
            role = entry.get("role")
            if role == "user":
                return {"error": "Deleting user messages is not supported"}
            if role in ("assistant", "tool"):
                self.deleted_msg_ids.add(msg_id)
                return {"status": "success", "deleted_msg_id": msg_id, "deleted_role": role}
            return {"error": f"Unsupported role '{role}' for deletion"}

        # The tool name is model-chosen: only dispatch to public callables.
        if isinstance(action, str) and not action.startswith("_"):
            tool = getattr(self.tool_library, action, None)
            if callable(tool):
                return tool(params)
        return {"error": f"Tool '{action}' not found."}

    # ---------- Built-in: save trajectory & final answer ----------

    def _sanitize_for_json(self, obj):
        """
        Recursively convert SDK / complex objects to plain JSON-safe types.

        The conversion order is:
          1. Pydantic-style objects are dumped via ``model_dump()`` or ``dict()``.
          2. Containers are recursed into; sets and tuples become lists.
          3. Dict keys that JSON cannot encode are stringified.
          4. Bytes are decoded as UTF-8 with replacement characters.
          5. Any remaining value that ``json.dumps`` rejects becomes ``str(obj)``.
        """
        # 1) Try model_dump()/dict()
        if hasattr(obj, "model_dump"):
            obj = obj.model_dump()
        elif hasattr(obj, "dict"):
            obj = obj.dict()

        # 2) Recurse containers
        if isinstance(obj, dict):
            return {
                (k if k is None or isinstance(k, (str, int, float, bool)) else str(k)):
                    self._sanitize_for_json(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple, set)):
            return [self._sanitize_for_json(v) for v in obj]

        # 3) Handle bytes-like
        if isinstance(obj, (bytes, bytearray)):
            return obj.decode("utf-8", errors="replace")

        # 4) Last-resort: ensure jsonable via round-trip
        try:
            json.dumps(obj, ensure_ascii=False)
            return obj
        except TypeError:
            return str(obj)

    # ------- Save trajectory override (adds deleted ids) -------

    def save_trajectory(self, out_dir="logs", filename=None, correct_answer=None, meta_info=None):
        """
        Write the session (history, deleted ids, final answer, metadata) to a JSON file.

        Args:
            out_dir: Target directory, created if missing. ``None``/``""`` means ``"."``.
            filename: File name. Defaults to ``trajectory_<YYYYmmdd_HHMMSS>.json``.
            correct_answer: Optional reference answer stored alongside the result.
            meta_info: Optional extra metadata. Non-JSON values are stringified.

        Returns:
            The path of the written file.
        """
        snapshot = {
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime()),
            "model": getattr(self, "model_name", "Qwen3-8B-Agentic"),
            "session_notes_summary": self.state_manager.get_notes_summary(),
            "api_call_count": getattr(self, "api_call_counter", 0),
            "final_answer": self._extract_final_answer(),
            "correct_answer": correct_answer,
            "full_history": self.full_history,
            "deleted_msg_ids": sorted(self.deleted_msg_ids),
            "meta_info": meta_info,
        }
        # Serialize before opening the file so an unexpected value cannot
        # leave a truncated trajectory on disk.
        text = json.dumps(self._sanitize_for_json(snapshot), ensure_ascii=False, indent=2)

        out_dir = out_dir or "."
        os.makedirs(out_dir, exist_ok=True)
        if not filename:
            ts = time.strftime("%Y%m%d_%H%M%S", time.localtime())
            filename = f"trajectory_{ts}.json"
        path = os.path.join(out_dir, filename)

        with open(path, "w", encoding="utf-8") as f:
            f.write(text)
        print(f"[INFO] Trajectory saved to: {path}")
        return path

    # ------- Self-contained run loop with Ctrl-C handling -------

    def run(self, user_query, max_turns_to_fail=80):
        """
        Self-contained OpenAI/vLLM loop so base Orchestrator never touches
        the raw SDK response (avoids resp.get() errors).

        Each round does the following:
          1. Build the payload and call the model.
          2. Record the assistant turn.
          3. If the model requested a tool, execute it and record the result.

        The loop stops on any of:
          * the ``finish`` tool;
          * any non-tool stop reason;
          * ``max_turns_to_fail`` rounds;
          * Ctrl-C.

        Returns:
            The OpenAI-format message list built from the final history. If the
            LLM API fails, returns the payload that was being sent instead.
        """
        # The user turn carries no msg_id; assistant/tool turns are numbered from 1.
        self.full_history.append({"role": "user", "content": user_query})
        self.ctx_counter = 0
        turn = 0
        try:
            while turn < max_turns_to_fail:
                print(f"\n--- Round {turn} (Max {max_turns_to_fail} rounds, expected within {self.max_turns} rounds) ---")
                api_payload = self._build_api_payload()
                try:
                    resp = self._call_llm_api(api_payload)
                except Exception as e:
                    err = f"LLM API failed after retries: {type(e).__name__}: {e}"
                    print("[ERROR]", err)
                    self.full_history.append({
                        "role": "tool",
                        "content": {"status": "error", "message": err},
                        "msg_id": self.ctx_counter + 1,
                        "msg_id(invoking_assistant)": self.ctx_counter,
                        "tool_use_id": "api_failure",
                        "tool_name": "finish"
                    })
                    return api_payload

                self.ctx_counter += 1
                thought, action, params, tool_use_id, stop_reason = self._parse_llm_output(resp)
                msg_id = self.ctx_counter

                message = resp.choices[0].message
                raw_calls = message.tool_calls
                self.full_history.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": message.content or ""}],
                    # Only the first call is executed, and every replayed call
                    # needs a matching tool result, so store just that one.
                    "tool_calls": raw_calls[:1] if raw_calls else raw_calls,
                    "msg_id": msg_id
                })
                print("[RUN] Assistant:", thought)

                if stop_reason != "tool_calls":
                    print(f"[INFO] Process terminated due to stop_reason '{stop_reason}'.")
                    break
                if not action:
                    print("[WARN] finish_reason 'tool_calls' but no tool call was parsed; stopping.")
                    break

                print(f"[RUN] Assistant action: Call tool `{action}`, parameters: {params}")
                result = self._execute_tool(action, params)
                self.ctx_counter += 1
                msg_id_tool = self.ctx_counter
                self.full_history.append({
                    "role": "tool",
                    "content": result,
                    "msg_id": msg_id_tool,
                    "msg_id(invoking_assistant)": msg_id,
                    "tool_use_id": tool_use_id,
                    "tool_name": action
                })

                # Limit the length of the printed result preview.
                result_preview = json.dumps(result, ensure_ascii=False, default=str)
                if len(result_preview) > 200:
                    result_preview = result_preview[:200] + "..."
                print(f"[RUN] Tool result (ID: {msg_id_tool}): {result_preview}")

                if action == "finish":
                    final = (
                        result.get('final_answer', 'No final answer provided.')
                        if isinstance(result, dict) else result
                    )
                    print(f"\n--- Final Answer --- \n{final}")
                    break

                turn += 1
            else:
                # Runs only when the round budget is exhausted (not after a break).
                print(f"[INFO] Reached max rounds {max_turns_to_fail}, stopping execution.")
            return self._build_api_payload()  # full api payload for logging

        except KeyboardInterrupt:
            print("\n[INFO] Interrupted by user (Ctrl-C). Exiting gracefully...")
            return self._build_api_payload()

    # ------- Final answer extraction (skips deleted assistant turns) -------

    def _extract_final_answer(self):
        """
        Return the best available final answer for the session.

        The lookup order is:
          1. The most recent ``finish`` tool result that has a ``final_answer`` key.
          2. Otherwise, the text of the most recent non-deleted assistant turn
             that has non-blank text.
          3. Otherwise, ``""``.
        """
        # Prefer a finish tool result if present
        for msg in reversed(self.full_history):
            if msg.get("role") == "tool" and msg.get("tool_name") == "finish":
                content = msg.get("content", {})
                if isinstance(content, dict) and "final_answer" in content:
                    return content.get("final_answer")
        # Else the last non-deleted assistant text
        for msg in reversed(self.full_history):
            if msg.get("role") == "assistant" and msg.get("msg_id") not in self.deleted_msg_ids:
                text = ""
                for blk in (msg.get("content") or []):
                    if isinstance(blk, dict) and blk.get("type") == "text":
                        text += blk.get("text") or ""
                if (text := text.strip()):
                    return text
        return ""
