# Copyright (c) Meta Platforms, Inc. and affiliates
"""
Sample call:

python substr_matching.py --synthetic_captions_folder synthetic_captions \
                          --captions_with_count_folder synthetic_captions_with_count \
                          --metadata_filepath ./metadata.json \
                          --num_processes 8
                             

This code is adapted from the original Meta implementation
"""
import argparse
import json
import multiprocessing
import os
from itertools import repeat

from tqdm import tqdm

# Per-worker cache of the space-padded, lower-cased metadata entries.
# Populated by ``init_pool`` in each multiprocessing worker; ``None`` until then.
_shared_metadata = None

def spacing(text):
    """Return ``text`` padded so that substring matches fall on space boundaries.

    The result has one extra space at each end. Each of the characters in
    ``,.;:?!`` and the backtick is surrounded by spaces. Tab, newline and
    carriage-return characters are each replaced by a single space.
    """
    # Single-pass equivalent of replacing each character in turn: every
    # replacement string contains only spaces and the character itself.
    translation = {ord(mark): f" {mark} " for mark in ",.;:?!`"}
    translation.update({ord(blank): " " for blank in "\t\n\r"})
    return f" {text} ".translate(translation)

def substr_matching(text, metadata):
    """Return, in ascending order, the indices of ``metadata`` entries found in ``text``.

    ``text`` is normalised with ``spacing`` before the search. ``metadata`` is
    expected to be pre-padded with one space on each side (as ``init_pool``
    does), so an entry only matches when bounded by spaces in the normalised
    text. The comparison is a plain, case-sensitive ``in`` test.
    """
    padded_text = spacing(text)
    return [idx for idx, needle in enumerate(metadata) if needle in padded_text]

def init_pool(metadata):
    """Pool initializer: cache the padded, lower-cased metadata in this process.

    Every entry is stored as ``" " + entry.lower() + " "`` in the module
    global ``_shared_metadata``, which ``dist_func_text`` reads.
    """
    global _shared_metadata
    padded_entries = []
    for concept in metadata:
        padded_entries.append(" " + concept.lower() + " ")
    _shared_metadata = padded_entries

def dist_func_text(text):
    """Worker task: pair one caption with the metadata entries it mentions.

    ``init_pool`` lower-cases every metadata entry, so the caption is
    lower-cased before matching as well. Without this, any caption word
    containing an uppercase letter could never match. The caption itself is
    returned unchanged.

    Returns:
        ``[text, matched_entry_ids]`` where ``matched_entry_ids`` are indices
        into the metadata list passed to ``init_pool``.
    """
    return [text, substr_matching(text.lower(), _shared_metadata)]

def main(args):
    """Annotate every caption JSON file in a folder with its metadata matches.

    Args:
        args: object with attributes ``synthetic_captions_folder``,
            ``captions_with_count_folder``, ``metadata_filepath`` and
            ``num_processes`` (as produced by the CLI parser below).

    Each input ``*.json`` file must contain a ``"captions"`` list. That list is
    replaced by ``[caption, matched_metadata_indices]`` pairs, and the whole
    object is written under the same file name in ``captions_with_count_folder``.
    """
    synthetic_captions_folder = args.synthetic_captions_folder
    captions_with_count_folder = args.captions_with_count_folder
    metadata_filepath = args.metadata_filepath
    num_processes = int(args.num_processes)

    os.makedirs(captions_with_count_folder, exist_ok=True)

    with open(metadata_filepath, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    # Sorted so runs process files in a reproducible order.
    json_files = sorted(
        f for f in os.listdir(synthetic_captions_folder) if f.endswith(".json")
    )
    print(f"There are {len(json_files)} json files.")

    # One pool for all files: each worker prepares the metadata once in
    # init_pool instead of once per input file.
    with multiprocessing.Pool(
        processes=num_processes,
        initializer=init_pool,
        initargs=(metadata,),
    ) as pool:
        for file in json_files:
            print(f"Processing file {file}")
            with open(
                os.path.join(synthetic_captions_folder, file), "r", encoding="utf-8"
            ) as f:
                parsed_json = json.load(f)

            raw_text = [
                text.replace('"', "").strip(" ").strip("\n")
                for text in tqdm(parsed_json["captions"])
            ]

            # Batch tasks to cut inter-process overhead; imap preserves order.
            chunksize = max(1, len(raw_text) // (num_processes * 4))
            text_with_count = list(
                tqdm(
                    pool.imap(dist_func_text, raw_text, chunksize=chunksize),
                    total=len(raw_text),
                )
            )

            parsed_json["captions"] = text_with_count
            with open(
                os.path.join(captions_with_count_folder, file), "w", encoding="utf-8"
            ) as f:
                json.dump(parsed_json, f)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Arguments for substring matching.")
    parser.add_argument(
        "--synthetic_captions_folder",
        type=str,
        required=True,
        help="Name of the folder where the raw caption json files are",
    )
    parser.add_argument(
        "--captions_with_count_folder",
        type=str,
        required=True,
        help="Name of the folder path where to save caption with count json files",
    )
    parser.add_argument(
        "--metadata_filepath",
        type=str,
        required=True,
        help="Path to metadata (concept bank) file",
    )
    parser.add_argument(
        "--num_processes",
        type=int,
        default=64,
        help="Number of processes to use",
    )

    args = parser.parse_args()

    main(args)