import json
import numpy as np
import random
from tqdm.auto import tqdm
import itertools
import os
from copy import deepcopy
import matplotlib.pyplot as plt
import argparse 

def build_dicts(entities):
    """Build a de-duplicated index <-> item mapping.

    Items receive consecutive integer indices in order of first appearance;
    later duplicates are ignored.

    Args:
        entities: Iterable of hashable items (e.g. entity or relation tokens).

    Returns:
        A tuple ``(ind2entity, entity2ind)``: ``ind2entity`` is the list of
        unique items in first-seen order, and ``entity2ind`` maps each item to
        its position in that list.
    """
    entity2ind = dict()
    ind2entity = []
    for entity in entities:
        # Membership test against the dict is O(1); scanning the list instead
        # made this loop quadratic in the number of items.
        if entity not in entity2ind:
            entity2ind[entity] = len(ind2entity)
            ind2entity.append(entity)
    return ind2entity, entity2ind

def choose(arr, ratio_or_count):
    """Randomly sample elements of ``arr`` without replacement.

    Args:
        arr: Indexable sequence to sample from.
        ratio_or_count: A float (Python or NumPy) is a fraction of
            ``len(arr)``, rounded to the nearest integer; an int (Python or
            NumPy, but not ``bool``) is an absolute count.

    Returns:
        A list of elements taken from ``num`` distinct positions of ``arr``,
        in the random order produced by ``np.random.choice``. When ``num`` is
        at least ``len(arr)`` the input ``arr`` itself is returned (no copy,
        no shuffle).

    Raises:
        TypeError: if ``ratio_or_count`` is not a float or an int.
        Any error raised by ``np.random.choice`` (e.g. for a negative count)
        propagates unchanged.
    """
    # bool is a subclass of int; it is rejected (as the former exact-type check
    # did) instead of silently being treated as a count of 0 or 1.
    if isinstance(ratio_or_count, (float, np.floating)):
        num = int(round(ratio_or_count * len(arr)))
    elif isinstance(ratio_or_count, (int, np.integer)) and not isinstance(ratio_or_count, bool):
        num = int(ratio_or_count)
    else:
        # An explicit raise survives `python -O`; the former `assert False`
        # would have been stripped, leaving `num` unbound.
        raise TypeError(
            "ratio_or_count must be a float or an int, got {}".format(type(ratio_or_count).__name__)
        )
    if num >= len(arr):
        return arr
    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()
    return [arr[i] for i in rand_inds]
    
def split(arr, ratio_or_count):
    """Randomly partition ``arr`` into two lists.

    Args:
        arr: Indexable sequence to partition.
        ratio_or_count: Size of the first part. A float (Python or NumPy) is a
            fraction of ``len(arr)``, rounded to the nearest integer; an int
            (Python or NumPy, but not ``bool``) is an absolute count.

    Returns:
        ``[train, test]``. ``train`` holds the ``num`` elements at positions
        drawn by ``np.random.choice(len(arr), num, replace=False)`` and
        ``test`` holds the rest. Both keep the relative order of ``arr``.

    Raises:
        TypeError: if ``ratio_or_count`` is not a float or an int.
        Any error raised by ``np.random.choice`` (e.g. a count larger than
        ``len(arr)``) propagates unchanged.
    """
    if isinstance(ratio_or_count, (float, np.floating)):
        num = int(round(ratio_or_count * len(arr)))
    elif isinstance(ratio_or_count, (int, np.integer)) and not isinstance(ratio_or_count, bool):
        num = int(ratio_or_count)
    else:
        # Explicit raise instead of `assert False`, which `python -O` strips.
        raise TypeError(
            "ratio_or_count must be a float or an int, got {}".format(type(ratio_or_count).__name__)
        )
    # Same single RNG draw as before; a set makes each membership test O(1)
    # (the old list lookup made this O(len(arr) * num)), so the pass is linear
    # and fast enough not to need a progress bar.
    selected = set(np.random.choice(len(arr), num, replace=False).tolist())
    train, test = [], []
    for i, item in enumerate(arr):
        if i in selected:
            train.append(item)
        else:
            test.append(item)
    return [train, test]

def form_items(c, t):
    """Build a plain query/answer item.

    Args:
        c: Sequence of string tokens forming the prompt.
        t: The answer token (must be a string).

    Returns:
        ``{"input_text": prompt, "target_text": prompt + t + "</a>"}`` where
        ``prompt`` is the concatenation of ``c``.
    """
    prompt = "".join(c)
    return {
        "input_text": prompt,
        "target_text": prompt + t + "</a>",
    }

def form_item_counterfactual(c, t, unrel, original_t, hop):
    """Build a query/answer item annotated with counterfactual metadata.

    Args:
        c: Sequence of string tokens forming the prompt (context + query).
        t: The expected answer token (must be a string).
        unrel: Flag stored verbatim under ``"unrel"``.
        original_t: Value stored verbatim under ``"original_t"``.
        hop: Value stored verbatim under ``"hop"``.

    Returns:
        A dict with keys, in order, ``input_text``, ``target_text``,
        ``unrel``, ``original_t`` and ``hop``; ``target_text`` is the prompt
        followed by ``t`` and ``"</a>"``.
    """
    prompt = "".join(c)
    answer_text = prompt + t + "</a>"
    return dict(
        input_text=prompt,
        target_text=answer_text,
        unrel=unrel,
        original_t=original_t,
        hop=hop,
    )
def form_item_cot_counterfactual(c, cot ,t, unrel, original_t, hop):
    """Build a chain-of-thought item annotated with counterfactual metadata.

    Args:
        c: Sequence of string tokens forming the prompt.
        cot: Sequence of string tokens for the reasoning chain, emitted in the
            target between the prompt and the answer.
        t: The expected answer token (must be a string).
        unrel: Flag stored verbatim under ``"unrel"``.
        original_t: Value stored verbatim under ``"original_t"``.
        hop: Value stored verbatim under ``"hop"``.

    Returns:
        A dict with keys, in order, ``input_text``, ``target_text``,
        ``unrel``, ``original_t`` and ``hop``; ``target_text`` is
        prompt + reasoning + ``t`` + ``"</a>"``.
    """
    prompt = "".join(c)
    reasoning = "".join(cot)
    item = {}
    item["input_text"] = prompt
    item["target_text"] = prompt + reasoning + t + "</a>"
    item["unrel"] = unrel
    item["original_t"] = original_t
    item["hop"] = hop
    return item
