"""
Orchestra Core - 轨迹格式化器

提供可扩展的轨迹格式化机制:
1. ActionFormatter: 格式化不同类型的 action
2. ObservationFormatter: 格式化不同类型的 observation
3. TraceFormatter: 组合格式化完整轨迹

设计原则:
- 通过注册机制支持自定义格式化
- 提供合理的默认格式化
- 支持截断长输出
"""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Protocol


class StepLike(Protocol):
    """Structural type for one trajectory step, compatible with the StepRecord layout.

    Any object exposing these attributes can be passed to
    ``TraceFormatter.format_trace``, which reads ``action`` and
    ``observation`` from each step.
    """
    # Action payload; TraceFormatter.format_action reads its "action" and
    # "params" keys.
    action: Dict[str, Any]
    # Result of the action; dict observations are dispatched to observation
    # formatters, any other value is rendered with str().
    observation: Any
    # The fields below are part of the protocol but are not read by the
    # formatting code visible in this module.
    reward: float
    done: bool
    info: Dict[str, Any]


class ActionFormatter(ABC):
    """Abstract base class for action formatters.

    A concrete formatter declares which action type it handles through
    ``action_type`` and turns that action's parameters into a readable
    string through ``format``.
    """

    @property
    @abstractmethod
    def action_type(self) -> str:
        """Return the action type name handled by this formatter."""

    @abstractmethod
    def format(self, params: Dict[str, Any], max_len: int = 100) -> str:
        """Render an action as a human-readable string.

        Args:
            params: Parameter mapping of the action to render.
            max_len: Length budget for the rendered text; how it is applied
                is decided by each concrete formatter.

        Returns:
            The readable representation of the action.
        """


class ObservationFormatter(ABC):
    """Abstract base class for observation formatters.

    A concrete formatter first reports, through ``can_format``, whether it
    recognises an observation, and then renders it through ``format``.
    """

    @abstractmethod
    def can_format(self, obs: Dict[str, Any]) -> bool:
        """Return whether this formatter is able to handle ``obs``."""

    @abstractmethod
    def format(self, obs: Dict[str, Any], max_len: int = 300) -> tuple[str, str]:
        """Render an observation.

        Args:
            obs: The observation mapping to render.
            max_len: Length budget for the output text; how it is applied
                is decided by each concrete formatter.

        Returns:
            tuple[str, str]: ``(status line, output text)``.
        """


# ============== 通用格式化器 ==============

class ExecuteActionFormatter(ActionFormatter):
    """Formatter for ``execute`` actions.

    Renders the action as ``execute(command="...")`` with the command text
    cut to at most ``max_len`` characters.
    """

    @property
    def action_type(self) -> str:
        """Return the action type handled by this formatter: ``"execute"``."""
        return "execute"

    def format(self, params: Dict[str, Any], max_len: int = 100) -> str:
        """Render an execute action as a readable string.

        Args:
            params: Action parameters; only the ``"command"`` entry is used.
            max_len: Maximum number of command characters to keep.

        Returns:
            ``execute(command="<command>")``. The command is empty when the
            entry is missing or ``None``.
        """
        command = params.get("command")
        # dict.get only applies its default to a *missing* key, so an
        # explicit None (or a non-string value such as an argv list) has to
        # be converted to text before slicing; otherwise None raises
        # TypeError and a list is cut by element instead of by character.
        text = "" if command is None else str(command)
        return f'execute(command="{text[:max_len]}")'


class FinishActionFormatter(ActionFormatter):
    """Formatter for ``finish`` actions.

    Renders the action as ``finish(status="...", msg="...")`` with a
    shortened message.
    """

    # Hard upper bound on the number of message characters shown; a smaller
    # ``max_len`` passed to ``format`` tightens it further.
    _MESSAGE_LIMIT = 60

    @property
    def action_type(self) -> str:
        """Return the action type handled by this formatter: ``"finish"``."""
        return "finish"

    def format(self, params: Dict[str, Any], max_len: int = 100) -> str:
        """Render a finish action as a readable string.

        Args:
            params: Action parameters; ``"status"`` (default ``"done"``) and
                ``"message"`` (default empty) are used.
            max_len: Maximum number of message characters to keep, capped at
                60 so the default output is unchanged.

        Returns:
            ``finish(status="<status>", msg="<message>")``.
        """
        status = params.get("status", "done")
        message = params.get("message")
        # An explicit None message must not reach the slice below (it would
        # raise TypeError); non-string messages are shown as their text.
        text = "" if message is None else str(message)
        # Honour the caller's budget instead of ignoring max_len, while
        # keeping the historical 60-character ceiling.
        limit = min(max_len, self._MESSAGE_LIMIT)
        return f'finish(status="{status}", msg="{text[:limit]}")'


