#!/usr/bin/env python3
"""
ASVP-ESD数据集采样程序
用于从ASVP-ESD数据集中采样不同情绪的音频文件
"""

import argparse
import csv
import os
import random
import shutil
from pathlib import Path
from typing import Dict, List, Optional


class ASVPESDDataset:
    """Index an ASVP-ESD dataset directory by emotion and sample from it.

    On construction the directory is scanned recursively for ``*.wav`` files.
    Each file's metadata is decoded from its dash-separated filename; files
    whose names cannot be decoded are grouped under the ``"unknown"`` emotion.
    """

    # Emotion code (3rd filename field) -> emotion label.
    EMOTION_MAP = {
        "01": "boredom",
        "02": "neutral",
        "03": "happy",
        "04": "sad",
        "05": "angry",
        "06": "fearful",
        "07": "disgust",
        "08": "surprised",
        "09": "excited",
        "10": "pleasure",
        "11": "pain",
        "12": "disappointment",
        "13": "breath",
    }

    # Age code (7th filename field) -> age-group label.
    AGE_MAP = {"01": "above_65", "02": "20_64", "03": "under_20", "04": "baby"}

    # Language code (9th filename field) -> language name / short abbreviation.
    # Codes not listed here fall back to "Russian_Others" / "other".
    LANGUAGE_MAP = {"01": "Chinese", "02": "English", "04": "French"}
    LANGUAGE_ABBR_MAP = {"01": "zh", "02": "en", "04": "fr"}

    # Metadata keys copied from the parsed filename into every file record,
    # in the order they appear in the record dict.
    _RECORD_FIELDS = (
        "gender",
        "speaker_id",
        "age_group",
        "language",
        "intensity",
        "source",
        "statement",
    )

    def __init__(self, dataset_path: str):
        """Scan ``dataset_path`` and build the per-emotion file index.

        Args:
            dataset_path: Root directory of the dataset.

        Raises:
            FileNotFoundError: If ``dataset_path`` does not exist.
            NotADirectoryError: If ``dataset_path`` exists but is not a directory.
        """
        self.dataset_path = Path(dataset_path)
        # Fail loudly on a bad path: rglob() on a missing directory silently
        # yields nothing, which would make every emotion look empty.
        if not self.dataset_path.exists():
            raise FileNotFoundError(f"数据集目录不存在: {self.dataset_path}")
        if not self.dataset_path.is_dir():
            raise NotADirectoryError(f"数据集路径不是目录: {self.dataset_path}")
        self.audio_files = self._scan_audio_files()

    def _parse_filename(self, filename: str) -> Optional[Dict]:
        """Decode metadata from an ASVP-ESD filename.

        Filename format: 03-01-06-01-02-12-02-01-01-16.wav
        Fields (dash separated; only the first nine are read):
        - Modality (03 = audio-only)
        - Vocal channel (01 = speech, 02 = non speech)
        - Emotion (01-13)
        - Emotional intensity (01 = normal, 02 = high)
        - Statement (rank)
        - Actor (even = male, odd = female)
        - Age (01 = above 65, 02 = 20-64, 03 = under 20, 04 = baby)
        - Source (01-02 = website/youtube, 03 = movies)
        - Language (01 = Chinese, 02 = English, 04 = French, others = Russian/Others)

        Returns:
            A dict with ``emotion``, ``gender``, ``speaker_id``, ``age_group``,
            ``language``, ``language_abbr``, ``intensity``, ``source`` and
            ``statement``; or ``None`` if the name is not a ``.wav`` file, has
            fewer than 10 fields, or the actor field is not an integer. An
            unrecognised emotion code yields ``emotion == "unknown"``.
        """
        if not filename.endswith(".wav"):
            return None

        name_parts = filename[: -len(".wav")].split("-")
        if len(name_parts) < 10:
            return None

        emotion_code = name_parts[2]
        intensity = name_parts[3]
        statement = name_parts[4]
        actor = name_parts[5]
        age = name_parts[6]
        source = name_parts[7]
        language = name_parts[8]

        try:
            actor_num = int(actor)
        except ValueError:
            return None

        return {
            "emotion": self.EMOTION_MAP.get(emotion_code, "unknown"),
            # Even actor numbers are male, odd are female.
            "gender": "male" if actor_num % 2 == 0 else "female",
            "speaker_id": actor,
            "age_group": self.AGE_MAP.get(age, "unknown"),
            "language": self.LANGUAGE_MAP.get(language, "Russian_Others"),
            "language_abbr": self.LANGUAGE_ABBR_MAP.get(language, "other"),
            "intensity": intensity,
            "source": source,
            "statement": statement,
        }

    def _scan_audio_files(self) -> Dict[str, List[Dict]]:
        """Recursively collect ``*.wav`` files under the dataset root.

        Returns:
            Mapping from every emotion label (plus ``"unknown"``) to a list of
            file records. Each record has ``file_path``, ``filename``,
            ``emotion``, ``gender``, ``speaker_id``, ``age_group``,
            ``language``, ``intensity``, ``source`` and ``statement``.
            Unparseable files get ``"unknown"`` for all metadata fields.
        """
        audio_files = {emotion: [] for emotion in self.EMOTION_MAP.values()}
        audio_files["unknown"] = []  # For files that can't be parsed

        # Sorting makes the index order (and therefore seeded sampling)
        # independent of filesystem enumeration order.
        for wav_file in sorted(self.dataset_path.rglob("*.wav")):
            if not wav_file.is_file():
                continue

            file_info = self._parse_filename(wav_file.name)
            record = {"file_path": str(wav_file), "filename": wav_file.name}
            if file_info:
                emotion = file_info["emotion"]
                record["emotion"] = emotion
                record.update({key: file_info[key] for key in self._RECORD_FIELDS})
            else:
                emotion = "unknown"
                record["emotion"] = emotion
                record.update(dict.fromkeys(self._RECORD_FIELDS, "unknown"))
            audio_files[emotion].append(record)

        return audio_files

    def sample_emotion(
        self, emotion: str, num_samples: int, language_filter: Optional[str] = None
    ) -> List[Dict]:
        """Randomly sample file records of one emotion, without replacement.

        Args:
            emotion: Emotion label to sample from.
            num_samples: Number of files to draw; capped (with a printed
                warning) at the number of available files.
            language_filter: Optional language name (Chinese, English, French)
                that records must match.

        Returns:
            The sampled file records.

        Raises:
            ValueError: If ``num_samples`` is negative, or no file matches the
                emotion (and language filter, if given).
        """
        if num_samples < 0:
            raise ValueError(f"采样数量不能为负数: {num_samples}")

        available_files = self.audio_files.get(emotion, [])
        if not available_files:
            raise ValueError(f"数据集中没有找到情绪为 {emotion} 的音频文件")

        if language_filter:
            filtered_files = [
                f for f in available_files if f.get("language") == language_filter
            ]
            if not filtered_files:
                raise ValueError(
                    f"数据集中没有找到情绪为 {emotion} 且语言为 {language_filter} 的音频文件"
                )
            available_files = filtered_files
            print(
                f"语言筛选后，{emotion} 情绪的 {language_filter} 语言文件数量: {len(available_files)}"
            )

        # Requesting more than is available falls back to taking every file.
        if num_samples > len(available_files):
            print(
                f"警告: 请求采样 {num_samples} 个文件，但只有 {len(available_files)} 个可用文件"
            )
            num_samples = len(available_files)

        return random.sample(available_files, num_samples)

    def get_available_emotions(self) -> List[str]:
        """Return the emotion labels that have at least one file."""
        return [emotion for emotion, files in self.audio_files.items() if files]

    def get_emotion_stats(self) -> Dict[str, int]:
        """Return the number of files per emotion label (including empty ones)."""
        return {emotion: len(files) for emotion, files in self.audio_files.items()}


