"""
Supervisor Agent - Fast-Slow Feedback 的 Slow Agent 实现

负责:
1. 感知批评 (Perceptual Critique) - 分析合成音频质量
2. 语义-声学一致性检查
3. 触发强度微调或条件重置
"""

import logging
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from enum import Enum

# Configure the root logger at import time: INFO level, records formatted as
# "<timestamp> - <level> - <message>".
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers, but as an import-time side effect it can still pre-empt the
# host application's own logging setup — consider moving it to an entry point.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class DeviationType(Enum):
    """Enumeration of the deviation kinds a critique can report."""
    # No deviation detected.
    NONE = "none"
    # Emotion is expressed too weakly.
    EMOTION_TOO_WEAK = "emotion_too_weak"
    # Emotion is expressed too strongly.
    EMOTION_TOO_STRONG = "emotion_too_strong"
    # The expressed emotion category is wrong.
    INCORRECT_EMOTION = "incorrect_emotion"
    # Speaker mismatch (not referenced by the code visible in this file).
    SPEAKER_MISMATCH = "speaker_mismatch"
    # Audio quality problem (not referenced by the code visible in this file).
    QUALITY_ISSUE = "quality_issue"


@dataclass
class CritiqueResult:
    """Result of one perceptual critique of a synthesized audio clip."""
    # Kind of deviation that was detected (DeviationType.NONE when there is none).
    deviation_type: DeviationType
    # Confidence of the critique; presumably in [0.0, 1.0] — TODO confirm.
    confidence: float
    # Human-readable critique text.
    critique_text: str
    # Name of the suggested follow-up action (e.g. "continue").
    suggested_action: str
    # Optional explicit alpha value; when set, SupervisorAgent.get_alpha_adjustment
    # uses it in place of its own heuristic.
    alpha_adjustment: Optional[float] = None