class SubmitActionFormatter(ActionFormatter):
    """Formatter for ``submit`` actions, which are rendered without arguments."""

    @property
    def action_type(self) -> str:
        """Return the action type handled by this formatter: ``"submit"``."""
        handled_type = "submit"
        return handled_type

    def format(self, params: Dict[str, Any], max_len: int = 100) -> str:
        """Render a submit action.

        Args:
            params: Action parameters; not used by this formatter.
            max_len: Accepted for interface compatibility; not used.

        Returns:
            The fixed text ``submit()``.
        """
        rendered = "submit()"
        return rendered


class ExitCodeObservationFormatter(ObservationFormatter):
    """Formatter for observations that carry an ``exit_code`` (plus ``output``)."""

    def can_format(self, obs: Dict[str, Any]) -> bool:
        """Return True when ``obs`` contains an ``"exit_code"`` key."""
        has_exit_code = "exit_code" in obs
        return has_exit_code

    def format(self, obs: Dict[str, Any], max_len: int = 300) -> tuple[str, str]:
        """Render an exit-code style observation.

        Args:
            obs: Observation mapping; ``"exit_code"`` (default ``"N/A"``) and
                ``"output"`` (default empty) are read.
            max_len: Accepted for interface compatibility; this formatter
                does not truncate the output itself.

        Returns:
            tuple[str, str]: ``("exit_code=<code>", <output as text>)``.
        """
        code = obs.get("exit_code", "N/A")
        raw_output = obs.get("output", "")
        status_line = f"exit_code={code}"
        return status_line, str(raw_output)


class SuccessObservationFormatter(ObservationFormatter):
    """Formatter for observations that carry a ``success`` flag (plus ``output``)."""

    def can_format(self, obs: Dict[str, Any]) -> bool:
        """Return True when ``obs`` contains a ``"success"`` key."""
        has_success = "success" in obs
        return has_success

    def format(self, obs: Dict[str, Any], max_len: int = 300) -> tuple[str, str]:
        """Render a success-flag style observation.

        Args:
            obs: Observation mapping; ``"success"`` (default ``False``) is
                read for the status line. The output text comes from
                ``"output"`` when that key exists, otherwise from
                ``"error"`` (default empty).
            max_len: Accepted for interface compatibility; this formatter
                does not truncate the output itself.

        Returns:
            tuple[str, str]: ``("success=<flag>", <output as text>)``.
        """
        flag = obs.get("success", False)
        # "output" wins whenever the key is present, even if its value is
        # empty; "error" is only consulted when "output" is absent.
        if "output" in obs:
            payload = obs["output"]
        else:
            payload = obs.get("error", "")
        status_line = f"success={flag}"
        return status_line, str(payload)


class FallbackActionFormatter(ActionFormatter):
    """Generic formatter used for action types that have no dedicated formatter.

    The action type it reports is whatever name was given at construction.
    """

    def __init__(self, action_type: str = "unknown"):
        """Remember the action type name to show in the rendered text."""
        self._action_type = action_type

    @property
    def action_type(self) -> str:
        """Return the action type name supplied at construction."""
        return self._action_type

    def format(self, params: Dict[str, Any], max_len: int = 100) -> str:
        """Render the action as its type name followed by a few parameter names.

        Args:
            params: Action parameters; only the keys are shown, never values.
            max_len: Accepted for interface compatibility; not used.

        Returns:
            ``<action_type>([<up to three parameter keys>])``.
        """
        # Show at most the first three keys to keep the line short.
        shown_keys = [*params.keys()][:3]
        return f'{self._action_type}({shown_keys})'


class FallbackObservationFormatter(ObservationFormatter):
    """Generic formatter that accepts every observation."""

    def can_format(self, obs: Dict[str, Any]) -> bool:
        """Return True unconditionally; this formatter handles anything."""
        return True

    def format(self, obs: Dict[str, Any], max_len: int = 300) -> tuple[str, str]:
        """Render the observation as its plain ``str()`` text.

        Args:
            obs: The observation to render.
            max_len: Accepted for interface compatibility; not used.

        Returns:
            tuple[str, str]: an empty status line and ``str(obs)``.
        """
        status_line = ""
        rendered = str(obs)
        return status_line, rendered


# ============== TraceFormatter 主类 ==============

