# Copyright (c) Meta Platforms, Inc. and affiliates
"""
Sample call:

python balancing.py --captions_with_count_folder synthetic_captions_with_count \
                    --balanced_captions_folder balanced_curated_captions \
                    --metadata_filepath ./metadata.json \
                    --t 10

This code is adapted from the original Meta (MetaCLIP) implementation. In addition to the curated captions,
it records the attribute each caption was generated from: the attribute is stored at the index of the first
curated caption generated from it, and every following caption is assumed to come from the same attribute
until the next stored index is reached.
It also saves the concept corresponding to each caption.
"""
import argparse
import json
import os
import random
from multiprocessing import Pool
import numpy as np
from tqdm import tqdm
import gc

def dump_json(data, filepath):
    """Serialize ``data`` as JSON into ``filepath``, overwriting any existing file."""
    with open(filepath, mode="w") as out_file:
        json.dump(data, fp=out_file)
def load_json(filepath):
    """Parse the JSON document stored at ``filepath`` and return the resulting object.

    The file is always decoded as UTF-8 rather than with the platform's
    default encoding, so non-ASCII text loads identically on every system.
    """
    # JSON exchanged between systems is UTF-8; relying on the locale default
    # mis-decodes (or fails on) non-ASCII content on non-UTF-8 platforms.
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def balance_sampling(matched_entry_ids, entry_prob):
    """Decide whether a caption is kept, given the entries it matched.

    One uniform random draw is made per matched entry, in order, and the
    caption is kept as soon as a draw falls below that entry's probability
    (looked up as ``entry_prob[entry_id]``). A caption with no matched
    entries is never kept.
    """
    # any() stops at the first successful draw, so no further random numbers
    # are consumed once the caption has been accepted.
    return any(
        random.random() < entry_prob[matched_id] for matched_id in matched_entry_ids
    )

def _build_dedup_corpus(captions_with_count_folder, metadata):
    """Read every ``.json`` caption file in the folder and deduplicate captions by text.

    Returns a tuple ``(D, D_concepts, start_index_attribute, entry_count)``:
    ``D`` holds the unique ``[text, matched_entry_ids]`` records,
    ``D_concepts`` the metadata entry paired with each record,
    ``start_index_attribute`` maps the index in ``D`` at which a file's records
    begin to that file's attribute, and ``entry_count`` counts, per entry id,
    how many unique captions matched it.
    """
    entry_count = np.zeros(shape=(len(metadata),), dtype=np.uint64)
    D = []
    D_concepts = []
    start_index_attribute = {}
    seen_captions = set()

    # os.listdir returns names in arbitrary order; sorting makes the dedup
    # result (which file "wins" a duplicated caption) reproducible.
    json_files = sorted(
        f for f in os.listdir(captions_with_count_folder) if f.endswith(".json")
    )
    print(f"There are {len(json_files)} json files.")

    for file in json_files:
        print(f"Getting captions from {file}")
        # Layout read below: {"attribute": str, "captions": list of [text, [matched_entry_ids]]}
        parsed_json = load_json(os.path.join(captions_with_count_folder, file))

        # Records appended from this index on come from this file's attribute.
        # If the file contributes nothing new, the next file overwrites this key.
        start_index_attribute[len(D)] = parsed_json["attribute"]
        file_captions = parsed_json["captions"]
        for i, rec in tqdm(enumerate(file_captions), total=len(file_captions)):
            if rec[0] in seen_captions:  # skip duplicate captions
                continue
            seen_captions.add(rec[0])
            for entry in rec[1]:
                entry_count[entry] += 1  # update the count for each matched entry
            D.append(rec)
            # The concept is chosen by position within the file, cycling through
            # metadata if there are more captions than metadata entries.
            D_concepts.append(metadata[i % len(metadata)])
        del parsed_json, file_captions
        gc.collect()

    return D, D_concepts, start_index_attribute, entry_count


def _sample_balanced(D, D_concepts, start_index_attribute, entry_prob):
    """Sample records from ``D`` and carry attribute start markers into the sample.

    Returns ``(D_star, D_star_ids, D_star_concepts, start_index_attribute_star)``
    where the first three lists hold the text, matched entry ids and concept of
    every kept record, and the dict maps the index in ``D_star`` of the first
    kept caption of an attribute to that attribute.
    """
    D_star = []
    D_star_ids = []
    D_star_concepts = []
    start_index_attribute_star = {}
    pending_attribute = None
    for i, rec in tqdm(enumerate(D), total=len(D)):
        # Always pick up a new attribute at its start index, even if the
        # previous attribute never had a caption sampled; otherwise the stale
        # attribute would be attached to captions of the new one.
        if i in start_index_attribute:
            pending_attribute = start_index_attribute[i]
        if not balance_sampling(rec[1], entry_prob):
            continue
        if pending_attribute is not None:
            start_index_attribute_star[len(D_star)] = pending_attribute
            pending_attribute = None  # recorded; wait for the next attribute
        D_star.append(rec[0])
        D_star_ids.append(rec[1])
        D_star_concepts.append(D_concepts[i])
    return D_star, D_star_ids, D_star_concepts, start_index_attribute_star


