"""
Voice Clone Agent - 完整的语音克隆流程控制

集成:
1. DAC (Retrieval + Synthesis)
2. Fast-Slow Feedback (Fast Agent + Supervisor)
3. 迭代生成和选择
"""

import os
import logging
from typing import Optional, List, Dict, Any, Callable
from pathlib import Path
from datetime import datetime

# Configure the root logger at import time and create this module's logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Emotion order: [happy, angry, sad, afraid, disgusted, melancholic, surprised, calm]
# The position of a name in this list is the index of that emotion in the
# vector produced by emotion_dict_to_vector().
EMO_ORDER = ["happy", "angry", "sad", "afraid", "disgusted", "melancholic", "surprised", "calm"]

# Alias -> canonical emotion name. Every target is one of the EMO_ORDER names;
# keys are matched after the incoming key has been stripped and lower-cased
# (see normalize_emotion_keys()).
EMO_KEY_NORMALIZE = {
    "joy": "happy", "happiness": "happy", "fear": "afraid",
    "neutral": "calm", "disgust": "disgusted", "surprise": "surprised",
    "melancholy": "melancholic", "depression": "melancholic"
}


def normalize_emotion_keys(emodict: Dict[str, float]) -> Dict[str, float]:
    """Return a new dict whose keys are canonical emotion names.

    Each key is converted to ``str``, stripped of surrounding whitespace and
    lower-cased, then mapped through ``EMO_KEY_NORMALIZE``; keys without an
    alias entry are kept in their cleaned form. Each value is converted with
    ``float``.

    When several input keys collapse onto the same canonical name, the value
    of the last one (in iteration order) wins.

    Args:
        emodict: Mapping of emotion name (or alias) to weight.

    Returns:
        A new dict of canonical name to float weight.
    """

    def _canonical(raw_key) -> str:
        # Clean first so that aliases such as " Joy " still hit the table.
        cleaned = str(raw_key).strip().lower()
        return EMO_KEY_NORMALIZE.get(cleaned, cleaned)

    return {_canonical(name): float(weight) for name, weight in emodict.items()}


def emotion_dict_to_vector(emodict: Dict[str, float], normalize: bool = True) -> List[float]:
    """Convert an emotion dict into a list ordered like ``EMO_ORDER``.

    Keys are first canonicalised with ``normalize_emotion_keys``. Emotions
    missing from the dict contribute ``0.0``; keys that are not in
    ``EMO_ORDER`` are ignored.

    Args:
        emodict: Mapping of emotion name (or alias) to weight.
        normalize: When true, scale the vector so its entries sum to 1. The
            scaling is skipped if the sum is not strictly positive.

    Returns:
        One float per entry of ``EMO_ORDER``.
    """
    canonical = normalize_emotion_keys(emodict)
    weights = [float(canonical.get(name, 0.0)) for name in EMO_ORDER]

    if not normalize:
        return weights

    total = sum(weights)
    # Only a strictly positive sum is usable as a divisor; otherwise the raw
    # weights are returned untouched.
    return [w / total for w in weights] if total > 0 else weights