class TraceFormatter:
    """Formats a trajectory (a sequence of steps) as readable text.

    Action formatters are looked up by action type. Observation formatters
    are tried in registration order and the first one whose ``can_format``
    accepts the observation is used. Action types without a registered
    formatter and observations nobody accepts are rendered by generic
    fallback formatters.
    """

    def __init__(self):
        """Create a formatter with nothing registered yet."""
        self._action_formatters: Dict[str, ActionFormatter] = {}
        self._obs_formatters: List[ObservationFormatter] = []
        self._fallback_obs_formatter = FallbackObservationFormatter()

    def register_action_formatter(self, formatter: ActionFormatter) -> "TraceFormatter":
        """Register an action formatter under its ``action_type``.

        A later registration for the same action type replaces the earlier
        one. Returns ``self`` so calls can be chained.
        """
        self._action_formatters[formatter.action_type] = formatter
        return self

    def register_obs_formatter(self, formatter: ObservationFormatter) -> "TraceFormatter":
        """Append an observation formatter to the lookup order.

        Formatters registered earlier take precedence. Returns ``self`` so
        calls can be chained.
        """
        self._obs_formatters.append(formatter)
        return self

    def format_action(self, action: Dict[str, Any], max_len: int = 100) -> str:
        """Render a single action.

        Args:
            action: Mapping with an ``"action"`` entry (the type name,
                default ``"unknown"``) and an optional ``"params"`` entry.
            max_len: Length budget forwarded to the chosen formatter.

        Returns:
            The text produced by the formatter registered for the action
            type, or by a ``FallbackActionFormatter`` when none is registered.
        """
        action_type = action.get("action", "unknown")
        # "params" can be present but None; dict.get's default only covers a
        # missing key, so normalise to an empty mapping here. Formatters call
        # dict methods on it and would otherwise raise AttributeError.
        params = action.get("params") or {}

        formatter = self._action_formatters.get(action_type)
        # Compare against None explicitly so that a registered formatter is
        # never skipped merely because it evaluates as falsy.
        if formatter is None:
            formatter = FallbackActionFormatter(action_type)
        return formatter.format(params, max_len)

    def format_observation(self, obs: Any, max_len: int = 300) -> tuple[str, str]:
        """Render a single observation.

        Args:
            obs: The observation. Non-dict values are rendered with ``str()``
                and an empty status line.
            max_len: Length budget forwarded to the chosen formatter.

        Returns:
            tuple[str, str]: ``(status line, output text)``.
        """
        if not isinstance(obs, dict):
            return "", str(obs)

        # First registered formatter that accepts the observation wins.
        for formatter in self._obs_formatters:
            if formatter.can_format(obs):
                return formatter.format(obs, max_len)

        return self._fallback_obs_formatter.format(obs, max_len)

    def format_trace(self, trace: List[StepLike], max_output_len: int = 300) -> str:
        """Render a complete trajectory.

        Args:
            trace: The steps to render; each must expose ``action`` and
                ``observation``.
            max_output_len: Maximum number of output characters kept per
                step before a ``...[+N chars]`` marker is appended.

        Returns:
            ``"No steps executed"`` for an empty trace; otherwise one
            ``Step i: ...`` line per step, an optional status line, an
            output line and a blank separator line, joined by newlines.
        """
        if not trace:
            return "No steps executed"

        lines = []
        for index, step in enumerate(trace, 1):
            lines.append(f"Step {index}: {self.format_action(step.action)}")

            status_line, output = self.format_observation(step.observation, max_output_len)
            if status_line:
                lines.append(f"  → {status_line}")

            # Truncate on the raw text first, then collapse newlines so the
            # output always occupies exactly one line of the report.
            overflow = len(output) - max_output_len
            if overflow > 0:
                output = output[:max_output_len] + f"...[+{overflow} chars]"
            output = output.replace("\n", " ").strip()
            lines.append(f"  → output: {output}")
            lines.append("")

        return "\n".join(lines)


# ============== 预置格式化器工厂 ==============

def create_default_formatter() -> TraceFormatter:
    """Build a TraceFormatter preloaded with the built-in formatters.

    Registers the execute, finish and submit action formatters, then the
    exit-code and success observation formatters, in that order. For
    observations the order matters: an observation containing both keys is
    handled by the exit-code formatter because it is registered first.

    Returns:
        The configured ``TraceFormatter`` instance.
    """
    formatter = TraceFormatter()

    builtin_action_formatters = (
        ExecuteActionFormatter(),
        FinishActionFormatter(),
        SubmitActionFormatter(),
    )
    for action_formatter in builtin_action_formatters:
        formatter.register_action_formatter(action_formatter)

    builtin_obs_formatters = (
        ExitCodeObservationFormatter(),
        SuccessObservationFormatter(),
    )
    for obs_formatter in builtin_obs_formatters:
        formatter.register_obs_formatter(obs_formatter)

    return formatter
