#!/usr/bin/env python3
"""
IEMOCAP数据集高质量数据筛选程序

该程序从IEMOCAP数据集中筛选出情绪判断一致的音频片段，
并验证对应的wav文件是否存在，最终输出CSV格式的结果。

注意：本程序不考虑说话人平衡或性别比例，只要满足情绪判断一致
且对应的wav文件存在，就会被加入到输出文件中。

对于恐惧情绪(Fear)，只需要至少有一个标注者标注为恐惧即可，
不需要三个标注者都一致。其他情绪仍需要三个标注者都一致。
"""

import os
import json
import csv
import argparse
import glob
import random
from collections import defaultdict
from typing import Dict, List, Set, Tuple
import re
import shutil


def parse_annotation_file(file_path: str) -> Dict[str, Set[str]]:
    """
    Parse one categorical annotation file into a segment-to-emotions mapping.

    A useful line looks like ``<segment_id> :<label>; <label>; (<comment>)``.
    Parenthesised free-text comments are removed before the labels are split
    on ``;``, so a comment that itself contains ``;`` cannot leak into the
    label set. Lines that do not match the pattern, and segments left with no
    label at all, are skipped.

    Args:
        file_path: Path of the annotation file (read as UTF-8).

    Returns:
        Dict[str, Set[str]]: Segment ID mapped to the set of emotion labels
        found on its line. If the file cannot be opened or decoded, a warning
        is printed and whatever was parsed before the failure is returned.
    """
    annotations: Dict[str, Set[str]] = {}
    line_pattern = re.compile(r"^(\S+)\s*:(.*)$")
    # Matches a balanced, non-nested "( ... )" comment.
    comment_pattern = re.compile(r"\([^()]*\)")
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                m = line_pattern.match(line.strip())
                if not m:
                    continue
                segment_id = m.group(1)
                labels_part = comment_pattern.sub("", m.group(2))
                emotions = set()
                for raw_label in labels_part.split(";"):
                    # split("(") still guards against an unclosed comment;
                    # strip(":") drops the leading colon of follow-up labels.
                    emotion = raw_label.strip().split("(")[0].strip(":").strip()
                    if emotion:
                        emotions.add(emotion)
                if emotions:
                    annotations[segment_id] = emotions
    except (OSError, UnicodeError) as e:
        # Best effort: an unreadable file is reported, not fatal.
        print(f"Warning: 无法解析文件 {file_path}: {e}")
    return annotations


def find_annotation_files(dataset_dir: str) -> List[str]:
    """
    Collect every categorical annotation file under the dataset root.

    Files are looked up as
    ``<dataset_dir>/Session*/dialog/EmoEvaluation/Categorical/*_cat.txt``.

    Args:
        dataset_dir: Root directory of the dataset.

    Returns:
        List[str]: Matching file paths in sorted order. ``glob`` itself gives
        no ordering guarantee, so sorting keeps every later step that depends
        on the order (grouping, sampling) stable from run to run. The list is
        empty when nothing matches.
    """
    pattern = os.path.join(
        dataset_dir, "Session*", "dialog", "EmoEvaluation", "Categorical", "*_cat.txt"
    )
    return sorted(glob.glob(pattern))


def group_annotations_by_speaker(
    annotation_files: List[str],
) -> Dict[str, List[Dict[str, Set[str]]]]:
    """
    Group parsed annotation files by the dialog they belong to.

    The group key is the file name without its trailing evaluator token and
    ``cat.txt`` suffix, e.g. ``Ses01F_impro01_e2_cat.txt`` -> ``Ses01F_impro01``
    and ``Ses01F_script01_1_e2_cat.txt`` -> ``Ses01F_script01_1``. Keeping the
    full prefix stops numbered dialogs such as ``script01_1`` and
    ``script01_2`` from being merged into a single group.

    Args:
        annotation_files: Paths of the annotation files to parse.

    Returns:
        Dict[str, List[Dict[str, Set[str]]]]: Group key mapped to one parsed
        annotation dict per file. Files whose name has fewer than three
        ``_``-separated tokens, or which yield no annotations, are left out.
    """
    speaker_annotations = defaultdict(list)

    for file_path in annotation_files:
        parts = os.path.basename(file_path).split("_")
        if len(parts) < 3:
            continue
        if len(parts) >= 4:
            # Drop the last two tokens: "<evaluator>" and "cat.txt".
            speaker_id = "_".join(parts[:-2])
        else:
            # Three-token names keep the historical two-token key.
            speaker_id = f"{parts[0]}_{parts[1]}"
        annotations = parse_annotation_file(file_path)
        if annotations:
            speaker_annotations[speaker_id].append(annotations)

    return speaker_annotations


