#!/usr/bin/env python3
"""
TESS数据集文件复制程序
根据指定的情绪参数，从源目录复制符合条件的wav文件到输出目录
文件名格式: YAF_youth_fear.wav
"""

import os
import argparse
import shutil
from pathlib import Path
from typing import List, Tuple


# 支持的情绪类型
SUPPORTED_EMOTIONS = ["fear", "happy", "sad", "angry", "disgust", "surprise", "neutral"]


def parse_emotion_from_filename(filename: str) -> str:
    """
    从TESS格式的文件名中解析情绪信息

    Args:
        filename: 文件名 (如: YAF_youth_fear.wav)

    Returns:
        情绪字符串，如果无法解析则返回空字符串
    """
    # 移除.wav扩展名
    name_without_ext = filename.replace(".wav", "")

    # 按下划线分割文件名
    parts = name_without_ext.split("_")

    # 最后一个部分应该是情绪
    if len(parts) >= 3:
        emotion = parts[-1].lower()
        return emotion

    return ""


def scan_wav_files(source_dir: str) -> List[Tuple[str, str]]:
    """
    扫描源目录下的所有wav文件

    Args:
        source_dir: 源目录路径

    Returns:
        文件路径和情绪信息的元组列表
    """
    wav_files = []
    source_path = Path(source_dir)

    if not source_path.exists():
        raise ValueError(f"源目录不存在: {source_dir}")

    # 递归查找所有.wav文件
    for wav_file in source_path.rglob("*.wav"):
        emotion = parse_emotion_from_filename(wav_file.name)
        if emotion:
            wav_files.append((str(wav_file), emotion))

    return wav_files


def filter_files_by_emotion(
    wav_files: List[Tuple[str, str]], target_emotion: str
) -> List[str]:
    """
    根据目标情绪过滤文件

    Args:
        wav_files: 文件路径和情绪信息的元组列表
        target_emotion: 目标情绪

    Returns:
        匹配的文件路径列表
    """
    matched_files = []

    for file_path, emotion in wav_files:
        if emotion == target_emotion.lower():
            matched_files.append(file_path)

    return matched_files


def copy_files_to_output(matched_files: List[str], output_dir: str) -> int:
    """
    将匹配的文件复制到输出目录

    Args:
        matched_files: 匹配的文件路径列表
        output_dir: 输出目录路径

    Returns:
        成功复制的文件数量
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    copied_count = 0

    for file_path in matched_files:
        try:
            src_path = Path(file_path)
            dst_path = output_path / src_path.name

            shutil.copy2(src_path, dst_path)
            copied_count += 1

        except Exception as e:
            print(f"复制文件失败 {file_path}: {e}")

    return copied_count


def print_statistics(
    source_dir: str,
    target_emotion: str,
    total_files: int,
    matched_files: List[str],
    copied_count: int,
    output_dir: str,
):
    """
    打印操作统计信息

    Args:
        source_dir: 源目录
        target_emotion: 目标情绪
        total_files: 扫描到的总文件数
        matched_files: 匹配的文件列表
        copied_count: 成功复制的文件数
        output_dir: 输出目录
    """
    print(f"\n=== 操作统计 ===")
    print(f"源目录: {source_dir}")
    print(f"目标情绪: {target_emotion}")
    print(f"扫描到的wav文件总数: {total_files}")
    print(f"匹配的文件数量: {len(matched_files)}")
    print(f"成功复制的文件数量: {copied_count}")
    print(f"输出目录: {output_dir}")

    if copied_count > 0:
        print(f"\n复制的文件:")
        for file_path in matched_files:
            filename = Path(file_path).name
            print(f"  - {filename}")


def main():
    """主函数"""
    parser = argparse.ArgumentParser(description="TESS数据集文件复制程序")
    parser.add_argument("--source_dir", type=str, required=True, help="源目录路径")
    parser.add_argument(
        "--emotion",
        type=str,
        required=True,
        choices=SUPPORTED_EMOTIONS,
        help="目标情绪类型",
    )
    parser.add_argument("--output_dir", type=str, required=True, help="输出目录路径")
    parser.add_argument("--verbose", action="store_true", help="显示详细信息")

    args = parser.parse_args()

    try:
        # 扫描源目录下的所有wav文件
        print(f"正在扫描目录: {args.source_dir}")
        wav_files = scan_wav_files(args.source_dir)

        if not wav_files:
            print("未找到任何wav文件")
            return

        # 过滤出匹配目标情绪的文件
        matched_files = filter_files_by_emotion(wav_files, args.emotion)

        if not matched_files:
            print(f"未找到情绪为 '{args.emotion}' 的文件")
            return

        # 复制文件到输出目录
        print(f"正在复制文件到: {args.output_dir}")
        copied_count = copy_files_to_output(matched_files, args.output_dir)

        # 打印统计信息
        print_statistics(
            args.source_dir,
            args.emotion,
            len(wav_files),
            matched_files,
            copied_count,
            args.output_dir,
        )

        print(f"\n操作完成!")

    except Exception as e:
        print(f"错误: {e}")
        return


if __name__ == "__main__":
    main()
