"""
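Convert precomputed GTZAN audio representations (.npy files) into chat-formatted
WebDataset tar shards, each record holding a JSON example plus its audio encoding.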

Usage:
python scripts/gtzan_to_chat.py \
    --file-prefix gtzan-jukebox-f10-genre \
    --representations-dir datasets/gtzan/representations/jukebox/f10/ \
    --output-dir datasets/gtzan/preprocessed \
    --prompt-type "list_all"
"""
import glob
import os
import argparse
import numpy as np

from tqdm import tqdm
import webdataset as wds

from m2t.conversation_utils import make_example


# Base question asked for every example.
PROMPT = "What genre is this song?"

# The base question followed by the ten genre labels offered as answer options.
PROMPT_WITH_CHOICES = (
    f"{PROMPT} Choose one of the following: "
    "blues, classical, country, disco, hiphop, jazz, metal, pop, reggae, rock."
)

# The base question asking the model to enumerate every possible genre.
PROMPT_LIST_ALL = f"{PROMPT} List all possible genres."

# Maps each --prompt-type CLI choice to the question text used for it.
# Insertion order is also the order in which the choices are listed by argparse.
_PROMPTS = dict(
    basic=PROMPT,
    choices=PROMPT_WITH_CHOICES,
    list_all=PROMPT_LIST_ALL,
)


def _parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Argument strings to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``file_prefix``, ``representations_dir``,
        ``output_dir``, ``prompt_type`` and ``ensemble_prompts``.

    Exits via ``parser.error`` (status 2) when the representations directory
    is missing, or when neither ``--prompt-type`` nor ``--ensemble-prompts``
    is given.
    """
    parser = argparse.ArgumentParser(
        description="Convert GTZAN audio representations to chat-formatted WebDataset shards."
    )

    parser.add_argument(
        "--file-prefix",
        default="shard",
        help="Prefix for the output shard filenames.",
    )
    parser.add_argument(
        "--representations-dir",
        required=True,
        help="path to directory containing .npy audio encodings",
    )
    parser.add_argument("--output-dir", required=True)

    parser.add_argument(
        "--prompt-type",
        choices=list(_PROMPTS),
        help="Type of prompt to use. Required unless --ensemble-prompts is set.",
    )
    parser.add_argument(
        "--ensemble-prompts",
        default=False,
        action="store_true",
        help="If true, create one example per prompt (for a total of num_samples * num_prompts output samples).",
    )
    args = parser.parse_args(argv)

    # Use parser.error rather than assert: asserts are stripped under `python -O`.
    if not os.path.isdir(args.representations_dir):
        parser.error(
            f"--representations-dir {args.representations_dir!r} is not a directory"
        )
    if args.prompt_type is None and not args.ensemble_prompts:
        # Without this check the later lookup _PROMPTS[None] raises KeyError.
        parser.error("--prompt-type is required unless --ensemble-prompts is set")
    return args


def _selected_prompt_types(args):
    """Return the prompt types to emit per input file.

    When ensembling, every prompt type is used and ``--prompt-type`` is ignored.
    Otherwise only the chosen type is used.
    """
    return list(_PROMPTS) if args.ensemble_prompts else [args.prompt_type]


def _parse_filename(path):
    """Split a representation file path into ``(example_id, audio_name, genre)``.

    ``example_id`` is the basename without its ``.npy`` extension, and
    ``audio_name`` is that id with a ``.wav`` extension. ``genre`` is the text
    before the first ``.`` in the basename, so names are expected to look like
    ``<genre>.<...>.npy`` (e.g. ``blues.00000.npy``).
    """
    basename = os.path.basename(path)
    example_id = os.path.splitext(basename)[0]
    return example_id, example_id + ".wav", basename.split(".")[0]


def _encoding_shape(encoding):
    """Return the shape recorded for ``encoding``.

    A 1-D array of length ``D`` is reported as ``[1, D]``. Any other array
    keeps its own ``.shape`` tuple.
    """
    shape = encoding.shape
    if len(shape) == 1:
        shape = [1] + list(shape)
    return shape


def main(argv=None):
    """Convert every ``*.npy`` file in the representations dir into shard records.

    Each record is written to ``<output-dir>/<file-prefix>-%06d.tar`` shards
    (at most 1024 records per shard). It holds the chat-formatted example as
    ``json`` and the encoding array as ``audio_encoding.pyd``.
    """
    args = _parse_args(argv)

    os.makedirs(args.output_dir, exist_ok=True)

    input_glob = os.path.join(args.representations_dir, "*.npy")
    # glob order is filesystem-dependent; sort so shard contents are reproducible.
    input_files = sorted(glob.glob(input_glob))
    if not input_files:
        raise FileNotFoundError(f"no files found matching {input_glob}")
    shard_fmt = os.path.join(args.output_dir, "-".join((args.file_prefix, "%06d.tar")))
    prompt_types = _selected_prompt_types(args)

    with wds.ShardWriter(shard_fmt, maxcount=1024) as sink:
        for path in tqdm(input_files):
            example_id, audio, genre = _parse_filename(path)
            audio_encoding = np.load(path)
            audio_encoding_shape = _encoding_shape(audio_encoding)
            for prompt_type in prompt_types:
                elem = make_example(
                    id=example_id,
                    audio=audio,
                    audio_encoding=audio_encoding,
                    audio_encoding_shape=audio_encoding_shape,
                    prompt_question=_PROMPTS[prompt_type],
                    response=genre,
                )
                # The encoding is stored as its own .pyd entry, not inside the JSON.
                encoding = elem.pop("audio_encoding")
                del elem["audio_encoding_shape"]
                # Webdataset cannot properly parse records with '.' chars in keys.
                key = elem["id"].replace(".", "_")
                if args.ensemble_prompts:
                    # Several records share one input file; keys must stay unique.
                    key = f"{key}_{prompt_type}"
                sink.write(
                    {
                        "__key__": key,
                        "audio_encoding.pyd": encoding,
                        "json": elem,
                    }
                )


if __name__ == "__main__":
    main()
