"""
BMD-HS Dataset File Processor

This script processes downloaded BMD-HS dataset files by:
1. Reading file_link_table.csv to get file mapping information
2. Copying files from download_dir/<filename> to processed_dir/<rename_filename>

The file_link_table.csv contains:
- raw: Original filename from the dataset (e.g., "raw/MD_001_sup_Mit.wav")
  Note: The "raw/" prefix is ignored, files are read directly from download_dir
- rename: New filename for processed data (e.g., "rename/00001.wav")
- Additional metadata columns (patient_id, labels, splits, etc.)

Usage:
    python process_bmdhs.py --download_dir /path/to/download --processed_dir /path/to/processed

Arguments:
    --download_dir: Directory containing downloaded files (files should be directly in this directory)
    --processed_dir: Directory where processed files will be saved
    --csv_file: Path to file_link_table.csv (default: file_link_table.csv in the same directory as this script)
    --skip_existing: Skip copying if destination file already exists
    --verify: After copying, verify that source and destination file sizes match
"""

import argparse
import csv
import os
import shutil
import sys
from pathlib import Path
from typing import Dict, List, Tuple


def read_file_link_table(csv_path: Path) -> List[Dict[str, str]]:
    """
    Read file_link_table.csv and return list of file mappings.

    The first CSV row is used as the header; every following row becomes one
    dictionary keyed by those header names. A leading UTF-8 byte-order mark
    (written by some spreadsheet tools) is dropped so that the first column
    name is read cleanly.

    Args:
        csv_path: Path to file_link_table.csv

    Returns:
        List of dictionaries containing file mapping information

    Raises:
        FileNotFoundError: If csv_path does not exist
    """
    if not csv_path.exists():
        raise FileNotFoundError(f"file_link_table.csv not found: {csv_path}")

    # 'utf-8-sig' decodes plain UTF-8 as well as UTF-8 with a BOM.
    # newline='' hands line-ending handling to the csv module, which is what
    # it needs to parse line breaks inside quoted fields correctly.
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        file_mappings = list(csv.DictReader(f))

    print(f"Loaded {len(file_mappings)} file mappings from {csv_path}")
    return file_mappings


def _strip_leading_dir(path_str: str, prefix: str) -> str:
    """Return path_str without a leading prefix; later occurrences are kept."""
    if path_str.startswith(prefix):
        return path_str[len(prefix):]
    return path_str


def process_files(
    file_mappings: List[Dict[str, str]],
    download_dir: Path,
    processed_dir: Path,
    skip_existing: bool = False,
    verify: bool = False
) -> Tuple[int, int, int]:
    """
    Process files according to file_link_table mappings.

    For each mapping, the 'raw' entry (minus a leading 'raw/') is looked up
    under download_dir and copied to the 'rename' entry (minus a leading
    'rename/') under processed_dir. Per-file problems are printed and counted
    rather than raised, so one bad entry does not stop the run.

    Args:
        file_mappings: List of file mapping dictionaries
        download_dir: Source directory containing raw files
        processed_dir: Destination directory for processed files
        skip_existing: Skip if destination file already exists
        verify: Verify file size after copying

    Returns:
        Tuple of (success_count, skipped_count, error_count)
    """
    success_count = 0
    skipped_count = 0
    error_count = 0
    total = len(file_mappings)

    # Create processed directory
    processed_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nProcessing {total} files...")
    print(f"Source directory: {download_dir}")
    print(f"Destination directory: {processed_dir}")
    print("-" * 60)

    for idx, mapping in enumerate(file_mappings, 1):
        progress = f"[{idx}/{total}]"
        raw_path = mapping.get('raw', '')
        rename_path = mapping.get('rename', '')

        if not raw_path or not rename_path:
            print(f"{progress} Error: Missing raw or rename path")
            error_count += 1
            continue

        # Only the leading 'raw/' / 'rename/' marker is dropped. A plain
        # str.replace would also mangle names that merely contain the text,
        # e.g. 'raw/myraw/a.wav' -> 'mya.wav'.
        raw_filename = _strip_leading_dir(raw_path, 'raw/')
        source_file = download_dir / raw_filename
        rename_filename = _strip_leading_dir(rename_path, 'rename/')
        dest_file = processed_dir / rename_filename

        # Check if source file exists
        if not source_file.exists():
            print(f"{progress} Error: Source file not found: {raw_filename}")
            error_count += 1
            continue

        # Skip if destination exists and skip_existing is True
        if skip_existing and dest_file.exists():
            if idx % 100 == 0:  # Print progress every 100 files
                print(f"{progress} Skipped (already exists): {rename_filename}")
            skipped_count += 1
            continue

        try:
            # Destination sub-directories are created only for files that are
            # really copied; a failure here is reported like any copy error.
            dest_file.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(source_file, dest_file)

            # Verify file size if requested
            if verify and source_file.stat().st_size != dest_file.stat().st_size:
                # Remove the bad copy so a later --skip_existing run does not
                # mistake it for a finished file.
                dest_file.unlink()
                print(f"{progress} Error: File size mismatch for {rename_filename}")
                error_count += 1
                continue
        except Exception as e:
            # Per-file boundary: report, count, and move on to the next file.
            print(f"{progress} Error copying {raw_filename}: {e}")
            error_count += 1
            continue

        if idx % 100 == 0:  # Print progress every 100 files
            print(f"{progress} Copied: {raw_filename} -> {rename_filename}")
        success_count += 1

    return success_count, skipped_count, error_count


