import pandas as pd
import argparse
import os
import random
import shutil
import json
from collections import defaultdict


# Global emotion mapping: single-letter label code -> display name.
emotion_map = {
    "A": "Angry",
    "S": "Sad",
    "H": "Happy",
    "U": "Surprise",
    "F": "Fear",
    "D": "Disgust",
    "C": "Contempt",
    "N": "Neutral",
    "O": "Other",
    "X": "None",
}

# Reverse lookup (lowercase display name -> code). Derived from emotion_map so
# the two tables can never drift apart; key order follows emotion_map.
reverse_emotion_map = {name.lower(): code for code, name in emotion_map.items()}

# Gender normalization: lowercase spelling -> canonical capitalized form.
gender_map = {label.lower(): label for label in ("Male", "Female")}

# Emotions processed when "ALL" is requested: every named emotion except the
# catch-all "other" (O) and "none" (X) classes.
ALL_EMOTIONS = [
    name for name in reverse_emotion_map if name not in ("other", "none")
]


def normalize_emotion_label(emotion_label):
    """Convert an emotion name (any capitalization) to its single-letter code.

    Known names such as ``"angry"`` or ``"Angry"`` are mapped to their code via
    ``reverse_emotion_map``. Any other string (including a value that is
    already a code, e.g. ``"A"``) is returned unchanged. Non-string values,
    such as the ``NaN`` pandas produces for an empty CSV cell, are also
    returned unchanged instead of raising ``AttributeError``.
    """
    if not isinstance(emotion_label, str):
        return emotion_label
    return reverse_emotion_map.get(emotion_label.lower(), emotion_label)


def normalize_gender(gender):
    """Canonicalize the capitalization of a gender label.

    ``"male"``/``"MALE"`` become ``"Male"`` and ``"female"`` becomes
    ``"Female"`` via ``gender_map``. Any other string is returned unchanged.
    Non-string values (e.g. the ``NaN`` pandas reads for an empty CSV cell)
    are returned unchanged rather than raising ``AttributeError``, so callers
    can filter them out afterwards.
    """
    if not isinstance(gender, str):
        return gender
    return gender_map.get(gender.lower(), gender)


def denormalize_emotion_label(emotion_code):
    """Map a single-letter emotion code back to its lowercase name.

    Codes missing from ``emotion_map`` are passed through, lowercased.
    """
    if emotion_code in emotion_map:
        display_name = emotion_map[emotion_code]
    else:
        display_name = emotion_code
    return display_name.lower()


def parse_args(argv=None):
    """Parse the command-line options for the emotion sampler.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like a plain ``parse_args()`` call.

    Returns:
        argparse.Namespace with ``emotion``, ``output_dir``, ``label_file``,
        ``sample_num`` and ``dataset_root``.

    Raises:
        SystemExit: On invalid arguments, including a non-positive
            ``--sample_num`` (argparse's usage-error exit, status 2).
    """
    parser = argparse.ArgumentParser(
        description="Sample audio files of a specific emotion"
    )
    parser.add_argument(
        "--emotion",
        type=str,
        required=True,
        help="Target emotion class: angry, sad, happy, surprise, fear, disgust, contempt, neutral, other, none, or ALL for all emotions (except other and none)",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Directory to save the output CSV"
    )
    parser.add_argument(
        "--label_file",
        type=str,
        default="labels_consensus.csv",
        help="Path to the labels_consensus.csv file",
    )
    parser.add_argument(
        "--sample_num", type=int, default=1000, help="Number of samples to extract"
    )
    parser.add_argument(
        "--dataset_root",
        type=str,
        help="Root directory of the MSP-Podcast dataset containing wav files",
    )
    args = parser.parse_args(argv)

    # A zero or negative count would silently produce an empty result, so
    # reject it up front with a normal argparse usage error.
    if args.sample_num < 1:
        parser.error("--sample_num must be a positive integer")
    return args


def stratified_sample(df, sample_num):
    """
    Randomly draw up to ``sample_num`` rows of *df*, maximizing speaker diversity.

    Priorities, in order:
    1. Sample count: return exactly ``sample_num`` rows, or every row when
       *df* has fewer than that.
    2. Speaker diversity: one row is drawn from as many distinct speakers as
       possible before any speaker contributes a second row.

    NOTE: gender balance is not enforced; the gender mix of the result follows
    whichever speakers the random draw picks.

    Rows are chosen by position and returned through ``df.iloc``. This keeps
    the original columns, dtypes and index labels, and an empty result still
    carries *df*'s columns. Randomness comes from the ``random`` module, so
    ``random.seed`` makes the draw reproducible.

    Args:
        df: DataFrame with a ``speaker_id`` column.
        sample_num: Number of rows to draw; values <= 0 yield an empty frame.

    Returns:
        DataFrame holding the selected rows of *df*, in selection order.
    """
    # Group row positions by speaker, in first-appearance order.
    positions_by_speaker = defaultdict(list)
    for position, speaker_id in enumerate(df["speaker_id"].tolist()):
        positions_by_speaker[speaker_id].append(position)

    speaker_rows = list(positions_by_speaker.values())
    random.shuffle(speaker_rows)

    # First pass: one randomly chosen row per speaker, in shuffled order.
    # Rows not picked are remembered for the second pass.
    selected = []
    leftovers = []
    for positions in speaker_rows:
        if len(selected) >= sample_num:
            break
        pick = random.choice(positions)
        selected.append(pick)
        leftovers.extend(p for p in positions if p != pick)

    # Second pass: reaching here with rows still needed means every speaker
    # has already contributed once, so draw the remainder uniformly from the
    # rows not yet picked (or take all of them if there are too few).
    remaining_needed = sample_num - len(selected)
    if remaining_needed > 0:
        selected.extend(
            random.sample(leftovers, min(remaining_needed, len(leftovers)))
        )

    return df.iloc[selected]


