#!/usr/bin/env python3
"""
MELD数据集采样程序
用于从MELD数据集中采样不同情绪的音频文件
"""

import os
import argparse
import random
import shutil
import csv
import json
from pathlib import Path
from typing import List, Dict, Tuple
import logging


class MELDDataset:
    """Index the MELD annotation CSVs and sample utterances by emotion.

    On construction the three ``*_sent_emo_clean.csv`` files found under
    ``dataset_path`` are read and every utterance is bucketed by emotion.
    An utterance is kept only if, within its own dialogue, the target
    emotion is expressed by exactly one speaker.
    """

    # Set of supported emotion labels (lower-case). The name is historical:
    # this is a set used for membership tests, not a mapping.
    EMOTION_MAPPING = {
        "neutral",
        "joy",
        "sadness",
        "anger",
        "surprise",
        "fear",
        "disgust",
    }

    # Speaker name -> gender. Speakers not listed here are reported as "unknown".
    SPEAKER_GENDER_MAP = {
        "Monica": "female",
        "Rachel": "female",
        "Phoebe": "female",
        "Susan": "female",
        "Ross": "male",
        "Joey": "male",
        "Chandler": "male",
    }

    def __init__(self, dataset_path: str):
        """
        Load and index the dataset.

        Args:
            dataset_path: Root directory of the dataset (contains the CSV files).
        """
        self.dataset_path = Path(dataset_path)
        self.emotion_data = self._scan_emotion_data()

    def _load_csv_files(self) -> List[Dict]:
        """
        Read every split CSV that exists and merge the rows.

        Missing files are logged and skipped. Rows whose number of fields does
        not match the header are logged and skipped. Each returned row carries
        an extra ``source_split`` key ("train", "dev" or "test").

        Returns:
            List of row dicts (header name -> string value) from all splits.
        """
        all_data = []
        csv_files = [
            "train_sent_emo_clean.csv",
            "dev_sent_emo_clean.csv",
            "test_sent_emo_clean.csv",
        ]

        for csv_file in csv_files:
            csv_path = self.dataset_path / csv_file
            if not csv_path.exists():
                logging.warning(f"CSV文件不存在: {csv_path}")
                continue

            split_name = csv_file.split("_")[0]  # train/dev/test

            try:
                # utf-8-sig strips an optional BOM so the first column name is
                # not polluted; newline="" is what the csv module expects.
                with open(csv_path, "r", newline="", encoding="utf-8-sig") as f:
                    reader = csv.DictReader(f)
                    # Let the csv parser count header fields so quoted commas
                    # in the header cannot skew the expected count.
                    expected_field_count = len(reader.fieldnames or [])

                    for row in reader:
                        # DictReader stores surplus values under the key None
                        # and pads missing values with None, so len(row) alone
                        # cannot reveal a malformed row.
                        surplus = row.pop(None, None) or []
                        missing = sum(1 for value in row.values() if value is None)
                        actual_field_count = (
                            expected_field_count - missing + len(surplus)
                        )
                        if actual_field_count != expected_field_count:
                            # reader.line_num is the physical line in the file.
                            logging.warning(
                                f"跳过无效行 {csv_file}:{reader.line_num}, 字段数量不匹配 (期望:{expected_field_count}, 实际:{actual_field_count})"
                            )
                            continue

                        row["source_split"] = split_name
                        all_data.append(row)
            except Exception as e:
                # Best effort: keep whatever was read and move to the next file.
                logging.error(f"读取CSV文件失败 {csv_path}: {e}")
                continue

        return all_data

    def _generate_wav_filename(self, dialogue_id: str, utterance_id: str) -> str:
        """
        Build the WAV file name for an utterance.

        Args:
            dialogue_id: Dialogue identifier as it appears in the CSV.
            utterance_id: Utterance identifier as it appears in the CSV.

        Returns:
            File name of the form ``dia<dialogue_id>_utt<utterance_id>.wav``.
        """
        return f"dia{dialogue_id}_utt{utterance_id}.wav"

    def _scan_emotion_data(self) -> Dict[str, List[Dict]]:
        """
        Bucket all utterances by emotion, applying the single-speaker constraint.

        For every emotion, a dialogue contributes its utterances of that
        emotion only when all of them come from one speaker. Dialogues are
        identified by (split, Dialogue_ID) because the merged rows come from
        separate split files that each use their own Dialogue_ID numbering.

        Prints per-emotion counts before and after the constraint.

        Returns:
            Dict mapping each supported emotion to its list of sample records,
            in CSV row order.
        """
        # Sorted so that printed statistics and dict order are reproducible
        # (plain set iteration order changes between interpreter runs).
        emotions = sorted(self.EMOTION_MAPPING)
        emotion_data: Dict[str, List[Dict]] = {emotion: [] for emotion in emotions}
        raw_emotion_counts = {emotion: 0 for emotion in emotions}

        # (emotion, split, dialogue id) -> speakers expressing that emotion there.
        speakers_by_group: Dict[Tuple[str, str, str], set] = {}
        labelled_rows = []

        for row in self._load_csv_files():
            emotion = (row.get("Emotion") or "").strip().lower()
            if emotion not in self.EMOTION_MAPPING:
                continue
            raw_emotion_counts[emotion] += 1
            group = (
                emotion,
                row.get("source_split", ""),
                row.get("Dialogue_ID", ""),
            )
            speakers_by_group.setdefault(group, set()).add(row.get("Speaker", ""))
            labelled_rows.append((emotion, group, row))

        print("原始数据集统计信息（应用约束前）:")
        for emotion, count in raw_emotion_counts.items():
            print(f"  {emotion}: {count} 个原始样本")

        for emotion, group, row in labelled_rows:
            if len(speakers_by_group[group]) != 1:
                continue

            speaker = row.get("Speaker", "")
            dialogue_id = row.get("Dialogue_ID", "")
            utterance_id = row.get("Utterance_ID", "")

            emotion_data[emotion].append(
                {
                    "dialogue_id": dialogue_id,
                    "utterance_id": utterance_id,
                    "wav_filename": self._generate_wav_filename(
                        dialogue_id, utterance_id
                    ),
                    "speaker": speaker,
                    "gender": self.SPEAKER_GENDER_MAP.get(speaker, "unknown"),
                    "emotion": emotion,
                    "source_split": row.get("source_split", ""),
                    "utterance": row.get("Utterance", ""),
                    "season": row.get("Season", ""),
                    "episode": row.get("Episode", ""),
                }
            )

        print("应用约束后的统计信息:")
        for emotion, samples in emotion_data.items():
            print(f"  {emotion}: {len(samples)} 个有效样本")

        return emotion_data

    def sample_emotion(self, emotion: str, num_samples: int) -> List[Dict]:
        """
        Randomly draw sample records for one emotion, without replacement.

        Args:
            emotion: Emotion label; must be a member of ``EMOTION_MAPPING``.
            num_samples: Number of records to draw. If it exceeds the number
                available, a warning is logged and all records are returned
                (in random order).

        Returns:
            List of sample record dicts.

        Raises:
            ValueError: If the emotion is unsupported, has no samples, or
                ``num_samples`` is negative.
        """
        if emotion not in self.EMOTION_MAPPING:
            raise ValueError(f"不支持的情绪类别: {emotion}")

        available_files = self.emotion_data.get(emotion, [])
        if not available_files:
            raise ValueError(f"数据集中没有找到情绪为 {emotion} 的音频文件")

        # Fail with a clear message instead of random.sample's generic one.
        if num_samples < 0:
            raise ValueError(f"采样数量不能为负数: {num_samples}")

        # Clamp oversized requests to what is available.
        if num_samples > len(available_files):
            logging.warning(
                f"请求采样 {num_samples} 个文件，但只有 {len(available_files)} 个可用文件"
            )
            num_samples = len(available_files)

        return random.sample(available_files, num_samples)

    def get_available_emotions(self) -> List[str]:
        """Return the emotions that have at least one sample after filtering."""
        return [emotion for emotion, files in self.emotion_data.items() if files]

    def get_emotion_stats(self) -> Dict[str, int]:
        """Return the number of filtered samples for every supported emotion."""
        return {emotion: len(files) for emotion, files in self.emotion_data.items()}


