"""
Convert the GiantSteps dataset to chat-formatted webdataset.

Usage:
python scripts/giantsteps_to_chat.py \
    --input datasets/giantsteps/giantsteps-key.json \
    --task key \
    --representations-dir datasets/giantsteps/giantsteps-key-representations/jukebox/f10/ \
    --output-dir datasets/giantsteps/key/preprocessed/wds \
    --file-prefix giantsteps-eval-jukebox-f10-key

python scripts/giantsteps_to_chat.py \
    --input datasets/giantsteps/giantsteps-tempo.json \
    --task tempo \
    --representations-dir datasets/giantsteps/giantsteps-tempo-representations/jukebox/f10/ \
    --output-dir datasets/giantsteps/tempo/preprocessed/wds \
    --file-prefix giantsteps-eval-jukebox-f10-tempo
"""
import os
import pandas as pd
import argparse
from typing import Dict
from tqdm import tqdm

import webdataset as wds
from m2t.conversation_utils import make_example
from m2t.dataset_utils import (
    format_examples_for_model,
    maybe_trim_json,
    read_audio_encoding,
)

# Question posed to the model, keyed by the value passed to --task.
PROMPTS: Dict[str, str] = dict(
    key="What is the key of this song?",
    tempo="What is the tempo of this song?",
)

# Fields of the per-example JSON payload that are retained when trimming is
# enabled (see the --trim-json flag).
JSON_FIELDS_TO_KEEP = [
    "response",
]


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the GiantSteps conversion script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--input``, ``--task``,
        ``--representations-dir``, ``--output-dir``, ``--file-prefix``,
        ``--maxcount``, ``--ensemble-prompts`` and ``--trim-json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input",
        required=True,
        type=str,
        help="Path to JSON file with preprocessed GiantSteps labels.",
    )
    # Required: the task selects the prompt, so a missing value used to surface
    # only later as an opaque ``KeyError: None`` inside the conversion loop.
    parser.add_argument(
        "--task",
        required=True,
        choices=["tempo", "key"],
        help="Which GiantSteps task the input file belongs to.",
    )
    parser.add_argument(
        "--representations-dir",
        required=True,
        help="path to directory containing .npy audio encodings",
    )
    parser.add_argument("--output-dir", required=True)

    parser.add_argument("--file-prefix", default="shard")
    parser.add_argument(
        "--maxcount", default=128, type=int, help="max observations per tar file."
    )
    parser.add_argument(
        "--ensemble-prompts",
        default=False,
        action="store_true",
        help="If true, create one example per prompt (for a total of "
        "num_samples * num_prompts output samples).",
    )
    parser.add_argument(
        "--trim-json",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="If true, will only keep the 'response' field in the JSON component of the"
        " record. This helps keep entries small by avoiding storing arbitrarily large "
        "metadata fields. Set --no-trim-json to disable. Note that the id of the "
        "example is always stored as the key.",
    )
    return parser


def main(argv=None) -> None:
    """Convert GiantSteps labels plus audio encodings into webdataset shards.

    Reads the JSON-lines label file given by ``--input``, looks up the audio
    encoding of each row's 30s-60s clip in ``--representations-dir``, and writes
    one record per row (key, audio encoding, chat-formatted JSON) to tar shards
    under ``--output-dir``. Rows without an audio encoding are reported and
    skipped.

    Args:
        argv: Optional list of command-line arguments; when ``None``, argparse
            falls back to ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    if args.ensemble_prompts:
        # The flag is parsed but only a single prompt per task exists and no
        # ensembling is performed below; say so instead of silently ignoring it.
        print(
            "[WARNING] --ensemble-prompts is accepted but has no effect in this "
            "script; one example per input row will be written."
        )

    # exist_ok avoids the check-then-create race of a separate os.path.exists().
    os.makedirs(args.output_dir, exist_ok=True)

    df = pd.read_json(args.input, lines=True)

    shard_fmt = os.path.join(args.output_dir, "-".join((args.file_prefix, "%06d.tar")))

    id_colname = "id"
    # The prompt depends only on the task, so look it up once.
    prompt_question = PROMPTS[args.task]

    with wds.ShardWriter(shard_fmt, maxcount=args.maxcount) as sink:
        for _, row in tqdm(df.iterrows(), total=len(df)):
            # str() guards against pandas inferring a numeric dtype for the id
            # column, which would make the concatenation below raise TypeError.
            example_id = str(row[id_colname])
            clip_id = example_id + "-start30.000-end60.000"

            audio_encoding = read_audio_encoding(
                clip_id, args.representations_dir, numpy_to_torch=False
            )
            if audio_encoding is None:
                print(
                    "[DEBUG] got element response with no audio encoding "
                    f"with id {example_id}"
                )
                continue

            # NOTE(review): the example id/audio name use the un-suffixed id while
            # the encoding lookup and the record key use the clip id; this mirrors
            # the previous behavior -- confirm that it is intended.
            example = make_example(
                id=example_id,
                audio=example_id + ".wav",
                audio_encoding=audio_encoding,
                audio_encoding_shape=list(audio_encoding.shape),
                prompt_question=prompt_question,
                # The response is written as a placeholder, not taken from the row.
                response="<EMPTY>",
            )
            example = format_examples_for_model(example)
            example = maybe_trim_json(
                example, fields_to_keep=JSON_FIELDS_TO_KEEP, trim_json=args.trim_json
            )
            audio_encoding = example.pop("audio_encoding")
            del example["audio_encoding_shape"]
            sink.write(
                {
                    # Webdataset cannot properly parse records with '.' chars in keys.
                    "__key__": clip_id.replace(".", "_"),
                    "audio_encoding.pyd": audio_encoding,
                    "json": example["json"],
                }
            )


if __name__ == "__main__":
    main()