def check_consistency(
    annotations_list: List[Dict[str, Set[str]]], segment_id: str
) -> Tuple[bool, str]:
    """
    Check whether every annotator of a segment shares at least one emotion.

    Args:
        annotations_list: One annotation dict (segment ID -> labels) per
            annotator.
        segment_id: ID of the segment to check.

    Returns:
        Tuple[bool, str]: ``(True, label)`` when at least three annotators
        labelled the segment and some label appears in all of their label
        sets; otherwise ``(False, "")``. If several labels are shared, the
        alphabetically smallest one is returned so the result does not depend
        on set iteration order.
    """
    if len(annotations_list) < 3:
        return False, ""

    # Label sets of the annotators that actually rated this segment.
    rated = [
        annotations[segment_id]
        for annotations in annotations_list
        if segment_id in annotations
    ]
    if len(rated) < 3:
        return False, ""

    shared = set(rated[0]).intersection(*rated[1:])
    if not shared:
        return False, ""
    return True, min(shared)


def verify_wav_file_exists(dataset_dir: str, segment_id: str) -> bool:
    """
    Tell whether the wav file of a segment is present under any session folder.

    The wav is expected at
    ``<dataset_dir>/<session>/sentences/wav/<dialog>/<segment_id>.wav``, where
    ``<dialog>`` is every ``_``-separated token of the segment ID that comes
    before the first token starting with ``F`` or ``M``.

    Args:
        dataset_dir: Root directory of the dataset.
        segment_id: Segment ID, e.g. ``Ses01F_impro01_M013`` or
            ``Ses01F_script01_3_F011``.

    Returns:
        bool: True if the file exists in some entry of ``dataset_dir`` whose
        name starts with ``Ses``; False otherwise, including when the ID has
        fewer than three tokens or no ``F``/``M`` token.
    """
    tokens = segment_id.split("_")
    if len(tokens) < 3:
        return False

    # Index of the first token that starts with F or M, if any.
    marker_idx = next(
        (idx for idx, token in enumerate(tokens) if token.startswith(("F", "M"))),
        None,
    )
    if marker_idx is None:
        return False

    dialog_dir = "_".join(tokens[:marker_idx])
    wav_name = f"{segment_id}.wav"

    return any(
        os.path.exists(
            os.path.join(dataset_dir, entry, "sentences", "wav", dialog_dir, wav_name)
        )
        for entry in os.listdir(dataset_dir)
        if entry.startswith(("Session", "Ses"))
    )


def check_consistency_for_emotion(annotations_list, segment_id, target_emotion):
    """
    Decide whether a segment qualifies for the target emotion.

    Fear (compared case-insensitively) is delegated to
    ``check_fear_emotion_consistency``. For every other emotion, at least three
    annotators must have rated the segment and each of them must have given
    the target emotion as their one and only label.

    Returns:
        bool: True when the segment qualifies; False otherwise, including when
        fewer than three annotation dicts are supplied.
    """
    if len(annotations_list) < 3:
        return False

    if target_emotion.lower() == "fear":
        # Fear has its own, looser rule.
        return check_fear_emotion_consistency(
            annotations_list, segment_id, target_emotion
        )

    # Label sets of the annotators that rated this segment.
    votes = [rater[segment_id] for rater in annotations_list if segment_id in rater]
    if len(votes) < 3:
        return False

    return all(target_emotion in vote and len(vote) == 1 for vote in votes)


