#!/usr/bin/env python3
"""
RAVDESS数据集采样程序
用于从RAVDESS数据集中采样不同情绪的音频文件
"""

import argparse
import csv
import json
import os
import random
import shutil
import sys
from pathlib import Path
from typing import List, Dict, Tuple


class RAVDESSDataset:
    """Index of RAVDESS audio files grouped by emotion label."""

    # Emotion code (third dash-separated field of a RAVDESS filename) -> label.
    EMOTION_MAP = {
        "01": "neutral",
        "02": "calm",
        "03": "happy",
        "04": "sad",
        "05": "angry",
        "06": "fearful",
        "07": "disgust",
        "08": "surprised",
    }

    def __init__(self, dataset_path: str):
        """
        Scan the dataset directory and build the per-emotion file index.

        Args:
            dataset_path: Root directory that contains the ``Actor_*`` sub-directories.
        """
        self.dataset_path = Path(dataset_path)
        self.audio_files = self._scan_audio_files()

    def _scan_audio_files(self) -> Dict[str, List[Dict[str, str]]]:
        """
        Scan every ``Actor_*`` directory for ``.wav`` files and group them by emotion.

        Directories and files are visited in sorted order. ``Path.iterdir`` and
        ``Path.glob`` yield entries in an arbitrary, filesystem-dependent order.
        Without sorting, the same ``random.seed`` could select different files
        on different machines.

        Returns:
            Dict mapping every emotion label in ``EMOTION_MAP`` (possibly with an
            empty list) to a list of file-info dicts with the keys ``file_path``,
            ``filename``, ``emotion``, ``speaker`` and ``gender``.
        """
        audio_files: Dict[str, List[Dict[str, str]]] = {
            emotion: [] for emotion in self.EMOTION_MAP.values()
        }

        for actor_dir in sorted(self.dataset_path.iterdir()):
            if not actor_dir.is_dir() or not actor_dir.name.startswith("Actor_"):
                continue

            speaker = actor_dir.name
            try:
                actor_number = int(speaker.split("_")[1])
            except ValueError:
                # E.g. "Actor_backup": there is no numeric id to derive the gender
                # from. Skip the directory instead of aborting the whole scan.
                print(f"警告: 跳过无法解析演员编号的目录: {speaker}")
                continue
            # Odd-numbered actors are labelled male, even-numbered female.
            gender = "male" if actor_number % 2 == 1 else "female"

            for wav_file in sorted(actor_dir.glob("*.wav")):
                # Expect at least 7 dash-separated fields; field 2 is the emotion code.
                filename_parts = wav_file.stem.split("-")
                if len(filename_parts) < 7:
                    continue
                emotion = self.EMOTION_MAP.get(filename_parts[2])
                if emotion is None:
                    continue
                audio_files[emotion].append(
                    {
                        "file_path": str(wav_file),
                        "filename": wav_file.name,
                        "emotion": emotion,
                        "speaker": speaker,
                        "gender": gender,
                    }
                )

        return audio_files

    def sample_emotion(self, emotion: str, num_samples: int) -> List[Dict]:
        """
        Randomly sample audio files of one emotion without replacement.

        Uses the module-level ``random`` generator, so results depend on
        ``random.seed``.

        Args:
            emotion: Emotion label; must be one of the ``EMOTION_MAP`` values.
            num_samples: Number of files to draw. If it exceeds the number of
                available files, a warning is printed and all files are returned.

        Returns:
            The sampled file-info dicts.

        Raises:
            ValueError: If the emotion is unsupported, ``num_samples`` is negative,
                or no files exist for the emotion.
        """
        if emotion not in self.EMOTION_MAP.values():
            raise ValueError(f"不支持的情绪类别: {emotion}")
        if num_samples < 0:
            raise ValueError(f"采样数量不能为负数: {num_samples}")

        available_files = self.audio_files.get(emotion, [])
        if not available_files:
            raise ValueError(f"数据集中没有找到情绪为 {emotion} 的音频文件")

        # Clamp the request to the population size instead of failing.
        if num_samples > len(available_files):
            print(
                f"警告: 请求采样 {num_samples} 个文件，但只有 {len(available_files)} 个可用文件"
            )
            num_samples = len(available_files)

        return random.sample(available_files, num_samples)

    def get_available_emotions(self) -> List[str]:
        """Return the emotion labels that have at least one file, in ``EMOTION_MAP`` order."""
        return [emotion for emotion, files in self.audio_files.items() if files]

    def get_emotion_stats(self) -> Dict[str, int]:
        """Return the number of indexed files for every emotion label (including zeros)."""
        return {emotion: len(files) for emotion, files in self.audio_files.items()}


