"""
Convert an instruction-tuning dataset to webdataset format for training.

If a dataset has captions only (the 'caption' field of response is populated),
    then it will be converted to a question/answer format.

Usage:

# musiccaps train mir
python scripts/openai_instruct_data_to_webdataset.py \
    --input-dir datasets/musiccaps/instruct/train/mir/gpt-3.5-06-29  \
    --output-dir datasets/musiccaps/preprocessed/train \
    --representations-dir datasets/musiccaps/representations/jukebox/f10 \
    --file-prefix "musiccaps-train-jukebox-f10-mir"

For the commands used for all datasets, see preprocess_all_datasets_to_webdataset.sh.
"""
import argparse
import json
import os
import logging
import sys
import tempfile
from typing import Dict, Any, Union

import pandas as pd
from tqdm import tqdm
import numpy as np
import webdataset as wds
import torch
from google.cloud import storage

from m2t.arguments import write_args_to_file
from m2t.dataset_utils import get_cropped_uri, DATASET_INFO, read_ids_file
from m2t.gcs_utils import list_files_with_extension
from m2t.instruct.captioning import is_caption_resonse
from m2t.instruct.captioning import insert_caption_qa
from m2t.instruct.data_validation import (
    element_response_is_not_exception,
    drop_invalid_qa_responses,
    element_is_valid_strict,
)


