import json
import os
import pickle

import pandas as pd

# Base output directory (a relative path, so it resolves against the current
# working directory). convert_to_json_format_gen writes its index.json files
# underneath it.
SAVE_PATH = "../../outputs/evalgim_metadata/"


def get_dsg_data(
    complexity,
    dsg_path="../../../outputs/dsg_cc12m/",
    num_chunks=10,
    chunk_size=500,
):
    """Load the DSG rows for one complexity level, keyed by caption text.

    Reads ``dsg_complexity_{complexity}_chunk_{chunk_id}.csv`` from
    ``dsg_path`` for every ``chunk_id`` in ``range(num_chunks)``. Chunk ``i``
    is expected to hold the rows whose ``item_id`` is
    ``complexity_{complexity}_caption_{k}`` for every ``k`` in
    ``[i * chunk_size, (i + 1) * chunk_size)``.

    Args:
        complexity: Complexity level used in the CSV file names and item ids.
        dsg_path: Directory containing the chunked CSV files.
        num_chunks: Number of chunk files to read.
        chunk_size: Number of consecutive caption ids stored in each chunk.

    Returns:
        A tuple ``(dsg_questions, dsg_children, dsg_parents)`` of dicts keyed
        by the caption text (the ``text`` column). The values are the lists
        of the ``question_natural_language``, ``proposition_id`` and
        ``dependency`` columns for that caption, in file order. If the same
        caption text occurs under several item ids, only the first
        occurrence is kept.

    Raises:
        ValueError: If a chunk file has no rows for an expected item id.
    """
    dsg_questions = {}
    dsg_children = {}
    dsg_parents = {}
    for chunk_id in range(num_chunks):
        dsg_fname = f"dsg_complexity_{complexity}_chunk_{chunk_id}.csv"
        df = pd.read_csv(f"{dsg_path}/{dsg_fname}")
        # Split the chunk by item id once instead of boolean-masking the whole
        # frame for every caption; groupby keeps the row order inside each
        # group, so the resulting lists are in file order.
        rows_by_item = {
            item_id: rows for item_id, rows in df.groupby('item_id', sort=False)
        }
        first_id = chunk_id * chunk_size
        for caption_id in range(first_id, first_id + chunk_size):
            item_id = f"complexity_{complexity}_caption_{caption_id}"
            selected_rows = rows_by_item.get(item_id)
            if selected_rows is None:
                # Explicit error rather than `assert`, which is stripped
                # when Python runs with -O.
                raise ValueError(
                    f"No DSG rows found for item_id {item_id!r} in {dsg_fname}"
                )
            caption = selected_rows['text'].values[0]
            # First occurrence wins when several item ids share a caption.
            if caption not in dsg_questions:
                dsg_questions[caption] = selected_rows['question_natural_language'].tolist()
                dsg_children[caption] = selected_rows['proposition_id'].tolist()
                dsg_parents[caption] = selected_rows['dependency'].tolist()
    return dsg_questions, dsg_children, dsg_parents


def convert_to_json_format_gen(data, complexity):
    """Write the generation metadata index for one complexity level.

    For every image in ``data`` this builds a record that pairs the image's
    caption at the given complexity with the DSG questions, proposition ids
    and dependencies loaded by ``get_dsg_data``. The records are dumped as a
    JSON list to
    ``{SAVE_PATH}/evalgim_genldm_metadata_cc12m/dsg_c{complexity}/index.json``.
    The target directory is created if it does not exist yet.

    Args:
        data: Mapping from image name to an entry where
            ``entry['caps'][complexity]`` is the caption to use.
        complexity: Complexity level; selects both the caption and the DSG
            CSV files.

    Raises:
        KeyError: If a caption from ``data`` has no entry in the DSG data.
    """
    dsg_questions, dsg_children, dsg_parents = get_dsg_data(complexity)

    output_list = []
    for image_name, entry in data.items():
        caption = entry['caps'][complexity]
        output_list.append({
            'image_path': f"/tmp/generations/{image_name}.png",
            "prompt": caption,
            "dsg_questions": dsg_questions[caption],
            "dsg_children": dsg_children[caption],
            "dsg_parents": dsg_parents[caption],
            # The caption is also used as the conditioning "class_id".
            "condition": {"class_id": caption},
        })

    output_dir = f"{SAVE_PATH}/evalgim_genldm_metadata_cc12m/dsg_c{complexity}"
    # open(..., "w") does not create missing parent directories, so make sure
    # the per-complexity directory exists before writing.
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/index.json", "w") as f:
        json.dump(output_list, f, indent=4)


if __name__ == "__main__":
    # Script entry point: load the pickled caption dictionary once, then emit
    # one index.json per complexity level (0 through 3).
    # NOTE: pickle.load can execute arbitrary code; only run this against a
    # trusted metadata file.
    metadata_pickle = "../../metadata/cc12m/full_dict_gemma3_siglip_eval_clean_5k_4caps.pkl"
    with open(metadata_pickle, "rb") as pkl_file:
        data = pickle.load(pkl_file)

    for complexity in (0, 1, 2, 3):
        convert_to_json_format_gen(data, complexity)