import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import random
import tqdm
import sys


def _load_json(path: str):
    """Read and parse the JSON document at ``path``, decoding it as UTF-8."""
    # Explicit encoding: the default depends on the locale, and CounterFact
    # contains non-ASCII text.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _rewrite_prompt(ori_dp: dict) -> str:
    """Return the CounterFact rewrite prompt with its subject substituted in."""
    rewrite = ori_dp["requested_rewrite"]
    return rewrite["prompt"].format(rewrite["subject"])


def _build_edit_sample(ori_dp: dict, dp: dict) -> dict:
    """Merge a CounterFact record and its matched paraphrase record.

    Target strings come from the CounterFact record (``ori_dp``); the prompt,
    subject, case id and paraphrase/locality fields come from ``dp``.
    """
    rewrite = ori_dp["requested_rewrite"]
    target_new = rewrite["target_new"]["str"]
    target_true = rewrite["target_true"]["str"]
    return {
        "case_id": dp["case_id"],
        "prompt": dp["prompt"],
        "target_new": target_new,
        "target_true": target_true,
        "subject": dp["subject"],
        "rephrase_prompt": dp["rephrase_prompt"],
        "rephrase_target": target_new,
        "locality_prompt": dp["locality_prompt"],
        # NOTE(review): locality ground truth is the edited fact's CounterFact
        # true target, as in the original script -- confirm this is intended.
        "locality_ground_truth": target_true,
        "subject_rephrase": dp["subject_rephrase"],
    }


def main(
    ori_path: str = "/run/determined/workdir/data/many-shot-icl/icl_test/data_construction/counterfact/counterfact.json",
    data_path: str = "/run/determined/workdir/data/many-shot-icl/icl_test/data_construction/dataset_july/edit_reph-subj.json",
    out_dir: str = "/run/determined/workdir/data/many-shot-icl/icl_test/data_construction/dataset_aug",
    out_name: str = "counterfact_para_prompt_para_subject.json",
) -> None:
    """Pair CounterFact records with paraphrase records and write edit samples.

    Reads the original CounterFact dump and a paraphrased dataset, and matches
    their records on "rendered prompt + target string". Pairs whose prompts
    agree are merged into edit samples, which are written as an indented JSON
    list. Record counts are printed after each stage.

    Args:
        ori_path: CounterFact JSON file (a list of records with a
            ``requested_rewrite`` entry).
        data_path: Paraphrase dataset JSON file (a list of records with
            ``prompt``, ``target_new``, ``case_id``, ``subject``,
            ``rephrase_prompt``, ``locality_prompt`` and ``subject_rephrase``).
        out_dir: Output directory; created if it does not exist.
        out_name: Output file name inside ``out_dir``.
    """
    ori_data = _load_json(ori_path)
    print(len(ori_data))

    data = _load_json(data_path)
    print(len(data))

    # Index paraphrase records by prompt + target; later duplicates win.
    # NOTE(review): this key uses the paraphrase record's "target_new" while
    # the lookup below uses CounterFact's "target_true"; kept as in the
    # original script -- confirm this pairing is intended.
    data_dict = {item["prompt"] + item["target_new"]: item for item in data}

    # Key each CounterFact record once and keep it only if a paraphrase record
    # matches, so no second lookup (and no "missing match" case) is needed.
    paired_data = []
    for ori_dp in ori_data:
        key = _rewrite_prompt(ori_dp) + ori_dp["requested_rewrite"]["target_true"]["str"]
        dp = data_dict.get(key)
        if dp is not None:
            paired_data.append((ori_dp, dp))
    print(len(paired_data))

    results = []
    for ori_dp, dp in tqdm.tqdm(paired_data):
        # Concatenated keys are ambiguous ("ab" + "c" == "a" + "bc"), so
        # confirm the prompts themselves agree before accepting the pair.
        if _rewrite_prompt(ori_dp) != dp["prompt"]:
            continue
        results.append(_build_edit_sample(ori_dp, dp))

    # exist_ok=True already tolerates a pre-existing directory.
    os.makedirs(out_dir, exist_ok=True)

    print(len(results))
    with open(os.path.join(out_dir, out_name), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)




if __name__ == "__main__":
    # Build the dataset only when run as a script, not on import.
    main()