"""
Orchestra Core - OrchestraRunner (编排运行器)

负责管理整个编排流程:
1. 创建 MainAgent 并初始化工具
2. 运行编排循环（MainAgent 决策 -> 工具执行 -> 结果收集）
3. 保存执行轨迹和结果

关键设计:
- MainAgent 通过 delegate_task 委派任务给 SubAgent
- SubAgent 执行后通过 finish 报告结果
- MainAgent 根据结果决定下一步或提交
"""
from __future__ import annotations

import asyncio
import csv
import inspect
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from orchestra_core.interfaces import (
    BasicInfo,
    Environment,
    LevelResult,
    LLMInterface,
    PromptBuilder,
    StepRecord,
)
from orchestra_core.main_agent import MainAgent
from orchestra_core.sub_runner import SubAgentRunner


class OrchestraRunner:
    """编排运行器 - 管理 MainAgent 的执行流程"""
    
    def __init__(
        self,
        main_model: str,
        sub_models: List[str],
        max_attempts: int = 10,
        prompt_builder: Optional[PromptBuilder] = None,
        trajectory_dir: Optional[Path] = None,
        csv_summary_path: Optional[Path] = None,
        llm_factory: Optional[Callable[[str], LLMInterface]] = None,
        mask_model_names: bool = True,
    ):
        """
        初始化编排运行器
        
        Args:
            main_model: MainAgent 使用的模型
            sub_models: SubAgent 可用的模型列表
            max_attempts: 最大尝试次数
            prompt_builder: Prompt 构建器
            trajectory_dir: 轨迹保存目录
            csv_summary_path: CSV 汇总文件路径
            llm_factory: LLM 实例工厂函数
            mask_model_names: 是否遮蔽模型名称
        """
        self.main_model = main_model
        self.sub_models = sub_models
        self.max_attempts = max_attempts
        self.prompt_builder = prompt_builder
        self.trajectory_dir = Path(trajectory_dir) if trajectory_dir else None
        self.csv_summary_path = Path(csv_summary_path) if csv_summary_path else None
        self.llm_factory = llm_factory
        self.mask_model_names = mask_model_names
        
        self._csv_lock = asyncio.Lock()
    
    async def run(
        self,
        env: Environment,
        tools_factory: Callable[[Environment, "SubAgentRunner"], List[Any]],
    ) -> LevelResult:
        """
        运行编排流程
        
        Args:
            env: 执行环境
            tools_factory: 工具工厂函数，接收 env 和 sub_runner，返回工具列表
        
        Returns:
            LevelResult: 执行结果
        """
        env_info = env.get_basic_info()
        
        main_agent = None
        
        try:
            # 创建 MainAgent
            main_info = BasicInfo(
                env_id=env_info.env_id,
                instruction=env_info.instruction,
                action_space="",  # MainAgent 不使用环境的 action_space
                max_steps=self.max_attempts,
                meta_data=env_info.meta_data,
            )
            
            # 创建 LLM 实例
            if not self.llm_factory:
                raise ValueError("llm_factory is required")
            main_llm = self.llm_factory(self.main_model)
            
            # 创建 SubAgentRunner
            sub_runner = SubAgentRunner()
            
            # 创建工具
            tools = tools_factory(env, sub_runner)
            
            # 创建 MainAgent
            main_agent = MainAgent(
                llm=main_llm,
                sub_models=self.sub_models,
                tools=tools,
                prompt_builder=self.prompt_builder,
                max_attempts=self.max_attempts,
                mask_model_names=self.mask_model_names,
            )
            main_agent.reset(main_info)
            
        except Exception as init_error:
            raise RuntimeError(f"Failed to initialize MainAgent: {init_error}") from init_error
        
        # 编排循环
        history = []
        total_reward = 0.0
        total_sub_cost = 0.0
        done = False
        level_result = None
        exception_occurred = None
        
        try:
            for attempt_idx in range(self.max_attempts):
                try:
                    action, resp = await main_agent.step(None, history)
                except Exception as step_error:
                    history.append(StepRecord(
                        observation={},
                        action={"action": "error", "error": str(step_error)},
                        reward=0.0,
                        raw_response=str(step_error),
                        done=False,
                        info={"error": str(step_error)},
                    ))
                    continue
                
                action_name = action.get("action")
                result = action.get("result", {})
                reward = result.get("reward", 0.0)
                step_done = result.get("done", False)
                is_submit = action_name == "submit"
                
                # 累计 SubAgent 成本
                if action_name == "delegate_task":
                    total_sub_cost += result.get("cost", 0.0)
                
                history.append(StepRecord(
                    observation={},
                    action=action,
                    reward=reward,
                    raw_response=resp,
                    done=step_done,
                    info=result,
                ))
                
                if is_submit:
                    total_reward = reward
                
                if step_done and is_submit:
                    done = True
                    break
            
            # 如果没有 submit，尝试强制 submit
            if not done:
                executor = getattr(env, "_executor", None)
                if executor and hasattr(executor, "run_tests"):
                    try:
                        result = await executor.run_tests()
                        if isinstance(result, tuple):
                            reward, _ = result
                        else:
                            reward = result
                        total_reward = float(reward or 0.0)
                        done = True
                    except Exception:
                        done = True
                        total_reward = 0.0
                        
        except Exception as e:
            exception_occurred = e
            
        finally:
            # 清理环境
            if hasattr(env, 'close'):
                try:
                    await env.close()
                except Exception:
                    pass
            
            # 构建结果
            try:
                usage = main_agent.llm.get_usage_summary() if main_agent else {}
                main_cost = usage.get("total_cost", 0.0)
                total_cost = main_cost + total_sub_cost
                
                level_result = LevelResult(
                    model=usage.get("model", self.main_model),
                    total_reward=total_reward,
                    steps=len(history),
                    done=done,
                    trace=history,
                    cost=total_cost,
                    input_tokens=usage.get("total_input_tokens", 0),
                    output_tokens=usage.get("total_output_tokens", 0),
                    timestamp=datetime.now().isoformat(),
                )
                
                # 保存轨迹
                if self.trajectory_dir:
                    self._save_trajectory(env_info, level_result, main_agent, history)
                
                # 保存 CSV
                if self.csv_summary_path:
                    await self._save_csv(env_info.env_id, level_result)
                    
            except Exception:
                pass
            
            if exception_occurred:
                raise exception_occurred
            
            return level_result
    
    def _save_trajectory(
        self,
        info: BasicInfo,
        result: LevelResult,
        main_agent: Optional[MainAgent],
        history: List[StepRecord]
    ) -> None:
        """保存执行轨迹"""
        try:
            self.trajectory_dir.mkdir(parents=True, exist_ok=True)
            
            attempts = []
            for i, record in enumerate(history):
                action_data = record.action
                action_name = action_data.get("action")
                result_data = action_data.get("result", {})
                
                attempt = {
                    "attempt": i + 1,
                    "subtask_history": action_data.get("subtask_history", ""),
                    "main_agent": {
                        "action": action_name,
                        "params": action_data.get("params", {}),
                        "raw_response": record.raw_response,
                    },
                }
                
                if action_name == "delegate_task":
                    attempt["sub_agent"] = {
                        "model": result_data.get("model"),
                        "steps": result_data.get("steps_taken", 0),
                        "cost": result_data.get("cost", 0.0),
                        "finish_result": result_data.get("finish_result"),
                        "trace_summary": result_data.get("trace_summary", ""),
                    }
                elif action_name == "submit":
                    attempt["submit_result"] = {
                        "success": result_data.get("success"),
                        "reward": result_data.get("reward"),
                    }
                
                attempts.append(attempt)
            
            trajectory = {
                "task_id": info.env_id,
                "instruction": info.instruction,
                "metadata": info.meta_data,
                "main_model": self.main_model,
                "sub_models": self.sub_models,
                "total_reward": result.total_reward,
                "success": result.total_reward > 0,
                "done": result.done,
                "total_attempts": len(history),
                "total_cost": result.cost,
                "timestamp": result.timestamp,
                "attempts": attempts,
            }
            
            filename = f"{info.env_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            filepath = self.trajectory_dir / filename
            
            with filepath.open("w", encoding="utf-8") as f:
                json.dump(trajectory, f, indent=2, ensure_ascii=False)
                
        except Exception:
            pass
    
    async def _save_csv(self, task_id: str, result: LevelResult) -> None:
        """保存 CSV 汇总"""
        async with self._csv_lock:
            try:
                if not self.csv_summary_path:
                    return
                
                self.csv_summary_path.parent.mkdir(parents=True, exist_ok=True)
                
                fieldnames = ["task_id", "model", "success", "reward", "attempts", "cost", "timestamp"]
                
                need_header = not self.csv_summary_path.exists() or self.csv_summary_path.stat().st_size == 0
                if need_header:
                    with self.csv_summary_path.open("w", newline="", encoding="utf-8") as f:
                        writer = csv.DictWriter(f, fieldnames=fieldnames)
                        writer.writeheader()
                
                with self.csv_summary_path.open("a", newline="", encoding="utf-8") as f:
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writerow({
                        "task_id": task_id,
                        "model": result.model,
                        "success": result.total_reward > 0,
                        "reward": f"{result.total_reward:.4f}",
                        "attempts": result.steps,
                        "cost": f"{result.cost:.6f}",
                        "timestamp": result.timestamp,
                    })
            except Exception:
                pass