def save_dataset_info(df, sampled_df, emotion, output_dir, sample_num, dataset_root):
    """Write summary statistics of a sample to ``<output_dir>/<emotion>.json``.

    Args:
        df: The full filtered dataset the sample was drawn from. It does not
            contribute to the JSON output; the parameter is kept so existing
            callers continue to work.
        sampled_df: The sampled rows; needs a ``speaker_id`` column when
            non-empty.
        emotion: Emotion name, used as the ``emotion`` field and file stem.
        output_dir: Existing directory the JSON file is written into.
        sample_num: The requested sample count, recorded as-is.
        dataset_root: Dataset root path; a falsy value is recorded as null.
    """
    sampled_total = len(sampled_df)
    if sampled_total > 0:
        sampled_speaker_count = sampled_df["speaker_id"].nunique()
        sampled_speakers = sorted(sampled_df["speaker_id"].unique().tolist())
    else:
        # An empty sample may not even have a speaker_id column.
        sampled_speaker_count = 0
        sampled_speakers = []

    json_content = {
        "dataset": "MSP-Podcast",
        "emotion": emotion,
        "total_samples": sampled_total,
        "speaker_count": sampled_speaker_count,
        "speakers": sampled_speakers,
        "sample_count_requested": sample_num,
        "dataset_dir": dataset_root if dataset_root else None,
    }

    json_path = os.path.join(output_dir, f"{emotion}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(json_content, f, indent=2)

    print(f"Saved dataset info to {json_path}")


def copy_wav_files(sampled_df, dataset_root, output_dir, emotion):
    """Copy the sampled wav files into ``<output_dir>/<emotion>_wavs``.

    Each row's ``wav_filename`` is resolved relative to *dataset_root* and
    copied with ``shutil.copy2`` under its basename. Rows whose
    ``wav_filename`` is empty or not a string (e.g. NaN from an empty CSV
    cell) are skipped. Files that are missing or fail to copy are reported and
    counted, but do not abort the run.

    Args:
        sampled_df: DataFrame of selected rows, expected to have a
            ``wav_filename`` column.
        dataset_root: Directory the ``wav_filename`` paths are relative to.
            When falsy, a warning is printed and nothing is copied.
        output_dir: Existing directory in which ``<emotion>_wavs`` is created.
        emotion: Emotion name used for the subdirectory name.
    """
    if not dataset_root:
        print("Warning: --dataset_root not specified, skipping wav file copying")
        return

    wav_dir = os.path.join(output_dir, f"{emotion}_wavs")
    # exist_ok avoids the exists()/makedirs() race and lets reruns reuse it.
    os.makedirs(wav_dir, exist_ok=True)

    if "wav_filename" in sampled_df.columns:
        wav_filenames = sampled_df["wav_filename"].tolist()
    else:
        wav_filenames = []

    copied_count = 0
    failed_count = 0
    for wav_filename in wav_filenames:
        # Skip empty cells: "" as well as NaN/None, which are not strings.
        if not isinstance(wav_filename, str) or not wav_filename:
            continue

        wav_path = os.path.join(dataset_root, wav_filename)
        if not os.path.isfile(wav_path):
            print(f"Warning: Wav file not found: {wav_path}")
            failed_count += 1
            continue

        # NOTE(review): files sharing a basename in different subdirectories
        # overwrite each other here.
        dest_path = os.path.join(wav_dir, os.path.basename(wav_filename))
        try:
            shutil.copy2(wav_path, dest_path)
        except OSError as e:
            print(f"Failed to copy {wav_path}: {e}")
            failed_count += 1
        else:
            copied_count += 1

    print(f"Copied {copied_count} wav files to {wav_dir}")
    if failed_count > 0:
        print(f"Failed to copy {failed_count} wav files")


def _print_emotion_summary(emotion, emotion_df):
    """Print row count, gender split and per-speaker counts for one emotion."""
    print(f"\n=== Original Dataset Information for Emotion: {emotion} ===")
    total_count = len(emotion_df)
    if total_count > 0:
        gender_counts = emotion_df["gender"].value_counts()
        male_count = gender_counts.get("Male", 0)
        female_count = gender_counts.get("Female", 0)
        male_percent = male_count / total_count * 100
        female_percent = female_count / total_count * 100
        speaker_counts = emotion_df["speaker_id"].value_counts()

        print(f"Total samples: {total_count}")
        print(f"Male samples: {male_count} ({male_percent:.1f}%)")
        print(f"Female samples: {female_count} ({female_percent:.1f}%)")
        print(f"Unique speakers: {len(speaker_counts)}")
        print(
            f"Speaker distribution: min={speaker_counts.min()}, max={speaker_counts.max()}, mean={speaker_counts.mean():.1f}"
        )
    else:
        print("No samples found for this emotion")

    print("=" * 60)


def process_single_emotion(df, emotion, sample_num, output_dir, dataset_root):
    """Sample one emotion and write its CSV, JSON summary and wav copies.

    Args:
        df: Label table with ``emotion_label`` (single-letter codes),
            ``gender`` and ``speaker_id`` columns.
        emotion: Emotion name (e.g. ``"angry"``) or code, resolved through
            ``normalize_emotion_label``.
        sample_num: Number of rows to draw.
        output_dir: Existing directory that receives ``<EmotionName>.csv``,
            ``<EmotionName>.json`` and, optionally, ``<EmotionName>_wavs/``.
        dataset_root: Dataset root for copying wav files; falsy skips copying.

    Raises:
        KeyError: If *emotion* does not resolve to a code in ``emotion_map``.
    """
    emotion_code = normalize_emotion_label(emotion)
    # Resolve the display name up front so an unknown emotion fails before
    # any filtering or printing is done.
    emotion_name = emotion_map[emotion_code]

    # Keep this emotion's rows that have a known gender and a speaker id.
    # main() already applies the last two filters; repeating them keeps this
    # function safe to call on an unfiltered table.
    mask = (
        (df["emotion_label"] == emotion_code)
        & df["gender"].isin(["Male", "Female"])
        & df["speaker_id"].notnull()
    )
    emotion_df = df[mask]

    _print_emotion_summary(emotion, emotion_df)

    if len(emotion_df) < sample_num:
        print(
            f"Warning: only {len(emotion_df)} samples available for emotion '{emotion}'"
        )

    # Speaker-diverse random sample; the sampler does not balance genders.
    sampled_df = stratified_sample(emotion_df, sample_num)

    output_path = os.path.join(output_dir, f"{emotion_name}.csv")
    sampled_df.to_csv(output_path, index=False)
    print(f"Saved {len(sampled_df)} samples to {output_path}")

    save_dataset_info(
        emotion_df, sampled_df, emotion_name, output_dir, sample_num, dataset_root
    )

    # Copies only when dataset_root is set; otherwise prints a warning.
    copy_wav_files(sampled_df, dataset_root, output_dir, emotion_name)


def main():
    """Command-line entry point: load, normalize and sample the label table.

    ``--emotion ALL`` (matched case-insensitively) processes every emotion in
    ``ALL_EMOTIONS``; otherwise the single named emotion is processed.

    Raises:
        SystemExit: With a non-zero status when ``--emotion`` is neither ALL
            nor a known emotion name/code.
    """
    args = parse_args()

    # Validate the requested emotion before the (potentially large) label CSV
    # is read. Exit non-zero so shell scripts and pipelines notice the error.
    process_all = args.emotion.upper() == "ALL"
    if not process_all and normalize_emotion_label(args.emotion) not in emotion_map:
        raise SystemExit(
            f"Error: Invalid emotion '{args.emotion}'. Valid emotions: {list(reverse_emotion_map.keys())} + ALL"
        )

    os.makedirs(args.output_dir, exist_ok=True)

    df = pd.read_csv(args.label_file)

    # Canonicalize gender capitalization and map emotion names to codes.
    df["gender"] = df["gender"].apply(normalize_gender)
    df["emotion_label"] = df["emotion_label"].apply(normalize_emotion_label)

    # Drop rows without a known gender or a speaker id.
    df = df[df["gender"].isin(["Male", "Female"]) & df["speaker_id"].notnull()]

    if process_all:
        print(f"Processing all emotions: {ALL_EMOTIONS}")
        for emotion in ALL_EMOTIONS:
            print(
                f"\nProcessing emotion: {emotion} ({emotion_map[normalize_emotion_label(emotion)]})"
            )
            process_single_emotion(
                df, emotion, args.sample_num, args.output_dir, args.dataset_root
            )
        print(f"\nCompleted processing all {len(ALL_EMOTIONS)} emotions")
    else:
        process_single_emotion(
            df, args.emotion, args.sample_num, args.output_dir, args.dataset_root
        )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