class SupervisorAgent:
    """
    Supervisor Agent - the Slow Agent of the Fast-Slow Feedback loop.

    Performs a high-level perceptual critique of synthesized audio and
    classifies semantic-acoustic deviations so that a correction can be
    triggered:
    - emotion too strong / too weak -> the Fast Agent should recalibrate alpha
    - wrong emotion category -> DAC retrieval should be redone

    Attributes:
        llm_client: Optional LLM client (e.g. Gemini) used for high-level
            reasoning; ``None`` disables LLM-based analysis.
        device: Compute device name. It is stored but not otherwise used by
            this class.
        critique_history: Critiques recorded by ``evaluate_batch``. Direct
            calls to ``analyze_audio`` are not recorded.
    """

    # Valid range for alpha; the same 0.0-2.0 range is requested from the LLM
    # in the critique prompt.
    ALPHA_MIN = 0.0
    ALPHA_MAX = 2.0
    # Multiplicative steps applied when no explicit alpha is suggested.
    ALPHA_BOOST = 1.2  # emotion too weak: strengthen by 20%
    ALPHA_DAMP = 0.8   # emotion too strong: weaken by 20%
    # Score multiplier applied in select_best to critiques reporting a deviation.
    DEVIATION_PENALTY = 0.5

    def __init__(self, llm_client: Optional[Any] = None, device: str = "cpu"):
        """
        Initialize the Supervisor Agent.

        Args:
            llm_client: LLM client (e.g. Gemini) used for high-level reasoning.
            device: Compute device.
        """
        self.llm_client = llm_client
        self.device = device
        self.critique_history: List[CritiqueResult] = []
        logger.info("SupervisorAgent 初始化完成")

    def analyze_audio(self, audio_path: str, target_emotion: Optional[str] = None,
                      target_intensity: Optional[float] = None) -> CritiqueResult:
        """
        Analyze a synthesized audio clip and produce a perceptual critique.

        Args:
            audio_path: Path of the synthesized audio.
            target_emotion: Target emotion label.
            target_intensity: Target intensity (0-1).

        Returns:
            The critique. Without an LLM client this is a fixed
            "no deviation" result with confidence 0.8.
        """
        logger.info("分析音频: %s", audio_path)

        # Compare against None explicitly so that a configured client whose
        # truth value happens to be False (e.g. it defines __len__ or
        # __bool__) is not silently ignored.
        if self.llm_client is not None:
            return self._analyze_with_llm(audio_path, target_emotion, target_intensity)

        # No client configured: fall back to the default result.
        return CritiqueResult(
            deviation_type=DeviationType.NONE,
            confidence=0.8,
            critique_text="音频质量符合预期",
            suggested_action="continue"
        )

    @staticmethod
    def _build_prompt(target_emotion: Optional[str],
                      target_intensity: Optional[float]) -> str:
        """Build the critique prompt sent to the LLM.

        Args:
            target_emotion: Target emotion label; ``None`` or an empty string
                is rendered as "unspecified".
            target_intensity: Target intensity; only ``None`` is rendered as
                "unspecified".

        Returns:
            The prompt text asking for a JSON-formatted critique.
        """
        emotion_text = target_emotion or '未指定'
        # An intensity of 0.0 is a legitimate target, so test for None rather
        # than relying on truthiness (which would report 0.0 as unspecified).
        intensity_text = '未指定' if target_intensity is None else target_intensity
        return f"""请分析以下合成音频的情感表现:
目标情感: {emotion_text}
目标强度: {intensity_text}

请评估:
1. 情感类型是否正确
2. 情感强度是否合适
3. 是否存在质量问题

返回JSON格式:
{{
    "deviation_type": "none|emotion_too_weak|emotion_too_strong|incorrect_emotion",
    "confidence": 0.0-1.0,
    "critique": "详细批评",
    "action": "continue|adjust_alpha|re_retrieve",
    "alpha_adjustment": 0.0-2.0 (如需调整)
}}"""

    def _analyze_with_llm(self, audio_path: str, target_emotion: Optional[str],
                          target_intensity: Optional[float]) -> CritiqueResult:
        """Run the LLM-based analysis.

        NOTE: the LLM call itself is not implemented yet. The prompt is built
        but not sent, and a fixed "no deviation" result is returned.

        Args:
            audio_path: Path of the synthesized audio.
            target_emotion: Target emotion label, if any.
            target_intensity: Target intensity, if any.

        Returns:
            The critique. If anything raises, the error is logged and a
            "no deviation" result with confidence 0.5 is returned instead, so
            this method never propagates an ``Exception``.
        """
        try:
            # Built for the pending LLM call below; currently unused.
            prompt = self._build_prompt(target_emotion, target_intensity)

            # TODO: call the LLM (example only, needs a concrete implementation):
            # response = self.llm_client.analyze_audio(audio_path, prompt)

            # Placeholder result until the response is actually parsed.
            return CritiqueResult(
                deviation_type=DeviationType.NONE,
                confidence=0.8,
                critique_text="音频分析完成",
                suggested_action="continue"
            )

        except Exception as e:
            # Deliberate best-effort: an analysis failure must not break the
            # caller, so degrade to a low-confidence default critique.
            logger.error("LLM 分析失败: %s", e)
            return CritiqueResult(
                deviation_type=DeviationType.NONE,
                confidence=0.5,
                critique_text=f"分析出错: {e}",
                suggested_action="continue"
            )

    def should_adjust_alpha(self, critique: CritiqueResult) -> bool:
        """Return True when the critique reports an emotion that is too weak or too strong."""
        return critique.deviation_type in (
            DeviationType.EMOTION_TOO_WEAK,
            DeviationType.EMOTION_TOO_STRONG,
        )

    def should_re_retrieve(self, critique: CritiqueResult) -> bool:
        """Return True when the critique reports the wrong emotion category."""
        return critique.deviation_type == DeviationType.INCORRECT_EMOTION

    def get_alpha_adjustment(self, critique: CritiqueResult, current_alpha: float) -> float:
        """
        Compute the adjusted alpha value.

        An explicit ``critique.alpha_adjustment`` takes precedence; otherwise
        alpha is scaled by +/-20% according to the deviation type. Any adjusted
        value is clamped to [ALPHA_MIN, ALPHA_MAX].

        Args:
            critique: The critique result.
            current_alpha: The current alpha value.

        Returns:
            The adjusted alpha, or ``current_alpha`` unchanged when the
            critique calls for no adjustment.
        """
        if critique.alpha_adjustment is not None:
            proposed = critique.alpha_adjustment
        elif critique.deviation_type == DeviationType.EMOTION_TOO_WEAK:
            # NOTE: the step is multiplicative, so an alpha of 0 stays at 0.
            proposed = current_alpha * self.ALPHA_BOOST
        elif critique.deviation_type == DeviationType.EMOTION_TOO_STRONG:
            proposed = current_alpha * self.ALPHA_DAMP
        else:
            return current_alpha

        # Clamp on both sides for every adjusted value, so that neither an
        # out-of-range suggestion nor an out-of-range starting alpha can
        # escape the valid interval.
        return min(self.ALPHA_MAX, max(self.ALPHA_MIN, proposed))

    def evaluate_batch(self, audio_paths: List[str], target_emotion: Optional[str] = None,
                       target_intensity: Optional[float] = None) -> List[CritiqueResult]:
        """
        Evaluate several audio clips and record each critique in the history.

        Args:
            audio_paths: List of audio paths.
            target_emotion: Target emotion.
            target_intensity: Target intensity.

        Returns:
            One CritiqueResult per path, in input order.
        """
        results = []
        for path in audio_paths:
            result = self.analyze_audio(path, target_emotion, target_intensity)
            results.append(result)
            self.critique_history.append(result)
        return results

    def select_best(self, audio_paths: List[str], critiques: List[CritiqueResult]) -> int:
        """
        Pick the best audio clip from a batch of critiques.

        Each critique is scored by its confidence, multiplied by
        DEVIATION_PENALTY when it reports a deviation. The first critique with
        the strictly highest positive score wins.

        Args:
            audio_paths: List of audio paths (not read here; kept for the
                call signature).
            critiques: The critiques corresponding to the paths.

        Returns:
            Index of the best critique; 0 when ``critiques`` is empty or no
            score is greater than 0.
        """
        if not critiques:
            return 0

        best_idx = 0
        best_score = 0.0

        for idx, critique in enumerate(critiques):
            score = critique.confidence
            if critique.deviation_type != DeviationType.NONE:
                score *= self.DEVIATION_PENALTY  # down-weight results with a deviation
            # Strict comparison keeps the earliest entry on ties.
            if score > best_score:
                best_score = score
                best_idx = idx

        return best_idx
