import argparse 
import os 
import pandas as pd 
from wikidata.client import Client
import tqdm 
import random
import numpy as np
import pickle 
import datasets
def load_kg(filename):
    """Load a pickled knowledge graph and return its entity and relation tables.

    Args:
        filename: Path to a pickle file containing a dict with ``'relations'``
            and ``'entities'`` keys.

    Returns:
        A ``(entities, relations)`` tuple taken from those two keys.

    Raises:
        KeyError: if either key is missing from the unpickled dict.
    """
    # Open the caller-supplied path; hard-coding a file name here would make
    # the --kg_file command-line option a no-op.
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file -- only load KG files from trusted sources.
    with open(filename, "rb") as kg_file:
        kg_dict = pickle.load(kg_file)
    relations = kg_dict['relations']
    entities = kg_dict['entities']
    return entities, relations
def generate_all_2_hops(entity_1_name_id, entity_obj, entities, relations):
    """Collect every two-hop fact chain that starts at one entity.

    For each outgoing relation of ``entity_obj``, the target ("bridge")
    entity is looked up in ``entities`` under the relation's
    ``output_type``. Every outgoing relation of that bridge entity then
    yields one chain::

        ((start, rel_1, bridge), (bridge, rel_2, final_entity))

    where ``start`` is ``entity_1_name_id`` and ``bridge`` is a
    ``(name, id)`` tuple. If the bridge entity is not present in
    ``entities``, the start entity's name is printed and that relation is
    skipped.
    """
    chains = []
    for first_rel, hop_target in entity_obj.items():
        bridge = (hop_target[0], hop_target[1])
        bridge_type = relations[first_rel]["output_type"]
        try:
            bridge_obj = entities[bridge_type][bridge]
        except KeyError:
            # Bridge entity missing from the KG: report and move on.
            print(entity_1_name_id[0])
            continue
        chains.extend(
            ((entity_1_name_id, first_rel, bridge), (bridge, second_rel, final_entity))
            for second_rel, final_entity in bridge_obj.items()
        )
    return chains
def generate_mh(entities, relations):
    """Enumerate every two-hop fact chain in the knowledge graph.

    Walks each entity type (printing the type name as a progress marker)
    and every ``(name, id)`` entity of that type, concatenating the chains
    returned by ``generate_all_2_hops`` in traversal order.
    """
    chains = []
    for type_name, typed_entities in entities.items():
        print(type_name)
        for name, qid in typed_entities:
            start = (name, qid)
            chains.extend(generate_all_2_hops(start, typed_entities[start], entities, relations))
    return chains
def generate_replacement_entities_data(entities):
    """Build the pool of candidate replacement entities for every entity type.

    Args:
        entities: Mapping of entity type -> mapping keyed by entity
            (``(name, id)`` tuples in this KG).

    Returns:
        Dict mapping each entity type to a list of that type's entity keys,
        in the KG's insertion order.
    """
    # Dict keys are already unique, so no set() de-duplication is needed.
    # Skipping the set also keeps the pool order deterministic: set iteration
    # order for string-containing keys depends on PYTHONHASHSEED, which would
    # make random.choice over the pool irreproducible even with a fixed seed.
    return {entity_type: list(typed_entities) for entity_type, typed_entities in entities.items()}

_MAX_REJECTION_TRIES = 100


def _sample_valid(pool, is_valid, max_tries=_MAX_REJECTION_TRIES):
    """Return a uniformly random element of ``pool`` accepted by ``is_valid``, or None.

    Rejection sampling is tried first because a valid candidate is normally
    found within a few draws. If ``max_tries`` draws all fail, the pool is
    filtered exhaustively, so a pool with no acceptable element yields None
    instead of looping forever.
    """
    if not pool:
        return None
    for _ in range(max_tries):
        candidate = random.choice(pool)
        if is_valid(candidate):
            return candidate
    valid = [candidate for candidate in pool if is_valid(candidate)]
    return random.choice(valid) if valid else None