def build_dataset_with_pretrain_held_out_cf_ent_multi(num_entities, num_relations, out_degree=20, args={}, split_train_inferred=False, generate_in_prompt_cf = True, cot = False, multi=0,):
    """Generate a random knowledge graph plus two-hop, factual-CoT and in-prompt
    counterfactual datasets, with ``multi`` counterfactuals of each type per
    counterfactual-train composition.

    Every entity ``<e_i>`` gets ``args.out_degree`` outgoing facts ``(h, r, t)``
    with distinct relations and a uniformly random tail. A two-hop composition
    ``(ent, r1, b), (b, r2, t)`` yields the query ``ent r1 r2`` with answer ``t``.

    Args:
        num_entities: Number of entity tokens ``<e_0>`` .. ``<e_{num_entities-1}>``.
        num_relations: Number of relation tokens ``<r_0>`` .. ``<r_{num_relations-1}>``.
        out_degree: NOTE(review): unused -- the body reads ``args.out_degree``.
        args: Object exposing ``out_degree``, ``ood_fraction``, ``cf_fmt_at_pt``
            and ``cf_ood_fraction`` as attributes (presumably an argparse
            ``Namespace``; the ``{}`` default has no such attributes and fails).
        split_train_inferred: If False, return every two-hop composition unsplit.
        generate_in_prompt_cf: If True, build the in-prompt counterfactual items.
        cot: If True, counterfactual items use the chain-of-thought target format.
        multi: Number of items of EACH counterfactual type generated for a
            composition whose two facts are both in the counterfactual-train split.

    Returns:
        If ``split_train_inferred`` is False:
        ``(entities, relations, atomic_facts, inferred_facts)``.
        Otherwise the 11-tuple ``(entities, relations, id_atomic_facts,
        ood_atomic_facts, train_inferred_facts, test_inferred_iid,
        test_inferred_ood, counterfactual_train_facts,
        counterfactual_test_facts, factual_train_example,
        factual_test_example)``.

    NOTE(review): draws from both ``np.random`` and the stdlib ``random``
    module, and iterates Python sets of string tuples (``ID_facts``,
    ``OOD_facts``) whose order depends on string hashing. Output is only
    reproducible when both RNGs are seeded and ``PYTHONHASHSEED`` is fixed.
    """
    entities = ["<e_{}>".format(i) for i in range(num_entities)]
    ind2entity, entity2ind = build_dicts(entities)
    relations = ["<r_{}>".format(i) for i in range(num_relations)]
    ind2relation, relation2ind = build_dicts(relations)
    atomic_dict = dict()   # maps a head entity to a list of (r, t) pairs
    atomic_facts = []
    atomics = []
    for i in tqdm(range(num_entities)):
        # for each subject entity, randomly select some outgoing relations to some random object entity
        # (drawn without replacement, so args.out_degree must not exceed num_relations)
        num_rows = args.out_degree
        selected_rows = np.random.choice(num_relations, size=num_rows, replace=False).tolist()
        for row_idx in selected_rows:
            col_idx = np.random.randint(num_entities)  # pick some random tail entity for each selected (h,r)
            h,r,t = ind2entity[i], ind2relation[row_idx], ind2entity[col_idx]
            atomic_facts.append(form_items([h, r], t))
            atomics.append((h,r,t))
            if h not in atomic_dict:
                atomic_dict[h] = []
            atomic_dict[h].append((r, t))
    print(len(atomics))
    if not split_train_inferred:
        # Unsplit mode: every two-hop composition becomes a plain "ent r1 r2 -> t" item.
        inferred_facts = []
        for ent in tqdm(entities):
            for (r1, b) in atomic_dict[ent]:
                for (r2, t) in atomic_dict[b]:
                    inferred_facts.append(form_items([ent, r1, r2], t))
        return entities, relations, atomic_facts, inferred_facts
    # split ID/OOD
    # The first (randomly selected) part returned by split() becomes the OOD set,
    # sized round(len(atomics) * args.ood_fraction).
    OOD_ratio = args.ood_fraction
    OOD_facts, ID_facts = split(atomics, round(len(atomics)*OOD_ratio))
    OOD_facts, ID_facts = set(OOD_facts), set(ID_facts)
    id_atomic_facts = [form_items([h, r], t) for (h,r,t) in ID_facts]
    ood_atomic_facts = [form_items([h, r], t) for (h,r,t) in OOD_facts]
    train_inferred_facts, test_inferred_iid, test_inferred_ood = [], [], []
    for ent in tqdm(entities):
        for (r1, b) in atomic_dict[ent]:
            for (r2, t) in atomic_dict[b]:
                # Compositions touching any OOD fact are never trained on; only those
                # whose BOTH hops are OOD are kept, as OOD test items.
                if (ent, r1, b) in OOD_facts or (b, r2, t) in OOD_facts:
                    if (ent, r1, b) in OOD_facts and (b, r2, t) in OOD_facts:
                        test_inferred_ood.append(form_items([ent, r1, r2], t))
                    continue
                # All-ID compositions: each goes to train with probability 0.995,
                # otherwise to the IID test set.
                if np.random.uniform() > 0.005:
                    # cf_fmt_at_pt: prepend (50/50) the true first-hop fact (hop=1) or
                    # second-hop fact (hop=2) as in-prompt context.
                    if args.cf_fmt_at_pt:
                        if np.random.uniform()>0.5:
                            train_inferred_facts.append(form_item_counterfactual([ent, r1, b, ent, r1, r2], t, unrel = False,original_t=t, hop=1))
                        else:
                            train_inferred_facts.append(form_item_counterfactual([b, r2, t, ent, r1, r2], t,  unrel = False, original_t=t, hop=2))
                    else:
                        train_inferred_facts.append(form_item_counterfactual([ent, r1, r2], t, unrel = False,original_t=t, hop=1))
                else:
                    if args.cf_fmt_at_pt:
                        if np.random.uniform()>0.5:
                            test_inferred_iid.append(form_item_counterfactual([ent, r1, b, ent, r1, r2], t, unrel=False, original_t=t, hop=1))
                        else:
                            test_inferred_iid.append(form_item_counterfactual([b, r2, t, ent, r1, r2], t, unrel = False , original_t = t, hop=2))
                    else:
                        test_inferred_iid.append(form_item_counterfactual([ent, r1, r2], t, unrel = False,original_t=t, hop=1))
    counterfactual_train_facts, counterfactual_test_facts = [], []
    # Second split of the ID facts: the first (randomly selected) part, sized by
    # args.cf_ood_fraction, is held out from counterfactual training. Only
    # compositions whose two hops are both in cf_train_id produce training items.
    # (cf_ood_train itself is not used further in this function.)
    cf_ood_train, cf_train_id = split(list(ID_facts), round(len(ID_facts)*args.cf_ood_fraction))
    cf_ood_train, cf_train_id = set(cf_ood_train), set(cf_train_id)
    factual_train_example = []
    factual_test_example = []
    # Factual chain-of-thought items: prompt "ent r1 r2", reasoning "ent r1 b r2 t",
    # answer t; routed to train/test by the cf_train_id split.
    for ent in tqdm(entities):
        for (r1, b) in atomic_dict[ent]:
            for (r2, t) in atomic_dict[b]:
                if (ent, r1, b) in OOD_facts or (b, r2, t) in OOD_facts:
                    continue
                if ((ent, r1, b) in cf_train_id) and ((b, r2, t) in cf_train_id):
                    factual_train_example.append(form_item_cot_counterfactual([ent, r1, r2], [ent, r1, b, r2, t], t,False,t, hop=1))
                else:
                    factual_test_example.append(form_item_cot_counterfactual([ent, r1, r2], [ent, r1, b, r2, t], t,False,t, hop=1))
    if generate_in_prompt_cf:
        for ent in tqdm(entities):
            for (r1, b) in atomic_dict[ent]:
                for (r2, t) in atomic_dict[b]:
                    if (ent, r1, b) in OOD_facts or (b, r2, t) in OOD_facts:
                        continue
                    if ((ent, r1, b) in cf_train_id) and ((b, r2, t) in cf_train_id):
                        # Counterfactual-train composition: `multi` items of each of four
                        # types below, de-duplicated within each type for this composition.
                        inserted_counterfactual = set()
                        for multi_idx in range(multi):
                            # Type 1 (hop=1, answer changes): pick a random entity that has an
                            # (r2, tsub) fact with tsub != t, and put "ent r1 bridge_cf" in the
                            # prompt, so the expected answer becomes tsub.
                            # NOTE(review): retries forever if no unused (bridge_cf, tsub)
                            # candidate exists.
                            counterfactual_found = False
                            while not counterfactual_found:
                                bridge_cf = random.choice(entities)
                                bridges = atomic_dict[bridge_cf].copy()
                                random.shuffle(bridges)
                                for (rsub,tsub) in bridges:
                                    if (rsub==r2) and (tsub!=t) and not (tuple((bridge_cf,tsub)) in inserted_counterfactual):
                                        counterfactual_train_facts.append(form_item_counterfactual([ent, r1, bridge_cf, ent, r1, r2], tsub,False,t, hop=1) if (not cot) else form_item_cot_counterfactual([ent, r1, bridge_cf, ent, r1, r2], [ent, r1, bridge_cf, rsub, tsub], tsub,False,t, hop=1))
                                        inserted_counterfactual.add(tuple((bridge_cf,tsub)))
                                        counterfactual_found = True
                                        break
                                if counterfactual_found:break
                        ###Add a counterfactual for the "second hop"
                        # Type 2 (hop=2, answer changes): prompt contains "b r2 second_ent_sub"
                        # with second_ent_sub != t and unique per composition; answer is
                        # second_ent_sub.
                        inserted_counterfactual = set()
                        for multi_idx in range(multi):
                            second_counterfactual_found = False
                            second_ent_sub = random.choice(entities)
                            while second_ent_sub == t or (second_ent_sub in inserted_counterfactual) :
                                print("loop")
                                second_ent_sub = random.choice(entities)
                            inserted_counterfactual.add(second_ent_sub)
                            counterfactual_train_facts.append(form_item_counterfactual([b,r2, second_ent_sub, ent, r1, r2], second_ent_sub, False, t, hop=2) if (not cot) else form_item_cot_counterfactual([b,r2, second_ent_sub, ent, r1, r2],[ent,r1,b,r2,second_ent_sub], second_ent_sub, False, t, hop=2))
                        ##Randomly select an unrelated entity
                        # Type 3 (hop=1, unrel=True, answer stays t): irrelevant context
                        # "u1 r1 u2" with u1 not in {ent, b}; on a duplicate pair only u1
                        # is re-drawn.
                        inserted_counterfactual = set()
                        for multi_idx in range(multi):
                            unrel_entity_1 = random.choice(entities)
                            unrel_entity_2 = random.choice(entities)
                            while (unrel_entity_1 == b) or (unrel_entity_1 == ent) or (tuple((unrel_entity_1, unrel_entity_2)) in inserted_counterfactual):
                                print("loop")
                                unrel_entity_1 = random.choice(entities)
                            inserted_counterfactual.add((unrel_entity_1, unrel_entity_2))
                            counterfactual_train_facts.append(form_item_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2], t, True, t, hop=1) if (not cot) else form_item_cot_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2], [ent,r1,b,r2,t], t, True, t, hop=1))
                        ###Create irrelevant second-hop fact
                        # Type 4 (hop=2, unrel=True, answer stays t): irrelevant context
                        # "u1 r2 u2" with u1 not in {ent, b}.
                        inserted_counterfactual = set()
                        for multi_idx in range(multi):
                            unrel_secondhop_entity_1 = random.choice(entities)
                            unrel_secondhop_entity_2 = random.choice(entities)
                            while (unrel_secondhop_entity_1 == b) or (unrel_secondhop_entity_1 == ent) or (tuple((unrel_secondhop_entity_1,unrel_secondhop_entity_2)) in inserted_counterfactual):
                                print("loop")
                                unrel_secondhop_entity_1 = random.choice(entities)
                            inserted_counterfactual.add((unrel_secondhop_entity_1,unrel_secondhop_entity_2))
                            counterfactual_train_facts.append(form_item_counterfactual([unrel_secondhop_entity_1, r2, unrel_secondhop_entity_2, ent, r1, r2], t, True,t, hop = 2) if (not cot) else form_item_cot_counterfactual([unrel_secondhop_entity_1, r2, unrel_secondhop_entity_2, ent, r1, r2],[ent,r1,b,r2,t], t, True,t, hop = 2))
                    else:
                        # Held-out composition: exactly one item of each of the same four
                        # types, all to counterfactual_test_facts (`multi` is ignored and
                        # there is no de-duplication here).
                        counterfactual_found = False
                        while not counterfactual_found:
                            bridge_cf = random.choice(entities)
                            bridges = atomic_dict[bridge_cf].copy()
                            random.shuffle(bridges)
                            for (rsub,tsub) in bridges:
                                if (rsub==r2) and (tsub!=t):
                                    counterfactual_test_facts.append(form_item_counterfactual([ent, r1, bridge_cf, ent, r1, r2], tsub,False,t, hop=1) if (not cot) else form_item_cot_counterfactual([ent, r1, bridge_cf, ent, r1, r2],[ent, r1, bridge_cf, rsub, tsub], tsub,False,t, hop=1))
                                    counterfactual_found = True
                                    break
                            if counterfactual_found:break
                        ###Add a counterfactual for the "second hop"
                        second_counterfactual_found = False
                        second_ent_sub = random.choice(entities)
                        while second_ent_sub == t:
                            second_ent_sub = random.choice(entities)
                        counterfactual_test_facts.append(form_item_counterfactual([b,r2, second_ent_sub, ent, r1, r2], second_ent_sub, False, t, hop=2) if (not cot) else form_item_cot_counterfactual([b,r2, second_ent_sub, ent, r1, r2],[ent,r1,b,r2,second_ent_sub], second_ent_sub, False, t, hop=2))
                        ##Randomly select an unrelated entity
                        unrel_entity_1 = random.choice(entities)
                        unrel_entity_2 = random.choice(entities)
                        while (unrel_entity_1 == b) or (unrel_entity_1 == ent):
                            unrel_entity_1 = random.choice(entities)
                        counterfactual_test_facts.append(form_item_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2], t,True, t, hop=1) if (not cot) else form_item_cot_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2],[ent,r1,b,r2,t], t,True, t, hop=1))
                        ###Create irrelevant second-hop fact
                        unrel_secondhop_entity_1 = random.choice(entities)
                        unrel_secondhop_entity_2 = random.choice(entities)
                        while (unrel_secondhop_entity_1 == b) or (unrel_secondhop_entity_1 == ent):
                            unrel_secondhop_entity_1 = random.choice(entities)
                        counterfactual_test_facts.append(form_item_counterfactual([unrel_secondhop_entity_1, r2,unrel_secondhop_entity_2, ent, r1, r2], t, True,t, hop =2) if (not cot) else form_item_cot_counterfactual([unrel_secondhop_entity_1, r2,unrel_secondhop_entity_2, ent, r1, r2],[ent,r1,b,r2,t], t, True,t, hop =2))
    return entities, relations, id_atomic_facts, ood_atomic_facts, train_inferred_facts, test_inferred_iid, test_inferred_ood, counterfactual_train_facts, counterfactual_test_facts,factual_train_example,factual_test_example
