import logging
from typing import Any, Dict, Optional, Tuple

from audio_processor import AudioProcessor
from emotion_detector import EmotionDetector

logger = logging.getLogger(__name__)


class QualityMetrics:
    """Unified quality-evaluation facade over the audio and emotion checkers.

    Holds the configured thresholds, delegates the actual measurements to an
    ``AudioProcessor`` and an ``EmotionDetector``, and offers helpers to turn
    the resulting metrics dictionaries into human-readable rejection reasons
    and a printed summary.
    """

    def __init__(
        self,
        min_duration: float = 3.0,
        max_duration: float = 15.0,
        max_silence_ratio: float = 0.3,
        min_snr: float = 10.0,
        min_confidence: float = 0.8,
        model_name: str = "iic/emotion2vec_plus_large",
        device: Optional[str] = None,
    ):
        """
        Initialise the quality evaluator and its underlying processors.

        Args:
            min_duration: Minimum accepted audio duration, in seconds.
            max_duration: Maximum accepted audio duration, in seconds.
            max_silence_ratio: Maximum accepted silence ratio.
            min_snr: Minimum accepted signal-to-noise ratio, in dB.
            min_confidence: Minimum accepted emotion confidence.
            model_name: Name of the emotion2vec model handed to the detector.
            device: Device identifier handed to the detector; ``None`` is
                passed through unchanged.
        """
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.max_silence_ratio = max_silence_ratio
        self.min_snr = min_snr
        self.min_confidence = min_confidence

        # Underlying processors that perform the actual measurements.
        self.audio_processor = AudioProcessor()
        self.emotion_detector = EmotionDetector(model_name=model_name, device=device)

        # Lazy %-style arguments: formatting is skipped when INFO is disabled.
        logger.info("Quality metrics initialized with parameters:")
        logger.info("  Duration: %s-%ss", min_duration, max_duration)
        logger.info("  Max silence ratio: %s", max_silence_ratio)
        logger.info("  Min SNR: %sdB", min_snr)
        logger.info("  Min confidence: %s", min_confidence)
        logger.info("  Model: %s", model_name)
        logger.info("  Device: %s", device)

    def evaluate_audio_quality(
        self, file_path: str, original_label: str
    ) -> Tuple[bool, Dict[str, Any]]:
        """
        Run the audio and emotion quality checks for one file.

        Any exception raised by the underlying processors is logged and
        converted into a failed result carrying an ``"error"`` entry, so this
        method itself does not propagate those exceptions.

        Args:
            file_path: Path of the audio file to evaluate.
            original_label: Emotion label originally assigned to the file.

        Returns:
            A ``(is_qualified, metrics)`` tuple. ``is_qualified`` is True only
            when both the audio check and the emotion check pass; ``metrics``
            holds the per-check flags and the raw metrics of both checks, or
            an ``"error"`` message when evaluation failed.
        """
        try:
            # The loaded data is not used below; the call is kept so a file
            # that cannot be loaded still fails at this point.
            # NOTE(review): check_audio_quality receives the path, so the file
            # may be decoded twice - confirm whether this call is still needed.
            _audio_data, _sample_rate = self.audio_processor.load_audio(file_path)

            # 1. Audio quality check.
            audio_qualified, audio_metrics = self.audio_processor.check_audio_quality(
                file_path,
                self.min_duration,
                self.max_duration,
                self.max_silence_ratio,
                self.min_snr,
            )

            # 2. Emotion quality check.
            emotion_qualified, emotion_metrics = (
                self.emotion_detector.evaluate_audio_emotion(
                    file_path, original_label, self.min_confidence
                )
            )

            # A file qualifies only when both checks pass.
            is_qualified = audio_qualified and emotion_qualified

            all_metrics = {
                "file_path": file_path,
                "original_label": original_label,
                "is_qualified": is_qualified,
                "audio_qualified": audio_qualified,
                "emotion_qualified": emotion_qualified,
                "audio_metrics": audio_metrics,
                "emotion_metrics": emotion_metrics,
            }

            return is_qualified, all_metrics

        except Exception as e:
            logger.error("Failed to evaluate audio quality for %s: %s", file_path, e)
            return False, {
                "file_path": file_path,
                "original_label": original_label,
                "is_qualified": False,
                # Exceptions raised without a message stringify to ""; fall
                # back to the class name so the failure stays identifiable.
                "error": str(e) or type(e).__name__,
            }

    def get_filter_reason(self, metrics: Dict[str, Any]) -> str:
        """
        Describe why a metrics dictionary failed the quality checks.

        Args:
            metrics: Metrics dictionary in the shape produced by
                ``evaluate_audio_quality``.

        Returns:
            ``"Error: <message>"`` when the dictionary carries an error,
            otherwise the failed checks joined with ``"; "``, or
            ``"Passed all checks"`` when no check is flagged as failed.
        """
        error = metrics.get("error")
        # Compare against None rather than truthiness so that an empty error
        # message is still reported as an error instead of as a pass.
        if error is not None:
            return f"Error: {error}"

        reasons = []

        # Audio checks. ``or {}`` also covers a key that is present but None.
        audio_metrics = metrics.get("audio_metrics") or {}
        if not audio_metrics.get("duration_ok", True):
            duration = audio_metrics.get("duration", 0)
            reasons.append(
                f"Duration {duration:.2f}s not in range [{self.min_duration}, {self.max_duration}]s"
            )

        if not audio_metrics.get("silence_ok", True):
            silence_ratio = audio_metrics.get("silence_ratio", 0)
            reasons.append(
                f"Silence ratio {silence_ratio:.2%} > {self.max_silence_ratio:.2%}"
            )

        if not audio_metrics.get("snr_ok", True):
            snr = audio_metrics.get("snr", 0)
            reasons.append(f"SNR {snr:.2f}dB < {self.min_snr}dB")

        # Emotion checks.
        emotion_metrics = metrics.get("emotion_metrics") or {}
        if not emotion_metrics.get("confidence_ok", True):
            confidence = emotion_metrics.get("confidence", 0)
            reasons.append(f"Confidence {confidence:.3f} < {self.min_confidence}")

        if not emotion_metrics.get("label_consistent", True):
            pred_label = emotion_metrics.get("predicted_label", "unknown")
            orig_label = emotion_metrics.get("original_label", "unknown")
            reasons.append(
                f"Label mismatch: predicted '{pred_label}' vs original '{orig_label}'"
            )

        return "; ".join(reasons) if reasons else "Passed all checks"

    @staticmethod
    def _print_distribution(label: str, values: list, spec: str, unit: str = "") -> None:
        """Print min/max/avg of ``values`` using format ``spec``; no-op when empty."""
        if not values:
            return
        lowest = format(min(values), spec)
        highest = format(max(values), spec)
        mean = format(sum(values) / len(values), spec)
        print(f"{label}: min={lowest}{unit}, max={highest}{unit}, avg={mean}{unit}")

    def print_metrics_summary(self, all_metrics: list) -> None:
        """
        Print a summary of the quality-filter results to stdout.

        The summary contains overall counts, the frequency of each rejection
        reason (most frequent first) and the min/max/avg of each metric.

        Args:
            all_metrics: List of metrics dictionaries, one per audio file.
                An empty list is accepted and reported with a 0.00% rate.
        """
        total_count = len(all_metrics)
        qualified_count = sum(1 for m in all_metrics if m.get("is_qualified", False))
        rejected_count = total_count - qualified_count
        # Guard against division by zero when there is nothing to summarise.
        rate = qualified_count / total_count * 100 if total_count else 0.0

        print("\n=== Quality Filter Summary ===")
        print(f"Total audio files: {total_count}")
        print(f"Qualified: {qualified_count}")
        print(f"Rejected: {rejected_count}")
        print(f"Qualification rate: {rate:.2f}%")

        # Count how often each rejection reason occurs.
        filter_reasons: Dict[str, int] = {}
        for metrics in all_metrics:
            if not metrics.get("is_qualified", False):
                reason = self.get_filter_reason(metrics)
                filter_reasons[reason] = filter_reasons.get(reason, 0) + 1

        if filter_reasons:
            print("\n=== Filter Reasons ===")
            for reason, count in sorted(
                filter_reasons.items(), key=lambda x: x[1], reverse=True
            ):
                print(f"  {reason}: {count}")

        # Distribution of the raw metrics; entries without the corresponding
        # metrics dictionary (e.g. errored files) are skipped.
        print("\n=== Quality Metrics Distribution ===")
        audio_entries = [
            m["audio_metrics"] for m in all_metrics if m.get("audio_metrics")
        ]
        emotion_entries = [
            m["emotion_metrics"] for m in all_metrics if m.get("emotion_metrics")
        ]

        durations = [a.get("duration", 0) for a in audio_entries]
        silence_ratios = [a.get("silence_ratio", 0) for a in audio_entries]
        snrs = [a.get("snr", 0) for a in audio_entries]
        confidences = [e.get("confidence", 0) for e in emotion_entries]

        self._print_distribution("Duration", durations, ".2f", "s")
        self._print_distribution("Silence ratio", silence_ratios, ".2%")
        self._print_distribution("SNR", snrs, ".2f", "dB")
        self._print_distribution("Confidence", confidences, ".3f")
