import argparse
import json
import os
import uuid

import pandas as pd

from audio_utils import schemas as s


def prepare_audiocaps(input_path: str) -> list[s.InputCaption]:
    """Load an AudioCaps CSV and turn every row into an input caption.

    The ``caption`` column supplies the caption text. All remaining columns
    of the row are carried over as metadata, together with a freshly
    generated ``uid`` and the path of the CSV the row was read from.

    :param input_path: file path to the input AudioCaps CSV dataset.
    :return: one input caption per CSV row, in row order.
    """
    frame = pd.read_csv(input_path)
    texts = frame["caption"].to_list()
    extras = frame.drop(columns=["caption"]).to_dict("records")

    # "uid" and "file_path" come after the unpacked row so they take
    # precedence over CSV columns that happen to share those names.
    return [
        {
            "caption": text,
            "metadata": {
                **extra,
                "uid": uuid.uuid4().hex,
                "file_path": input_path,
            },
        }
        for text, extra in zip(texts, extras)
    ]


def prepare_wavecaps(input_path: str) -> list[s.InputCaption]:
    """Load a JSON caption file and turn every entry into an input caption.

    The entries are read from the top-level ``"data"`` key of the file.
    Each entry's ``caption`` becomes the caption text, while ``id`` and the
    joined ``href``/``file_name`` are stored as metadata next to the entry's
    row number and a freshly generated ``uid``.

    :param input_path: file path to the input JSON dataset.
    :return: one input caption per entry, in row order.
    :raises KeyError: if the file has no ``"data"`` key, or the entries do
        not provide ``caption``, ``id``, ``href`` or ``file_name``.
    """
    data = pd.read_json(input_path)
    data = pd.json_normalize(data["data"])

    # Convert the frame to plain dicts once instead of using iterrows(),
    # which constructs a throwaway Series for every single row.
    records = data.to_dict("records")

    return [
        {
            "caption": record["caption"],
            "metadata": {
                "row_num": index,
                "id": record["id"],
                "file_path": os.path.join(record["href"], record["file_name"]),
                "uid": uuid.uuid4().hex,
            },
        }
        for index, record in zip(data.index.tolist(), records)
    ]


def prepare_audiosetsl(input_path: str) -> list[s.InputCaption]:
    """Load a JSON caption file and turn every entry into an input caption.

    The entries are read from the top-level ``"data"`` key of the file.
    Each entry's ``caption`` becomes the caption text; its ``id`` is stored
    under the ``filename`` metadata key, next to the entry's row number and
    a freshly generated ``uid``.

    :param input_path: file path to the input JSON dataset.
    :return: one input caption per entry, in row order.
    :raises KeyError: if the file has no ``"data"`` key, or an entry lacks
        ``caption`` or ``id``.
    """
    entries = pd.read_json(input_path)["data"]

    return [
        {
            "caption": entry["caption"],
            "metadata": {
                "row_num": row_num,
                "filename": entry["id"],
                "uid": uuid.uuid4().hex,
            },
        }
        for row_num, entry in entries.items()
    ]


def main():
    """Parse the command line, convert one dataset and write it as JSON Lines.

    The ``--type`` option selects which loader reads ``--input-path``; the
    resulting captions are written to ``--output-path``, one JSON object per
    line.

    :raises SystemExit: raised by argparse when a required option is missing
        or ``--type`` is not one of the supported choices.
    """
    parser = argparse.ArgumentParser(
        description="Process audio and music caption datasets into a standardized JSON format.",
    )
    # Every loader reads from the input path, so it must be required: leaving
    # it optional only deferred the failure to an opaque error inside pandas.
    parser.add_argument(
        "--input-path",
        required=True,
        type=str,
        help="File path to the input dataset (e.g., audiocaps.csv)",
    )
    parser.add_argument(
        "--output-path",
        required=True,
        type=str,
        help="File path where processed captions will be saved (e.g., processed_captions.jsonl)",
    )
    parser.add_argument(
        "--type",
        choices=["audiocaps", "wavcaps", "audioset-sl"],
        required=True,
        type=str,
        help="Type of audio caption dataset to process",
    )
    args = parser.parse_args()

    # process dataset based on specified type
    if args.type == "audiocaps":
        input_captions = prepare_audiocaps(input_path=args.input_path)
    elif args.type == "wavcaps":
        input_captions = prepare_wavecaps(input_path=args.input_path)
    elif args.type == "audioset-sl":
        input_captions = prepare_audiosetsl(input_path=args.input_path)
    else:
        # Defensive only: argparse's `choices` already rejects other values.
        raise NotImplementedError(f"Unsupported dataset type: {args.type!r}")

    # write processed captions in JSON Lines format
    with open(args.output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(input_caption) + "\n" for input_caption in input_captions)


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
