import os
import pickle
from collections import defaultdict

# Root metadata directory. The caption chunks are read from, and the merged pickle is
# written to, the "cc12m" subfolder beneath it. This is a relative path, so it resolves
# against the current working directory, not this file's location.
DATA_PATH = "../metadata/"


def gather_captioning_multi_complexity(num_complexity_levels, chunk_num, data_path=None):
    """Merge per-chunk Gemma3 caption files into one pickled ``{img_id: {'caps': [...]}}`` dict.

    Reads ``<data_path>/cc12m/gemma3_captions/gemma3_{level}obj_{index}.txt`` for every
    ``level`` in ``range(num_complexity_levels)`` and every ``index`` in
    ``range(chunk_num)``. Each non-blank line must have the form
    ``<img_id>*****<caption>``.

    Captions are appended per image in read order: complexity level first, then chunk
    index, then line order within the file. The merged mapping is written to
    ``<data_path>/cc12m/full_dict_gemma3_eval_clean_4caps.pkl``.

    Args:
        num_complexity_levels: Number of complexity levels to read, starting at 0. This
            is the ``{level}obj`` part of each file name.
        chunk_num: Number of chunk files per complexity level, starting at index 0.
        data_path: Root metadata directory. Defaults to the module-level ``DATA_PATH``.

    Returns:
        The merged mapping, as a plain ``dict``. This is the same object that was pickled.

    Raises:
        FileNotFoundError: If an expected chunk file does not exist.
        ValueError: If a non-blank line has no ``*****`` separator, or if its id part is
            not an integer.
    """
    if data_path is None:
        data_path = DATA_PATH
    caption_dir = os.path.join(data_path, "cc12m", "gemma3_captions")

    # defaultdict needs a *callable* factory. Passing a dict literal raises TypeError.
    output_dict = defaultdict(lambda: {'caps': []})

    for complexity_level in range(num_complexity_levels):
        for index in range(chunk_num):
            chunk_path = os.path.join(
                caption_dir, f"gemma3_{complexity_level}obj_{index}.txt"
            )
            # Use explicit UTF-8 so decoding does not depend on the platform locale.
            with open(chunk_path, "r", encoding="utf-8") as f:
                for lineno, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        # Skip blank lines, such as a trailing empty line.
                        continue
                    # Split on the first separator only, so a caption that itself
                    # contains "*****" stays intact.
                    img_id, sep, caption = line.partition("*****")
                    if not sep:
                        raise ValueError(
                            f"{chunk_path}:{lineno}: missing '*****' separator"
                        )
                    output_dict[int(img_id)]['caps'].append(caption)

    # A defaultdict with a lambda factory cannot be pickled, so store a plain dict.
    result = dict(output_dict)
    output_path = os.path.join(
        data_path, "cc12m", "full_dict_gemma3_eval_clean_4caps.pkl"
    )
    with open(output_path, "wb") as f:
        pickle.dump(result, f)
    return result


if __name__ == "__main__":
    # Set both values to match the caption chunk files present on disk.
    levels = 4  # number of complexity levels (the "{level}obj" part of each file name)
    chunks_per_level = 196  # number of chunk files for each complexity level
    gather_captioning_multi_complexity(levels, chunks_per_level)