"""
Convert many audio files to wav in parallel using ffmpeg.

Usage:

# on a small local sample
python scripts/convert_audio_to_wav.py \
    --input-dir datasets/testdata \
    --input-extension ".mp3" \
    --output-dir datasets/tmp \
    --runner DirectRunner

# on a small GCS sample
python scripts/convert_audio_to_wav.py \
    --input-dir gs://music2text/datasets/testdata-mp3/ \
    --input-extension ".mp3" \
    --output-dir gs://music2text/datasets/testoutputs-mp3towav/ \
    --runner DirectRunner

# on FMA full
python scripts/convert_audio_to_wav.py \
    --input-dir "gs://music2text/datasets/fma/full/mp3/" \
    --input-extension ".mp3" \
    --output-dir gs://music2text/datasets/fma/full/wav/ \
    --runner DataflowRunner \
    --num-workers 512

# magnatagatune (train)
python scripts/convert_audio_to_wav.py \
    --input-dir "gs://music2text/datasets/magnatagatune/mp3/train/" \
    --input-extension ".mp3" \
    --output-dir gs://music2text/datasets/magnatagatune/wav/train/ \
    --runner DataflowRunner \
    --num-workers 512

# magnatagatune (test)
python scripts/convert_audio_to_wav.py \
    --input-dir "gs://music2text/datasets/magnatagatune/mp3/test/" \
    --input-extension ".mp3" \
    --output-dir gs://music2text/datasets/magnatagatune/wav/test/ \
    --runner DataflowRunner \
    --num-workers 512
"""

import argparse
import logging
import os
import tempfile
import time

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from google.cloud import storage


from m2t.audio_io import convert_to_wav
from m2t.gcs_utils import (
    GCP_PROJECT_NAME,
    GCS_BUCKET_NAME,
    US_CENTRAL1_REGION,
    US_CENTRAL1_SUBNETWORK,
    split_gcs_bucket_and_filepath,
    move_file,
)
from m2t.gcs_utils import list_files_with_extension


def process_file(infile, output_dir):
    """Convert a single audio file to wav and place the result in ``output_dir``.

    Two cases are handled, selected by the ``gs://`` prefix of ``infile``:

    * Local file: ``convert_to_wav`` is called directly with ``output_dir``
      as its destination.
    * GCS object: the object is downloaded to a temporary directory,
      converted there, and the converted file is handed to ``move_file``
      with a destination path under ``output_dir``.

    Args:
        infile: Path to the input audio file; either a local path or a
            ``gs://`` URI.
        output_dir: Directory (local path or ``gs://`` URI) that should
            receive the converted file.

    Returns:
        None. If the conversion of a GCS file yields no output file, a
        warning is logged and the file is skipped.
    """
    if not infile.startswith("gs://"):
        # Case: this is a local file; ffmpeg can read it in place.
        convert_to_wav(infile, output_dir)
        return

    # Case: file is on GCS; download it first so ffmpeg can read it.
    gcs = storage.Client()
    bucket_src, filepath_src = split_gcs_bucket_and_filepath(infile)
    gcs_bucket_obj = gcs.get_bucket(bucket_src)
    blob_src = gcs_bucket_obj.blob(filepath_src)

    # All intermediate files live in a temporary directory that is removed
    # when the block exits, whether or not the conversion succeeded.
    with tempfile.TemporaryDirectory() as tmpdir:
        local_src_fp = os.path.join(tmpdir, os.path.basename(filepath_src))
        blob_src.download_to_filename(local_src_fp)

        converted_local_fp = convert_to_wav(local_src_fp, tmpdir)
        # Check the result *before* deriving the destination path from it:
        # os.path.basename() raises TypeError on None, which would turn a
        # skippable conversion failure into a crash.
        if not converted_local_fp:
            logging.warning(f"got no converted file for {filepath_src}")
            return

        converted_dest_fp = os.path.join(
            output_dir, os.path.basename(converted_local_fp)
        )
        # NOTE(review): the bucket object passed here is the *source* bucket;
        # confirm move_file handles an output_dir on a different bucket.
        move_file(
            converted_local_fp, converted_dest_fp, gcs_bucket_obj=gcs_bucket_obj
        )


def main():
    """Parse command-line arguments and run the conversion pipeline.

    Lists every file under ``--input-dir`` that has ``--input-extension``,
    then runs a Beam pipeline that applies ``process_file`` to each path,
    either locally (``DirectRunner``) or on Dataflow (``DataflowRunner``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir", required=True, help="path or wildcard to input files"
    )
    parser.add_argument(
        "--input-extension",
        default=".mp3",
        help="The extension of the files to process (e.g. '.mp3'). "
        "Files without this extension will be ignored. "
        "Must be a file type that ffmpeg can read.",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Directory to write the converted wav files to. "
        "Required when the input files are on GCS.",
    )
    parser.add_argument(
        "--runner", default="DirectRunner", choices=["DirectRunner", "DataflowRunner"]
    )
    parser.add_argument("--job-name", default="music2text-convert-audio")
    parser.add_argument("--num-workers", default=32, help="max workers", type=int)
    parser.add_argument(
        "--worker-disk-size-gb",
        default=32,
        type=int,
        help="Worker disk size in GB. Note that disk size must be at least size of the docker image.",
    )
    parser.add_argument(
        "--machine-type", default="n1-standard-2", help="Worker machine type to use."
    )
    args = parser.parse_args()

    # Suffix the job name with a timestamp so repeated launches do not collide.
    job_name = f"{args.job_name}-{int(time.time())}"
    print(f"job name is {job_name}")

    # Options shared by both runners.
    option_kwargs = {
        "runner": args.runner,
        "project": GCP_PROJECT_NAME,
        "temp_location": f"gs://{GCS_BUCKET_NAME}/dataflow-tmp",
    }
    if args.runner != "DirectRunner":
        # Options that only apply when the job is submitted to Dataflow.
        option_kwargs.update(
            {
                "job_name": job_name,
                "region": US_CENTRAL1_REGION,
                "subnetwork": US_CENTRAL1_SUBNETWORK,
                "max_num_workers": args.num_workers,
                "worker_disk_type": "pd-ssd",
                "disk_size_gb": args.worker_disk_size_gb,
                "machine_type": args.machine_type,
                "save_main_session": True,
                "experiments": [
                    "use_runner_v2",
                    "beam_fn_api",
                    "no_use_multiple_sdk_containers",
                ],
                "sdk_container_image": "gcr.io/audio-diffusion/m2t-preprocess:latest",
            }
        )
    pipeline_options = PipelineOptions(**option_kwargs)

    # Collect the input files up front so the full work list is known (and
    # can be shown to the user) before the pipeline is launched.
    input_paths = list_files_with_extension(
        args.input_dir, extension=args.input_extension
    )
    print(
        f"processing {len(input_paths)} files; printing the first and last 10: {input_paths[:10]}"
    )
    if len(input_paths) > 10:
        print(input_paths[-10:])

    # process_file joins output_dir with the converted file name for every
    # gs:// input, which raises TypeError when output_dir is None. Fail here
    # instead of after the job has started and every element has failed.
    if args.output_dir is None and any(
        path.startswith("gs://") for path in input_paths
    ):
        parser.error("--output-dir is required when input files are on GCS")

    with beam.Pipeline(options=pipeline_options) as p:
        # The pipeline runs when the `with` block exits; the resulting
        # PCollection is not needed because process_file returns nothing.
        _ = (
            p
            | "CreatePColl" >> beam.Create(input_paths)
            | "ProcessAudioFile"
            >> beam.Map(process_file, output_dir=args.output_dir)
        )


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
