import argparse
import json
import os
import random
from multiprocessing import Pool
import numpy as np
from tqdm import tqdm
import gc

def dump_json(data, filepath):
    """Write ``data`` to ``filepath`` as JSON, replacing any existing file.

    Args:
        data: Any object the ``json`` module can serialize.
        filepath: Destination path; opened in text write mode.
    """
    with open(filepath, mode="w") as out_handle:
        json.dump(obj=data, fp=out_handle)
def load_json(filepath):
    """Read and return the JSON document stored at ``filepath``.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    default encoding, so files containing non-ASCII text load the same way
    on every system.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def _expand_attributes(start_index_attributes, num_captions):
    """Expand a run-start map into one attribute per caption.

    Args:
        start_index_attributes: Mapping ``{start_index: attribute}``. Each
            attribute applies from its start index up to (not including) the
            next larger start index; the last one applies up to
            ``num_captions``.
        num_captions: Length of the list to build.

    Returns:
        A list of length ``num_captions``. Positions before the smallest
        start index stay ``None``.
    """
    index_to_attribute = [None] * num_captions
    run_starts = sorted(start_index_attributes)
    # Every run ends where the next begins; the final run ends at the list end.
    run_ends = run_starts[1:] + [num_captions]
    for run_start, run_end in zip(run_starts, run_ends):
        # Clamp to the valid range so an out-of-range start index can neither
        # raise IndexError nor make the slice assignment grow the list.
        lo = max(0, min(run_start, num_captions))
        hi = max(lo, min(run_end, num_captions))
        index_to_attribute[lo:hi] = [start_index_attributes[run_start]] * (hi - lo)
    return index_to_attribute


def _compress_attributes(attributes):
    """Collapse a per-item attribute list into a ``{start_index: attribute}`` map.

    A position is recorded when its attribute is not ``None`` and it is either
    the first position or differs from the attribute just before it.
    """
    return {
        position: attribute
        for position, attribute in enumerate(attributes)
        if attribute is not None
        and (position == 0 or attribute != attributes[position - 1])
    }


def main(args):
    """Downsample the curated captions and write the sampled subset to disk.

    Loads the curated captions, concepts and attribute start indices for
    ``args.t`` from ``args.balanced_captions_folder``, draws
    ``args.num_final_captions`` caption indices without replacement (all of
    them when that many are not available), and writes the selected captions,
    their concepts and a rebuilt attribute start-index map into
    ``args.final_captions_folder``. Sampled indices are sorted, so the
    original ordering is preserved in the output.

    Args:
        args: Namespace providing ``seed``, ``balanced_captions_folder``,
            ``t``, ``num_final_captions`` and ``final_captions_folder``.

    Raises:
        ValueError: If ``num_final_captions`` is negative, or if the captions
            and concepts files do not contain the same number of entries.
    """
    # Fail fast on a bad sample size, before loading the input files.
    if args.num_final_captions < 0:
        raise ValueError(
            f"num_final_captions must be non-negative, got {args.num_final_captions}"
        )

    random.seed(args.seed)
    curated_captions_path = os.path.join(args.balanced_captions_folder, f"curated_captions_t_{args.t}.json")
    curated_concepts_path = os.path.join(args.balanced_captions_folder, f"curated_concepts_t_{args.t}.json")
    start_index_attributes_path = os.path.join(args.balanced_captions_folder, f"curated_attribute_indices_t_{args.t}.json")
    captions = load_json(curated_captions_path)
    concepts = load_json(curated_concepts_path)
    # JSON object keys are always strings; convert them back to integer indices.
    start_index_attributes = {int(k): v for k, v in load_json(start_index_attributes_path).items()}

    # Captions and concepts are selected with the same indices below, so a
    # length mismatch would either crash late or silently misalign the output.
    if len(concepts) != len(captions):
        raise ValueError(
            f"Mismatched inputs: {len(captions)} captions but {len(concepts)} concepts"
        )

    # Reconstruct the attribute for each caption
    print("Reconstructing attributes for each caption...")
    index_to_attribute = _expand_attributes(start_index_attributes, len(captions))

    # Sample indices
    print(f"Sampling {args.num_final_captions} captions from {len(captions)} available captions...")
    if args.num_final_captions >= len(captions):
        print(f"Warning: num_final_captions ({args.num_final_captions}) is larger than or equal to the number of available captions ({len(captions)}). Using all captions.")
        sampled_indices = list(range(len(captions)))
    else:
        sampled_indices = sorted(random.sample(range(len(captions)), args.num_final_captions))

    # Create final lists
    final_captions = [captions[i] for i in sampled_indices]
    final_concepts = [concepts[i] for i in sampled_indices]

    # Create final_start_index_attributes
    print("Creating final start index attributes...")
    final_start_index_attributes = _compress_attributes(
        [index_to_attribute[i] for i in sampled_indices]
    )

    # Save the results
    os.makedirs(args.final_captions_folder, exist_ok=True)

    # NOTE: file names carry the requested count, which can exceed the number
    # of captions actually saved when fewer were available.
    output_captions_path = os.path.join(args.final_captions_folder, f"captions_{args.num_final_captions}.json")
    output_concepts_path = os.path.join(args.final_captions_folder, f"concepts_{args.num_final_captions}.json")
    output_attributes_path = os.path.join(args.final_captions_folder, f"attribute_indices_{args.num_final_captions}.json")

    dump_json(final_captions, output_captions_path)
    dump_json(final_concepts, output_concepts_path)
    dump_json(final_start_index_attributes, output_attributes_path)

    print(f"Saved {len(final_captions)} captions, concepts, and attribute indices to {args.final_captions_folder}")


if __name__ == "__main__":
    # Command-line entry point: parse the options below and pass them to main().
    parser = argparse.ArgumentParser(description="Arguments for downsampling the curated captions to the desired number.")
    parser.add_argument(
        "--balanced_captions_folder",
        type=str,
        required=True,
        help="Name of the folder from which to read the curated captions, concepts and attribute indices",
    )
    parser.add_argument(
        "--final_captions_folder",
        type=str,
        required=True,
        help="Name of the folder where to store the final captions",
    )
    # This option is never read by main(), so it is no longer mandatory; it is
    # still accepted so existing invocations that pass it keep working.
    parser.add_argument(
        "--metadata_filepath",
        type=str,
        default=None,
        help="Path to metadata (concept bank) file. Currently unused by this script; accepted for backward compatibility.",
    )
    parser.add_argument(
        "--num_final_captions",
        type=int,
        default=12000000,
        help="Number of final captions to sample.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--t",
        type=int,
        default=30,
        help="Hyperparameter t used in the previous balancing step.",
    )

    args = parser.parse_args()

    main(args)