"""
Convert a caption dataset (e.g. musiccaps or yt8m-musictextclips) to a chat-formatted webdataset.

Usage:

python scripts/caption_dataset_to_chat_webdataset.py \
    --input-file datasets/yt8m-musictextclips/yt8m-musictextclips-all.json \
    --dataset-name yt8m-musictextclips \
    --representations-dir datasets/yt8m-musictextclips/representations/ \
    --output-dir datasets/yt8m-musictextclips/preprocessed/wds \
    --file-prefix yt8m-musictextclips-all-jukebox-f10-captioning
    
"""

import os
import pandas as pd
import argparse
import numpy as np
import random
from tqdm import tqdm

import webdataset as wds
from m2t.dataset_utils import DATASET_INFO
from m2t.conversation_utils import make_example
from m2t.instruct.captioning import CAPTIONING_PROMPTS

# Allow-list of JSON fields matching the --trim-json help text, which says only
# the 'response' field is kept in the JSON component of each record.
# NOTE(review): this constant is not referenced anywhere in the visible part of
# this script, so --trim-json appears to have no effect on the written
# records -- confirm before relying on it.
JSON_FIELDS_TO_KEEP = ["response"]


def parse_args():
    """Parse the command-line arguments of the conversion script.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-file",
        required=True,
        help="Path to a JSON file (read as newline-delimited JSON records).",
    )
    # Required: the name is used to look up both the prompts and the dataset
    # metadata, so omitting it could only fail later with an opaque KeyError.
    parser.add_argument(
        "--dataset-name", required=True, choices=list(DATASET_INFO.keys())
    )

    parser.add_argument(
        "--representations-dir",
        required=True,
        help="path to directory containing .npy audio encodings",
    )
    parser.add_argument("--output-dir", required=True)

    parser.add_argument("--file-prefix", default="shard")
    parser.add_argument(
        "--maxcount", default=128, type=int, help="max observations per tar file."
    )
    parser.add_argument(
        "--ensemble-prompts",
        default=False,
        action="store_true",
        help="If true, create one example per prompt (for a total of num_samples * num_prompts output samples).",
    )
    # NOTE(review): --trim-json is accepted for interface compatibility but is
    # not applied when records are written below. The keys returned by
    # make_example() are not visible from this script, so filtering the JSON
    # down to the 'response' field could silently empty every record; confirm
    # the schema before wiring this flag in.
    parser.add_argument(
        "--trim-json",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="If true, will only keep the 'response' field in the JSON component of the"
        " record. This helps keep entries small by avoiding storing arbitrarily large "
        "metadata fields. Set --no-trim-json to disable. Note that the id of the "
        "example is always stored as the key.",
    )
    return parser.parse_args()


def select_prompts(prompts, ensemble):
    """Choose which prompts to pair with a single example.

    Args:
        prompts: sequence of candidate prompt strings.
        ensemble: if true, use every prompt; otherwise pick one at random.

    Returns:
        A list of ``(prompt_index, prompt)`` pairs. ``prompt_index`` is the
        position of the prompt in ``prompts`` when ``ensemble`` is true, and
        ``None`` for the single randomly chosen prompt otherwise.
    """
    if ensemble:
        return list(enumerate(prompts))
    return [(None, random.choice(prompts))]


def record_key(example_id: str, prompt_index=None) -> str:
    """Build the webdataset ``__key__`` for one output record.

    Args:
        example_id: id of the source example.
        prompt_index: index of the prompt when one record is written per
            prompt; ``None`` when a single record is written for the example.

    Returns:
        ``example_id`` with every '.' replaced by '_', followed by a
        ``_prompt<N>`` suffix when ``prompt_index`` is given so that records
        derived from the same example get distinct keys.
    """
    # Webdataset cannot properly parse records with '.' chars in keys.
    key = example_id.replace(".", "_")
    if prompt_index is not None:
        key = f"{key}_prompt{prompt_index}"
    return key


def main():
    """Read the caption records and write them out as webdataset shards.

    For every input row the matching ``<id>.npy`` encoding is loaded from the
    representations directory (rows without one are reported and skipped), a
    chat example is built with ``make_example`` for the selected prompt(s), and
    each example is written to the shard writer as its own record.

    Raises:
        FileNotFoundError: if the representations directory does not exist.
        ValueError: if the input file does not have a ``.json`` extension.
    """
    args = parse_args()

    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if not os.path.exists(args.representations_dir):
        raise FileNotFoundError(
            f"representations dir does not exist: {args.representations_dir}"
        )
    if not args.input_file.endswith(".json"):
        raise ValueError("expect a JSON file as input.")
    df = pd.read_json(args.input_file, lines=True)

    os.makedirs(args.output_dir, exist_ok=True)

    prompts = CAPTIONING_PROMPTS[args.dataset_name]
    dataset_info = DATASET_INFO[args.dataset_name]
    id_colname = dataset_info.id_col

    shard_fmt = os.path.join(args.output_dir, "-".join((args.file_prefix, "%06d.tar")))

    with wds.ShardWriter(shard_fmt, maxcount=args.maxcount) as sink:
        for _, row in tqdm(df.iterrows(), total=len(df)):
            example_id = row[id_colname]
            audio = example_id + ".wav"
            encoding_fp = os.path.join(args.representations_dir, example_id + ".npy")
            try:
                audio_encoding = np.load(encoding_fp)
            except FileNotFoundError:
                print(f"no encodings found for {encoding_fp}; skipping")
                continue
            audio_encoding_shape = audio_encoding.shape
            if len(audio_encoding_shape) == 1:
                # Give 1-D encodings an explicit leading dimension of size 1.
                audio_encoding_shape = [1] + list(audio_encoding_shape)
            caption = row[dataset_info.caption_col]

            # One pair without --ensemble-prompts (a randomly chosen prompt),
            # otherwise one pair per prompt. Every pair is written as its own
            # record; previously only the last prompt's example was written
            # in ensemble mode.
            for prompt_index, prompt_question in select_prompts(
                prompts, args.ensemble_prompts
            ):
                elem = make_example(
                    id=example_id,
                    audio=audio,
                    audio_encoding=audio_encoding,
                    audio_encoding_shape=audio_encoding_shape,
                    prompt_question=prompt_question,
                    response=caption,
                )
                # The encoding is stored as its own record component, so it is
                # removed (with its shape) from the JSON component.
                encoding = elem.pop("audio_encoding")
                del elem["audio_encoding_shape"]

                sink.write(
                    {
                        "__key__": record_key(elem["id"], prompt_index),
                        "audio_encoding.pyd": encoding,
                        "json": elem,
                    }
                )


if __name__ == "__main__":
    main()
