"""
Emotion Detection Utilities

Utilities for emotion detection and emotion-embedding extraction.
"""

import logging
from typing import Optional, Dict, Any, List
from pathlib import Path
import numpy as np

# Configure root logging at INFO level with a timestamped line format, then
# create the logger used throughout this module.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


class EmotionDetector:
    """
    Speech emotion detector backed by a FunASR ``AutoModel``.

    The model (emotion2vec by default) is loaded once at construction time.
    If loading fails -- for example because ``funasr`` is not installed --
    ``self.model`` stays ``None`` and the detector works in a simulation mode
    that returns random embeddings and random scores instead of raising.
    """

    # Standard emotion labels; used as the keys of simulated scores.
    EMOTION_LABELS = ["happy", "angry", "sad", "afraid", "disgusted",
                      "melancholic", "surprised", "calm", "neutral"]

    # Ordered labels of the 8-dimensional emotion vector handled by the
    # dict <-> vector helpers (the standard labels without "neutral").
    _VECTOR_LABELS = tuple(EMOTION_LABELS[:8])

    # Shape of the random embedding returned in simulation mode.
    _MOCK_EMBEDDING_SHAPE = (1, 1280)

    def __init__(self, model_name: str = "iic/emotion2vec_plus_large",
                 device: str = "cpu"):
        """
        Create the detector and try to load the model immediately.

        Args:
            model_name: Model identifier passed to ``funasr.AutoModel``.
            device: Compute device string passed to ``funasr.AutoModel``.
        """
        self.model_name = model_name
        self.device = device
        # Stays None when loading fails; every method then simulates output.
        self.model = None

        self._try_load_model()

    def _try_load_model(self):
        """
        Try to load the emotion model; never raises.

        On any failure a warning is logged and ``self.model`` is left as
        ``None`` so the detector keeps working in simulation mode.
        """
        try:
            # Imported lazily so the module stays importable without funasr.
            from funasr import AutoModel
            self.model = AutoModel(
                model=self.model_name,
                hub="ms",
                device=self.device
            )
            logger.info("情感检测模型加载成功: %s", self.model_name)
        except ImportError:
            logger.warning("funasr 未安装，情感检测将使用模拟模式")
        except Exception as e:
            logger.warning("加载情感模型失败: %s", e)

    @staticmethod
    def _audio_file_exists(audio_path: str) -> bool:
        """Return True if ``audio_path`` is an existing regular file; log an error otherwise."""
        # is_file() rather than exists(): a directory is not a usable audio file.
        if Path(audio_path).is_file():
            return True
        logger.error("音频文件不存在: %s", audio_path)
        return False

    def _first_result(self, audio_path: str) -> Optional[Dict[str, Any]]:
        """
        Run the loaded model on one file and return its first result record.

        Returns ``None`` when the model produces an empty result. Exceptions
        raised by the model propagate to the caller, which logs them.
        """
        result = self.model.generate(
            audio_path,
            granularity="utterance",
            extract_embedding=True
        )
        if result and len(result) > 0:
            return result[0]
        return None

    def get_embedding(self, audio_path: str) -> Optional[np.ndarray]:
        """
        Extract the emotion embedding of an audio file.

        Args:
            audio_path: Path of the audio file.

        Returns:
            The ``"feats"`` entry of the model's first result. If no model is
            loaded, the model fails, or it returns no ``"feats"``, a random
            float32 array of shape (1, 1280) is returned instead. ``None``
            when ``audio_path`` is not an existing file.
        """
        if not self._audio_file_exists(audio_path):
            return None

        if self.model is not None:
            try:
                record = self._first_result(audio_path)
                if record is not None and "feats" in record:
                    return record["feats"]
                # Make the fallback to random data visible instead of silent.
                logger.warning("模型未返回情感嵌入，改用模拟嵌入: %s", audio_path)
            except Exception as e:
                logger.error("提取情感嵌入失败: %s", e)

        # Simulated embedding.
        return np.random.randn(*self._MOCK_EMBEDDING_SHAPE).astype(np.float32)

    def get_emotion_scores(self, audio_path: str) -> Optional[Dict[str, float]]:
        """
        Get per-emotion scores for an audio file.

        Args:
            audio_path: Path of the audio file.

        Returns:
            Mapping from label to score built from the ``"labels"`` and
            ``"scores"`` entries of the model's first result; the keys are
            whatever label strings the model reports, which need not match
            ``EMOTION_LABELS``. If no model is loaded, the model fails, or it
            returns no scores, random scores keyed by ``EMOTION_LABELS`` are
            returned instead. ``None`` when ``audio_path`` is not an existing
            file.
        """
        if not self._audio_file_exists(audio_path):
            return None

        if self.model is not None:
            try:
                record = self._first_result(audio_path)
                if record is not None and "scores" in record and "labels" in record:
                    return {
                        label: float(score)
                        for label, score in zip(record["labels"], record["scores"])
                    }
                # Make the fallback to random data visible instead of silent.
                logger.warning("模型未返回情感分数，改用模拟分数: %s", audio_path)
            except Exception as e:
                logger.error("获取情感分数失败: %s", e)

        # Simulated scores.
        return {label: float(np.random.random()) for label in self.EMOTION_LABELS}

    def get_dominant_emotion(self, audio_path: str) -> Optional[str]:
        """
        Get the label with the highest score for an audio file.

        Args:
            audio_path: Path of the audio file.

        Returns:
            The highest-scoring label from ``get_emotion_scores``, or ``None``
            when no scores are available.
        """
        scores = self.get_emotion_scores(audio_path)
        if scores:
            return max(scores, key=scores.get)
        return None

    def compute_similarity(self, audio_path1: str, audio_path2: str) -> Optional[float]:
        """
        Compute the emotional similarity of two audio files.

        Args:
            audio_path1: Path of the first audio file.
            audio_path2: Path of the second audio file.

        Returns:
            Cosine similarity of the two flattened embeddings, rescaled from
            [-1, 1] to [0, 1]. ``0.0`` when either embedding has zero norm.
            ``None`` when an embedding is unavailable or the two embeddings
            have different sizes.
        """
        emb1 = self.get_embedding(audio_path1)
        emb2 = self.get_embedding(audio_path2)

        if emb1 is None or emb2 is None:
            return None

        # Flatten and compute in float64 to limit rounding error.
        vec1 = np.asarray(emb1, dtype=np.float64).ravel()
        vec2 = np.asarray(emb2, dtype=np.float64).ravel()

        # np.dot would raise on mismatched sizes (e.g. a real embedding
        # compared with a simulated one); report it as "no similarity".
        if vec1.shape != vec2.shape:
            logger.error("情感嵌入维度不一致: %s vs %s", vec1.shape, vec2.shape)
            return None

        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        if norm1 == 0 or norm2 == 0:
            return 0.0

        # Rounding can push the cosine marginally outside [-1, 1]; clip so the
        # rescaled value is guaranteed to stay within [0, 1].
        cos_sim = float(np.clip(np.dot(vec1, vec2) / (norm1 * norm2), -1.0, 1.0))
        return (cos_sim + 1.0) / 2.0

    @staticmethod
    def emotion_vector_to_dict(vector: List[float]) -> Dict[str, float]:
        """
        Convert an emotion vector to a label -> value dictionary.

        Args:
            vector: Emotion values ordered as happy, angry, sad, afraid,
                disgusted, melancholic, surprised, calm. Values beyond the
                eighth are ignored; a shorter vector yields fewer keys.

        Returns:
            Dictionary mapping each label to its value as a float.
        """
        return {
            label: float(value)
            for label, value in zip(EmotionDetector._VECTOR_LABELS, vector)
        }

    @staticmethod
    def emotion_dict_to_vector(emo_dict: Dict[str, float]) -> List[float]:
        """
        Convert a label -> value dictionary to an 8-dimensional emotion vector.

        Args:
            emo_dict: Emotion dictionary; missing labels count as 0.0 and
                keys outside the eight vector labels are ignored.

        Returns:
            Eight floats ordered as happy, angry, sad, afraid, disgusted,
            melancholic, surprised, calm.
        """
        return [
            float(emo_dict.get(label, 0.0))
            for label in EmotionDetector._VECTOR_LABELS
        ]
