#!/usr/bin/env python3
"""
Test script for the updated sample.py implementation
"""

import pandas as pd
import os
import tempfile
import shutil
from sample import normalize_emotion_label, normalize_gender, denormalize_emotion_label


def create_test_csv():
    """Build the in-memory fixture table used by the tests in this script.

    Returns:
        A pandas DataFrame with eight rows and the columns
        ``dataset_name``, ``wav_filename``, ``emotion_label``, ``gender``
        and ``speaker_id``, in that order.
    """
    column_names = (
        "dataset_name",
        "wav_filename",
        "emotion_label",
        "gender",
        "speaker_id",
    )

    # One tuple per row, values listed in the same order as column_names.
    row_values = (
        ("MSP-PODCAST-Publish-1.12", "MSP-PODCAST_0001_0027.wav", "neutral", "male", 30),
        ("MSP-PODCAST-Publish-1.12", "MSP-PODCAST_0001_0041.wav", "neutral", "male", 30),
        ("MSP-PODCAST-Publish-1.12", "MSP-PODCAST_0001_0058.wav", "happy", "female", 45),
        ("MSP-PODCAST-Publish-1.12", "MSP-PODCAST_0001_0071.wav", "sad", "male", 30),
        ("MSP-PODCAST-Publish-1.12", "MSP-PODCAST_0001_0077.wav", "angry", "female", 45),
        ("MSP-PODCAST-Publish-1.12", "MSP-PODCAST_0001_0080.wav", "neutral", "male", 30),
        ("MSP-PODCAST-Publish-1.12", "MSP-PODCAST_0001_0085.wav", "happy", "female", 45),
        ("MSP-PODCAST-Publish-1.12", "MSP-PODCAST_0001_0090.wav", "sad", "male", 30),
    )

    # Keep the list-of-dicts construction so the resulting frame is built
    # exactly the way a record-oriented source would build it.
    records = [dict(zip(column_names, values)) for values in row_values]
    return pd.DataFrame(records)


def test_normalization_functions():
    """Check the label helpers imported from sample against known pairs.

    Asserts that normalize_emotion_label, normalize_gender and
    denormalize_emotion_label return the expected value for each input
    listed below, printing a confirmation line after each group.
    """
    print("Testing normalization functions...")

    # Full emotion name -> single-letter code.
    emotion_cases = (
        ("neutral", "N"),
        ("happy", "H"),
        ("sad", "S"),
        ("angry", "A"),
    )
    for raw_label, expected_code in emotion_cases:
        assert normalize_emotion_label(raw_label) == expected_code
    print("✅ Emotion normalization tests passed")

    # Lower-case gender -> capitalised form.
    gender_cases = (("male", "Male"), ("female", "Female"))
    for raw_gender, expected_gender in gender_cases:
        assert normalize_gender(raw_gender) == expected_gender
    print("✅ Gender normalization tests passed")

    # Single-letter code -> full emotion name.
    reverse_cases = (("N", "neutral"), ("H", "happy"))
    for code, expected_label in reverse_cases:
        assert denormalize_emotion_label(code) == expected_label
    print("✅ Emotion denormalization tests passed")


def test_csv_processing():
    """Apply the normalizers to the fixture table and check the results.

    Normalizes the ``gender`` and ``emotion_label`` columns of the frame
    returned by create_test_csv() and asserts that every expected
    normalized value occurs at least once in its column.
    """
    print("\nTesting CSV processing...")

    frame = create_test_csv()

    # Rewrite both label columns in place with their normalized values.
    frame["gender"] = frame["gender"].apply(normalize_gender)
    frame["emotion_label"] = frame["emotion_label"].apply(normalize_emotion_label)

    # Each expected value must be present somewhere in the column.
    for expected_gender in ("Male", "Female"):
        assert expected_gender in frame["gender"].values
    for expected_code in ("N", "H", "S", "A"):
        assert expected_code in frame["emotion_label"].values

    print("✅ CSV processing tests passed")


def test_sample_script():
    """Run sample.py end-to-end on the small fixture and report the outcome.

    Writes the frame from create_test_csv() to ``test_labels.csv`` inside a
    temporary directory, then runs ``sample.py`` as a subprocess (working
    directory: the directory containing this file) with
    ``--emotion neutral --sample_num 2`` and an ``output`` directory inside
    the same temporary directory. Afterwards it prints whether the run
    succeeded and whether ``Neutral.csv`` and ``Neutral.info`` exist in the
    output directory.

    Problems while running the subprocess or inspecting its output are
    printed, not raised. Errors while preparing the fixture files happen
    outside the guarded section and do propagate.
    """
    import subprocess
    import sys

    print("\nTesting sample script...")

    # Run the child with the interpreter executing this script so it uses the
    # same environment and installed packages. A bare "python" may be absent
    # from PATH or point at a different installation. sys.executable can be
    # empty/None when Python cannot determine its own path, so fall back to
    # the previous behaviour in that case.
    interpreter = sys.executable or "python"

    # Create temporary directory
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create test CSV file
        df = create_test_csv()
        test_csv_path = os.path.join(temp_dir, "test_labels.csv")
        df.to_csv(test_csv_path, index=False)

        # Create output directory
        output_dir = os.path.join(temp_dir, "output")
        os.makedirs(output_dir)

        command = [
            interpreter,
            "sample.py",
            "--emotion",
            "neutral",
            "--output_dir",
            output_dir,
            "--label_file",
            test_csv_path,
            "--sample_num",
            "2",
        ]

        # Test running the script
        try:
            result = subprocess.run(
                command,
                capture_output=True,
                text=True,
                # sample.py is addressed by a relative name, so run from the
                # directory that holds this test script.
                cwd=os.path.dirname(os.path.abspath(__file__)),
            )

            if result.returncode == 0:
                print("✅ Sample script execution successful")

                # Check if output files were created
                neutral_csv = os.path.join(output_dir, "Neutral.csv")
                neutral_info = os.path.join(output_dir, "Neutral.info")

                if os.path.exists(neutral_csv):
                    print("✅ Output CSV file created")
                    # Check the content
                    output_df = pd.read_csv(neutral_csv)
                    print(f"   - Output contains {len(output_df)} samples")
                    print(f"   - Columns: {list(output_df.columns)}")
                else:
                    print("❌ Output CSV file not found")

                if os.path.exists(neutral_info):
                    print("✅ Info file created")
                else:
                    print("❌ Info file not found")

            else:
                print(f"❌ Sample script failed: {result.stderr}")

        except Exception as e:
            # Report-only boundary: this script prints failures of the
            # end-to-end run instead of aborting.
            print(f"❌ Error running sample script: {e}")


def main():
    """Run every check in this script in order, with start/end banners."""
    print("Starting tests for updated sample.py implementation...")

    # Executed sequentially; an exception from any check stops the run.
    checks = (
        test_normalization_functions,
        test_csv_processing,
        test_sample_script,
    )
    for run_check in checks:
        run_check()

    print("\n🎉 All tests completed!")


# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
