#!/usr/bin/env python3
"""
AgentSteerTTS 完整测试脚本

用法:
    python scripts/test.py              # 运行所有模块测试
    python scripts/test.py --data_dir data  # 使用真实数据测试
    python scripts/test.py --module model   # 测试特定模块
"""

import os
import sys
import json
import argparse
import torch
import torch.nn as nn

# Prepend the parent of this script's directory to sys.path so it is searched
# first. Presumably this is the repository root that holds the local
# `agentsteertts` package -- TODO confirm against the project layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class SimulatedFeatureExtractor(nn.Module):
    """Lightweight stand-in for a W2V-BERT feature extractor.

    Three strided 1-D convolutions, each followed by GELU, map a
    single-channel input to ``output_dim`` channels.
    """

    def __init__(self, output_dim=1024):
        super().__init__()
        # One row per stage: (in_channels, out_channels, kernel, stride, padding).
        conv_specs = (
            (1, 256, 10, 5, 2),
            (256, 512, 3, 2, 1),
            (512, output_dim, 3, 2, 1),
        )
        stages = []
        for c_in, c_out, kernel, stride, pad in conv_specs:
            stages.append(
                nn.Conv1d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad)
            )
            stages.append(nn.GELU())
        self.net = nn.Sequential(*stages)

    def forward(self, audio):
        """Apply the conv stack; a 2-D input first gets a channel axis at dim 1."""
        batched = audio.unsqueeze(1) if audio.dim() == 2 else audio
        return self.net(batched)


def test_model(device="cpu"):
    """Smoke-test the core model (ADM + DAC + Fast Agent).

    Builds an ``AgentSteerTTS`` in eval mode, runs four forward-pass
    scenarios on random inputs under ``torch.no_grad()``, then calls
    ``compute_adm_losses`` once. Each check prints a ``[PASS]``/``[FAIL]``
    line; an exception raised by a check is caught and reported, not
    re-raised.

    Args:
        device: Torch device string for the model and every input tensor.

    Returns:
        Tuple ``(passed, total)``; ``total`` is 5 (four forward scenarios
        plus the ADM-loss check).
    """
    rule = "=" * 60
    print("\n" + rule)
    print("测试核心模型 (AgentSteerTTS)")
    print(rule)

    from agentsteertts import AgentSteerTTS

    model = AgentSteerTTS(model_dim=512, semantic_dim=1024).to(device)
    model.eval()

    batch, frames = 2, 100
    speaker_semantic = torch.randn(batch, 1024, frames, device=device)
    speaker_lengths = torch.tensor([frames, frames - 10], device=device)
    text_tokens = torch.randint(0, 256, (batch, 50), device=device)
    text_lengths = torch.tensor([50, 45], device=device)

    # Extra keyword arguments for each forward scenario.
    audio_emotion = {
        "emo_semantic": torch.randn(batch, 1024, 80, device=device),
        "emo_lengths": torch.tensor([80, 70], device=device),
    }
    vector_emotion = {
        "emo_vector": torch.tensor(
            [[0.6, 0.2, 0, 0, 0, 0.2], [0, 0.8, 0.2, 0, 0, 0]], device=device
        ),
    }
    fast_calibration = {
        "emo_semantic": torch.randn(batch, 1024, 80, device=device),
        "emo_lengths": torch.tensor([80, 70], device=device),
        "use_fast_calibration": True,
        "target_embedding": torch.randn(batch, 768, device=device),
    }
    scenarios = (
        ("基础前向传播", {}),
        ("音频情感控制", audio_emotion),
        ("向量情感控制", vector_emotion),
        ("Fast Agent 校准", fast_calibration),
    )

    passed = 0
    with torch.no_grad():
        for label, extra in scenarios:
            try:
                result = model(
                    speaker_semantic, speaker_lengths, text_tokens, text_lengths, **extra
                )
                print(f"[PASS] {label}: conditioning={result['conditioning'].shape}")
            except Exception as err:
                print(f"[FAIL] {label}: {err}")
            else:
                passed += 1

    # ADM loss check; this call is made outside the no_grad block.
    try:
        speaker_labels = torch.randint(0, 100, (batch,), device=device)
        emotion_labels = torch.randint(0, 8, (batch,), device=device)
        adm = model.compute_adm_losses(
            speaker_semantic, speaker_labels, emotion_labels, speaker_lengths
        )
        print(f"[PASS] ADM 损失: L_adv={adm['l_adv']:.4f}, L_orth={adm['l_orth']:.4f}")
    except Exception as err:
        print(f"[FAIL] ADM 损失: {err}")
    else:
        passed += 1

    return passed, len(scenarios) + 1


