import logging
from typing import Optional, Tuple

import torch
from funasr import AutoModel

logger = logging.getLogger(__name__)


class EmotionDetector:
    """Emotion recognition module that predicts emotions with an emotion2vec plus model.

    Wraps a FunASR ``AutoModel`` and offers utterance-level prediction, a
    consistency check between predicted and reference labels, and a combined
    quality gate (confidence threshold + label consistency).
    """

    # Canonical emotion label -> synonyms treated as the same emotion when
    # comparing predicted and reference labels.
    # NOTE(review): "unknown" is also the sentinel predict_emotion returns on
    # failure, so a failed prediction counts as consistent with "other"; in
    # evaluate_audio_emotion it is the 0.0 confidence that rejects it.
    _LABEL_VARIANTS = {
        "happy": ("happiness", "joy", "excited"),
        "sad": ("sadness", "depressed", "melancholy"),
        "angry": ("anger", "mad", "furious"),
        "fearful": ("fear", "afraid", "scared", "terrified"),
        "surprised": ("surprise", "shocked", "amazed"),
        "disgusted": ("disgust", "repulsed"),
        "neutral": ("calm", "normal", "peaceful"),
        "other": ("unknown", "misc"),
    }

    # Reverse lookup built once: every accepted spelling (canonical or synonym)
    # -> canonical label. Makes the equivalence transitive ("joy" == "happiness").
    _ALIAS_TO_CANONICAL = {
        alias: canonical
        for canonical, aliases in _LABEL_VARIANTS.items()
        for alias in (canonical, *aliases)
    }

    def __init__(
        self,
        model_name: str = "iic/emotion2vec_plus_large",
        device: Optional[str] = None,
    ):
        """Initialize the emotion detector.

        Args:
            model_name: Model name/identifier passed to FunASR ``AutoModel``.
            device: Device type ('cuda' or 'cpu'). Defaults to 'cuda' when
                ``torch.cuda.is_available()`` is true, otherwise 'cpu'.

        Raises:
            Exception: Re-raises whatever ``AutoModel`` raises while loading.
        """
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        try:
            # Load the emotion2vec plus model on the selected device. Without
            # ``device=`` FunASR would fall back to its own default and
            # self.device would be silently ignored.
            self.model = AutoModel(
                model=model_name, device=self.device, disable_update=True
            )
        except Exception as e:
            logger.error(f"Failed to initialize emotion detector: {e}")
            raise

        logger.info(f"Emotion detector initialized with model: {model_name}")
        logger.info(f"Device: {self.device}")

    @staticmethod
    def _strip_chinese_prefix(label: str) -> str:
        """Return the part after '/' of a bilingual "中文/english" label (unchanged if no '/')."""
        return label.split("/")[1] if "/" in label else label

    @classmethod
    def _normalize_label(cls, label: str) -> str:
        """Lower-case and trim a label, drop any Chinese prefix, and map synonyms to the canonical label."""
        cleaned = cls._strip_chinese_prefix(label.strip().lower()).strip()
        return cls._ALIAS_TO_CANONICAL.get(cleaned, cleaned)

    def predict_emotion(self, audio_path: str) -> Tuple[str, float]:
        """Predict the emotion category and confidence for an audio file.

        Args:
            audio_path: Path to the audio file.

        Returns:
            A ``(predicted_label, confidence)`` tuple. The label is the English
            part of the highest-scoring model label, and confidence is its
            score as a plain ``float``. Returns ``("unknown", 0.0)`` when the
            model yields no result, no positive score, or raises.
        """
        try:
            res = self.model.generate(
                audio_path,
                output_dir=None,
                granularity="utterance",
                extract_embedding=False,
            )
            if not res:
                return "unknown", 0.0

            # A single audio path is passed in, so only the first result is used.
            item = res[0]
            best_label, best_score = max(
                zip(item["labels"], item["scores"]),
                key=lambda pair: pair[1],
                default=("unknown", 0.0),
            )
            # Only a strictly positive score counts as a prediction.
            if not best_score > 0:
                return "unknown", 0.0

            # float() turns numpy/tensor scalars into a plain Python float.
            return self._strip_chinese_prefix(best_label), float(best_score)

        except Exception as e:
            logger.error(f"Failed to predict emotion for {audio_path}: {e}")
            return "unknown", 0.0

    def check_label_consistency(
        self, predicted_label: str, original_label: str
    ) -> bool:
        """Check whether the predicted label agrees with the original label.

        Both labels are lower-cased, trimmed and stripped of any "中文/"
        prefix. Synonyms are then mapped to a canonical label (see
        ``_LABEL_VARIANTS``), so any two spellings of the same emotion match.

        Args:
            predicted_label: Predicted emotion label.
            original_label: Original (reference) emotion label.

        Returns:
            True if both labels denote the same emotion.
        """
        return self._normalize_label(predicted_label) == self._normalize_label(
            original_label
        )

    def evaluate_audio_emotion(
        self, audio_path: str, original_label: str, min_confidence: float = 0.8
    ) -> Tuple[bool, dict]:
        """Evaluate the emotion quality of an audio file.

        Args:
            audio_path: Path to the audio file.
            original_label: Original (reference) emotion label.
            min_confidence: Minimum confidence required to pass.

        Returns:
            A ``(is_qualified, metrics)`` tuple. ``is_qualified`` is True only
            when confidence >= ``min_confidence`` and the labels are
            consistent. ``metrics`` holds the intermediate values, or
            ``{"error": ...}`` if evaluation raised.
        """
        try:
            predicted_label, confidence = self.predict_emotion(audio_path)

            confidence_ok = confidence >= min_confidence
            label_consistent = self.check_label_consistency(
                predicted_label, original_label
            )
            is_qualified = confidence_ok and label_consistent

            metrics = {
                "predicted_label": predicted_label,
                "original_label": original_label,
                "confidence": confidence,
                "confidence_ok": confidence_ok,
                "label_consistent": label_consistent,
            }
            return is_qualified, metrics

        except Exception as e:
            # e.g. a non-string original_label fails inside normalization.
            logger.error(f"Failed to evaluate audio emotion for {audio_path}: {e}")
            return False, {"error": str(e)}
