#!/usr/bin/env python3
"""
语音质量过滤CLI工具

用于过滤语音数据集，基于以下质量指标：
1. 音频持续时间（3-15秒）
2. 静音比例（<30%）
3. 信噪比（>10dB）
4. 情感识别置信度（>0.8）
5. 标签一致性
"""

import argparse
import pandas as pd
import os
import logging
from pathlib import Path
from tqdm import tqdm
from typing import List, Dict, Any, Optional

from quality_metrics import QualityMetrics

# 配置日志
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def parse_arguments():
    """解析命令行参数"""
    parser = argparse.ArgumentParser(
        description="语音质量过滤工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例用法:
  python filter.py --input data.csv --output filtered_data.csv --data-dir /path/to/audio/files
  python filter.py --input data.csv --output filtered_data.csv --data-dir /path/to/audio/files --min-duration 2.0 --max-duration 20.0
  python filter.py --input data.csv --output filtered_data.csv --data-dir /path/to/audio/files --max-count 1000
        """,
    )

    # 必需参数
    parser.add_argument(
        "--input", "-i", required=True, help="输入CSV文件路径（包含数据集信息）"
    )
    parser.add_argument(
        "--output", "-o", required=True, help="输出CSV文件路径（过滤后的数据）"
    )
    parser.add_argument("--data-dir", "-d", required=True, help="音频文件根目录")

    # 可选参数
    parser.add_argument(
        "--min-duration", type=float, default=3.0, help="最小音频时长（秒），默认3.0"
    )
    parser.add_argument(
        "--max-duration", type=float, default=15.0, help="最大音频时长（秒），默认15.0"
    )
    parser.add_argument(
        "--max-silence-ratio", type=float, default=0.3, help="最大静音比例，默认0.3"
    )
    parser.add_argument(
        "--min-snr", type=float, default=10.0, help="最小信噪比（dB），默认10.0"
    )
    parser.add_argument(
        "--min-confidence", type=float, default=0.8, help="最小情感置信度，默认0.8"
    )
    parser.add_argument("--verbose", "-v", action="store_true", help="详细输出模式")
    parser.add_argument(
        "--max-count",
        type=int,
        default=None,
        help="最大符合要求的数据数量，达到该数量后停止过滤，默认不限制",
    )

    return parser.parse_args()


def load_csv_data(csv_path: str) -> pd.DataFrame:
    """
    加载CSV数据

    Args:
        csv_path: CSV文件路径

    Returns:
        DataFrame: 包含数据集信息的DataFrame
    """
    try:
        df = pd.read_csv(csv_path)
        required_columns = [
            "dataset_name",
            "wav_filename",
            "emotion_label",
            "gender",
            "speaker_id",
        ]

        # 检查必需列
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(f"CSV文件缺少必需列: {missing_columns}")

        logger.info(f"成功加载CSV文件: {csv_path}")
        logger.info(f"数据行数: {len(df)}")
        logger.info(f"列名: {list(df.columns)}")

        return df

    except Exception as e:
        logger.error(f"加载CSV文件失败: {e}")
        raise


def build_audio_path(data_dir: str, dataset_name: str, wav_filename: str) -> str:
    """
    构建音频文件完整路径

    Args:
        data_dir: 数据根目录
        dataset_name: 数据集名称（未使用，保留参数兼容性）
        wav_filename: 音频文件路径（相对路径）

    Returns:
        str: 音频文件完整路径
    """
    # CSV文件中的wav_filename已经包含相对路径，直接拼接即可
    return os.path.join(data_dir, wav_filename)


def filter_audio_dataset(
    df: pd.DataFrame,
    data_dir: str,
    quality_metrics: QualityMetrics,
    max_count: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    过滤音频数据集

    Args:
        df: 输入数据DataFrame
        data_dir: 音频文件根目录
        quality_metrics: 质量评估器
        max_count: 最大符合要求的数据数量，达到该数量后停止过滤，默认不限制

    Returns:
        List[Dict]: 过滤后的数据列表
    """
    filtered_data = []
    all_metrics = []

    if max_count is not None:
        logger.info(
            f"开始处理 {len(df)} 个音频文件，目标收集 {max_count} 个符合要求的数据..."
        )
    else:
        logger.info(f"开始处理 {len(df)} 个音频文件...")

    # 使用tqdm显示进度
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing audio files"):
        try:
            # 构建音频文件路径
            audio_path = build_audio_path(
                data_dir, row["dataset_name"], row["wav_filename"]
            )

            # 检查文件是否存在
            if not os.path.exists(audio_path):
                logger.warning(f"音频文件不存在: {audio_path}")
                continue

            # 评估音频质量
            is_qualified, metrics = quality_metrics.evaluate_audio_quality(
                audio_path, row["emotion_label"]
            )

            all_metrics.append(metrics)

            # 如果通过质量检查，添加到过滤后的数据
            if is_qualified:
                filtered_data.append(
                    {
                        "dataset_name": row["dataset_name"],
                        "wav_filename": row["wav_filename"],
                        "emotion_label": row["emotion_label"],
                        "gender": row["gender"],
                        "speaker_id": row["speaker_id"],
                    }
                )

                # 检查是否达到数量限制
                if max_count is not None and len(filtered_data) >= max_count:
                    logger.info(f"已达到指定数量限制 ({max_count})，停止过滤")
                    break

            # 详细输出模式
            if args.verbose:
                status = "✓" if is_qualified else "✗"
                reason = quality_metrics.get_filter_reason(metrics)
                print(f"{status} {row['wav_filename']}: {reason}")

        except Exception as e:
            logger.error(f"处理文件 {row['wav_filename']} 时出错: {e}")
            continue

    # 打印质量指标摘要
    quality_metrics.print_metrics_summary(all_metrics)

    return filtered_data


def save_filtered_data(filtered_data: List[Dict[str, Any]], output_path: str):
    """
    保存过滤后的数据

    Args:
        filtered_data: 过滤后的数据列表
        output_path: 输出文件路径
    """
    try:
        df_filtered = pd.DataFrame(filtered_data)
        df_filtered.to_csv(output_path, index=False)

        logger.info(f"成功保存过滤后的数据到: {output_path}")
        logger.info(f"过滤后数据行数: {len(df_filtered)}")

    except Exception as e:
        logger.error(f"保存过滤后的数据失败: {e}")
        raise


def main():
    """主函数"""
    global args
    args = parse_arguments()

    # 设置日志级别
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    logger.info("=== 语音质量过滤工具 ===")
    logger.info(f"输入文件: {args.input}")
    logger.info(f"输出文件: {args.output}")
    logger.info(f"数据目录: {args.data_dir}")
    logger.info(f"过滤参数:")
    logger.info(f"  音频时长: {args.min_duration}-{args.max_duration}秒")
    logger.info(f"  静音比例: <{args.max_silence_ratio:.1%}")
    logger.info(f"  信噪比: >{args.min_snr}dB")
    logger.info(f"  情感置信度: >{args.min_confidence}")
    if args.max_count is not None:
        logger.info(f"  最大符合要求数据数量: {args.max_count}")

    try:
        # 1. 加载CSV数据
        df = load_csv_data(args.input)

        # 2. 初始化质量评估器
        quality_metrics = QualityMetrics(
            min_duration=args.min_duration,
            max_duration=args.max_duration,
            max_silence_ratio=args.max_silence_ratio,
            min_snr=args.min_snr,
            min_confidence=args.min_confidence,
        )

        # 3. 过滤音频数据集
        filtered_data = filter_audio_dataset(
            df, args.data_dir, quality_metrics, args.max_count
        )

        # 4. 保存过滤后的数据
        save_filtered_data(filtered_data, args.output)

        logger.info("=== 过滤完成 ===")

    except Exception as e:
        logger.error(f"处理过程中出现错误: {e}")
        raise


if __name__ == "__main__":
    main()