def build_dataset_with_pretrain_held_out_cf_ent(num_entities, num_relations, out_degree=20, args={}, split_train_inferred=False, generate_in_prompt_cf = True, cot = False):
    """Generate a random knowledge graph plus two-hop and in-prompt counterfactual
    datasets, holding a subset of ID facts out of counterfactual training.

    Every entity ``<e_i>`` gets ``args.out_degree`` outgoing facts ``(h, r, t)``
    with distinct relations and a uniformly random tail. A two-hop composition
    ``(ent, r1, b), (b, r2, t)`` yields the query ``ent r1 r2`` with answer ``t``.

    Args:
        num_entities: Number of entity tokens ``<e_0>`` .. ``<e_{num_entities-1}>``.
        num_relations: Number of relation tokens ``<r_0>`` .. ``<r_{num_relations-1}>``.
        out_degree: NOTE(review): unused -- the body reads ``args.out_degree``.
        args: Object exposing ``out_degree``, ``ood_fraction``, ``cf_fmt_at_pt``,
            ``cf_ood_fraction`` and ``cf_train_prob`` as attributes (presumably an
            argparse ``Namespace``; the ``{}`` default has no such attributes).
        split_train_inferred: If False, return every two-hop composition unsplit.
        generate_in_prompt_cf: If True, build the in-prompt counterfactual items.
        cot: If True, counterfactual items use the chain-of-thought target format.

    Returns:
        If ``split_train_inferred`` is False:
        ``(entities, relations, atomic_facts, inferred_facts)``.
        Otherwise the 9-tuple ``(entities, relations, id_atomic_facts,
        ood_atomic_facts, train_inferred_facts, test_inferred_iid,
        test_inferred_ood, counterfactual_train_facts,
        counterfactual_test_facts)``.

    NOTE(review): draws from both ``np.random`` and the stdlib ``random``
    module, and iterates Python sets of string tuples whose order depends on
    string hashing. Output is only reproducible when both RNGs are seeded and
    ``PYTHONHASHSEED`` is fixed.
    """
    entities = ["<e_{}>".format(i) for i in range(num_entities)]
    ind2entity, entity2ind = build_dicts(entities)
    relations = ["<r_{}>".format(i) for i in range(num_relations)]
    ind2relation, relation2ind = build_dicts(relations)
    atomic_dict = dict()   # maps a head entity to a list of (r, t) pairs
    atomic_facts = []
    atomics = []
    for i in tqdm(range(num_entities)):
        # for each subject entity, randomly select some outgoing relations to some random object entity
        # (drawn without replacement, so args.out_degree must not exceed num_relations)
        num_rows = args.out_degree
        selected_rows = np.random.choice(num_relations, size=num_rows, replace=False).tolist()
        for row_idx in selected_rows:
            col_idx = np.random.randint(num_entities)  # pick some random tail entity for each selected (h,r)
            h,r,t = ind2entity[i], ind2relation[row_idx], ind2entity[col_idx]
            atomic_facts.append(form_items([h, r], t))
            atomics.append((h,r,t))
            if h not in atomic_dict:
                atomic_dict[h] = []
            atomic_dict[h].append((r, t))
    print(len(atomics))
    if not split_train_inferred:
        # Unsplit mode: every two-hop composition becomes a plain "ent r1 r2 -> t" item.
        inferred_facts = []
        for ent in tqdm(entities):
            for (r1, b) in atomic_dict[ent]:
                for (r2, t) in atomic_dict[b]:
                    inferred_facts.append(form_items([ent, r1, r2], t))
        return entities, relations, atomic_facts, inferred_facts
    # split ID/OOD
    # The first (randomly selected) part returned by split() becomes the OOD set,
    # sized round(len(atomics) * args.ood_fraction).
    OOD_ratio = args.ood_fraction
    OOD_facts, ID_facts = split(atomics, round(len(atomics)*OOD_ratio))
    OOD_facts, ID_facts = set(OOD_facts), set(ID_facts)
    id_atomic_facts = [form_items([h, r], t) for (h,r,t) in ID_facts]
    ood_atomic_facts = [form_items([h, r], t) for (h,r,t) in OOD_facts]
    train_inferred_facts, test_inferred_iid, test_inferred_ood = [], [], []
    for ent in tqdm(entities):
        for (r1, b) in atomic_dict[ent]:
            for (r2, t) in atomic_dict[b]:
                # Compositions touching any OOD fact are never trained on; only those
                # whose BOTH hops are OOD are kept, as OOD test items.
                if (ent, r1, b) in OOD_facts or (b, r2, t) in OOD_facts:
                    if (ent, r1, b) in OOD_facts and (b, r2, t) in OOD_facts:
                        test_inferred_ood.append(form_items([ent, r1, r2], t))
                    continue
                # All-ID compositions: each goes to train with probability 0.995,
                # otherwise to the IID test set.
                if np.random.uniform() > 0.005:
                    # cf_fmt_at_pt: prepend (50/50) the true first-hop fact (hop=1) or
                    # second-hop fact (hop=2) as in-prompt context.
                    if args.cf_fmt_at_pt:
                        if np.random.uniform()>0.5:
                            train_inferred_facts.append(form_item_counterfactual([ent, r1, b, ent, r1, r2], t, unrel = False,original_t=t, hop=1))
                        else:
                            train_inferred_facts.append(form_item_counterfactual([b, r2, t, ent, r1, r2], t,  unrel = False, original_t=t, hop=2))
                    else:
                        train_inferred_facts.append(form_item_counterfactual([ent, r1, r2], t, unrel = False,original_t=t, hop=1))
                else:
                    if args.cf_fmt_at_pt:
                        if np.random.uniform()>0.5:
                            test_inferred_iid.append(form_item_counterfactual([ent, r1, b, ent, r1, r2], t, unrel=False, original_t=t, hop=1))
                        else:
                            test_inferred_iid.append(form_item_counterfactual([b, r2, t, ent, r1, r2], t, unrel = False , original_t = t, hop=2))
                    else:
                        test_inferred_iid.append(form_item_counterfactual([ent, r1, r2], t, unrel = False,original_t=t, hop=1))
    counterfactual_train_facts, counterfactual_test_facts = [], []
    # Second split of the ID facts: the first (randomly selected) part, sized by
    # args.cf_ood_fraction, is held out. Only compositions whose two hops are both
    # in cf_train_id can contribute training counterfactuals. (cf_ood_train itself
    # is not used further in this function.)
    cf_ood_train, cf_train_id = split(list(ID_facts), round(len(ID_facts)*args.cf_ood_fraction))
    cf_ood_train, cf_train_id = set(cf_ood_train), set(cf_train_id)



    if generate_in_prompt_cf:
        for ent in tqdm(entities):
            for (r1, b) in atomic_dict[ent]:
                for (r2, t) in atomic_dict[b]:
                    if (ent, r1, b) in OOD_facts or (b, r2, t) in OOD_facts:
                        continue
                    if ((ent, r1, b) in cf_train_id) and ((b, r2, t) in cf_train_id):
                        # Counterfactual-train composition: one item of each of four types;
                        # each item independently goes to train with probability
                        # args.cf_train_prob, otherwise to test.
                        # Type 1 (hop=1, answer changes): pick a random entity that has an
                        # (r2, tsub) fact with tsub != t, and put "ent r1 bridge_cf" in the
                        # prompt, so the expected answer becomes tsub.
                        # NOTE(review): retries forever if no such candidate exists.
                        counterfactual_found = False
                        while not counterfactual_found:
                            bridge_cf = random.choice(entities)
                            bridges = atomic_dict[bridge_cf].copy()
                            random.shuffle(bridges)
                            for (rsub,tsub) in bridges:
                                if (rsub==r2) and (tsub!=t):
                                    if np.random.uniform()<args.cf_train_prob:
                                        counterfactual_train_facts.append(form_item_counterfactual([ent, r1, bridge_cf, ent, r1, r2], tsub,False,t, hop=1) if (not cot) else form_item_cot_counterfactual([ent, r1, bridge_cf, ent, r1, r2], [ent, r1, bridge_cf, rsub, tsub], tsub,False,t, hop=1))
                                    else:
                                        counterfactual_test_facts.append(form_item_counterfactual([ent, r1, bridge_cf, ent, r1, r2], tsub,False,t, hop=1) if (not cot) else form_item_cot_counterfactual([ent, r1, bridge_cf, ent, r1, r2], [ent, r1, bridge_cf, rsub, tsub], tsub,False,t, hop=1))
                                    counterfactual_found = True
                                    break
                            if counterfactual_found:break
                        ###Add a counterfactual for the "second hop"
                        # Type 2 (hop=2, answer changes): prompt contains "b r2 second_ent_sub"
                        # with second_ent_sub != t; answer is second_ent_sub.
                        second_counterfactual_found = False
                        second_ent_sub = random.choice(entities)
                        while second_ent_sub == t:
                            second_ent_sub = random.choice(entities)
                        if np.random.uniform()<args.cf_train_prob:
                            counterfactual_train_facts.append(form_item_counterfactual([b,r2, second_ent_sub, ent, r1, r2], second_ent_sub, False, t, hop=2) if (not cot) else form_item_cot_counterfactual([b,r2, second_ent_sub, ent, r1, r2],[ent,r1,b,r2,second_ent_sub], second_ent_sub, False, t, hop=2))
                        else:
                            counterfactual_test_facts.append(form_item_counterfactual([b,r2, second_ent_sub, ent, r1, r2], second_ent_sub, False, t, hop=2) if (not cot) else form_item_cot_counterfactual([b,r2, second_ent_sub, ent, r1, r2],[ent,r1,b,r2,second_ent_sub], second_ent_sub, False, t, hop=2))
                        ##Randomly select an unrelated entity
                        # Type 3 (hop=1, unrel=True, answer stays t): irrelevant context
                        # "u1 r1 u2" with u1 not in {ent, b}.
                        unrel_entity_1 = random.choice(entities)
                        unrel_entity_2 = random.choice(entities)
                        while (unrel_entity_1 == b) or (unrel_entity_1 == ent):
                            unrel_entity_1 = random.choice(entities)
                        if np.random.uniform()<args.cf_train_prob:
                            counterfactual_train_facts.append(form_item_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2], t, True, t, hop=1) if (not cot) else form_item_cot_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2], [ent,r1,b,r2,t], t, True, t, hop=1))
                        else:
                            counterfactual_test_facts.append(form_item_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2], t,True, t, hop=1) if (not cot) else form_item_cot_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2], [ent,r1,b,r2,t], t, True, t, hop=1))
                        ###Create irrelevant second-hop fact
                        # Type 4 (hop=2, unrel=True, answer stays t): irrelevant context
                        # "u1 r2 u2" with u1 not in {ent, b}.
                        unrel_secondhop_entity_1 = random.choice(entities)
                        unrel_secondhop_entity_2 = random.choice(entities)
                        while (unrel_secondhop_entity_1 == b) or (unrel_secondhop_entity_1 == ent):
                            unrel_secondhop_entity_1 = random.choice(entities)
                        if np.random.uniform()<args.cf_train_prob:
                            counterfactual_train_facts.append(form_item_counterfactual([unrel_secondhop_entity_1, r2, unrel_secondhop_entity_2, ent, r1, r2], t, True,t, hop = 2) if (not cot) else form_item_cot_counterfactual([unrel_secondhop_entity_1, r2, unrel_secondhop_entity_2, ent, r1, r2],[ent,r1,b,r2,t], t, True,t, hop = 2))
                        else:
                            counterfactual_test_facts.append(form_item_counterfactual([unrel_secondhop_entity_1, r2,unrel_secondhop_entity_2, ent, r1, r2], t, True,t, hop =2) if (not cot) else form_item_cot_counterfactual([unrel_secondhop_entity_1, r2,unrel_secondhop_entity_2, ent, r1, r2],[ent,r1,b,r2,t], t, True,t, hop =2))
                    else:
                        # Held-out composition: one item of each of the same four types,
                        # always to counterfactual_test_facts.
                        counterfactual_found = False
                        while not counterfactual_found:
                            bridge_cf = random.choice(entities)
                            bridges = atomic_dict[bridge_cf].copy()
                            random.shuffle(bridges)
                            for (rsub,tsub) in bridges:
                                if (rsub==r2) and (tsub!=t):
                                    counterfactual_test_facts.append(form_item_counterfactual([ent, r1, bridge_cf, ent, r1, r2], tsub,False,t, hop=1) if (not cot) else form_item_cot_counterfactual([ent, r1, bridge_cf, ent, r1, r2],[ent, r1, bridge_cf, rsub, tsub], tsub,False,t, hop=1))
                                    counterfactual_found = True
                                    break
                            if counterfactual_found:break
                        ###Add a counterfactual for the "second hop"
                        second_counterfactual_found = False
                        second_ent_sub = random.choice(entities)
                        while second_ent_sub == t:
                            second_ent_sub = random.choice(entities)
                        counterfactual_test_facts.append(form_item_counterfactual([b,r2, second_ent_sub, ent, r1, r2], second_ent_sub, False, t, hop=2) if (not cot) else form_item_cot_counterfactual([b,r2, second_ent_sub, ent, r1, r2],[ent,r1,b,r2,second_ent_sub], second_ent_sub, False, t, hop=2))
                        ##Randomly select an unrelated entity
                        unrel_entity_1 = random.choice(entities)
                        unrel_entity_2 = random.choice(entities)
                        while (unrel_entity_1 == b) or (unrel_entity_1 == ent):
                            unrel_entity_1 = random.choice(entities)
                        counterfactual_test_facts.append(form_item_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2], t,True, t, hop=1) if (not cot) else form_item_cot_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2],[ent,r1,b,r2,t], t,True, t, hop=1))
                        ###Create irrelevant second-hop fact
                        unrel_secondhop_entity_1 = random.choice(entities)
                        unrel_secondhop_entity_2 = random.choice(entities)
                        while (unrel_secondhop_entity_1 == b) or (unrel_secondhop_entity_1 == ent):
                            unrel_secondhop_entity_1 = random.choice(entities)
                        counterfactual_test_facts.append(form_item_counterfactual([unrel_secondhop_entity_1, r2,unrel_secondhop_entity_2, ent, r1, r2], t, True,t, hop =2) if (not cot) else form_item_cot_counterfactual([unrel_secondhop_entity_1, r2,unrel_secondhop_entity_2, ent, r1, r2],[ent,r1,b,r2,t], t, True,t, hop =2))
    return entities, relations, id_atomic_facts, ood_atomic_facts, train_inferred_facts, test_inferred_iid, test_inferred_ood, counterfactual_train_facts, counterfactual_test_facts