def _make_datapoint(question, answer, hop1_rel, hop2_rel, relevant_cf, cf_hop, cot):
    """Assemble one dataset row; the key order fixes the output column order."""
    return {
        'question': question,
        "answer": answer,
        "hop1_rel": hop1_rel,
        "hop2_rel": hop2_rel,
        "relevant_cf": relevant_cf,
        "cf_hop": cf_hop,
        "CoT": cot,
    }


def generate_dataset(all_two_hop_pairs, entities, relations, replacement_entity_dict, p_relevant=0.5, p_hop_1=0.5):
    """Turn two-hop fact chains into counterfactual question/answer rows.

    Each chain ``((subject, rel_1, bridge), (bridge, rel_2, final_entity))``
    becomes at most one row. The row asks for "rel_2 of rel_1 of subject"
    under a hypothetical ("Suppose that ... was ...") premise. Four variants
    are possible:

    * relevant, hop 1: the bridge is replaced by another entity of the same
      type that also has ``rel_2``. The answer is that entity's ``rel_2``.
    * relevant, hop 2: ``rel_2`` of the bridge is replaced. The answer is
      the substitute.
    * irrelevant, hop 1 / hop 2: a premise about an unrelated entity under
      ``rel_1`` / ``rel_2``. The answer is the true ``final_entity``.

    Args:
        all_two_hop_pairs: Iterable of two-hop chains as described above.
            Entities are ``(name, id)`` tuples.
        entities: Mapping of entity type -> {entity: {relation_id: entity}}.
        relations: Mapping of relation id -> dict with ``'input_type'``,
            ``'output_type'``, ``'mh_phrase'`` (text placed directly before
            an entity name) and ``'relation_name'``.
        replacement_entity_dict: Mapping of entity type -> list of entities
            to sample substitutes from.
        p_relevant: Probability that a row uses a relevant counterfactual.
        p_hop_1: Probability that the counterfactual targets hop 1.

    Returns:
        List of row dicts with keys ``question``, ``answer``, ``hop1_rel``,
        ``hop2_rel``, ``relevant_cf``, ``cf_hop`` and ``CoT``. Chains for which
        no valid substitute exists are skipped.

    Branch choice uses ``np.random`` and entity sampling uses ``random``, so
    both must be seeded for reproducible output.
    """
    dataset = []
    for (subject, hop_1_relation_id, bridge), (hop_2_subject, hop_2_relation_id, final_entity) in all_two_hop_pairs:
        hop_1_relation_obj = relations[hop_1_relation_id]
        hop_2_relation_obj = relations[hop_2_relation_id]
        bridge_type = hop_1_relation_obj['output_type']
        final_type = hop_2_relation_obj['output_type']
        if not replacement_entity_dict[bridge_type] or not replacement_entity_dict[final_type]:
            continue

        hop_1_phrase = hop_1_relation_obj['mh_phrase']
        hop_2_phrase = hop_2_relation_obj['mh_phrase']
        hop1_rel = (hop_1_relation_obj['relation_name'], hop_1_relation_id)
        hop2_rel = (hop_2_relation_obj['relation_name'], hop_2_relation_id)
        # Every variant asks the same two-hop question; only the premise differs.
        target_question = f"What would {hop_2_phrase} {hop_1_phrase} {subject[0]} be?"

        if np.random.rand() < p_relevant:
            if np.random.rand() < p_hop_1:
                # Hop 1, relevant CF: the substitute must differ from the real
                # bridge and must itself have the hop-2 relation, or there is
                # no answer to give.
                bridge_entities = entities[bridge_type]
                substitute = _sample_valid(
                    replacement_entity_dict[bridge_type],
                    lambda e: e != bridge and hop_2_relation_id in bridge_entities[e],
                )
                if substitute is None:
                    continue
                cf_answer = bridge_entities[substitute][hop_2_relation_id]
                dataset.append(_make_datapoint(
                    question=f"Suppose that {hop_1_phrase} {subject[0]} was {substitute[0]}. {target_question}",
                    answer=f"{cf_answer[0]}",
                    hop1_rel=hop1_rel,
                    hop2_rel=hop2_rel,
                    relevant_cf=True,
                    cf_hop=1,
                    cot=f"""The provided hypothetical indicates that {hop_1_phrase} {subject[0]} is {substitute[0]}. Additionally, {hop_2_phrase} {substitute[0]} is {cf_answer[0]}. Therefore the answer is <<{cf_answer[0]}>>""",
                ))
            else:
                # Hop 2, relevant CF: any entity of the answer type other than
                # the true answer.
                substitute = _sample_valid(
                    replacement_entity_dict[final_type],
                    lambda e: e != final_entity,
                )
                if substitute is None:
                    continue
                dataset.append(_make_datapoint(
                    question=f"Suppose that {hop_2_phrase} {hop_2_subject[0]} was {substitute[0]}. {target_question}",
                    answer=f"{substitute[0]}",
                    hop1_rel=hop1_rel,
                    hop2_rel=hop2_rel,
                    relevant_cf=True,
                    cf_hop=2,
                    cot=f"""The provided hypothetical indicates that {hop_2_phrase} {hop_2_subject[0]} is {substitute[0]}. Additionally, {hop_1_phrase} {subject[0]} is {hop_2_subject[0]}. Therefore the answer is <<{substitute[0]}>>""",
                ))
        else:
            # Irrelevant cases: the premise uses one of the chain's relations
            # but on a different subject, so the true answer is unchanged.
            if np.random.rand() < p_hop_1:
                cf_hop = 1
                cf_relation_obj = hop_1_relation_obj
                cf_subject = _sample_valid(
                    replacement_entity_dict[cf_relation_obj['input_type']],
                    lambda e: e != subject,
                )
                # Avoid self-referential premises such as "X of A was A".
                cf_value = None if cf_subject is None else _sample_valid(
                    replacement_entity_dict[cf_relation_obj['output_type']],
                    lambda e: e != cf_subject,
                )
            else:
                cf_hop = 2
                cf_relation_obj = hop_2_relation_obj
                cf_subject = _sample_valid(
                    replacement_entity_dict[cf_relation_obj['input_type']],
                    lambda e: e != hop_2_subject,
                )
                cf_value = _sample_valid(
                    replacement_entity_dict[cf_relation_obj['output_type']],
                    lambda e: e != final_entity,
                )
            if cf_subject is None or cf_value is None:
                continue
            dataset.append(_make_datapoint(
                question=f"Suppose that {cf_relation_obj['mh_phrase']} {cf_subject[0]} was {cf_value[0]}. {target_question}",
                answer=f"{final_entity[0]}",
                hop1_rel=hop1_rel,
                hop2_rel=hop2_rel,
                relevant_cf=False,
                cf_hop=cf_hop,
                cot=f"""We have that {hop_1_phrase} {subject[0]} is {bridge[0]}. Additionally, {hop_2_phrase} {bridge[0]} is {final_entity[0]}. Therefore the answer is <<{final_entity[0]}>>""",
            ))
    return dataset