def test_retrieval(device="cpu"):
    """Smoke-test the retrieval system (``EmotionRetrievalSystem``).

    Three checks are run in order: construct the retriever, install a
    two-entry mock database on it, and call ``search_by_instruction``.
    Each prints a ``[PASS]``/``[FAIL]`` line. If construction fails the
    remaining checks are skipped.

    Args:
        device: Torch device string passed to the retriever and used for
            the mock embedding tensors.

    Returns:
        Tuple ``(passed, 3)``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("测试检索系统 (EmotionRetrievalSystem)")
    print(rule)

    from agentsteertts.retrieval import EmotionRetrievalSystem

    total = 3
    passed = 0

    # Check 1: construction. Nothing else can run without a retriever.
    try:
        retriever = EmotionRetrievalSystem(device=device)
        print("[PASS] 初始化检索系统")
    except Exception as err:
        print(f"[FAIL] 初始化检索系统: {err}")
        return passed, total
    else:
        passed += 1

    # Check 2: overwrite the retriever's database attributes with mock data.
    try:
        mock_rows = (
            ("test1", "happy voice", "spk1"),
            ("test2", "sad voice", "spk2"),
        )
        retriever.audio_metadata = [
            {
                'audio_path': f'{stem}.wav',
                'instruction': instruction,
                'dataset_name': 'test',
                'speaker_name': speaker,
                'json_path': f'{stem}.json',
            }
            for stem, instruction, speaker in mock_rows
        ]
        retriever.embeddings = torch.randn(2, 384, device=device)
        retriever.emotion_embeddings = torch.randn(2, 1280, device=device)
        print("[PASS] 创建模拟数据库")
    except Exception as err:
        print(f"[FAIL] 创建模拟数据库: {err}")
    else:
        passed += 1

    # Check 3: instruction-based search.
    try:
        hits = retriever.search_by_instruction("happy", top_k=2)
        print(f"[PASS] 指令检索: 返回 {len(hits)} 条结果")
    except Exception as err:
        print(f"[FAIL] 指令检索: {err}")
    else:
        passed += 1

    return passed, total


def test_agents(device="cpu"):
    """Smoke-test the agent modules (``SupervisorAgent`` and ``VoiceCloneAgent``).

    Three checks are run, each printing a ``[PASS]``/``[FAIL]`` line:
    supervisor audio analysis, supervisor alpha adjustment, and
    ``VoiceCloneAgent`` construction followed by ``clear()``. Exceptions
    raised by a check are caught and reported, not re-raised.

    Args:
        device: Torch device string passed to both agents.

    Returns:
        Tuple ``(passed, 3)``.
    """
    print("\n" + "=" * 60)
    print("测试 Agent 模块 (Supervisor + VoiceClone)")
    print("=" * 60)

    from agentsteertts.agents import SupervisorAgent, VoiceCloneAgent
    from agentsteertts.agents.supervisor import DeviationType

    passed = 0

    # Bound up front so the alpha check can tell "supervisor was never built"
    # apart from a genuine failure inside get_alpha_adjustment.
    supervisor = None

    # Supervisor analysis check.
    try:
        supervisor = SupervisorAgent(device=device)
        result = supervisor.analyze_audio("dummy.wav", target_emotion="happy")
        print(f"[PASS] Supervisor 分析: deviation={result.deviation_type.value}")
        passed += 1
    except Exception as e:
        print(f"[FAIL] Supervisor 分析: {e}")

    # Alpha adjustment check. It still runs when only analyze_audio failed,
    # because the supervisor instance exists in that case.
    try:
        if supervisor is None:
            # Report the real precondition failure instead of a NameError.
            raise RuntimeError("SupervisorAgent 初始化失败，无法测试 Alpha 调整")
        from agentsteertts.agents.supervisor import CritiqueResult
        critique = CritiqueResult(
            deviation_type=DeviationType.EMOTION_TOO_WEAK,
            confidence=0.8,
            critique_text="情感过弱",
            suggested_action="adjust_alpha"
        )
        new_alpha = supervisor.get_alpha_adjustment(critique, 1.0)
        print(f"[PASS] Alpha 调整: 1.0 -> {new_alpha:.2f}")
        passed += 1
    except Exception as e:
        print(f"[FAIL] Alpha 调整: {e}")

    # VoiceCloneAgent construction check.
    # NOTE(review): output_dir is relative to the current working directory;
    # whether the agent creates it on construction is not visible here.
    try:
        agent = VoiceCloneAgent(output_dir="./test_output", device=device)
        agent.clear()
        print("[PASS] VoiceCloneAgent 初始化")
        passed += 1
    except Exception as e:
        print(f"[FAIL] VoiceCloneAgent 初始化: {e}")

    return passed, 3


def test_utils():
    """Smoke-test the utility modules (``AudioProcessor`` and ``EmotionDetector``).

    Three checks are run, each printing a ``[PASS]``/``[FAIL]`` line:
    audio normalization (peak magnitude must not exceed 1.0), mel
    computation, and emotion-vector-to-dict conversion (the result must
    contain the key ``"happy"``). Exceptions raised by a check are caught
    and reported, not re-raised.

    The checks raise ``AssertionError`` explicitly rather than using the
    ``assert`` statement, so they still run under ``python -O``.

    Returns:
        Tuple ``(passed, 3)``.
    """
    print("\n" + "=" * 60)
    print("测试工具模块 (Audio + Emotion)")
    print("=" * 60)

    from agentsteertts.utils import AudioProcessor, EmotionDetector

    passed = 0

    # Shared fixtures are bound before any try block so that a failure in the
    # first check cannot surface as a NameError in the second.
    audio = torch.randn(1, 16000)
    processor = None

    # AudioProcessor normalization check.
    try:
        processor = AudioProcessor(sample_rate=16000)
        normalized = processor.normalize(audio)
        peak = float(normalized.abs().max())
        # Written as `not (peak <= 1.0)` so a NaN peak is a failure as well.
        if not (peak <= 1.0):
            raise AssertionError(f"归一化后峰值为 {peak}，超过 1.0")
        print("[PASS] AudioProcessor 归一化")
        passed += 1
    except Exception as e:
        print(f"[FAIL] AudioProcessor 归一化: {e}")

    # Mel computation check. It still runs when only the normalization
    # assertion failed, because the processor exists in that case.
    try:
        if processor is None:
            raise RuntimeError("AudioProcessor 初始化失败，无法计算 Mel 频谱")
        mel = processor.compute_mel(audio)
        print(f"[PASS] Mel 频谱计算: shape={mel.shape}")
        passed += 1
    except Exception as e:
        print(f"[FAIL] Mel 频谱计算: {e}")

    # EmotionDetector vector conversion check.
    try:
        detector = EmotionDetector(device="cpu")
        vec = [0.5, 0.2, 0.1, 0.0, 0.1, 0.0, 0.1, 0.0]
        emo_dict = detector.emotion_vector_to_dict(vec)
        if "happy" not in emo_dict:
            raise AssertionError("转换结果中缺少 'happy' 键")
        print("[PASS] EmotionDetector 向量转换")
        passed += 1
    except Exception as e:
        print(f"[FAIL] EmotionDetector 向量转换: {e}")

    return passed, 3


def test_with_data(data_dir, device="cpu"):
    """Run an end-to-end inference check on real audio files.

    Takes the first two ``.wav`` files in ``data_dir`` (sorted by name,
    extension matched case-insensitively). For each one it loads the audio,
    reads an optional ``<stem>.json`` sidecar for an ``"emotion"`` label,
    extracts features with ``SimulatedFeatureExtractor``, and runs the
    model with a one-hot emotion vector and random text tokens. Each sample
    prints ``[PASS]`` or ``[FAIL]``; per-sample exceptions are caught and
    reported, not re-raised.

    Args:
        data_dir: Directory to scan for ``.wav`` files. A path that is not
            a directory is treated like a directory with no ``.wav`` files.
        device: Torch device string for the models and tensors.

    Returns:
        Tuple ``(passed, total)`` where ``total`` is the number of samples
        tried (at most 2), or ``(0, 0)`` when there is nothing to test.
    """
    print("\n" + "=" * 60)
    print("真实数据端到端测试")
    print("=" * 60)

    # Discover samples before importing the project or building any model,
    # so the "nothing to test" path stays cheap. Sorting makes the chosen
    # pair deterministic; os.listdir order is arbitrary.
    wav_files = []
    if os.path.isdir(data_dir):
        wav_files = sorted(
            name for name in os.listdir(data_dir) if name.lower().endswith(".wav")
        )[:2]
    if not wav_files:
        print("未找到 .wav 文件，跳过真实数据测试")
        return 0, 0

    from agentsteertts import AgentSteerTTS
    from agentsteertts.utils import AudioProcessor

    feature_extractor = SimulatedFeatureExtractor(output_dim=1024).to(device)
    model = AgentSteerTTS(model_dim=512, semantic_dim=1024).to(device)
    processor = AudioProcessor(sample_rate=16000)

    feature_extractor.eval()
    model.eval()

    # One-hot emotion vectors; "neutral" and unknown labels map to all zeros.
    neutral_vector = [0, 0, 0, 0, 0, 0]
    emotion_map = {
        "happy": [1, 0, 0, 0, 0, 0], "angry": [0, 1, 0, 0, 0, 0], "sad": [0, 0, 1, 0, 0, 0],
        "fear": [0, 0, 0, 1, 0, 0], "disgust": [0, 0, 0, 0, 1, 0], "surprise": [0, 0, 0, 0, 0, 1],
        "neutral": neutral_vector,
    }

    passed = 0
    with torch.no_grad():
        for wav_file in wav_files:
            # splitext removes only the final extension; str.replace would
            # also strip any ".wav" occurring earlier in the file name.
            sample_id = os.path.splitext(wav_file)[0]
            print(f"\n测试样本: {sample_id}")

            try:
                # Load audio.
                audio = processor.load(os.path.join(data_dir, wav_file)).to(device)

                # Load the optional metadata sidecar.
                json_path = os.path.join(data_dir, f"{sample_id}.json")
                emotion = "neutral"
                if os.path.exists(json_path):
                    with open(json_path, 'r', encoding='utf-8') as f:
                        meta = json.load(f)
                    label = meta.get("emotion") if isinstance(meta, dict) else None
                    # A missing, null, or non-string label falls back to
                    # neutral, matching the fallback for unknown labels.
                    if isinstance(label, str):
                        emotion = label

                # Extract features.
                speaker_semantic = feature_extractor(audio)
                # Indexing dim 2 requires a 3-D feature tensor.
                speaker_lengths = torch.tensor([speaker_semantic.shape[2]], device=device)
                emo_vector = torch.tensor(
                    [emotion_map.get(emotion.lower(), neutral_vector)],
                    dtype=torch.float32, device=device,
                )
                text_tokens = torch.randint(0, 256, (1, 50), device=device)
                text_lengths = torch.tensor([50], device=device)

                # Inference.
                output = model(speaker_semantic, speaker_lengths, text_tokens, text_lengths,
                               emo_vector=emo_vector)
                print(f"  Emotion: {emotion}, Conditioning: {output['conditioning'].shape}")
                print("  [PASS]")
                passed += 1

            except Exception as e:
                print(f"  [FAIL] {e}")

    return passed, len(wav_files)


def main():
    """Parse CLI arguments, run the selected test groups, and print a summary.

    Options:
        --data_dir: Directory for the additional real-data test. If given
            but not an existing directory, a notice is printed and the
            real-data test is skipped.
        --device: ``cpu``, ``cuda``, or ``auto``. Omitting it or passing
            ``auto`` selects ``cuda`` when available, otherwise ``cpu``.
        --module: One of ``model``, ``retrieval``, ``agents``, ``utils``,
            ``all``. Omitting it is the same as ``all``.

    Returns:
        Process exit code: 0 when every counted test passed, otherwise 1.
    """
    parser = argparse.ArgumentParser(description="AgentSteerTTS 测试")
    parser.add_argument("--data_dir", type=str, default=None, help="数据目录路径")
    parser.add_argument("--device", type=str, default=None, help="设备 (auto/cpu/cuda)")
    parser.add_argument("--module", type=str, default=None,
                        choices=["model", "retrieval", "agents", "utils", "all"],
                        help="测试特定模块")
    args = parser.parse_args()

    # The help text advertises "auto", which is not a valid torch device
    # string, so resolve it here the same way as an omitted --device.
    if not args.device or args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    print(f"使用设备: {device}")

    # Test groups in execution order; test_utils takes no device argument.
    runners = {
        "model": lambda: test_model(device),
        "retrieval": lambda: test_retrieval(device),
        "agents": lambda: test_agents(device),
        "utils": test_utils,
    }
    if args.module in (None, "all"):
        selected = list(runners)
    else:
        selected = [args.module]

    total_passed = 0
    total_tests = 0
    for module in selected:
        p, t = runners[module]()
        total_passed += p
        total_tests += t

    # Real-data test, run in addition to the module tests.
    if args.data_dir:
        if os.path.isdir(args.data_dir):
            p, t = test_with_data(args.data_dir, device)
            total_passed += p
            total_tests += t
        else:
            # Say so explicitly instead of silently skipping the test.
            print(f"数据目录不存在或不是目录，跳过真实数据测试: {args.data_dir}")

    print("\n" + "=" * 60)
    print(f"测试完成: {total_passed}/{total_tests} 通过")
    print("=" * 60)

    return 0 if total_passed == total_tests else 1


if __name__ == "__main__":
    # Use sys.exit rather than the bare exit() helper: exit() is installed by
    # the site module for interactive use and is absent under `python -S`.
    sys.exit(main())