def build_dataset_with_pretrain(num_entities, num_relations, out_degree=20, args={}, split_train_inferred=False, generate_in_prompt_cf = True):
    """Generate a random knowledge graph plus two-hop and in-prompt counterfactual
    datasets.

    Every entity ``<e_i>`` gets ``args.out_degree`` outgoing facts ``(h, r, t)``
    with distinct relations and a uniformly random tail. A two-hop composition
    ``(ent, r1, b), (b, r2, t)`` yields the query ``ent r1 r2`` with answer ``t``.
    Unlike the held-out variants, every all-ID composition produces
    counterfactual items, each routed to train with probability
    ``args.cf_train_prob``.

    Args:
        num_entities: Number of entity tokens ``<e_0>`` .. ``<e_{num_entities-1}>``.
        num_relations: Number of relation tokens ``<r_0>`` .. ``<r_{num_relations-1}>``.
        out_degree: NOTE(review): unused -- the body reads ``args.out_degree``.
        args: Object exposing ``out_degree``, ``ood_fraction``, ``cf_fmt_at_pt``
            and ``cf_train_prob`` as attributes (presumably an argparse
            ``Namespace``; the ``{}`` default has no such attributes).
        split_train_inferred: If False, return every two-hop composition unsplit.
        generate_in_prompt_cf: If True, build the in-prompt counterfactual items.

    Returns:
        If ``split_train_inferred`` is False:
        ``(entities, relations, atomic_facts, inferred_facts)``.
        Otherwise the 9-tuple ``(entities, relations, id_atomic_facts,
        ood_atomic_facts, train_inferred_facts, test_inferred_iid,
        test_inferred_ood, counterfactual_train_facts,
        counterfactual_test_facts)``.

    NOTE(review): draws from both ``np.random`` and the stdlib ``random``
    module, and iterates Python sets of string tuples whose order depends on
    string hashing. Output is only reproducible when both RNGs are seeded and
    ``PYTHONHASHSEED`` is fixed.
    """
    entities = ["<e_{}>".format(i) for i in range(num_entities)]
    ind2entity, entity2ind = build_dicts(entities)
    relations = ["<r_{}>".format(i) for i in range(num_relations)]
    ind2relation, relation2ind = build_dicts(relations)
    atomic_dict = dict()   # maps a head entity to a list of (r, t) pairs
    atomic_facts = []
    atomics = []
    for i in tqdm(range(num_entities)):
        # for each subject entity, randomly select some outgoing relations to some random object entity
        # (drawn without replacement, so args.out_degree must not exceed num_relations)
        num_rows = args.out_degree
        selected_rows = np.random.choice(num_relations, size=num_rows, replace=False).tolist()
        for row_idx in selected_rows:
            col_idx = np.random.randint(num_entities)  # pick some random tail entity for each selected (h,r)
            h,r,t = ind2entity[i], ind2relation[row_idx], ind2entity[col_idx]
            atomic_facts.append(form_items([h, r], t))
            atomics.append((h,r,t))
            if h not in atomic_dict:
                atomic_dict[h] = []
            atomic_dict[h].append((r, t))
    print(len(atomics))
    if not split_train_inferred:
        # Unsplit mode: every two-hop composition becomes a plain "ent r1 r2 -> t" item.
        inferred_facts = []
        for ent in tqdm(entities):
            for (r1, b) in atomic_dict[ent]:
                for (r2, t) in atomic_dict[b]:
                    inferred_facts.append(form_items([ent, r1, r2], t))
        return entities, relations, atomic_facts, inferred_facts
    # split ID/OOD
    # The first (randomly selected) part returned by split() becomes the OOD set,
    # sized round(len(atomics) * args.ood_fraction).
    OOD_ratio = args.ood_fraction
    OOD_facts, ID_facts = split(atomics, round(len(atomics)*OOD_ratio))
    OOD_facts, ID_facts = set(OOD_facts), set(ID_facts)
    id_atomic_facts = [form_items([h, r], t) for (h,r,t) in ID_facts]
    ood_atomic_facts = [form_items([h, r], t) for (h,r,t) in OOD_facts]
    train_inferred_facts, test_inferred_iid, test_inferred_ood = [], [], []
    for ent in tqdm(entities):
        for (r1, b) in atomic_dict[ent]:
            for (r2, t) in atomic_dict[b]:
                # Compositions touching any OOD fact are never trained on; only those
                # whose BOTH hops are OOD are kept, as OOD test items.
                if (ent, r1, b) in OOD_facts or (b, r2, t) in OOD_facts:
                    if (ent, r1, b) in OOD_facts and (b, r2, t) in OOD_facts:
                        test_inferred_ood.append(form_items([ent, r1, r2], t))
                    continue
                # All-ID compositions: each goes to train with probability 0.995,
                # otherwise to the IID test set.
                if np.random.uniform() > 0.005:
                    # cf_fmt_at_pt: prepend (50/50) the true first-hop fact (hop=1) or
                    # second-hop fact (hop=2) as in-prompt context.
                    if args.cf_fmt_at_pt:
                        if np.random.uniform()>0.5:
                            train_inferred_facts.append(form_item_counterfactual([ent, r1, b, ent, r1, r2], t, unrel = False,original_t=t, hop=1))
                        else:
                            train_inferred_facts.append(form_item_counterfactual([b, r2, t, ent, r1, r2], t,  unrel = False, original_t=t, hop=2))
                    else:
                        train_inferred_facts.append(form_item_counterfactual([ent, r1, r2], t, unrel = False,original_t=t, hop=1))

                else:
                    if args.cf_fmt_at_pt:
                        if np.random.uniform()>0.5:
                            test_inferred_iid.append(form_item_counterfactual([ent, r1, b, ent, r1, r2], t, unrel=False, original_t=t, hop=1))
                        else:
                            test_inferred_iid.append(form_item_counterfactual([b, r2, t, ent, r1, r2], t, unrel = False , original_t = t, hop=2))
                    else:
                        test_inferred_iid.append(form_item_counterfactual([ent, r1, r2], t, unrel=False, original_t=t, hop=1))
    counterfactual_train_facts, counterfactual_test_facts = [], []
    if generate_in_prompt_cf:
        for ent in tqdm(entities):
            for (r1, b) in atomic_dict[ent]:
                for (r2, t) in atomic_dict[b]:
                    if (ent, r1, b) in OOD_facts or (b, r2, t) in OOD_facts:
                        continue
                    # One item of each of four types per all-ID composition; each item
                    # independently goes to train with probability args.cf_train_prob.
                    # Type 1 (hop=1, answer changes): pick a random entity that has an
                    # (r2, tsub) fact with tsub != t, and put "ent r1 bridge_cf" in the
                    # prompt, so the expected answer becomes tsub.
                    # NOTE(review): retries forever if no such candidate exists.
                    counterfactual_found = False
                    while not counterfactual_found:
                        bridge_cf = random.choice(entities)
                        bridges = atomic_dict[bridge_cf].copy()
                        random.shuffle(bridges)
                        for (rsub,tsub) in bridges:
                            if (rsub==r2) and (tsub!=t):
                                if np.random.uniform()<args.cf_train_prob:
                                    counterfactual_train_facts.append(form_item_counterfactual([ent, r1, bridge_cf, ent, r1, r2], tsub,False,t, hop=1))
                                else:
                                    counterfactual_test_facts.append(form_item_counterfactual([ent, r1, bridge_cf, ent, r1, r2], tsub,False,t, hop=1))
                                counterfactual_found = True
                                break
                        if counterfactual_found:break
                    ###Add a counterfactual for the "second hop"
                    # Type 2 (hop=2, answer changes): prompt contains "b r2 second_ent_sub"
                    # with second_ent_sub != t; answer is second_ent_sub.
                    second_counterfactual_found = False
                    second_ent_sub = random.choice(entities)
                    while second_ent_sub == t:
                        second_ent_sub = random.choice(entities)
                    if np.random.uniform()<args.cf_train_prob:
                        counterfactual_train_facts.append(form_item_counterfactual([b,r2, second_ent_sub, ent, r1, r2], second_ent_sub, False, t, hop=2))
                    else:
                        counterfactual_test_facts.append(form_item_counterfactual([b,r2, second_ent_sub, ent, r1, r2], second_ent_sub, False, t, hop=2))
                    ##Randomly select an unrelated entity
                    # Type 3 (hop=1, unrel=True, answer stays t): irrelevant context
                    # "u1 r1 u2" with u1 not in {ent, b}.
                    unrel_entity_1 = random.choice(entities)
                    unrel_entity_2 = random.choice(entities)
                    while (unrel_entity_1 == b) or (unrel_entity_1 == ent):
                        unrel_entity_1 = random.choice(entities)
                    if np.random.uniform()<args.cf_train_prob:
                        counterfactual_train_facts.append(form_item_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2], t, True, t, hop=1))
                    else:
                        counterfactual_test_facts.append(form_item_counterfactual([unrel_entity_1, r1, unrel_entity_2, ent, r1, r2], t,True, t, hop=1))
                    ###Create irrelevant second-hop fact
                    # Type 4 (hop=2, unrel=True, answer stays t): irrelevant context
                    # "u1 r2 u2" with u1 not in {ent, b}.
                    unrel_secondhop_entity_1 = random.choice(entities)
                    unrel_secondhop_entity_2 = random.choice(entities)
                    while (unrel_secondhop_entity_1 == b) or (unrel_secondhop_entity_1 == ent):
                        unrel_secondhop_entity_1 = random.choice(entities)
                    if np.random.uniform()<args.cf_train_prob:
                        counterfactual_train_facts.append(form_item_counterfactual([unrel_secondhop_entity_1, r2, unrel_secondhop_entity_2, ent, r1, r2], t, True,t, hop = 2))
                    else:
                        counterfactual_test_facts.append(form_item_counterfactual([unrel_secondhop_entity_1, r2,unrel_secondhop_entity_2, ent, r1, r2], t, True,t, hop =2))
    return entities, relations, id_atomic_facts, ood_atomic_facts, train_inferred_facts, test_inferred_iid, test_inferred_ood, counterfactual_train_facts, counterfactual_test_facts
    