def save_samples(
    sampled_files: List[Dict],
    emotion: str,
    save_path: str,
    dataset_path: str,
    num_samples_requested: int,
):
    """
    Copy the sampled audio files into ``save_path`` and write their metadata.

    Three artifacts are produced under ``save_path``, which is created if needed:
      * ``<emotion>_wavs/``     -- copies of the sampled wav files
      * ``<emotion>.csv``       -- one metadata row per sampled file
      * ``<emotion>_info.json`` -- summary of the sample (counts, speakers, source)

    Args:
        sampled_files: File-info dicts with the keys ``file_path``, ``filename``,
            ``emotion``, ``speaker`` and ``gender``.
        emotion: Emotion label, used to name the output artifacts.
        save_path: Output directory.
        dataset_path: Dataset root, recorded verbatim in the JSON summary.
        num_samples_requested: The originally requested sample size, recorded in
            the JSON summary (may differ from the number actually sampled).
    """
    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    wav_dir = out_dir / f"{emotion}_wavs"
    wav_dir.mkdir(exist_ok=True)

    # copy2 also carries over file metadata such as timestamps.
    for info in sampled_files:
        shutil.copy2(Path(info["file_path"]), wav_dir / info["filename"])

    csv_file = out_dir / f"{emotion}.csv"
    columns = ["dataset_name", "wavfile_name", "emotion", "speaker", "gender"]
    with open(csv_file, "w", newline="", encoding="utf-8") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=columns)
        csv_writer.writeheader()
        csv_writer.writerows(
            {
                "dataset_name": "RAVDESS",
                "wavfile_name": info["filename"],
                "emotion": info["emotion"],
                "speaker": info["speaker"],
                "gender": info["gender"],
            }
            for info in sampled_files
        )

    unique_speakers = sorted({info["speaker"] for info in sampled_files})

    summary = {
        "dataset": "RAVDESS",
        "emotion": emotion.capitalize(),
        "total_samples": len(sampled_files),
        "speaker_count": len(unique_speakers),
        "speakers": unique_speakers,
        "sample_count_requested": num_samples_requested,
        "dataset_dir": dataset_path,
    }

    json_file = out_dir / f"{emotion}_info.json"
    with open(json_file, "w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2, ensure_ascii=False)

    print(f"成功保存 {len(sampled_files)} 个 {emotion} 情绪的音频文件")
    print(f"音频文件保存在: {wav_dir}")
    print(f"元信息保存在: {csv_file}")
    print(f"统计信息保存在: {json_file}")


def main():
    """
    Command-line entry point: parse arguments, sample one emotion and save it.

    Exit status:
      * 2 (from argparse) for invalid arguments, including a non-positive ``--num_samples``;
      * 1 when the dataset path is not a directory or scanning/sampling/saving fails;
      * 0 on success.
    """
    parser = argparse.ArgumentParser(description="RAVDESS数据集采样程序")
    parser.add_argument(
        "--dataset_path", type=str, required=True, help="RAVDESS数据集根目录路径"
    )
    parser.add_argument(
        "--emotion",
        type=str,
        required=True,
        # Derived from the class mapping so the CLI can never drift out of sync with it.
        choices=list(RAVDESSDataset.EMOTION_MAP.values()),
        help="要采样的情绪类别",
    )
    parser.add_argument("--num_samples", type=int, required=True, help="采样数量")
    parser.add_argument("--save_path", type=str, required=True, help="保存路径")
    parser.add_argument("--seed", type=int, default=42, help="随机种子 (默认: 42)")

    args = parser.parse_args()

    # random.sample() rejects negative sizes, and 0 would only write empty outputs.
    # Fail early with a usage message instead.
    if args.num_samples <= 0:
        parser.error(f"--num_samples 必须为正整数, 收到: {args.num_samples}")

    # Seed the module-level generator used by RAVDESSDataset.sample_emotion.
    random.seed(args.seed)

    # The scan calls iterdir(), which needs a directory. Merely checking
    # existence would let a regular file through and fail later with a cryptic error.
    if not os.path.isdir(args.dataset_path):
        print(
            f"错误: 数据集路径不存在或不是目录: {args.dataset_path}",
            file=sys.stderr,
        )
        sys.exit(1)

    try:
        dataset = RAVDESSDataset(args.dataset_path)

        # Report only the emotions that actually have files.
        print("数据集统计信息:")
        for emotion, count in dataset.get_emotion_stats().items():
            if count > 0:
                print(f"  {emotion}: {count} 个文件")

        sampled_files = dataset.sample_emotion(args.emotion, args.num_samples)

        save_samples(
            sampled_files,
            args.emotion,
            args.save_path,
            args.dataset_path,
            args.num_samples,
        )
    except Exception as e:
        # Top-level CLI boundary: report on stderr and signal failure to the shell.
        print(f"错误: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