if __name__ == "__main__":
    # Command-line entry point: load the KG, enumerate two-hop chains, build
    # counterfactual QA rows and save them as a HuggingFace `datasets` dataset.
    parser = argparse.ArgumentParser()
    parser.add_argument("--kg_file", type = str, default="initial_kg")
    parser.add_argument("--p_relevant", type = float, default = 0.5)
    parser.add_argument("--p_hop1", type = float, default = 0.5)
    parser.add_argument('--outfile', type = str, default = "./wikidata_cf/")
    # Optional seed; the pipeline samples from both `random` and `np.random`,
    # so both are seeded. Omitting it keeps the previous unseeded behaviour.
    parser.add_argument("--seed", type = int, default = None)
    args = parser.parse_args()
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
    entities, relations = load_kg(args.kg_file)
    all_mh_facts = generate_mh(entities, relations)
    entity_replacement_pools = generate_replacement_entities_data(entities)
    cf_dataset = generate_dataset(all_two_hop_pairs = all_mh_facts, entities = entities, relations = relations, replacement_entity_dict = entity_replacement_pools, p_relevant = args.p_relevant, p_hop_1=args.p_hop1)
    ds_obj = datasets.Dataset.from_pandas(pd.DataFrame(data=cf_dataset))
    ds_obj.save_to_disk(args.outfile)
