#!/usr/bin/env python3
"""
Emo-Emilia Emotion Sampling Tool

Samples entries from Emo-Emilia-ALL.jsonl by specified emotion and saves to CSV format.
Also copies corresponding audio files to output directory.
"""

import json
import csv
import argparse
import os
import shutil
import random
from pathlib import Path


def parse_arguments():
    """Build the command-line parser and parse ``sys.argv``.

    Every option is required:

    * ``--emotion``: one of the seven supported emotion labels.
    * ``--label_file``: path to the JSONL label file.
    * ``--dataset_root``: root directory of the dataset.
    * ``--output_dir``: directory receiving the CSV and audio files.
    * ``--sample_num``: number of samples to extract (parsed as ``int``).

    Returns:
        The ``argparse.Namespace`` holding the parsed option values.
    """
    emotions = [
        "angry",
        "happy",
        "fearful",
        "surprised",
        "neutral",
        "sad",
        "disgusted",
    ]
    # Plain required string options, registered in the order shown in --help.
    path_options = (
        ("--label_file", "Path to Emo-Emilia-ALL.jsonl file"),
        ("--dataset_root", "Root directory of the dataset"),
        ("--output_dir", "Output directory for CSV and audio files"),
    )

    cli = argparse.ArgumentParser(
        description="Sample emotion-specific entries from Emo-Emilia dataset"
    )
    cli.add_argument(
        "--emotion",
        required=True,
        choices=emotions,
        help="Emotion category to sample",
    )
    for flag, help_text in path_options:
        cli.add_argument(flag, required=True, help=help_text)
    cli.add_argument(
        "--sample_num", type=int, required=True, help="Number of samples to extract"
    )
    return cli.parse_args()


def extract_speaker_id(index):
    """Derive a speaker ID from an entry index string.

    The index is split on underscores and its first three fields are
    re-joined with underscores. When the index has fewer than three
    fields, the placeholder ``"unknown"`` is returned instead.
    """
    fields = index.split("_")
    if len(fields) < 3:
        return "unknown"
    return "_".join(fields[:3])


def _load_matching_entries(label_file, emotion):
    """Return the usable entries in ``label_file`` labelled with ``emotion``.

    The file is read as JSON Lines (one JSON value per line). Blank lines
    are ignored. Lines that are not valid JSON, are not JSON objects, or
    match the emotion but lack a string "index" or "wav" field are reported
    and skipped, so later processing cannot fail on them.
    """
    matches = []
    with open(label_file, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # An empty line (e.g. a trailing newline) is not a bad record.
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                print(f"Warning: Skipping invalid JSON line {line_no}")
                continue
            if not isinstance(entry, dict):
                print(f"Warning: Skipping non-object JSON line {line_no}")
                continue
            if entry.get("emotion") != emotion:
                continue
            if not isinstance(entry.get("index"), str) or not isinstance(
                entry.get("wav"), str
            ):
                print(f"Warning: Skipping line {line_no}: missing 'index' or 'wav'")
                continue
            matches.append(entry)
    return matches


def _relative_wav_path(wav_path):
    """Return ``wav_path`` without leading "/" or "./" path prefixes.

    Only whole "./" prefixes and leading slashes are removed, so the result
    can be joined under the dataset root. Unlike ``str.lstrip("./")``, which
    strips any run of "." and "/" characters, this leaves names such as
    ".hidden/a.wav" and "../a.wav" intact.
    """
    rel = wav_path.lstrip("/")
    while rel.startswith("./"):
        rel = rel[2:].lstrip("/")
    return rel


def main():
    """Sample entries of one emotion, copy their audio, and write a CSV.

    Steps:
      1. Parse the command-line arguments.
      2. Load the label-file entries whose "emotion" matches ``--emotion``.
      3. Randomly pick ``--sample_num`` of them (all of them if fewer exist).
      4. Copy each picked entry's audio file into
         ``<output_dir>/<emotion>/`` and collect one CSV row per entry.
      5. Write the rows to ``<output_dir>/<emotion>.csv``.

    Returns:
        None. A missing label file or a negative sample count is reported
        on stdout and ends the run early; a missing or uncopyable audio
        file is reported and skipped.
    """
    args = parse_arguments()

    # Validate inputs
    if not os.path.exists(args.label_file):
        print(f"Error: Label file not found at {args.label_file}")
        return
    if args.sample_num < 0:
        # random.sample raises ValueError for a negative count; report instead.
        print(f"Error: --sample_num must be non-negative, got {args.sample_num}")
        return

    # Create output directories
    emotion_output_dir = os.path.join(args.output_dir, args.emotion)
    os.makedirs(emotion_output_dir, exist_ok=True)

    # Read and filter data
    filtered_entries = _load_matching_entries(args.label_file, args.emotion)
    print(f"Found {len(filtered_entries)} entries with emotion '{args.emotion}'")

    if len(filtered_entries) < args.sample_num:
        print(f"Warning: Only {len(filtered_entries)} entries available, sampling all")
        sampled_entries = filtered_entries
    else:
        sampled_entries = random.sample(filtered_entries, args.sample_num)

    # Prepare CSV data
    csv_data = []
    copied_files = 0

    for entry in sampled_entries:
        # NOTE: audio is flattened into one directory by basename, so two
        # entries sharing a basename overwrite each other's copy.
        wav_filename = os.path.basename(entry["wav"])

        # A CSV row is recorded for every sampled entry, even if the copy
        # below fails or the source file is missing.
        csv_data.append(
            {
                "dataset_name": "Emo-Emilia",
                "wav_filename": wav_filename,
                "emotion_label": entry["emotion"],
                "gender": "unknown",
                "speaker_id": extract_speaker_id(entry["index"]),
            }
        )

        # Copy audio file
        source_wav = os.path.join(args.dataset_root, _relative_wav_path(entry["wav"]))
        target_wav = os.path.join(emotion_output_dir, wav_filename)

        if not os.path.exists(source_wav):
            print(f"Warning: Audio file not found: {source_wav}")
            continue
        try:
            shutil.copy2(source_wav, target_wav)
        except OSError as e:
            # shutil's own errors derive from OSError; keep going with the
            # remaining files rather than aborting the whole run.
            print(f"Error copying {source_wav}: {e}")
        else:
            copied_files += 1

    # Save CSV file
    csv_path = os.path.join(args.output_dir, f"{args.emotion}.csv")
    fieldnames = [
        "dataset_name",
        "wav_filename",
        "emotion_label",
        "gender",
        "speaker_id",
    ]
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(csv_data)

    print(f"Successfully sampled {len(sampled_entries)} entries")
    print(f"Copied {copied_files} audio files to {emotion_output_dir}")
    print(f"Saved CSV file to {csv_path}")


# Run the sampling tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
