import json
from pathlib import Path

import fire
import pandas as pd


def load_tango_test_ref(
    tango_test_ref: str = "data/audiocaps_v2/test_audiocaps_subset.json"
):
    """Load the Tango AudioCaps test subset as a YouTube-ID -> captions map.

    The reference file is read as JSON Lines: one JSON object per line, each
    with a ``location`` (path to the audio file) and a ``captions`` field.
    Blank lines are ignored.

    Args:
        tango_test_ref: Path to the JSON Lines reference file.

    Returns:
        dict mapping the first 11 characters of the file stem of each item's
        ``location`` (the YouTube ID) to that item's ``captions`` value. If
        the same ID appears on several lines, the last one wins.
    """
    aid_to_caption = {}
    with open(tango_test_ref, "r", encoding="utf-8") as f:
        # Stream the file instead of readlines(), and skip blank lines
        # (e.g. a trailing newline) which json.loads would reject.
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            # YouTube IDs are 11 characters; anything after them in the stem
            # (presumably a segment/offset suffix) is discarded.
            youtube_id = Path(item["location"]).stem[:11]
            aid_to_caption[youtube_id] = item["captions"]
    return aid_to_caption


def process_audio_id(audio_id: str):
    """Reduce a wav file name/path to its 11-character YouTube ID.

    The directory and final extension are dropped, then the first character
    of the remaining stem is skipped and the next 11 characters are kept.
    (The skipped character is presumably a fixed prefix such as the "Y" used
    in AudioSet-style file names — confirm against wav.csv.)
    """
    file_stem = Path(audio_id).stem
    id_start, id_len = 1, 11
    return file_stem[id_start:id_start + id_len]


def main(
    audiocaps_raw_dir: str = "/cpfs02/shared/speechllm/AudioCaps",
    target_dir: str = "./data/audiocaps_v2_kqq",
    tango_test_ref: "str | None" = None,
):
    """Build per-split ``audio.jsonl`` / ``caption.jsonl`` files for AudioCaps.

    For each of the ``train``, ``val`` and ``test`` splits this reads
    ``<audiocaps_raw_dir>/<split>.csv`` and writes, under
    ``<target_dir>/<split>/``:

    * ``audio.jsonl``   -- ``{"audio_id": ..., "audio": <file_name from wav.csv>}``
    * ``caption.jsonl`` -- ``{"audio_id": ..., "caption": ...}``

    Rows whose YouTube ID has no entry in ``wav.csv`` are skipped, and each
    audio ID is written at most once per split (the first matching row
    wins). For the test split, only IDs present in the Tango reference
    subset are kept, and the caption is taken from that reference instead
    of the CSV.

    Args:
        audiocaps_raw_dir: Directory containing ``wav.csv`` (tab-separated,
            with at least ``audio_id`` and ``file_name`` columns) and
            ``{split}.csv`` (with at least ``youtube_id`` and, for non-test
            splits, ``caption`` columns).
        target_dir: Output directory; per-split subdirectories are created.
        tango_test_ref: Optional path to the Tango test reference JSON Lines
            file. When omitted, ``load_tango_test_ref``'s default path is used.
    """
    if tango_test_ref is None:
        test_tango_ref = load_tango_test_ref()
    else:
        test_tango_ref = load_tango_test_ref(tango_test_ref)
    audiocaps_raw_dir = Path(audiocaps_raw_dir)
    target_dir = Path(target_dir)

    # Read IDs as strings so pandas never coerces an all-digit ID to a number
    # (which would break Path() below and the string-keyed lookups).
    audio_df = pd.read_csv(
        audiocaps_raw_dir / "wav.csv", sep="\t", dtype={"audio_id": str}
    )
    audio_df["audio_id"] = audio_df["audio_id"].apply(process_audio_id)
    # The dict's keys double as the set of audio IDs that have a wav file.
    aid_to_fpath = dict(zip(audio_df["audio_id"], audio_df["file_name"]))

    for split in ["train", "val", "test"]:
        data_df = pd.read_csv(
            audiocaps_raw_dir / f"{split}.csv", dtype={"youtube_id": str}
        )
        split_dir = target_dir / split
        split_dir.mkdir(parents=True, exist_ok=True)
        audio_path = split_dir / "audio.jsonl"
        caption_path = split_dir / "caption.jsonl"
        processed_aids = set()
        with open(audio_path, "w", encoding="utf-8") as audio_writer, \
                open(caption_path, "w", encoding="utf-8") as text_writer:
            for _, row in data_df.iterrows():
                audio_id = row["youtube_id"]
                if audio_id not in aid_to_fpath:
                    continue
                if split == "test" and audio_id not in test_tango_ref:
                    continue
                # NOTE(review): deduplicating here also keeps only the first
                # CSV caption of any audio ID that has several rows (e.g. in
                # val); confirm that is intended.
                if audio_id in processed_aids:
                    continue
                audio_writer.write(
                    json.dumps({
                        "audio_id": audio_id,
                        "audio": aid_to_fpath[audio_id]
                    }) + "\n"
                )
                # Test captions come from the Tango reference (as stored
                # there); other splits use this CSV row's caption.
                if split == "test":
                    caption = test_tango_ref[audio_id]
                else:
                    caption = row["caption"]
                text_writer.write(
                    json.dumps({
                        "audio_id": audio_id,
                        "caption": caption
                    }) + "\n"
                )
                processed_aids.add(audio_id)


if __name__ == '__main__':
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(main)