def read_jsonl(f):
    """Read a JSON Lines file into a list of decoded records.

    Args:
        f: path to a file containing one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(f, "r", encoding="utf-8") as handle:
        # Blank / whitespace-only lines (e.g. a stray newline at the end of
        # the file) carry no record; skip them instead of letting
        # json.loads() fail on an empty document.
        return [json.loads(line) for line in handle if line.strip()]


def read_audio_encoding(
    uri: str, representations_dir: str
) -> Union[None, torch.Tensor]:
    """Load the stored audio encoding for `uri` as a torch tensor.

    The encoding is looked up at `<representations_dir>/<uri>.npy`. When
    that path starts with "gs://" the file is fetched from Google Cloud
    Storage into a temporary directory before being loaded; otherwise it is
    read from the local filesystem.

    Args:
        uri: identifier of the audio clip. Non-string values are cast to str.
        representations_dir: local directory or "gs://bucket/prefix"
            containing the .npy encodings.

    Returns:
        The encoding converted with torch.from_numpy, or None when no
        encoding file exists for `uri` (a warning is logged in that case).
    """
    if not isinstance(uri, str):
        logging.debug(f"casting uri {uri} of type {type(uri)} to string {str(uri)}")
        uri = str(uri)

    encoding_fp = os.path.join(representations_dir, uri + ".npy")
    audio_encoding = None

    if encoding_fp.startswith("gs://"):
        # Case: file located on GCS

        # Create a Cloud Storage client and parse the path into
        # (bucket name, object name).
        gcs = storage.Client()
        bucket, file_name = encoding_fp.replace("gs://", "").split("/", maxsplit=1)
        blob = gcs.get_bucket(bucket).blob(file_name)

        if not blob.exists():
            # Case: encoding does not exist on GCS.
            logging.warning(f"no encodings found for {encoding_fp}; skipping")
        else:
            # Case: encoding exists on GCS; download it and load it.
            with tempfile.TemporaryDirectory() as tmp:
                # Use only the final path component as the local file name:
                # a uri containing '/' would otherwise point into a
                # subdirectory that does not exist inside the temp dir.
                encoding_fp_local = os.path.join(tmp, os.path.basename(file_name))
                logging.info(f"downloading {encoding_fp} to {encoding_fp_local}")
                blob.download_to_filename(encoding_fp_local)
                logging.info(f"loading downloaded file from {encoding_fp_local}")
                # np.load reads the array eagerly, so it is safe for the
                # temp dir to be removed when this block exits.
                audio_encoding = np.load(encoding_fp_local)

    else:
        # Case: file is local.
        try:
            logging.debug(f"reading local encoding file from {encoding_fp}")
            audio_encoding = np.load(encoding_fp)
        except FileNotFoundError:
            logging.warning(f"no encodings found for {encoding_fp}; skipping")

    if audio_encoding is None:
        return None
    return torch.from_numpy(audio_encoding)


def get_uri(elem: Dict[str, Any]) -> str:
    """Return the URI to use for `elem`, resolving crops when needed.

    If the element carries both "start_secs" and "end_secs" but its "uri"
    does not already contain the substrings "start" and "end", the metadata
    describes a cropped piece of audio while the URI names the full audio;
    in that case the cropped URI is obtained from get_cropped_uri.
    Otherwise the stringified "uri" field is returned unchanged.
    """
    base_uri = str(elem["uri"])

    has_crop_bounds = "start_secs" in elem and "end_secs" in elem
    uri_names_crop = "start" in base_uri and "end" in base_uri

    if not has_crop_bounds or uri_names_crop:
        # Nothing to resolve: either there is no crop, or the URI already
        # refers to it.
        return base_uri

    return get_cropped_uri(
        base_uri, start_secs=elem["start_secs"], end_secs=elem["end_secs"]
    )


# Keys of each input record that are kept in the "json" component of an
# output sample when --trim-json is enabled.
JSON_FIELDS_TO_KEEP = ["response"]

if __name__ == "__main__":
    # Script entry point: read every .jsonl file under --input-dir, filter
    # and normalize the records, attach the matching audio encoding from
    # --representations-dir, and write the result as webdataset tar shards
    # under --output-dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-dir", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--dataset-name")
    parser.add_argument("--file-prefix", default="shard")
    parser.add_argument(
        "--ids-file",
        default=None,
        help="Newline-delimited text file containing the IDs to include in this split.",
    )
    parser.add_argument(
        "--trim-json",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="If true, will only keep the 'response' field in the JSON component of the"
        " record. This helps keep entries small by avoiding storing arbitrarily large "
        "metadata fields. Set --no-trim-json to disable. Note that the id of the "
        "example is always stored as the key.",
    )
    # NOTE(review): store_true combined with default=True means this flag is
    # always True, and its value is not read anywhere in this script.
    parser.add_argument("--filter-invalid-responses", action="store_true", default=True)
    parser.add_argument(
        "--maxcount", default=128, type=int, help="max observations per tar file."
    )
    parser.add_argument(
        "--min-duration",
        type=float,
        default=None,
        help="Minimum duration in seconds. If specified, clips shorter "
        "than this will be dropped.",
    )
    parser.add_argument(
        "--representations-dir",
        required=True,
        help="path to directory containing .npy audio encodings",
    )

    args = parser.parse_args()
    write_args_to_file(sys.argv, args.output_dir)

    # None means "no id filtering"; otherwise the set of ids to keep.
    allowed_ids = None
    if args.ids_file:
        ids = read_ids_file(args.ids_file)
        print(f"[INFO] got {len(ids)} ids for split: {ids}")
        # Membership is tested once per input element; build the set once so
        # each test is O(1).
        allowed_ids = set(ids)

    input_paths = list_files_with_extension(args.input_dir, extension=".jsonl")
    logging.info(f"processing {len(input_paths)} files from {args.input_dir}")
    elems = [x for y in input_paths for x in read_jsonl(y)]

    if not os.path.exists(args.output_dir):
        logging.info(f"creating output directory {args.output_dir}")
        os.makedirs(args.output_dir)

    filtered_samples = 0
    output_response_count = 0  # count of valid responses

    shard_fmt = os.path.join(args.output_dir, "-".join((args.file_prefix, "%06d.tar")))
    with wds.ShardWriter(shard_fmt, maxcount=args.maxcount) as sink:
        for elem in tqdm(elems):
            uri = get_uri(elem)

            if not element_response_is_not_exception(elem):
                # Case: invalid caption response
                logging.warning(f"invalid ChatGPT caption for elem {uri}; skipping")
                filtered_samples += 1
                continue

            if args.min_duration is not None:
                # Elements lacking start/end times are treated as duration 0.
                duration = elem.get("end_secs", 0) - elem.get("start_secs", 0)
                if duration < args.min_duration:
                    logging.warning(
                        f"dropping element {uri} with duration {duration}"
                    )
                    filtered_samples += 1
                    continue

            # Elements outside the requested split are skipped without being
            # counted as filtered.
            if allowed_ids is not None and str(elem["id"]) not in allowed_ids:
                print(f"[DEBUG] skipping id {elem['id']} not in ids file.")
                continue

            if is_caption_resonse(elem):
                # Case: valid caption response; convert it to QA form.
                # Raise explicitly rather than assert so the check is not
                # stripped when Python runs with -O.
                if not args.dataset_name:
                    raise ValueError(
                        "--dataset-name flag is required for caption data."
                    )
                elem = insert_caption_qa(
                    elem,
                    caption_prompts=DATASET_INFO[args.dataset_name].caption_prompts,
                )
            else:
                # Case: valid QA response; handle it.
                elem = drop_invalid_qa_responses(elem)

            if not element_is_valid_strict(elem):
                filtered_samples += 1
                continue

            # Elements without an encoding are skipped without being counted
            # as filtered.
            audio_encoding = read_audio_encoding(uri, args.representations_dir)
            if audio_encoding is None:
                logging.warning(f"no encoding for uri {uri}; skipping")
                continue

            if args.trim_json:
                output_json = {k: elem[k] for k in JSON_FIELDS_TO_KEEP}
            else:
                output_json = elem

            output_response_count += len(output_json["response"])

            sink.write(
                {
                    # Webdataset cannot properly parse records with '.' chars in keys.
                    "__key__": uri.replace(".", "_"),
                    "audio_encoding.pyd": audio_encoding,
                    "json": output_json,
                }
            )
    # NOTE(review): the first number is the summed length of the written
    # "response" fields, not the number of records written.
    print(
        f"finished; wrote {output_response_count} samples and removed {filtered_samples}"
        " samples with invalid responses after filtering."
    )