def check_fear_emotion_consistency(annotations_list, segment_id, target_emotion):
    """
    Apply the relaxed rule used for the fear emotion.

    A segment qualifies when at least three annotators rated it and at least
    one of them included ``target_emotion`` among their labels; the annotators
    do not have to agree with each other.

    Returns:
        bool: True when both conditions hold, False otherwise.
    """
    # Label sets of the annotators that rated this segment.
    votes = [rater[segment_id] for rater in annotations_list if segment_id in rater]
    someone_chose_target = any(target_emotion in vote for vote in votes)
    return someone_chose_target and len(votes) >= 3


def get_session_count(dataset_dir: str) -> int:
    """Count the entries of ``dataset_dir`` whose name starts with ``Session`` or ``Ses``."""
    return sum(
        1 for entry in os.listdir(dataset_dir) if entry.startswith(("Session", "Ses"))
    )


def get_gender_from_segment(segment_id: str) -> str:
    """
    Read the gender letter from the last ``_``-separated token of a segment ID.

    Returns:
        str: ``"F"`` or ``"M"`` when the last token starts with that letter;
        ``"Unknown"`` otherwise or when the ID has fewer than three tokens.
    """
    tokens = segment_id.split("_")
    if len(tokens) < 3:
        return "Unknown"
    initial = tokens[-1][:1]
    return initial if initial in ("F", "M") else "Unknown"


def filter_high_quality_data_with_limit(
    dataset_dir: str, dataset_name: str, target_emotion: str, sample_count: int
) -> list:
    """
    Select up to ``sample_count`` segments that qualify for ``target_emotion``.

    A segment is a candidate when ``check_consistency_for_emotion`` accepts it
    and ``verify_wav_file_exists`` finds its wav file. Speaker and gender
    balance are deliberately not considered.

    Candidates are visited in sorted order and sampled with a private
    ``random.Random(42)``, so the same dataset always yields the same
    selection and the global random state is left untouched.

    Args:
        dataset_dir: Root directory of the dataset.
        dataset_name: Value stored in the ``dataset`` field of every row.
        target_emotion: Emotion label to filter for.
        sample_count: Maximum number of segments to return.

    Returns:
        list: One dict per selected segment with the keys ``dataset``,
        ``wavfile``, ``emotion`` and ``gender``. All candidates are returned
        when there are fewer than ``sample_count``; the list is empty when
        there are none.
    """
    print(f"正在扫描数据集: {dataset_dir}")
    annotation_files = find_annotation_files(dataset_dir)
    print(f"找到 {len(annotation_files)} 个标注文件")
    speaker_annotations = group_annotations_by_speaker(annotation_files)
    print(f"找到 {len(speaker_annotations)} 个说话人")

    # Sorted iteration: set/dict order would otherwise vary between runs and
    # make the seeded sample below pick different segments each time.
    all_valid_segments = []
    for speaker_id in sorted(speaker_annotations):
        annotations_list = speaker_annotations[speaker_id]
        all_segments = set()
        for annotations in annotations_list:
            all_segments.update(annotations)
        for segment_id in sorted(all_segments):
            if not check_consistency_for_emotion(
                annotations_list, segment_id, target_emotion
            ):
                continue
            if verify_wav_file_exists(dataset_dir, segment_id):
                all_valid_segments.append(segment_id)
                print(f"  ✓ {segment_id}: {target_emotion}")
            else:
                print(f"  ✗ {segment_id}: wav文件不存在")

    print(f"找到 {len(all_valid_segments)} 个符合条件的片段")

    if not all_valid_segments:
        print("警告: 没有找到符合条件的片段")
        return []

    if len(all_valid_segments) < sample_count:
        print(
            f"警告: 符合条件的片段数量({len(all_valid_segments)})少于要求的数量({sample_count})，将使用所有可用片段"
        )
        selected_segments = all_valid_segments
    else:
        # Fixed seed on a local generator keeps the pick reproducible.
        rng = random.Random(42)
        selected_segments = rng.sample(all_valid_segments, sample_count)

    print(f"选择了 {len(selected_segments)} 个片段")

    return [
        {
            "dataset": dataset_name,
            "wavfile": f"{segment_id}.wav",
            "emotion": target_emotion,
            "gender": get_gender_from_segment(segment_id),
        }
        for segment_id in selected_segments
    ]