def main(args):
    """Deduplicate the raw captions, balance-sample them, and save the results.

    Reads ``captions_with_count_folder``, ``balanced_captions_folder``,
    ``metadata_filepath`` and ``t`` from ``args``. The deduplicated captions,
    their concepts, the attribute start indices and the per-entry counts are
    cached in ``balanced_captions_folder`` and reused on later runs when all
    four files exist. Each caption is then kept or dropped by
    ``balance_sampling`` using per-entry probability ``t / max(count, t)``,
    and the curated captions, ids, concepts, attribute indices and final
    per-entry counts are written with a ``_t_{t}`` suffix.

    Raises:
        ValueError: if ``t`` is not positive (the probabilities would be
            undefined or negative).
    """
    gc.collect()
    captions_with_count_folder = args.captions_with_count_folder
    balanced_captions_folder = args.balanced_captions_folder
    metadata_filepath = args.metadata_filepath
    t = args.t
    if t <= 0:
        raise ValueError(f"t must be a positive integer, got {t}")

    os.makedirs(balanced_captions_folder, exist_ok=True)
    save_dedup_captions_with_count_path = os.path.join(balanced_captions_folder, "all_dedup_captions_with_count.json")
    save_concepts_of_dedup_captions_path = os.path.join(balanced_captions_folder, "all_concepts_of_dedup_captions.json")
    save_starting_attribute_indices_path = os.path.join(balanced_captions_folder, "starting_attribute_indices.json")
    save_entry_count_path = os.path.join(balanced_captions_folder, "entry_count.npy")
    cache_paths = (
        save_dedup_captions_with_count_path,
        save_concepts_of_dedup_captions_path,
        save_starting_attribute_indices_path,
        save_entry_count_path,
    )

    metadata = load_json(metadata_filepath)

    if all(os.path.exists(path) for path in cache_paths):
        print("Found existing D, D_concepts, start_index_attribute related files. Loading them...")
        D = load_json(save_dedup_captions_with_count_path)
        D_concepts = load_json(save_concepts_of_dedup_captions_path)
        # JSON object keys are strings; restore the integer indices.
        start_index_attribute = {int(k): v for k, v in load_json(save_starting_attribute_indices_path).items()}
        entry_count = np.load(save_entry_count_path)
    else:
        print("No existing D, D_concepts, start_index_attribute related files found. Creating them...")
        D, D_concepts, start_index_attribute, entry_count = _build_dedup_corpus(
            captions_with_count_folder, metadata
        )
        gc.collect()
        np.save(save_entry_count_path, entry_count)
        dump_json(D, save_dedup_captions_with_count_path)
        dump_json(D_concepts, save_concepts_of_dedup_captions_path)
        dump_json(start_index_attribute, save_starting_attribute_indices_path)

    print(f"There are {len(D)} unique captions, with corresponding {len(D_concepts)} concepts.")

    # Clamp counts to at least t so the probability t / count never exceeds 1
    # (and never divides by zero for entries that matched nothing).
    entry_count[entry_count < t] = t
    entry_prob = t / entry_count

    print("Sampling...")

    D_star, D_star_ids, D_star_concepts, start_index_attribute_star = _sample_balanced(
        D, D_concepts, start_index_attribute, entry_prob
    )

    print(f"Total of {len(D_star)} captions curated.")
    print(f"attributes found at: {start_index_attribute_star}")

    print("Generating final entry count for statistics...")
    final_entry_count = np.zeros(shape=(len(metadata),), dtype=np.uint64)
    for rec in D_star_ids:
        for entry in rec:
            final_entry_count[entry] += 1

    print("Saving results...")
    np.save(os.path.join(balanced_captions_folder, f"final_entry_count_t_{t}.npy"), final_entry_count)
    dump_json(D_star, os.path.join(balanced_captions_folder, f"curated_captions_t_{t}.json"))
    dump_json(D_star_ids, os.path.join(balanced_captions_folder, f"curated_captions_ids_t_{t}.json"))
    dump_json(D_star_concepts, os.path.join(balanced_captions_folder, f"curated_concepts_t_{t}.json"))
    dump_json(start_index_attribute_star, os.path.join(balanced_captions_folder, f"curated_attribute_indices_t_{t}.json"))

    print("Results saved.")
    print(f"If you are not satisfied with the resulting number of captions {len(D_star)}, increase or decrease t and re-run to get more or less captions, respectively.")


if __name__ == "__main__":
    # Command-line entry point: collect the options and hand them to main().
    cli = argparse.ArgumentParser(description="Arguments for balancing captions.")

    # The three locations are mandatory string options; register them in the
    # same order as before so the --help output is unchanged.
    for flag, help_text in (
        (
            "--captions_with_count_folder",
            "Name of the folder where the raw captions with count are",
        ),
        (
            "--balanced_captions_folder",
            "Name of the folder where to store the curated captions",
        ),
        (
            "--metadata_filepath",
            "Path to metadata (concept bank) file",
        ),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Optional sampling threshold.
    cli.add_argument(
        "--t",
        type=int,
        default=30,
        help="Hyperparameter t in MetaCLIP; controls the probability of sampling captions with concepts.",
    )

    main(cli.parse_args())