def save_samples(
    sampled_files: List[Dict],
    emotion: str,
    save_path: str,
    dataset_path: str,
    num_samples_requested: int,
):
    """
    Copy the sampled WAV files and write CSV/JSON metadata next to them.

    Output layout under ``save_path``:
      * ``<emotion>_wavs/``      copied audio files
      * ``<emotion>.csv``        one row per saved sample
      * ``<emotion>_info.json``  summary statistics

    Audio is searched in a fixed list of well-known directory names under
    ``dataset_path``. When a sample carries a ``source_split`` of "train",
    "dev" or "test", only that split's directories and the split-neutral
    ones are searched, because the same ``diaX_uttY.wav`` name exists in
    several split directories and would otherwise be taken from the wrong one.
    Samples without a recognised split are searched in every directory.

    If two samples would be written to the same destination name, the later
    one is stored as ``<split>_<name>`` and the CSV records that name.

    If no audio file could be copied at all, metadata is still written for
    every sampled record.

    Args:
        sampled_files: Sample records; each needs the keys ``wav_filename``,
            ``emotion``, ``speaker`` and ``gender``; ``source_split`` is optional.
        emotion: Emotion label, used for output file and directory names.
        save_path: Output directory (created if missing).
        dataset_path: Dataset root directory that holds the audio folders.
        num_samples_requested: Number of samples the caller asked for;
            recorded in the JSON summary.
    """
    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    wavs_dir = save_path / f"{emotion}_wavs"
    wavs_dir.mkdir(exist_ok=True)

    dataset_root = Path(dataset_path)

    # Directory names tied to one split, in search order.
    split_dir_names = {
        "train": ("train_splits", "train_splits_complete"),
        "dev": ("dev_splits_complete", "dev_splits"),
        "test": ("output_repeated_splits_test", "test_splits"),
    }
    owning_split = {
        name: split for split, names in split_dir_names.items() for name in names
    }

    # Well-known audio directory names, in search order.
    possible_dirs = [
        "train_splits",
        "dev_splits_complete",
        "output_repeated_splits_test",
        "train_splits_complete",
        "dev_splits",
        "test_splits",
        "audio",
        "wavs",
        "wav_files",
    ]
    audio_source_dirs = [
        dataset_root / dir_name
        for dir_name in possible_dirs
        if (dataset_root / dir_name).is_dir()
    ]

    copied_files = []
    used_names = set()  # destination names already written during this call
    for file_info in sampled_files:
        wav_filename = file_info["wav_filename"]
        split = file_info.get("source_split", "")

        if split in split_dir_names:
            # Own split first, then split-neutral folders; never another split.
            candidates = [
                d for d in audio_source_dirs if owning_split.get(d.name) == split
            ]
            candidates += [d for d in audio_source_dirs if d.name not in owning_split]
        else:
            candidates = audio_source_dirs

        # Avoid overwriting a file already saved under the same name.
        dst_name = wav_filename
        if dst_name in used_names and split:
            dst_name = f"{split}_{wav_filename}"
        dst_path = wavs_dir / dst_name

        found = False
        for source_dir in candidates:
            src_path = source_dir / wav_filename
            if not src_path.exists():
                continue
            try:
                shutil.copy2(src_path, dst_path)
            except OSError as e:
                # Try the next candidate directory rather than aborting.
                logging.error(f"复制文件失败 {src_path} -> {dst_path}: {e}")
                continue

            used_names.add(dst_name)
            if dst_name == wav_filename:
                copied_files.append(file_info)
            else:
                # Record the on-disk name without mutating the caller's dict.
                copied_files.append({**file_info, "wav_filename": dst_name})
            found = True
            break

        if not found:
            logging.warning(f"未找到音频文件: {wav_filename}")

    # Metadata-only fallback when nothing could be copied.
    if not copied_files:
        logging.warning("未找到任何音频文件，仅生成元数据文件")
        copied_files = sampled_files

    csv_path = save_path / f"{emotion}.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as csvfile:
        fieldnames = ["dataset_name", "wavfile_name", "emotion", "speaker", "gender"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        writer.writeheader()
        for file_info in copied_files:
            writer.writerow(
                {
                    "dataset_name": "MELD",
                    "wavfile_name": file_info["wav_filename"],
                    "emotion": file_info["emotion"],
                    "speaker": file_info["speaker"],
                    "gender": file_info["gender"],
                }
            )

    speakers = sorted({file_info["speaker"] for file_info in copied_files})

    json_info = {
        "dataset": "MELD",
        "emotion": emotion.capitalize(),
        "total_samples": len(copied_files),
        "speaker_count": len(speakers),
        "speakers": speakers,
        "sample_count_requested": num_samples_requested,
        # str() keeps this serialisable if a Path is passed in.
        "dataset_dir": str(dataset_path),
    }

    json_path = save_path / f"{emotion}_info.json"
    with open(json_path, "w", encoding="utf-8") as jsonfile:
        json.dump(json_info, jsonfile, indent=2, ensure_ascii=False)

    print(f"成功保存 {len(copied_files)} 个 {emotion} 情绪的样本")
    print(f"音频文件保存在: {wavs_dir}")
    print(f"元信息保存在: {csv_path}")
    print(f"统计信息保存在: {json_path}")

    if not audio_source_dirs:
        print("警告: 未找到音频文件目录，请确保音频文件已正确提取")


def main():
    """
    Command-line entry point: parse arguments, sample one emotion, save results.

    The emotion argument is matched case-insensitively, and the sample count
    must be a positive integer.

    Returns:
        Exit status: 0 on success, 1 on invalid input or any runtime error.
    """
    parser = argparse.ArgumentParser(description="MELD数据集采样程序")
    parser.add_argument(
        "--dataset_path", type=str, required=True, help="MELD数据集根目录路径"
    )
    parser.add_argument(
        "--emotion",
        type=str,
        required=True,
        help="要采样的情绪类别",
    )
    parser.add_argument("--num_samples", type=int, required=True, help="采样数量")
    parser.add_argument("--save_path", type=str, required=True, help="保存路径")
    parser.add_argument("--verbose", "-v", action="store_true", help="详细输出")

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO if args.verbose else logging.WARNING,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    try:
        # Reject a meaningless sample count before any expensive loading.
        if args.num_samples <= 0:
            print(f"错误: 采样数量必须为正整数: {args.num_samples}")
            return 1

        dataset_path = Path(args.dataset_path)
        if not dataset_path.exists():
            print(f"错误: 数据集路径不存在: {args.dataset_path}")
            return 1

        if not dataset_path.is_dir():
            print(f"错误: 数据集路径不是目录: {args.dataset_path}")
            return 1

        print(f"正在加载MELD数据集: {args.dataset_path}")
        dataset = MELDDataset(args.dataset_path)

        # The dataset stores emotion labels lower-cased, so normalise the
        # user's input the same way ("Joy" and " joy " both mean "joy").
        emotion = args.emotion.strip().lower()

        available_emotions = dataset.get_available_emotions()
        if emotion not in available_emotions:
            print(f"错误: 不支持的情绪类别: {args.emotion}")
            print(f"可用的情绪类别: {', '.join(available_emotions)}")
            return 1

        emotion_stats = dataset.get_emotion_stats()
        print("\n数据集统计信息:")
        for name, count in emotion_stats.items():
            print(f"  {name}: {count} 个样本")

        print(f"\n正在采样 {emotion} 情绪的 {args.num_samples} 个样本...")
        sampled_files = dataset.sample_emotion(emotion, args.num_samples)

        save_samples(
            sampled_files,
            emotion,
            args.save_path,
            args.dataset_path,
            args.num_samples,
        )

        return 0

    except Exception as e:
        # Top-level boundary: report the error and signal failure to the shell.
        print(f"错误: {e}")
        if args.verbose:
            import traceback

            traceback.print_exc()
        return 1


if __name__ == "__main__":
    # Use SystemExit directly: the bare exit() helper is injected by the
    # `site` module for interactive use and is not guaranteed to exist
    # (e.g. under `python -S` or in frozen executables).
    raise SystemExit(main())