def copy_selected_wavs_and_labels(
    high_quality_data, dataset_dir, copy_dir, copy_labels=True
):
    """
    Copy the selected wav files, and optionally their annotation files.

    For every item, the wav file is searched under each entry of
    ``dataset_dir`` whose name starts with ``Ses`` at
    ``<entry>/sentences/wav/<dialog>/<wavfile>`` and the first hit is copied
    into ``copy_dir``. ``<dialog>`` is the part of the segment ID before the
    first token starting with ``F`` or ``M``.

    Args:
        high_quality_data: Iterable of dicts that each carry a ``wavfile`` key.
        dataset_dir: Root directory of the dataset.
        copy_dir: Destination directory; created if it does not exist.
        copy_labels: When true, also copy every ``<dialog>_*_cat.txt`` file
            from the same session's ``dialog/EmoEvaluation/Categorical``
            folder. Each label file is copied at most once per call.

    A warning is printed for every wav file that cannot be located.
    """
    os.makedirs(copy_dir, exist_ok=True)

    # Label files already copied, so shared ones are not copied again.
    copied_labels = set()

    for item in high_quality_data:
        wavfile = item["wavfile"]
        segment_id = os.path.splitext(wavfile)[0]
        parts = segment_id.split("_")
        gender_part_index = next(
            (i for i, part in enumerate(parts) if part.startswith(("F", "M"))),
            -1,
        )
        if gender_part_index == -1:
            # No dialog folder can be derived, so the file cannot be located.
            print(f"警告: 未找到音频文件 {wavfile}，未拷贝")
            continue

        speaker_part = "_".join(parts[:gender_part_index])
        # The trailing "_" keeps "..._1" from also matching "..._1b_...".
        label_prefix = f"{speaker_part}_"

        found = False
        for session_dir in os.listdir(dataset_dir):
            if not session_dir.startswith(("Session", "Ses")):
                continue
            wav_path = os.path.join(
                dataset_dir, session_dir, "sentences", "wav", speaker_part, wavfile
            )
            if not os.path.exists(wav_path):
                continue

            shutil.copy(wav_path, os.path.join(copy_dir, wavfile))
            found = True

            if copy_labels:
                label_dir = os.path.join(
                    dataset_dir,
                    session_dir,
                    "dialog",
                    "EmoEvaluation",
                    "Categorical",
                )
                if os.path.exists(label_dir):
                    for fname in os.listdir(label_dir):
                        if fname in copied_labels:
                            continue
                        if fname.startswith(label_prefix) and fname.endswith(
                            "_cat.txt"
                        ):
                            shutil.copy(
                                os.path.join(label_dir, fname),
                                os.path.join(copy_dir, fname),
                            )
                            copied_labels.add(fname)
            break

        if not found:
            print(f"警告: 未找到音频文件 {wavfile}，未拷贝")


