import pandas as pd 
from wikidata.client import Client
import tqdm 
import argparse 
import os 
import pickle
# Entity-type names. NOTE(review): not referenced elsewhere in the visible
# part of this file, and it omits the "currency" and "city" types that the
# relation table below uses as output types -- confirm whether that is intended.
ENTITY_TYPES = ["person", "country", "writtenwork", "language"]

# Wikidata property id -> metadata describing the relation.
#   relation_name            short identifier for the relation
#   mh_phrase                noun phrase for the relation (presumably used when
#                            composing multi-hop questions -- TODO confirm)
#   single_question_template one-hop question; "[ENTITY]" is the placeholder
#                            for the subject's label
#   input_type / output_type entity types of the subject and of the value;
#                            generate_kg() uses output_type to decide which
#                            seed list a discovered value is added to
relations = {
    "P36": {
        "relation_name": "capital",
        "mh_phrase": "the capital of",
        "single_question_template": "What is the capital of [ENTITY]?",
        "input_type": "country",
        "output_type": "city",
    },
    "P37": {
        "relation_name": "language",
        "mh_phrase": "the language of",
        "single_question_template": "What is the language of [ENTITY]?",
        "input_type": "country",
        "output_type": "language",
    },
    "P38": {
        "relation_name": "currency",
        "mh_phrase": "the currency of",
        # Trailing "?" added so this template matches the other four.
        "single_question_template": "What is the currency of [ENTITY]?",
        "input_type": "country",
        "output_type": "currency",
    },
    "P27": {
        "relation_name": "country_citizen",
        "mh_phrase": "the country of citizenship of",
        "single_question_template": "What is the country of citizenship of [ENTITY]?",
        "input_type": "person",
        "output_type": "country",
    },
    "P50": {
        "relation_name": "author",
        "mh_phrase": "the author of",
        "single_question_template": "Who is the author of [ENTITY]?",
        "input_type": "writtenwork",
        "output_type": "person",
    },
}
def get_seed_from_pqa(path="popQA.tsv"):
    """Build the initial per-type seed lists of Wikidata ids from a PopQA TSV.

    The file must be tab-separated and contain a ``prop`` column (relation
    name) and an ``s_uri`` column (subject URI). The seed id is the last
    ``/``-separated segment of the subject URI.

    Args:
        path: Location of the PopQA TSV file. Defaults to ``"popQA.tsv"`` in
            the current working directory.

    Returns:
        A dict mapping each entity type to a list of ids. ``"person"`` and
        ``"writtenwork"`` are filled from the file (duplicates are kept);
        ``"country"``, ``"language"``, ``"currency"`` and ``"city"`` start
        out empty.
    """
    ds = pd.read_csv(path, sep="\t")

    def subject_ids(prop_names):
        # Subjects of the rows whose relation is one of `prop_names`.
        uris = ds.loc[ds["prop"].isin(prop_names), "s_uri"].tolist()
        return [uri.split("/")[-1] for uri in uris]

    return {
        # Subjects of these relations are taken to be people.
        "person": subject_ids(["occupation", "place of birth", "father", "mother"]),
        # Subjects of "author" rows are taken to be written works.
        "writtenwork": subject_ids(["author"]),
        "country": [],
        "language": [],
        "currency": [],
        "city": [],
    }
def load_existing_progress(filename):
    """Read back a pickled progress snapshot.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        Whatever object was pickled into the file. The script's entry point
        expects a dict with ``"relations"``, ``"entities"`` and ``"seeds"``.

    NOTE: unpickling can execute arbitrary code, so only point this at
    files produced by this script or another trusted source.
    """
    with open(filename, "rb") as snapshot:
        return pickle.load(snapshot)
def generate_kg(relations, entity_seed_list, save_file_name, entity_dicts = None):
    """Crawl Wikidata from the seed entities and pickle the collected facts.

    For every seed entity, each property in ``relations`` is looked up. When
    the entity has at least one value, the first value is recorded and its id
    is appended to the seed list of the relation's ``output_type``.

    Args:
        relations: Mapping of Wikidata property id (e.g. ``"P36"``) to a
            metadata dict; only its ``"output_type"`` entry is read here.
        entity_seed_list: Mapping of entity type to a list of Wikidata entity
            ids. Mutated in place: discovered value ids are appended.
        save_file_name: Path of the pickle file that is overwritten with a
            progress snapshot after each entity type has been processed.
        entity_dicts: Earlier results to resume from, shaped
            ``{entity_type: {(label, entity_id): {prop_id: (label, id)}}}``.
            A fresh structure is created when ``None``; otherwise it is
            mutated in place.
    """
    client = Client()
    # Resolve each property id once; Entity.getlist() is keyed by the
    # property's entity object rather than by its id string.
    props = {prop_id: client.get(prop_id) for prop_id in relations}
    if entity_dicts is None:
        entity_dicts = {typ: {} for typ in entity_seed_list}
    # Fixed visiting order. With the module-level relation table, each type is
    # visited after the types whose relations add ids to its seed list.
    for entity_type in ["writtenwork", "person", "country", "language", "currency", "city"]:
        collected = entity_dicts[entity_type]
        # Result keys are (label, entity_id) tuples, so a bare id must be
        # compared against the second component. Testing the id directly
        # against the keys never matches and defeats resuming.
        already_done = {key[1] for key in collected}
        # Snapshot and de-duplicate the seeds: the seed lists grow while looping.
        for entity_id in tqdm.tqdm(list(set(entity_seed_list[entity_type]))):
            if entity_id in already_done:
                continue
            entity_object = client.get(entity_id)
            facts = {}
            collected[(str(entity_object.label), entity_id)] = facts
            for prop_id, prop in props.items():
                values = entity_object.getlist(key=prop)
                if not values:
                    continue
                # Only the first listed value of a property is kept.
                target = values[0]
                target.load()
                facts[prop_id] = (str(target.label), target.id)
                # Queue the value so it is crawled under its own entity type.
                entity_seed_list[relations[prop_id]["output_type"]].append(target.id)
        # Checkpoint once per entity type so an interrupted run can be resumed.
        with open(save_file_name, "wb") as save_file:
            dump_dict = {"relations": relations, "entities": entity_dicts, "seeds": entity_seed_list}
            pickle.dump(dump_dict, save_file)
if __name__ == "__main__":
    # Entry point: resume from --save_file when it already exists, otherwise
    # start from the PopQA seeds; then crawl and write progress to that file.
    parser = argparse.ArgumentParser()
    # The path is both the resume source and the output target, so it is
    # mandatory. With the old default of None, os.path.exists(None) failed
    # with an opaque TypeError instead of a usage message.
    parser.add_argument(
        "--save_file",
        type=str,
        required=True,
        help="pickle file to resume from (if it exists) and to save progress to",
    )
    args = parser.parse_args()
    if os.path.exists(args.save_file):
        # Resume: the snapshot carries the relation table, the collected
        # entities and the current seed lists.
        progress_dict = load_existing_progress(args.save_file)
        relations = progress_dict["relations"]
        entity_dicts = progress_dict["entities"]
        seeds = progress_dict["seeds"]
    else:
        # Fresh run: generate_kg() creates the empty result structure itself.
        entity_dicts = None
        seeds = get_seed_from_pqa()
    generate_kg(
        relations=relations,
        entity_seed_list=seeds,
        save_file_name=args.save_file,
        entity_dicts=entity_dicts,
    )