def generate_metadata_csv(
    file_mappings: List[Dict[str, str]],
    processed_dir: Path
) -> None:
    """
    Generate metadata CSV file with all information from file_link_table.

    Writes processed_dir/metadata.csv. The columns are the keys of the first
    mapping; each row is a mapping with a leading 'rename/' removed from its
    'rename' value. The input dictionaries are not modified. With an empty
    file_mappings list an empty file is written.

    Args:
        file_mappings: List of file mapping dictionaries
        processed_dir: Processed data directory (created if it does not exist)
    """
    metadata_path = processed_dir / "metadata.csv"

    print(f"\nGenerating metadata file: {metadata_path}")

    # Do not rely on an earlier step having created the directory.
    processed_dir.mkdir(parents=True, exist_ok=True)

    rename_prefix = 'rename/'
    with open(metadata_path, 'w', encoding='utf-8', newline='') as f:
        if file_mappings:
            # Get all column names from first mapping
            fieldnames = list(file_mappings[0].keys())
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()

            for mapping in file_mappings:
                updated_mapping = mapping.copy()
                rename_value = updated_mapping.get('rename')
                # csv.DictReader fills short rows with None, so only strings
                # are touched. Only the leading marker is removed; later
                # occurrences of 'rename/' are part of the file name.
                if isinstance(rename_value, str) and rename_value.startswith(rename_prefix):
                    updated_mapping['rename'] = rename_value[len(rename_prefix):]
                writer.writerow(updated_mapping)

    print(f"Metadata file created with {len(file_mappings)} entries")


def main():
    """
    Command-line entry point: parse arguments, copy files, write metadata.

    Returns:
        1 if the download directory does not exist, the CSV cannot be read,
        or at least one file ended in an error; 0 otherwise.
    """
    # The default mapping table is the one sitting next to this script.
    bundled_csv = Path(__file__).parent.resolve() / "file_link_table.csv"

    parser = argparse.ArgumentParser(
        description="Process BMD-HS Dataset files using file_link_table.csv"
    )
    required_dirs = (
        ("--download_dir", "Directory containing downloaded raw files"),
        ("--processed_dir", "Directory where processed files will be saved"),
    )
    for flag, help_text in required_dirs:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    parser.add_argument(
        "--csv_file",
        type=str,
        default=str(bundled_csv),
        help=f"Path to file_link_table.csv (default: {bundled_csv})",
    )
    switches = (
        ("--skip_existing", "Skip copying if destination file already exists"),
        ("--verify", "Verify file integrity after copying"),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    args = parser.parse_args()

    # Work with absolute paths from here on.
    src_dir = Path(args.download_dir).resolve()
    out_dir = Path(args.processed_dir).resolve()
    table_path = Path(args.csv_file).resolve()

    rule = "=" * 60
    banner = (
        rule,
        "BMD-HS Dataset File Processor",
        rule,
        f"CSV file: {table_path}",
        f"Download directory: {src_dir}",
        f"Processed directory: {out_dir}",
        f"Skip existing: {args.skip_existing}",
        f"Verify files: {args.verify}",
        rule,
        "",
    )
    for line in banner:
        print(line)

    # Nothing to do without the downloaded files.
    if not src_dir.exists():
        print(f"Error: Download directory not found: {src_dir}")
        print("\nPlease ensure the download directory exists and contains the dataset files")
        return 1

    # Load the raw -> rename table.
    try:
        mappings = read_file_link_table(table_path)
    except Exception as e:
        print(f"Error reading file_link_table.csv: {e}")
        return 1

    # Copy the files, then write the accompanying metadata table.
    success, skipped, errors = process_files(
        mappings,
        src_dir,
        out_dir,
        skip_existing=args.skip_existing,
        verify=args.verify,
    )
    generate_metadata_csv(mappings, out_dir)

    summary = (
        f"\n{rule}",
        "Processing Summary",
        rule,
        f"Total files:     {len(mappings)}",
        f"Successfully processed: {success}",
        f"Skipped:        {skipped}",
        f"Errors:         {errors}",
        rule,
    )
    for line in summary:
        print(line)

    if errors > 0:
        print(f"\nWarning: {errors} files could not be processed")
        return 1

    print("\nAll files processed successfully!")
    return 0


if __name__ == "__main__":
    sys.exit(main())
