#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Audio sampling program for M3ED dataset.
Extracts audio files based on emotion labels from JSON metadata.
"""

import json
import os
import argparse
import shutil
import random
import logging
import csv
from pathlib import Path
from typing import Dict, List, Tuple, Optional


def parse_arguments():
    """Parse command line arguments.

    All five options are required. ``--sample_count`` must be a non-negative
    integer: a negative value is rejected here as a usage error (argparse
    exits with status 2) instead of surfacing later as a ``ValueError``
    raised by ``random.sample``.

    Returns:
        argparse.Namespace with the attributes ``emotion``, ``output_dir``,
        ``dataset_dir``, ``sample_count`` (int) and ``meta_json``.
    """

    def non_negative_int(value: str) -> int:
        """Convert *value* to an int and reject negative numbers."""
        # A ValueError from int() is turned into a usage error by argparse.
        number = int(value)
        if number < 0:
            raise argparse.ArgumentTypeError(
                f"must be a non-negative integer, got {number}"
            )
        return number

    parser = argparse.ArgumentParser(
        description="Sample audio files by emotion from M3ED dataset"
    )
    parser.add_argument(
        "--emotion",
        required=True,
        help="Target emotion to sample (Happy, Neutral, Sad, Disgust, Anger, Fear, Surprise)",
    )
    parser.add_argument(
        "--output_dir", required=True, help="Output directory for sampled files"
    )
    parser.add_argument(
        "--dataset_dir", required=True, help="Directory containing WAV files"
    )
    parser.add_argument(
        "--sample_count",
        type=non_negative_int,
        required=True,
        help="Number of samples to extract",
    )
    parser.add_argument("--meta_json", required=True, help="Path to JSON metadata file")

    return parser.parse_args()


def setup_logging():
    """Configure the root logger for console and file output.

    Records at INFO level and above go to a stream handler and to a UTF-8
    encoded ``sampling.log`` file (relative to the current working
    directory). Each record is formatted as
    ``<timestamp> - <level> - <message>``.
    """
    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler("sampling.log", encoding="utf-8")
    record_format = "%(asctime)s - %(levelname)s - %(message)s"

    logging.basicConfig(
        format=record_format,
        level=logging.INFO,
        handlers=[console_handler, file_handler],
    )


def load_metadata(json_path: str) -> Dict:
    """Read the JSON metadata file and return its parsed content.

    Args:
        json_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.

    Both errors are logged before being re-raised unchanged.
    """
    try:
        with open(json_path, "r", encoding="utf-8") as handle:
            parsed = json.load(handle)
    except FileNotFoundError:
        logging.error(f"Metadata file not found: {json_path}")
        raise
    except json.JSONDecodeError as err:
        logging.error(f"Invalid JSON format in {json_path}: {err}")
        raise
    else:
        logging.info(f"Successfully loaded metadata from {json_path}")
        return parsed


def scan_wav_files(dataset_dir: str) -> List[str]:
    """Recursively collect the paths of all WAV files under *dataset_dir*.

    A file counts as WAV when its name ends with ``.wav`` (case-insensitive).

    Args:
        dataset_dir: Root directory to search.

    Returns:
        Sorted list of file paths, each built by joining the directory in
        which the file was found with the file name.

    Raises:
        FileNotFoundError: If ``dataset_dir`` does not exist.
        NotADirectoryError: If ``dataset_dir`` exists but is not a directory.
    """
    # os.walk() ignores listing errors by default, so a wrong path would
    # otherwise look exactly like an empty dataset. Fail loudly instead.
    if not os.path.exists(dataset_dir):
        logging.error(f"Dataset directory not found: {dataset_dir}")
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")
    if not os.path.isdir(dataset_dir):
        logging.error(f"Dataset path is not a directory: {dataset_dir}")
        raise NotADirectoryError(f"Dataset path is not a directory: {dataset_dir}")

    def log_walk_error(err: OSError) -> None:
        """Report a sub-directory that os.walk() could not list."""
        logging.warning(f"Skipping unreadable directory {err.filename}: {err}")

    wav_files = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(dataset_dir, onerror=log_walk_error)
        for name in names
        if name.lower().endswith(".wav")
    ]
    # Directory listing order is filesystem dependent; sorting makes the
    # result (and any seeded sampling done on it) reproducible.
    wav_files.sort()

    logging.info(f"Found {len(wav_files)} WAV files in {dataset_dir}")
    return wav_files


def parse_wav_filename(filename: str) -> Optional[Tuple[str, str, str, str]]:
    """Derive speaker, tv_show, episode and dialog identifiers from a WAV path.

    The file name (directory and extension removed) is split on underscores
    and the first four fields are used; any further fields are ignored.
    For ``A_shaonianpai_1_1.wav`` the result is
    ``("A", "shaonianpai", "shaonianpai_1", "shaonianpai_1_1")``.

    Args:
        filename: File name or path of the WAV file.

    Returns:
        ``(speaker, tv_show, episode, dialog)``, or ``None`` (after logging a
        warning) when the name has fewer than four underscore-separated fields.
    """
    stem, _extension = os.path.splitext(os.path.basename(filename))
    fields = stem.split("_")

    if len(fields) < 4:
        logging.warning(f"Cannot parse filename: {filename}")
        return None

    speaker, tv_show, episode_no, dialog_no = fields[:4]

    # Episode id is "<tv_show>_<episode number>".
    episode = "_".join((tv_show, episode_no))
    # Dialog id is "<episode id>_<dialog number>".
    dialog = "_".join((episode, dialog_no))

    return speaker, tv_show, episode, dialog


def match_wav_to_metadata(wav_files: List[str], metadata: Dict) -> List[Dict]:
    """Attach emotion and speaker metadata to every WAV file that has an entry.

    Each path is parsed with ``parse_wav_filename``. The dialog entry is then
    looked up as ``metadata[tv_show][episode]["Dialog"][dialog]`` and the
    speaker entry as ``metadata[tv_show][episode]["SpeakerInfo"][speaker]``.
    Files whose name cannot be parsed or whose dialog entry is missing or
    empty are skipped. Missing emotion, gender or speaker name values fall
    back to ``"unknown"``.

    Args:
        wav_files: Paths of the WAV files to match.
        metadata: Parsed metadata, indexed as described above.

    Returns:
        One dict per matched file with the keys ``file_path``, ``speaker``,
        ``tv_show``, ``episode``, ``dialog``, ``emotion``, ``gender`` and
        ``speaker_name``.
    """
    matches = []

    for path in wav_files:
        name_parts = parse_wav_filename(path)
        if not name_parts:
            continue
        speaker, tv_show, episode, dialog = name_parts

        # Any failure while reading the metadata for one file is logged and
        # that file is skipped; the remaining files are still processed.
        try:
            episode_entry = metadata.get(tv_show, {}).get(episode, {})
            dialog_entry = episode_entry.get("Dialog", {}).get(dialog, {})

            if not dialog_entry:
                logging.debug(
                    f"No metadata found for {path} (tv_show={tv_show}, episode={episode}, dialog={dialog})"
                )
                continue

            annotation = dialog_entry.get("EmoAnnotation", {})
            speaker_entry = episode_entry.get("SpeakerInfo", {}).get(speaker, {})

            matches.append(
                {
                    "file_path": path,
                    "speaker": speaker,
                    "tv_show": tv_show,
                    "episode": episode,
                    "dialog": dialog,
                    "emotion": annotation.get("final_main_emo", "unknown"),
                    "gender": speaker_entry.get("Gender", "unknown"),
                    "speaker_name": speaker_entry.get("Name", "unknown"),
                }
            )
        except Exception as err:
            logging.warning(f"Error processing {path}: {err}")
            continue

    logging.info(f"Successfully matched {len(matches)} files with metadata")
    return matches


def filter_by_emotion(matched_files: List[Dict], target_emotion: str) -> List[Dict]:
    """Keep only the entries whose ``"emotion"`` value equals *target_emotion*.

    The comparison is an exact, case-sensitive equality test.

    Args:
        matched_files: File-info dicts, each with an ``"emotion"`` key.
        target_emotion: Emotion label to select.

    Returns:
        A new list with the matching entries in their original order.
    """

    def has_target_emotion(entry: Dict) -> bool:
        return entry["emotion"] == target_emotion

    selected = list(filter(has_target_emotion, matched_files))
    logging.info(f"Found {len(selected)} files with emotion '{target_emotion}'")
    return selected


def sample_by_emotion(filtered_files: List[Dict], sample_count: int) -> List[Dict]:
    """Randomly pick up to *sample_count* entries from *filtered_files*.

    Args:
        filtered_files: Candidate file-info dicts.
        sample_count: Maximum number of entries to return; must be >= 0.

    Returns:
        A new list. When *sample_count* is at least the number of candidates,
        it holds every candidate in the original order. Otherwise it holds
        *sample_count* entries chosen without replacement by ``random.sample``.

    Raises:
        ValueError: If *sample_count* is negative.
    """
    if sample_count < 0:
        # random.sample() would also raise ValueError, but with a message
        # that does not name the offending argument.
        raise ValueError(f"sample_count must be non-negative, got {sample_count}")

    available = len(filtered_files)
    if available <= sample_count:
        logging.info(
            f"Sample count ({sample_count}) >= available files ({available}), taking all files"
        )
        # Return a copy so callers can modify the result without changing
        # the list they passed in.
        return list(filtered_files)

    sampled_files = random.sample(filtered_files, sample_count)
    logging.info(f"Randomly sampled {sample_count} files from {available} available")
    return sampled_files


def create_output_dirs(output_dir: str, emotion: str):
    """Ensure the directory ``<output_dir>/<emotion>`` exists.

    Missing parent directories are created as well; an already existing
    directory is not an error.

    Args:
        output_dir: Root output directory.
        emotion: Name of the sub-directory to create below ``output_dir``.
    """
    target_dir = os.path.join(output_dir, emotion)

    # exist_ok=True keeps repeated runs against the same output directory
    # from failing.
    os.makedirs(target_dir, exist_ok=True)

    logging.info(f"Created output directory: {target_dir}")


def copy_wav_files(sampled_files: List[Dict], output_dir: str, emotion: str):
    """Copy the sampled WAV files into ``<output_dir>/<emotion>``.

    The destination directory is created if it does not exist yet. Files are
    copied with ``shutil.copy2`` under their base name, so two sources with
    the same base name end up in the same destination file; a warning is
    logged when that happens. A file that cannot be copied is logged as an
    error and skipped, and the remaining files are still processed.

    Args:
        sampled_files: File-info dicts, each with a ``"file_path"`` key.
        output_dir: Root output directory.
        emotion: Name of the sub-directory below ``output_dir``.

    Returns:
        Base names of the files that were copied successfully, in copy order.
    """
    emotion_dir = os.path.join(output_dir, emotion)
    # Without the directory every single copy below would fail.
    os.makedirs(emotion_dir, exist_ok=True)

    copied_files = []
    seen_names = set()

    for file_info in sampled_files:
        source_path = file_info["file_path"]
        filename = os.path.basename(source_path)
        dest_path = os.path.join(emotion_dir, filename)

        if filename in seen_names:
            logging.warning(
                f"Duplicate file name {filename}: {source_path} overwrites an earlier copy"
            )

        try:
            shutil.copy2(source_path, dest_path)
        except OSError as e:
            # Covers missing or unreadable sources and shutil.SameFileError.
            logging.error(f"Error copying {source_path}: {e}")
            continue

        seen_names.add(filename)
        copied_files.append(filename)
        logging.debug(f"Copied {filename}")

    logging.info(f"Successfully copied {len(copied_files)} files to {emotion_dir}")
    return copied_files


def generate_csv(sampled_files: List[Dict], output_dir: str, emotion: str):
    """Write ``<output_dir>/<emotion>.csv`` describing the sampled files.

    The file is UTF-8 encoded and starts with the header
    ``dataset_name,wav_filename,emotion_label,gender,speaker_id``. One row
    per entry follows: ``dataset_name`` is always ``M3ED``, ``wav_filename``
    is the base name of ``file_path``, and ``speaker_id`` is taken from the
    entry's ``speaker_name`` value.

    Args:
        sampled_files: File-info dicts with the keys ``file_path``,
            ``emotion``, ``gender`` and ``speaker_name``.
        output_dir: Directory in which the CSV file is written.
        emotion: Emotion label; used as the CSV file name (without extension).
    """
    columns = [
        "dataset_name",
        "wav_filename",
        "emotion_label",
        "gender",
        "speaker_id",
    ]
    csv_path = os.path.join(output_dir, f"{emotion}.csv")

    def to_row(info: Dict) -> Dict:
        """Map one file-info dict to its CSV row."""
        return {
            "dataset_name": "M3ED",
            "wav_filename": os.path.basename(info["file_path"]),
            "emotion_label": info["emotion"],
            "gender": info["gender"],
            "speaker_id": info["speaker_name"],
        }

    # newline="" lets the csv module control line endings itself.
    with open(csv_path, "w", newline="", encoding="utf-8") as out:
        writer = csv.DictWriter(out, fieldnames=columns)
        writer.writeheader()
        writer.writerows(map(to_row, sampled_files))

    logging.info(f"Generated CSV file: {csv_path}")


def main():
    """Run the sampling pipeline from the command line.

    Steps: parse the arguments, configure logging, load the JSON metadata,
    scan the dataset directory for WAV files, match them against the
    metadata, keep the files with the requested emotion, sample them, copy
    the sample into ``<output_dir>/<emotion>`` and write
    ``<output_dir>/<emotion>.csv``.

    If no file carries the requested emotion, an error is logged and the
    function returns without creating any output.
    """
    # Parse first: --help or a usage error exits here, before
    # setup_logging() creates the log file.
    args = parse_arguments()
    setup_logging()

    logging.info("Starting audio sampling process")
    logging.info(f"Target emotion: {args.emotion}")
    logging.info(f"Output directory: {args.output_dir}")
    logging.info(f"Dataset directory: {args.dataset_dir}")
    logging.info(f"Sample count: {args.sample_count}")
    logging.info(f"Metadata file: {args.meta_json}")

    # Load metadata
    metadata = load_metadata(args.meta_json)

    # Scan WAV files
    wav_files = scan_wav_files(args.dataset_dir)

    # Match WAV files with metadata
    matched_files = match_wav_to_metadata(wav_files, metadata)

    # Filter by emotion
    filtered_files = filter_by_emotion(matched_files, args.emotion)

    if not filtered_files:
        logging.error(f"No files found with emotion '{args.emotion}'")
        return

    # Sample files
    sampled_files = sample_by_emotion(filtered_files, args.sample_count)

    # Create output directories
    create_output_dirs(args.output_dir, args.emotion)

    # Copy WAV files; copy_wav_files() returns the base names it copied.
    copied_names = set(copy_wav_files(sampled_files, args.output_dir, args.emotion))

    # Only list files in the CSV that are really present in the output
    # directory, so the CSV never references a file whose copy failed.
    exported_files = [
        file_info
        for file_info in sampled_files
        if os.path.basename(file_info["file_path"]) in copied_names
    ]
    failed_count = len(sampled_files) - len(exported_files)
    if failed_count:
        logging.warning(
            f"{failed_count} sampled files could not be copied and are omitted from the CSV"
        )

    # Generate CSV
    generate_csv(exported_files, args.output_dir, args.emotion)

    # Final statistics
    logging.info("=" * 50)
    logging.info("SAMPLING COMPLETED")
    logging.info("=" * 50)
    logging.info(f"Total WAV files scanned: {len(wav_files)}")
    logging.info(f"Files matched with metadata: {len(matched_files)}")
    logging.info(f"Files with emotion '{args.emotion}': {len(filtered_files)}")
    logging.info(f"Files sampled: {len(sampled_files)}")
    logging.info(f"Files copied: {len(exported_files)}")
    logging.info(f"Output directory: {os.path.join(args.output_dir, args.emotion)}")
    logging.info(f"CSV file: {os.path.join(args.output_dir, args.emotion)}.csv")


# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