def main():
    """
    Command-line entry point: filter segments for one emotion and export them.

    Parses the arguments, selects up to ``--sample-count`` segments for
    ``--emotion`` with ``filter_high_quality_data_with_limit``, writes
    ``<emotion>.csv`` and ``<emotion>_stats.json`` into ``--output_dir``, then
    copies the selected wav files (and, unless ``--no-label-file`` is given,
    their annotation files) into ``<output_dir>/<emotion>_wavs``.

    Returns:
        int: 0 on success; 1 when the dataset directory does not exist, the
        sample count is not positive, or no segment qualifies.
    """
    parser = argparse.ArgumentParser(description="IEMOCAP数据集高质量数据筛选程序")
    # Both paths are mandatory; without them the script cannot do anything.
    parser.add_argument("--dataset_dir", required=True, help="数据集根目录路径")
    parser.add_argument("--output_dir", required=True, help="输出目录路径")
    parser.add_argument(
        "--dataset-name", default="IEMOCAP", help="数据集名称 (默认: IEMOCAP)"
    )
    parser.add_argument(
        "--emotion", "-e", required=True, help="筛选目标情绪类别（如Anger）"
    )
    parser.add_argument(
        "--sample-count", "-n", type=int, required=True, help="筛选的片段数量"
    )
    parser.add_argument("--no-label-file", action="store_true", help="不复制标注文件")
    args = parser.parse_args()

    if not os.path.exists(args.dataset_dir):
        print(f"错误: 数据集目录不存在: {args.dataset_dir}")
        return 1

    if args.sample_count <= 0:
        print("错误: 采样数量必须大于0")
        return 1

    print("开始筛选高质量数据...")
    print(f"数据集目录: {args.dataset_dir}")
    print(f"输出目录: {args.output_dir}")
    print(f"数据集名称: {args.dataset_name}")
    print(f"筛选情绪: {args.emotion}")
    print(f"采样数量: {args.sample_count}")
    print("-" * 50)

    high_quality_data = filter_high_quality_data_with_limit(
        args.dataset_dir, args.dataset_name, args.emotion, args.sample_count
    )

    if not high_quality_data:
        print("没有找到符合条件的数据")
        return 1

    # Output layout: <output_dir>/<emotion>.csv, <emotion>_stats.json and
    # the <emotion>_wavs/ folder.
    os.makedirs(args.output_dir, exist_ok=True)
    emotion_wavs_dir = os.path.join(args.output_dir, f"{args.emotion}_wavs")
    os.makedirs(emotion_wavs_dir, exist_ok=True)

    # newline="" lets the csv module control line endings itself; without it
    # Windows output gets an extra blank line after every row.
    main_csv_path = os.path.join(args.output_dir, f"{args.emotion}.csv")
    with open(main_csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["dataset", "wavfile", "emotion", "gender"])
        for item in high_quality_data:
            writer.writerow(
                [item["dataset"], item["wavfile"], item["emotion"], item["gender"]]
            )

    # Distinct dialog prefixes: the part of each segment ID before the first
    # token that starts with F or M.
    speaker_count = set()
    for item in high_quality_data:
        parts = item["wavfile"][:-4].split("_")  # drop ".wav"
        for i, part in enumerate(parts):
            if part.startswith(("F", "M")):
                speaker_count.add("_".join(parts[:i]))
                break

    stats_data = {
        "dataset": args.dataset_name,
        "emotion": args.emotion,
        "total_samples": len(high_quality_data),
        "speaker_count": len(speaker_count),
        # Sorted so the JSON is identical from run to run.
        "speakers": sorted(speaker_count),
        "sample_count_requested": args.sample_count,
        "dataset_dir": args.dataset_dir,
    }

    stats_json_path = os.path.join(args.output_dir, f"{args.emotion}_stats.json")
    with open(stats_json_path, "w", encoding="utf-8") as f:
        json.dump(stats_data, f, ensure_ascii=False, indent=2)

    print("-" * 50)
    print("筛选完成！")
    print(f"找到 {len(high_quality_data)} 个高质量音频片段")
    print(f"主结果已保存到: {main_csv_path}")
    print(f"统计信息已保存到: {stats_json_path}")

    print(f"开始拷贝音频文件到: {emotion_wavs_dir}")
    copy_selected_wavs_and_labels(
        high_quality_data, args.dataset_dir, emotion_wavs_dir, not args.no_label_file
    )
    print("文件拷贝完成！")

    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status. SystemExit is
    # raised directly because the bare exit() helper is provided by the site
    # module for interactive use and is absent when Python runs with -S.
    raise SystemExit(main())