def save_samples(
    sampled_files: List[Dict],
    emotion: str,
    output_dir: str,
    dataset_path: str,
    num_samples_requested: int,
    language_filter: str = None,
):
    """Copy sampled audio files into the output directory and write a CSV manifest.

    Audio is copied into ``<output_dir>/<emotion>_<lang>/`` and the manifest is
    written to ``<output_dir>/<emotion>_<lang>.csv``. ``<lang>`` is ``zh``,
    ``en`` or ``fr`` for the matching ``language_filter``, ``other`` for any
    other filter value, and ``all`` when no filter is given.

    Only files that were copied successfully appear in the manifest. If two
    sampled files share a basename (e.g. found in different sub-directories),
    later ones are saved as ``<stem>_<n><suffix>`` rather than overwriting the
    earlier copy, and the manifest records the saved name.

    Args:
        sampled_files: File records; each must provide ``file_path``,
            ``filename``, ``emotion``, ``gender`` and ``speaker_id``.
        emotion: Emotion label used to name the output sub-directory and CSV.
        output_dir: Output root directory; created if missing.
        dataset_path: Dataset root (currently unused; kept for interface
            compatibility).
        num_samples_requested: Requested sample count (currently unused; kept
            for interface compatibility).
        language_filter: Optional language name the samples were filtered by.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Map the language filter to the suffix used in output names.
    language_abbr = "all"
    if language_filter:
        language_abbr_map = {"Chinese": "zh", "English": "en", "French": "fr"}
        language_abbr = language_abbr_map.get(language_filter, "other")

    emotion_dir = output_path / f"{emotion}_{language_abbr}"
    emotion_dir.mkdir(exist_ok=True)

    # Copy audio; remember which names this call has written so duplicates
    # get a unique name instead of silently overwriting an earlier copy.
    saved_rows = []
    used_names = set()
    for file_info in sampled_files:
        src_path = Path(file_info["file_path"])
        dst_name = file_info["filename"]
        if dst_name in used_names:
            stem, suffix = Path(dst_name).stem, Path(dst_name).suffix
            n = 1
            while f"{stem}_{n}{suffix}" in used_names:
                n += 1
            dst_name = f"{stem}_{n}{suffix}"

        try:
            shutil.copy2(src_path, emotion_dir / dst_name)
        except OSError as e:
            # Best effort: report and skip, so the manifest only lists files
            # that actually exist in the output directory.
            print(f"复制文件失败 {src_path}: {e}")
            continue

        used_names.add(dst_name)
        saved_rows.append(
            {
                "dataset_name": "ASVP-ESD",
                "wav_filename": dst_name,
                "emotion_label": file_info["emotion"],
                "gender": file_info["gender"],
                "speaker_id": file_info["speaker_id"],
            }
        )
    copied_count = len(saved_rows)

    csv_path = output_path / f"{emotion}_{language_abbr}.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as csvfile:
        fieldnames = [
            "dataset_name",
            "wav_filename",
            "emotion_label",
            "gender",
            "speaker_id",
        ]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(saved_rows)

    print(f"成功采样 {len(sampled_files)} 个 {emotion} 情绪的音频文件")
    print(f"复制了 {copied_count} 个音频文件到: {emotion_dir}")
    print(f"元信息保存在: {csv_path}")


def main() -> int:
    """Command-line entry point: sample one emotion from ASVP-ESD and save it.

    Returns:
        Process exit status: 0 on success, 1 when the dataset cannot be loaded,
        the emotion is unavailable, sampling fails, or saving fails. Invalid
        command-line arguments make argparse exit with status 2.
    """
    parser = argparse.ArgumentParser(description="ASVP-ESD数据集采样程序")
    parser.add_argument("--emotion", type=str, required=True, help="目标情绪类别")
    parser.add_argument(
        "--dataset_root", type=str, required=True, help="ASVP-ESD数据集根目录路径"
    )
    parser.add_argument("--output_dir", type=str, required=True, help="输出目录路径")
    parser.add_argument("--sample_num", type=int, default=300, help="采样数量")
    parser.add_argument(
        "--language",
        type=str,
        choices=["Chinese", "English", "French"],
        help="语言筛选条件 (可选)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="随机种子 (可选，用于复现采样结果)",
    )

    args = parser.parse_args()
    if args.sample_num < 0:
        parser.error("--sample_num 必须为非负整数")

    # Only reseed when asked, so default runs keep the original randomness.
    if args.seed is not None:
        random.seed(args.seed)

    try:
        dataset = ASVPESDDataset(args.dataset_root)
    except Exception as e:  # top-level boundary: report and exit non-zero
        print(f"初始化数据集失败: {e}")
        return 1

    available_emotions = dataset.get_available_emotions()
    print(f"数据集中可用的情绪类别: {available_emotions}")

    stats = dataset.get_emotion_stats()
    print("各情绪类别的文件数量:")
    for emotion, count in stats.items():
        if count > 0:
            print(f"  {emotion}: {count}")

    if args.emotion not in available_emotions:
        print(f"错误: 数据集中没有找到情绪为 '{args.emotion}' 的音频文件")
        print(f"可用的情绪类别: {available_emotions}")
        return 1

    try:
        sampled_files = dataset.sample_emotion(
            args.emotion, args.sample_num, args.language
        )
    except ValueError as e:
        print(f"采样失败: {e}")
        return 1

    try:
        save_samples(
            sampled_files,
            args.emotion,
            args.output_dir,
            args.dataset_root,
            args.sample_num,
            args.language,
        )
    except OSError as e:
        print(f"保存采样结果失败: {e}")
        return 1

    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shell callers can detect failures.
    raise SystemExit(main())