def _write_json(path, obj):
    """Write ``obj`` to ``path`` as JSON (UTF-8), replacing any existing file."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f)


def main(args):
    """Generate one procedural dataset from ``args`` and write it to disk.

    Chooses a dataset builder based on ``args.cf_ood_fraction`` / ``args.multi``,
    builds the vocabulary, samples evaluation probes, and writes the
    train/valid/test/counterfactual/CoT splits plus ``vocab.json`` under
    ``data/procedural/{dataset_name}.{num_entities}.{num_relations}.{phi}/``.

    Args:
        args: parsed command-line namespace (see ``build_parser``/``__main__``).
    """
    # Only the multi builder returns factual chain-of-thought splits; default
    # them to empty so the cot_fact_*.json files can always be written and
    # non-multi runs do not fail with a NameError halfway through writing.
    fact_train_cot, fact_test_cot = [], []
    # NOTE(review): split_train_inferred is hard-coded True below even though a
    # --split_train_inferred flag exists; confirm the builders read it from args.
    if not args.cf_ood_fraction:
        (train_entities, train_relations, id_atomic_facts, ood_atomic_facts,
         train_inferred_facts, test_inferred_iid, test_inferred_facts,
         cf_train, cf_test) = build_dataset_with_pretrain(
            args.num_entities, args.num_relations, out_degree=args.out_degree,
            args=args, split_train_inferred=True)
    elif args.multi == 0:
        print("Not Running Multi")
        (train_entities, train_relations, id_atomic_facts, ood_atomic_facts,
         train_inferred_facts, test_inferred_iid, test_inferred_facts,
         cf_train, cf_test) = build_dataset_with_pretrain_held_out_cf_ent(
            args.num_entities, args.num_relations, out_degree=args.out_degree,
            args=args, split_train_inferred=True, cot=bool(args.cot_fmt))
    else:
        print("Running Multi")
        (train_entities, train_relations, id_atomic_facts, ood_atomic_facts,
         train_inferred_facts, test_inferred_iid, test_inferred_facts,
         cf_train, cf_test, fact_train_cot,
         fact_test_cot) = build_dataset_with_pretrain_held_out_cf_ent_multi(
            args.num_entities, args.num_relations, out_degree=args.out_degree,
            args=args, split_train_inferred=True, cot=bool(args.cot_fmt),
            multi=args.multi)

    # Vocabulary = entities, then relations, then special tokens; no token may
    # appear twice.
    special_tokens = ["<suppose>", "<mask>", "<sep>", "<a>", "</a>", "<q>", "</q>"]
    vocab = train_entities + train_relations + special_tokens
    assert len(vocab) == len(set(vocab))
    print("vocab size:", len(vocab))

    # Each evaluation category is subsampled to at most test_size items.
    test_size = 3000
    id_atomic_facts_ds = choose(id_atomic_facts, test_size)
    ood_atomic_facts_ds = choose(ood_atomic_facts, test_size)
    test_inferred_iid = choose(test_inferred_iid, test_size)
    test_inferred_facts_ds = choose(test_inferred_facts, test_size)
    all_atomics = id_atomic_facts + ood_atomic_facts

    phi = args.phi
    dataset_name = f"{args.dataset_name}.{args.num_entities}.{args.num_relations}.{phi}"
    out_dir = f"data/procedural/{dataset_name}"
    os.makedirs(out_dir, exist_ok=True)
    # phi sets the number of training inferred facts relative to the number of
    # ID atomic facts.
    train_inferred_facts_ds = choose(train_inferred_facts, round(phi * len(id_atomic_facts)))

    # Probes are deep copies tagged with their category so the source lists
    # (which choose() may return un-copied) are never mutated.
    probe_groups = [
        (id_atomic_facts_ds, "id_atomic"),
        (ood_atomic_facts_ds, "ood_atomic"),
        (choose(train_inferred_facts_ds, test_size), "train_inferred"),
        (test_inferred_iid, "test_inferred_iid"),
        (test_inferred_facts_ds, "test_inferred_ood"),
    ]
    probes = []
    for items, probe_type in probe_groups:
        for item in items:
            probe = deepcopy(item)
            probe["type"] = probe_type
            probes.append(probe)

    outputs = {
        "train.json": all_atomics + train_inferred_facts_ds,
        "valid.json": test_inferred_facts_ds,
        "cf_train.json": cf_train,
        "cf_test.json": cf_test,
        "cot_fact_train.json": fact_train_cot,
        "cot_fact_test.json": fact_test_cot,
        "test.json": probes,
        "vocab.json": vocab,
    }
    for filename, payload in outputs.items():
        _write_json(os.path.join(out_dir, filename), payload)

def build_parser():
    """Return the command-line parser for the procedural dataset generator."""
    parser = argparse.ArgumentParser(
        description="Generate a procedural dataset under data/procedural/.")
    # phi is a ratio (train inferred facts per ID atomic fact), so it must be a
    # float; type=int rejected fractional values and disagreed with the default.
    parser.add_argument("--phi", type=float, default=18.0,
                        help="number of training inferred facts as a multiple "
                             "of the number of ID atomic facts")
    parser.add_argument("--num_entities", type=int, default=2000)
    parser.add_argument("--num_relations", type=int, default=200)
    parser.add_argument("--out_degree", type=int, default=20)
    parser.add_argument("--cf_train_prob", type=float, default=0.5,
                        help="probability that a generated counterfactual "
                             "example is put in the train split")
    # Without a name the output directory silently became "None.<...>".
    parser.add_argument("--dataset_name", type=str, required=True,
                        help="prefix of the output directory name")
    parser.add_argument("--split_train_inferred", type=int, default=1)
    parser.add_argument("--cf_train_reps", type=int, default=0)
    parser.add_argument("--ood_fraction", type=float, default=0.05)
    parser.add_argument("--generate_cf", type=int, default=1)
    parser.add_argument("--cf_ood_fraction", type=float, default=None,
                        help="if unset (or 0), use the plain pretrain builder; "
                             "otherwise use a held-out counterfactual builder")
    parser.add_argument("--multi", type=int, default=0,
                        help="0 selects the non-multi held-out builder")
    parser.add_argument("--cf_fmt_at_pt", type=int, default=0)
    parser.add_argument("--cot_fmt", type=int, default=0,
                        help="nonzero enables the chain-of-thought format")
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())
