"""
Orchestra Core - MainAgent (主编排智能体)

MainAgent 是整个编排系统的大脑，负责:
1. 分析任务需求，制定执行策略
2. 选择合适的 SubAgent 和模型执行子任务
3. 根据执行结果决定下一步（继续、重试、完成）
4. 管理执行历史和上下文

核心设计:
- MainAgent 不直接执行任务，通过 delegate_task 工具委派
- 通过 subtask_history 追踪所有执行记录
- 支持模型名称遮蔽，避免对特定模型产生偏好
"""
from __future__ import annotations

import json
from collections import Counter
from typing import Any, Dict, List, Optional

from orchestra_core.interfaces import (
    BaseTool,
    BasicInfo,
    LLMInterface,
    PromptBuilder,
)


class MainAgent:
    """Main orchestrating agent: decides what to delegate and tracks the outcome.

    The agent never executes work itself. On every ``step`` it asks the LLM for
    a JSON decision of the form ``{"action": <tool name>, "params": {...}}``,
    invokes the matching tool and records the result. Delegated subtasks are
    kept as structured records in ``task_entries`` and rendered back into the
    next prompt. Model names can optionally be replaced by neutral aliases
    (``model_1``, ``model_2``, ...) in everything shown to the LLM.
    """

    # Step budget displayed when a delegate result carries no usable
    # ``statistics.max_steps`` value.
    DEFAULT_MAX_STEPS = 30

    def __init__(
        self,
        llm: LLMInterface,
        sub_models: List[str],
        tools: Optional[List[BaseTool]] = None,
        prompt_builder: Optional[PromptBuilder] = None,
        max_attempts: int = 10,
        mask_model_names: bool = True,
    ):
        """
        Initialize the MainAgent.

        Args:
            llm: Awaitable LLM callable; called as ``await llm(prompt)``.
            sub_models: Names of the models available to SubAgents.
            tools: Tools the LLM may select by name (matched on ``tool.name``).
            prompt_builder: Object providing ``build_prompt(...)``. Required
                despite the ``None`` default.
            max_attempts: Attempt budget forwarded to the prompt builder.
            mask_model_names: Whether to show aliases instead of model names.

        Raises:
            ValueError: If ``prompt_builder`` is ``None``.
        """
        if prompt_builder is None:
            raise ValueError("prompt_builder is required")

        self.llm = llm
        self.sub_models = sub_models
        self.tools = tools or []
        self.prompt_builder = prompt_builder
        self.max_attempts = max_attempts

        # Per-task state (cleared by ``reset``).
        self.instruction: str = ""
        self.meta: Dict[str, Any] = {}
        self.attempt: int = 0
        self.context: str = ""
        self.history: List[Dict] = []
        self.task_entries: List[Dict] = []  # structured subtask records

        # Model-name masking configuration.
        self.mask_model_names = mask_model_names
        if mask_model_names:
            # De-duplicate (order preserved) so aliases are numbered
            # contiguously even if a model is listed more than once.
            unique_models = list(dict.fromkeys(sub_models))
            self.model_to_alias: Dict[str, str] = {
                model: f"model_{index}"
                for index, model in enumerate(unique_models, start=1)
            }
            self.alias_to_model: Dict[str, str] = {
                alias: model for model, alias in self.model_to_alias.items()
            }
            self.masked_sub_models = list(self.model_to_alias.values())
        else:
            self.model_to_alias = {m: m for m in sub_models}
            self.alias_to_model = {m: m for m in sub_models}
            # Copy so later edits to this list never touch the caller's list.
            self.masked_sub_models = list(sub_models)

    def reset(self, env_info: BasicInfo) -> None:
        """Clear all per-task state and load the instruction/metadata of a new task."""
        self.instruction = env_info.instruction
        self.meta = env_info.meta_data or {}
        self.attempt = 0
        self.context = ""
        self.history = []
        self.task_entries = []

    async def step(self, observation: Any, history: List, **kwargs) -> tuple:
        """
        Run one decision step: prompt the LLM, invoke the chosen tool, record it.

        ``observation``, ``history`` and ``**kwargs`` are accepted but not used
        by this implementation.

        Returns:
            tuple: ``(action_dict, raw_response)``. ``action_dict`` has keys
            ``action``, ``params``, ``result`` and ``subtask_history`` on
            success, or ``{"action": "error", "error": ...}`` when the LLM
            picked an unknown tool or supplied non-object ``params``.

        Raises:
            json.JSONDecodeError: If no JSON object can be extracted from the
                LLM response.
        """
        self.attempt += 1

        # Rendered before the tool runs, so it reflects previous attempts only.
        subtask_history = self._format_subtask_history()

        prompt = self.prompt_builder.build_prompt(
            instruction=self.instruction,
            meta=self.meta,
            tools_description=self._get_tools_description(),
            prior_context=self.context,
            attempt_index=self.attempt,
            max_attempts=self.max_attempts,
            sub_models=self.masked_sub_models,
            subtask_history=subtask_history,
            model_to_alias=self.model_to_alias if self.mask_model_names else None,
        )

        resp = await self.llm(prompt)
        decision = self._parse_json(resp)

        action_name = decision.get("action")

        tool = next((t for t in self.tools if t.name == action_name), None)
        if not tool:
            return {"action": "error", "error": f"Unknown action: {action_name}"}, resp

        # ``"params": null`` is valid JSON and means "no arguments"; any other
        # non-object value cannot be expanded into keyword arguments.
        params = decision.get("params")
        if params is None:
            params = {}
        if not isinstance(params, dict):
            return {
                "action": "error",
                "error": f"Invalid params for action {action_name}: expected a JSON object",
            }, resp

        # Translate an alias chosen by the LLM back to the real model name.
        # Values that are not known aliases are passed through untouched.
        if self.mask_model_names:
            alias = params.get("model")
            if isinstance(alias, str) and alias in self.alias_to_model:
                params["model"] = self.alias_to_model[alias]

        result = await tool(**params)

        self._update_context(action_name, params, result)

        return {
            "action": action_name,
            "params": params,
            "result": result,
            "subtask_history": subtask_history,
        }, resp

    def _get_tools_description(self) -> str:
        """Render name, description and JSON parameters of every tool as text."""
        if not self.tools:
            return "No tools available."

        return "\n".join(
            f"\nTool Name: {tool.name}\n"
            f"Description: {tool.description}\n"
            f"Parameters: {json.dumps(tool.parameters, indent=2)}\n"
            for tool in self.tools
        )

    def _parse_json(self, resp: str) -> Dict:
        """
        Extract the JSON object from an LLM response.

        Candidates are tried in order and the first one that decodes to a dict
        wins:

        1. the content of the first markdown code fence,
        2. the text between the first ``{`` and the last ``}``,
        3. an object decoded starting at each ``{`` in turn, ignoring whatever
           follows it.

        Steps 2 and 3 rescue responses where the fence regex stops early
        (a JSON string value that itself contains a code fence) or where
        trailing prose contains a stray ``}``.

        Raises:
            json.JSONDecodeError: If no candidate decodes to a JSON object.
        """
        import re

        text = resp.strip()
        candidates: List[str] = []

        fenced = re.search(r'```(?:json)?\s*\n?([\s\S]*?)\n?```', text)
        if fenced:
            candidates.append(fenced.group(1).strip())

        first_brace = text.find('{')
        last_brace = text.rfind('}')
        if first_brace != -1 and last_brace > first_brace:
            candidates.append(text[first_brace:last_brace + 1])

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        # Last resort: decode an object at each '{' and ignore the remainder.
        decoder = json.JSONDecoder()
        position = first_brace
        while position != -1:
            try:
                parsed, _ = decoder.raw_decode(text, position)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict):
                return parsed
            position = text.find('{', position + 1)

        raise json.JSONDecodeError("No JSON object found in LLM response", text, 0)

    def _update_context(self, action: str, params: Dict, result: Dict) -> None:
        """
        Record the outcome of an executed action.

        Prepends a short text summary to ``self.context`` (newest first),
        appends the raw result to ``self.history`` and, for ``delegate_task``,
        appends a structured record to ``self.task_entries``.
        """
        summary = f"[{self.attempt}] {action}\n"

        if action == "delegate_task":
            finish = result.get("finish_result") or {}
            if finish:
                summary += f"  Status: {finish.get('status')}\n"
                if finish.get('completed'):
                    summary += f"  Completed: {finish['completed']}\n"
                if finish.get('issues'):
                    summary += f"  Issues: {finish['issues']}\n"
                if finish.get('message'):
                    summary += f"  Message: {finish.get('message')}\n"

                entry_status = finish.get('status', 'partial')
                entry_message = finish.get('message', '')
                entry_completed = finish.get('completed', [])
                entry_issues = finish.get('issues', [])
            else:
                # No finish report: record the run as an unfinished attempt.
                summary += f"  Steps: {result.get('steps_taken', 0)}, Done: {result.get('done', False)}\n"

                entry_status = 'partial'
                entry_message = 'SubAgent did not finish (max steps reached).'
                entry_completed = []
                entry_issues = ['SubAgent timeout']

            # The tool received the real model name; the recorded history is
            # shown to the LLM, so map it back to the alias when masking.
            model_display = params.get('model', 'unknown')
            if self.mask_model_names and model_display in self.model_to_alias:
                model_display = self.model_to_alias[model_display]

            # ``statistics`` may be missing or explicitly None.
            statistics = result.get('statistics') or {}

            self.task_entries.append({
                "attempt": self.attempt,
                "status": entry_status,
                "instruction": params.get('task_instruction', 'N/A'),
                "model": model_display,
                "steps_taken": result.get('steps_taken', 0),
                "max_steps": statistics.get('max_steps', self.DEFAULT_MAX_STEPS),
                "cost": result.get('cost', 0),
                "message": entry_message,
                "completed": entry_completed,
                "issues": entry_issues,
                "trace_summary": result.get('trace_summary', ''),
            })

        elif action in ("submit", "complete"):
            summary += f"  Success: {result.get('success')}, Reward: {result.get('reward')}\n"

        self.context = summary + "\n" + self.context
        self.history.append({
            "attempt": self.attempt,
            "action": action,
            "result": result
        })

    @staticmethod
    def _as_list(value: Any) -> List:
        """Coerce a finish-report field to a list (a bare string becomes one item)."""
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    def _format_subtask_history(self) -> str:
        """
        Render ``task_entries`` as a tree-style text report plus a summary.

        Returns:
            One block per entry followed by a summary block, separated by
            blank lines; a fixed placeholder string when there are no entries.
        """
        if not self.task_entries:
            return "No subtasks completed yet."

        blocks = []
        done_count = 0
        all_completed = []
        all_issues = []

        for e in self.task_entries:
            emoji = "✅" if e["status"] == "done" else "⚠️"
            steps_info = f'{e.get("steps_taken", "?")}/{e.get("max_steps", self.DEFAULT_MAX_STEPS)}'
            model_display = e.get("model", "?")

            entry_lines = [
                f'[Attempt {e["attempt"]}] {emoji} {e["status"]} | Model: {model_display} | Steps: {steps_info}',
                f'├─ Task: {e["instruction"]}',
            ]

            if e.get("message"):
                entry_lines.append(f'├─ Message: {e["message"]}')

            # Normalise first: extending a list with a bare string would add
            # its characters one by one.
            completed = self._as_list(e.get("completed"))
            if completed:
                entry_lines.append(f'├─ ✅ Completed: {completed}')
                all_completed.extend(completed)

            issues = self._as_list(e.get("issues"))
            if issues:
                entry_lines.append(f'├─ ❌ Issues: {issues}')
                all_issues.extend(issues)

            trace_summary = e.get("trace_summary", "")
            if trace_summary and trace_summary != "N/A":
                entry_lines.append(f'└─ Trace summary:\n{self._indent_text(trace_summary, "   ")}')
            else:
                # Close the tree on the last line. Only the leading connector
                # is swapped, so the glyph inside user text is left alone.
                entry_lines[-1] = entry_lines[-1].replace('├─', '└─', 1)

            blocks.append("\n".join(entry_lines))

            if e["status"] == "done":
                done_count += 1

        summary_lines = ["---", f"Summary: {done_count}/{len(self.task_entries)} subtasks done"]
        if all_completed:
            summary_lines.append(f"✅ All successful steps: {all_completed}")
        if all_issues:
            summary_lines.append(f"❌ All issues encountered: {all_issues}")

        blocks.append("\n".join(summary_lines))

        return "\n\n".join(blocks)

    def _indent_text(self, text: str, indent: str = "   ") -> str:
        """Strip ``text`` and prefix each of its lines with ``indent``."""
        return "\n".join(indent + line for line in text.strip().split("\n"))

    def get_usage_cost(self) -> float:
        """Return ``total_cost`` from the LLM usage summary (0.0 if absent)."""
        return self.llm.get_usage_summary().get("total_cost", 0.0)