class VoiceCloneAgent:
    """Orchestrate the voice-cloning TTS pipeline.

    The agent coordinates three optional collaborators:

    1. a retrieval system that looks up emotion reference audio
       (DAC Retrieval),
    2. a TTS model that synthesises speech (DAC Synthesis + Fast Agent),
    3. a supervisor that critiques generated samples.

    ``run`` drives a generate -> evaluate -> select loop until enough samples
    have been accepted or the iteration budget is used up.

    ``output_dir`` is the *base* directory. Every ``run`` /
    ``generate_candidates`` call creates its own timestamped sub-directory
    below it and reports that path in its return value; the base itself is
    never modified by the agent.
    """

    # Default emotion-strength sweep for ``generate_candidates``. Immutable so
    # that it cannot be altered between calls.
    _DEFAULT_ALPHAS = (0.6, 0.8, 1.0, 1.2, 1.4)

    def __init__(
        self,
        model=None,
        retrieval_system=None,
        supervisor=None,
        output_dir: str = "./results",
        device: str = "cuda",
        callback: Optional[Callable] = None
    ):
        """Store the collaborators and reset the run state.

        Args:
            model: AgentSteerTTS model instance, or ``None`` to run in
                mock mode (``generate_tts`` then only returns a path).
            retrieval_system: EmotionRetrievalSystem instance; must offer
                ``retrieve(query, top_k=...)`` returning objects with an
                ``audio_path`` attribute.
            supervisor: SupervisorAgent instance; must offer
                ``analyze_audio`` and ``should_adjust_alpha``.
            output_dir: Base directory for generated files.
            device: Compute device name (stored only).
            callback: Optional progress callback, invoked with one event dict.
        """
        self.model = model
        self.retrieval_system = retrieval_system
        self.supervisor = supervisor
        self.output_dir = output_dir
        self.device = device
        self.callback = callback

        # Run state
        self.generation_count = 0
        self.selected_voices: List[Dict[str, Any]] = []
        self.retrieved_audios: List[str] = []
        self.stopped = False

        logger.info("VoiceCloneAgent 初始化完成")

    @staticmethod
    def _alpha_tag(alpha: float) -> str:
        """Format *alpha* for use in a file name, e.g. ``0.6`` -> ``"0p60"``."""
        return f"{alpha:.2f}".replace(".", "p")

    def _notify(self, payload: Dict[str, Any]) -> None:
        """Pass *payload* to the progress callback when one is configured."""
        if self.callback:
            self.callback(payload)

    def stop(self):
        """Request that a running ``run`` loop exits at its next iteration."""
        logger.info("收到停止请求")
        self.stopped = True

    def clear(self):
        """Reset counters, selections, retrieved paths and the stop flag."""
        self.generation_count = 0
        self.selected_voices = []
        self.retrieved_audios = []
        self.stopped = False
        logger.info("状态已清空")

    def retrieve_emotion_audio(self, query: str, top_k: int = 5) -> List[str]:
        """Look up emotion reference audio (DAC Retrieval Agent).

        On success the paths are also stored in ``self.retrieved_audios`` and
        a ``retrieve_complete`` event is sent to the callback.

        Args:
            query: Free-text emotion description.
            top_k: Number of results to request.

        Returns:
            List of audio paths; empty when no retrieval system is configured
            or when retrieval (or the callback) raises.
        """
        # Identity check: a retrieval object that happens to be falsy (for
        # example one defining __len__) must still be used.
        if self.retrieval_system is None:
            logger.warning("检索系统未初始化")
            return []

        try:
            results = self.retrieval_system.retrieve(query, top_k=top_k)
            audio_paths = [r.audio_path for r in results]
            self.retrieved_audios = audio_paths

            self._notify({
                'event': 'retrieve_complete',
                'audio_paths': audio_paths,
                'query': query
            })

            return audio_paths

        except Exception as e:
            # Best effort: a failed lookup must not abort the whole pipeline.
            logger.error(f"检索失败: {e}")
            return []

    def generate_tts(
        self,
        text: str,
        reference_audio: str,
        emo_audio: Optional[str] = None,
        emo_vector: Optional[List[float]] = None,
        emo_alpha: float = 1.0,
        emo_merge_alpha: float = 1.0,
        save_path: Optional[str] = None
    ) -> Optional[str]:
        """Produce one TTS sample and return its path.

        NOTE: the inference call itself is not wired in yet. With a model
        configured, this method only bumps ``generation_count``, emits a
        ``generate_complete`` event and returns ``save_path`` unchanged;
        nothing is written to disk here. ``reference_audio``, ``emo_audio``,
        ``emo_vector`` and ``emo_merge_alpha`` are accepted for the future
        integration but are not used yet.

        Args:
            text: Text to synthesise.
            reference_audio: Reference speaker audio.
            emo_audio: Emotion reference audio.
            emo_vector: Emotion vector (8 dimensions).
            emo_alpha: Emotion strength.
            emo_merge_alpha: Fusion ratio.
            save_path: Destination path.

        Returns:
            Without a model: ``save_path`` or a mock ``generated_<n>.wav``
            name. With a model: ``save_path`` (possibly ``None``), or ``None``
            if the callback raised.
        """
        logger.info(f"生成 TTS: text={text[:30]}...")

        # Mock mode: no model configured, hand back a placeholder path.
        if self.model is None:
            logger.warning("模型未初始化，返回模拟路径")
            self.generation_count += 1
            return save_path or f"generated_{self.generation_count}.wav"

        try:
            # TODO: call the real TTS inference pipeline here.
            self.generation_count += 1

            self._notify({
                'event': 'generate_complete',
                'audio_path': save_path,
                'text': text,
                'emo_alpha': emo_alpha
            })

            return save_path

        except Exception as e:
            logger.error(f"TTS 生成失败: {e}")
            return None

    def generate_candidates(
        self,
        text: str,
        reference_audio: str,
        emotion_audio: str,
        emotion_text: str,
        alphas: Optional[List[float]] = None,
        fuse_weight: float = 0.6
    ) -> Dict[str, Any]:
        """Generate the A / B / C candidate set in a fresh sub-directory.

        * A: audio-driven, one sample per alpha.
        * B: vector-driven, a single sample.
        * C: fused audio + vector, one sample per alpha.

        With the default five alphas this yields 5 + 1 + 5 = 11 candidates.

        Args:
            text: Text to synthesise.
            reference_audio: Reference speaker audio.
            emotion_audio: Emotion reference audio.
            emotion_text: Emotion description. Currently unused: the emotion
                vector is fixed to "calm" (simplified version).
            alphas: Emotion strengths to sweep; ``None`` selects
                0.6, 0.8, 1.0, 1.2, 1.4. Any iterable is accepted.
            fuse_weight: Fusion ratio passed to the C series.

        Returns:
            Manifest dict with the planned paths of every candidate, the
            inputs, and the ``output_dir`` that was created.
        """
        # Materialise once: avoids a shared mutable default and lets a
        # one-shot iterator feed both the A and the C sweep.
        alphas = list(self._DEFAULT_ALPHAS if alphas is None else alphas)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_dir = os.path.join(self.output_dir, timestamp)
        os.makedirs(out_dir, exist_ok=True)

        # Simplified emotion parsing: always "calm", the last of 8 slots.
        emo_vector = [0.0] * 8
        emo_vector[7] = 1.0

        A_paths = {}
        C_paths = {}

        # A series (audio-driven only)
        for alpha in alphas:
            tag = self._alpha_tag(alpha)
            a_path = os.path.join(out_dir, f"A_audio_alpha_{tag}.wav")
            self.generate_tts(
                text=text,
                reference_audio=reference_audio,
                emo_audio=emotion_audio,
                emo_vector=None,
                emo_alpha=alpha,
                emo_merge_alpha=1.0,
                save_path=a_path
            )
            A_paths[f"alpha_{tag}"] = a_path

        # B (vector-driven only)
        b_path = os.path.join(out_dir, "B_vector.wav")
        self.generate_tts(
            text=text,
            reference_audio=reference_audio,
            emo_audio=None,
            emo_vector=emo_vector,
            emo_alpha=1.0,
            save_path=b_path
        )

        # C series (fused)
        for alpha in alphas:
            tag = self._alpha_tag(alpha)
            c_path = os.path.join(out_dir, f"C_fused_alpha_{tag}.wav")
            self.generate_tts(
                text=text,
                reference_audio=reference_audio,
                emo_audio=emotion_audio,
                emo_vector=emo_vector,
                emo_alpha=alpha,
                emo_merge_alpha=fuse_weight,
                save_path=c_path
            )
            C_paths[f"alpha_{tag}"] = c_path

        return {
            "A_audio_driven": A_paths,
            "B_vector_driven": b_path,
            "C_fused": C_paths,
            "text": text,
            "reference_audio": reference_audio,
            "emotion_audio": emotion_audio,
            "output_dir": out_dir
        }

    def run(
        self,
        text: str,
        reference_audio: str,
        user_requirements: Optional[str] = None,
        emotion_audio: Optional[str] = None,
        num_to_gen: int = 3,
        max_iterations: int = 10
    ) -> Dict[str, Any]:
        """Run the full voice-cloning loop.

        State is cleared first (this also resets the stop flag, so a
        ``stop()`` issued before ``run`` has no effect). Each iteration
        optionally retrieves an emotion reference, generates one sample, asks
        the supervisor (if any) whether it is acceptable, and records it.

        Files go to a timestamped directory below ``self.output_dir``; that
        directory is reported as ``results["output_dir"]``. The base
        ``self.output_dir`` is left unchanged so repeated runs stay siblings
        instead of nesting inside each other.

        Args:
            text: Text to synthesise.
            reference_audio: Reference speaker audio.
            user_requirements: Emotion requirement description; also used as
                the retrieval query and as the supervisor's target emotion.
            emotion_audio: Emotion reference audio (optional).
            num_to_gen: Number of samples to accept.
            max_iterations: Upper bound on loop iterations.

        Returns:
            Result dict with iteration/generation counts, the selected
            samples, the inputs, the run directory and a timestamp.
        """
        logger.info("=" * 60)
        logger.info("开始语音克隆流程")
        logger.info("=" * 60)

        self.clear()

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Local run directory: do not overwrite the base output_dir.
        run_dir = os.path.join(self.output_dir, timestamp)
        os.makedirs(run_dir, exist_ok=True)

        iteration = 0

        while len(self.selected_voices) < num_to_gen and iteration < max_iterations:
            if self.stopped:
                logger.info("收到停止信号，退出循环")
                break

            iteration += 1
            logger.info(f"\n迭代 {iteration}: 已选择 {len(self.selected_voices)}/{num_to_gen}")

            # Step 1: retrieve an emotion reference if none is known yet.
            if not emotion_audio and user_requirements:
                retrieved = self.retrieve_emotion_audio(user_requirements, top_k=3)
                if retrieved:
                    emotion_audio = retrieved[0]

            # Step 2: generate a candidate.
            save_path = os.path.join(run_dir, f"gen_{self.generation_count + 1:03d}.wav")
            generated = self.generate_tts(
                text=text,
                reference_audio=reference_audio,
                emo_audio=emotion_audio,
                emo_alpha=1.0,
                save_path=save_path
            )

            if not generated:
                continue

            # Step 3: evaluate (Supervisor).
            if self.supervisor is not None:
                critique = self.supervisor.analyze_audio(generated, target_emotion=user_requirements)

                if self.supervisor.should_adjust_alpha(critique):
                    # NOTE(review): the retry reuses emo_alpha=1.0; no
                    # adjusted value is taken from the critique yet.
                    logger.info("Supervisor 建议调整 alpha")
                    continue

            # Step 4: select.
            self.selected_voices.append({
                "audio_path": generated,
                "iteration": iteration,
                "timestamp": datetime.now().isoformat()
            })

        results = {
            "total_iterations": iteration,
            "total_generations": self.generation_count,
            "selected_voices": self.selected_voices,
            "text": text,
            "reference_audio": reference_audio,
            "user_requirements": user_requirements,
            "output_dir": run_dir,
            "timestamp": datetime.now().isoformat()
        }

        logger.info(f"\n{'=' * 60}")
        logger.info(f"语音克隆完成! 选择了 {len(self.selected_voices)}/{num_to_gen} 个样本")
        logger.info(f"{'=' * 60}")

        return results
