import os
import copy
import json
import faiss
import argparse
import numpy as np
import _pickle as pkl
from struct import *

from sentence_transformers import SentenceTransformer

from tqdm import tqdm, trange
from collections import defaultdict

def parse_args():
    """Parse the command-line options of the preprocessing script.

    Returns:
        argparse.Namespace with ``gpu_id`` (int), ``datasets`` (str),
        ``data_path`` (str) and ``save_path`` (str).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--gpu_id", default=0, type=int)
    cli.add_argument("--datasets", default='WebQSP', type=str, help='WebQSP/MetaQA')

    # Root of the raw datasets and destination of the generated files.
    cli.add_argument("--data_path", default="/data/home/datasets", type=str)
    cli.add_argument("--save_path", default="../processed", type=str)

    return cli.parse_args()

def preprocess_webqsp(args):
    """Convert the raw WebQSP splits into question-decomposition samples.

    Reads ``webqsp_{split}.json`` (split in train/test) from
    ``args.qa_data_path`` and writes into ``args.save_path``:

    * ``WebQSP_question_{split}.csv`` -- every question template observed for
      each inferential chain;
    * ``WebQSP_relations.pkl`` -- ``{split: {frozenset(chain): template}}``
      where ``template`` is the most frequent template of that chain;
    * ``WebQSP_{split}.json`` -- list of ``{qid, instruction, input, output,
      answers}`` samples.

    Only questions whose first parse has a non-empty ``InferentialChain`` are
    kept. Entity mentions found in the question are replaced by ``<mask>`` to
    form the question template.

    Args:
        args: namespace providing ``qa_data_path`` and ``save_path``.

    Returns:
        dict mapping ``'train'``/``'test'`` to lists of question records with
        keys ``QuestionId``, ``RawQuestion``, ``ProcessedQuestion``,
        ``OracleEntities``, ``InferentialChain`` and ``Answers``.
    """
    import pandas as pd
    from collections import Counter

    splits = ['train', 'test']

    def get_fid(s):
        """Turn a dotted Freebase id such as 'm.0abc' into '/m/0abc'."""
        return '/' + s.replace('.', '/')

    def get_answers(question):
        """Extract unique (AnswerArgument, EntityName) pairs from all parses.

        A dict is used as an ordered set so the result (and every file derived
        from it) does not depend on the interpreter's hash seed.
        """
        answers = {}
        for parse in question["Parses"]:
            for answer in parse["Answers"]:
                answers[(answer["AnswerArgument"], answer["EntityName"])] = None
        return list(answers)

    def get_entities(question):
        """Extract unique oracle (mid, name, mention) triples from all parses, first-seen order."""
        entities = {}
        for parse in question["Parses"]:
            if parse["TopicEntityMid"] is not None:
                entities[(parse["TopicEntityMid"], parse["TopicEntityName"],
                          parse["PotentialTopicEntityMention"])] = None
        return list(entities)

    def normalize_mention(mention):
        """Remove tokenizer-style spaces around '.' and apostrophes in a mention."""
        for old, new in ((' . ', '.'), (' .', '.'), (" ' ", "'"), (" '", "'"), (" 's", "'s")):
            mention = mention.replace(old, new)
        return mention

    def mask_entities(question_text, oracle_entities):
        """Replace every locatable oracle-entity surface form with '<mask>'.

        For each entity, try the lower-cased potential mention, then the
        lower-cased entity name, then (after removing spaces around '.' in the
        template) the normalized mention. Entities that reach the last step have
        their 'PotentialTopicEntityMention' normalized in place.

        Returns:
            (template, masked surface strings in the order they were masked)
        """
        template = question_text
        masked = {}  # ordered set of surface strings already replaced
        for entity in oracle_entities:
            mention = str.lower(entity['PotentialTopicEntityMention'])
            if mention in template:
                template = template.replace(mention, '<mask>')
                masked[mention] = None
                continue

            topic_name = str.lower(entity['TopicEntityName'])
            if topic_name in template:
                template = template.replace(topic_name, '<mask>')
                masked[topic_name] = None
                continue

            template = template.replace('. ', '.').replace(' .', '.')
            entity['PotentialTopicEntityMention'] = normalize_mention(entity['PotentialTopicEntityMention'])
            mention = str.lower(entity['PotentialTopicEntityMention'])
            if mention in template:
                template = template.replace(mention, '<mask>')
                masked[mention] = None
            elif mention not in masked:
                # Not found and not masked earlier: report for manual inspection.
                print(template)
                print(entity['PotentialTopicEntityMention'])
        return template, list(masked)

    # 1. Load the questions that have an inferential chain.
    webqsp_questions = defaultdict(list)
    for split in splits:
        webqsp_file = os.path.join(args.qa_data_path, f"webqsp_{split}.json")
        with open(webqsp_file) as f:
            webqsp_data = json.load(f)
        for question in webqsp_data["Questions"]:
            inferential_chain = question['Parses'][0]['InferentialChain']
            if not inferential_chain:
                continue
            webqsp_questions[split].append({
                "QuestionId": question["QuestionId"],
                "RawQuestion": str.lower(question["RawQuestion"]),
                "ProcessedQuestion": str.lower(question["ProcessedQuestion"]),
                "OracleEntities": [
                    {"TopicEntityMid": get_fid(mid),
                     "TopicEntityName": name,
                     "PotentialTopicEntityMention": mention}
                    for mid, name, mention in get_entities(question)
                ],
                "InferentialChain": inferential_chain,
                "Answers": [
                    {"AnswerArgument": get_fid(arg) if arg.startswith(("m.", "g.")) else arg,
                     "EntityName": entity_name}
                    for arg, entity_name in get_answers(question)
                ],
            })

    print(
        f"num of training questions: {len(webqsp_questions['train'])}, num of testing questions: {len(webqsp_questions['test'])}")

    # 2. Build question templates and group them by inferential chain.
    webqsp_relations = defaultdict(dict)  # split -> {frozenset(chain): [template, ...]}
    webqsp_samples = defaultdict(list)
    for split in splits:
        for question in webqsp_questions[split]:
            question_text = question['ProcessedQuestion']
            question_template, question_entities = mask_entities(question_text, question['OracleEntities'])
            webqsp_samples[split].append(
                [question['Answers'], question['QuestionId'], question_text, question_template,
                 question_entities, question['InferentialChain']])
            # NOTE(review): frozenset ignores hop order/duplicates; kept because the
            # pickled relation dictionary is keyed this way.
            chain_key = frozenset(question['InferentialChain'])
            webqsp_relations[split].setdefault(chain_key, []).append(question_template)

    print(
        f"num of training relations: {len(webqsp_relations['train'])}, num of testing relations: {len(webqsp_relations['test'])}")

    # 3. Save all templates and pick one template per logical chain.
    # The CSV can be edited to substitute predefined questions.
    webqsp_relation_dict = defaultdict(dict)
    for split in splits:
        relations = webqsp_relations[split]
        dataframe = pd.DataFrame({'logical_chains': list(relations.keys()),
                                  'question_templates': list(relations.values())})
        dataframe.to_csv(os.path.join(args.save_path, f'WebQSP_question_{split}.csv'), index=False, sep=',')

        for logical_chain, templates in relations.items():
            # Most frequent template; ties go to the lexicographically largest one.
            best_template, _ = max(Counter(templates).items(), key=lambda item: (item[1], item[0]))
            webqsp_relation_dict[split][logical_chain] = best_template

    with open(os.path.join(args.save_path, 'WebQSP_relations.pkl'), 'wb') as f:
        pkl.dump(webqsp_relation_dict, f)

    # 4. Emit instruction-tuning samples.
    instruction = 'The AI assistant can parse the user input to several subquestions:'
    for split in splits:
        demonstration_list = []
        for cur_ans, cur_qid, cur_question, _, cur_seed_entities, cur_logical_chain in webqsp_samples[split]:
            chain_key = frozenset(cur_logical_chain)
            # Prefer the train-split wording so train and test share templates.
            if chain_key in webqsp_relation_dict['train']:
                sub_question = webqsp_relation_dict['train'][chain_key]
            else:
                sub_question = webqsp_relation_dict['test'][chain_key]

            # json.dumps produces the same text as manual concatenation for plain
            # strings, but escapes quotes/backslashes so the output stays valid JSON.
            # NOTE(review): all seed entities are joined into ONE comma-separated
            # string element; confirm downstream consumers expect this.
            output = json.dumps(
                [{"question": sub_question, "id": 0, "dep": [-1],
                  "args": {"seed_entities": [','.join(cur_seed_entities)]}}],
                ensure_ascii=False)

            answers = ','.join(
                f"'{str.lower(term['EntityName'])}'" for term in cur_ans if term['EntityName'])

            demonstration_list.append({
                'qid': cur_qid,
                'instruction': instruction,
                'input': cur_question,
                'output': output,
                'answers': answers,
            })

        with open(os.path.join(args.save_path, f'WebQSP_{split}.json'), 'w') as f:
            json.dump(demonstration_list, f)

    return webqsp_questions

def preprocess_metaqa(args, prompt_mode='easy', answer_mode='hard'):
    """Convert MetaQA questions into question-decomposition instruction samples.

    For each mode in ('train', 'test') this reads from ``args.qa_data_path``:

    * ``{mode}.json`` -- a list of records with ``question`` and ``seed_entities``;
    * ``qa_{mode}_qtype.txt`` -- one logical chain per line, matched to the JSON
      records by position. Hop names are the even-indexed '_'-separated tokens
      (presumably chains like ``movie_to_actor_to_movie`` -- confirm).

    It then writes ``MetaQA_{mode}.json`` (records with ``instruction``, ``input``
    and ``output``) into ``args.save_path``.

    Args:
        args: namespace providing ``qa_data_path`` and ``save_path``.
        prompt_mode: 'hard' puts the output schema and the allowed sub-questions
            in the instruction; any other value uses a short instruction.
        answer_mode: 'hard' emits a JSON-like list of dependent sub-questions
            followed by a trailing '.'; any other value emits a numbered
            plain-text list such as ``(1) ... (2) ...``.

    Raises:
        KeyError: if a chain contains a hop pair without a template below.
    """
    question_template = dict()
    question_template['movie_to_director'] = 'the film [mask] was directed by who?'
    question_template['director_to_movie'] = 'what films did [mask] direct?'
    question_template['movie_to_writer'] = 'the film [mask] was the written by who?'
    question_template['writer_to_movie'] = 'what films did [mask] write?'
    question_template['movie_to_actor'] = 'who acted in the film [mask]?'
    question_template['actor_to_movie'] = 'what films did [mask] act in?'
    question_template['movie_to_tags'] = 'what words can describe the film [mask]?'
    question_template['tags_to_movies'] = 'what film can be described by the word [mask]?'
    question_template['movie_to_year'] = 'what was the release year of the film [mask]?'
    question_template['movie_to_language'] = 'what is the language in the film [mask]?'
    question_template['movie_to_genre'] = 'what is the genre of the film [mask]?'
    question_template['movie_to_rating'] = 'what is the rating of the film [mask]?'
    question_template['movie_to_votes'] = 'how popular is the film [mask]?'

    if prompt_mode == 'hard':
        instruction = 'The AI assistant can parse the user input to several subquestions: [{"question": question, "question_id": question_id, "dep": dependency_question_id, "args": {"seed_entities": [text or <GENERATED>-dep_id]}}]. The special tag "<GENERATED>-dep_id" refer to the one generated text in the dependency question and "dep_id" must be in "dep" list. The "dep" field denotes the ids of the previous prerequisite questions which generate a new text that the current question relies on. The "args" field must in ["text"], nothing else. The question MUST be selected from the following options:"the film [mask] was directed by who?","what films did [mask] direct?","the film [mask] was the written by who?","what films did [mask] write?","who acted in the film [mask]?","what films did [mask] act in?","what words can describe the film [mask]?","what film can be described by the word [mask]?","what was the release year of the film [mask]?","what is the language in the film [mask]?","what is the genre of the film [mask]?", "what is the rating of the film [mask]?","how popular is the film [mask]?".'
    else:
        instruction = 'The AI assistant can parse the user input to several subquestions:'

    def read_logical_chains(input_file):
        """Return the stripped lines of a qtype file, one logical chain per question."""
        with open(input_file) as f_in:
            return [line.strip() for line in tqdm(f_in)]

    def strip_entity_brackets(question, seed_entities):
        """Turn each '[entity]' marker in *question* into plain 'entity'.

        Replacing the bracketed literal (rather than locating the bare entity
        text) avoids corrupting the question when the entity string also
        occurs earlier, unbracketed, or when the marker is absent.
        """
        for seed_entity in seed_entities:
            question = question.replace(f'[{seed_entity}]', seed_entity, 1)
        return question

    def chain_sub_questions(logical_chain):
        """Map 'a_to_b_to_c' to the templates for the hops a->b and b->c."""
        hops = logical_chain.split('_')[0::2]
        return [question_template[f'{src}_to_{dst}'] for src, dst in zip(hops, hops[1:])]

    def hard_output(sub_questions, seed_entities):
        """Render the dependency list; step k depends on step k-1's output."""
        parts = []
        for step, sub_question in enumerate(sub_questions):
            if step == 0:
                # NOTE(review): str() yields a Python repr (single quotes), not
                # JSON; kept to preserve the existing training-data format.
                seeds = str(seed_entities)
            else:
                seeds = '["<GENERATED>-' + str(step - 1) + '"]'
            parts.append('{"question": "' + sub_question +
                         '", "id": ' + str(step) +
                         ', "dep": [' + str(step - 1) +
                         '], "args": {"seed_entities": ' + seeds + '}}')
        return '[' + ','.join(parts) + '].'

    for mode in ['train', 'test']:
        with open(os.path.join(args.qa_data_path, f'{mode}.json')) as f:
            records = json.load(f)
        logical_chains = read_logical_chains(os.path.join(args.qa_data_path, f'qa_{mode}_qtype.txt'))

        demonstration_list = []
        for q_idx, record in enumerate(records):
            cur_seed_entities = record['seed_entities']
            question_text = strip_entity_brackets(record['question'], cur_seed_entities) + '?'

            sub_questions = chain_sub_questions(logical_chains[q_idx])
            if answer_mode == 'hard':
                output = hard_output(sub_questions, cur_seed_entities)
            else:
                output = ' '.join(f'({step + 1}) {sub_question}'
                                  for step, sub_question in enumerate(sub_questions))

            demonstration_list.append({
                'instruction': instruction,
                'input': question_text,
                'output': output,
            })

        with open(os.path.join(args.save_path, f'MetaQA_{mode}.json'), 'w') as f:
            json.dump(demonstration_list, f)


def preprocess_webqsp_kg(args):
    """Build retrieval-augmented reader samples for WebQSP from Freebase neighbourhoods.

    For every question in ``args.webqsp_questions`` (as returned by
    ``preprocess_webqsp``), the triples stored for it under ``args.kg_data_path``
    are verbalised with entity names from FastRDFStore's ``namesTable.bin``.
    They are ranked against the question with DPR encoders and a FAISS
    inner-product index, and the top 20 become the context of the instruction.
    Questions without usable context are skipped. Writes
    ``WebQSP_reader_{split}.json`` into ``args.save_path``.

    Args:
        args: namespace with ``kg_data_path``, ``entity_data_path``,
            ``save_path``, ``webqsp_questions`` and optionally ``gpu_id``
            (CUDA device index for the encoders; defaults to 0).
    """
    # Get entity names from FastRDFStore
    # https://github.com/microsoft/FastRDFStore
    class BinaryStream:
        """Reader/writer for fixed-width values (little-endian) and 7-bit length-prefixed strings."""

        def __init__(self, base_stream):
            self.base_stream = base_stream

        def readByte(self):
            return self.base_stream.read(1)

        def readBytes(self, length):
            return self.base_stream.read(length)

        def readChar(self):
            return self.unpack('b')

        def readUChar(self):
            return self.unpack('B')

        def readBool(self):
            return self.unpack('?')

        def readInt16(self):
            return self.unpack('h', 2)

        def readUInt16(self):
            return self.unpack('H', 2)

        def readInt32(self):
            return self.unpack('i', 4)

        def readUInt32(self):
            return self.unpack('I', 4)

        def readInt64(self):
            return self.unpack('q', 8)

        def readUInt64(self):
            return self.unpack('Q', 8)

        def readFloat(self):
            return self.unpack('f', 4)

        def readDouble(self):
            return self.unpack('d', 8)

        def decode_from_7bit(self):
            """Decode a 7-bit variable-length int: low 7-bit groups first, high bit = more bytes follow."""
            result = 0
            index = 0
            while True:
                byte_value = self.readUChar()
                result |= (byte_value & 0x7f) << (7 * index)
                if byte_value & 0x80 == 0:
                    break
                index += 1
            return result

        def encode_to_7bit(self, value):
            """Write a non-negative int in the form decode_from_7bit reads."""
            while True:
                byte_value = value & 0x7f
                value >>= 7
                if value:
                    self.writeUChar(byte_value | 0x80)
                else:
                    self.writeUChar(byte_value)
                    return

        def readString(self):
            """Read a 7-bit length-prefixed byte string; returns bytes (caller decodes)."""
            length = self.decode_from_7bit()
            return self.unpack(str(length) + 's', length)

        def writeBytes(self, value):
            self.base_stream.write(value)

        def writeChar(self, value):
            self.pack('c', value)

        def writeUChar(self, value):
            # 'B' is the unsigned-char struct code ('C' is not a valid format).
            self.pack('B', value)

        def writeBool(self, value):
            self.pack('?', value)

        def writeInt16(self, value):
            self.pack('h', value)

        def writeUInt16(self, value):
            self.pack('H', value)

        def writeInt32(self, value):
            self.pack('i', value)

        def writeUInt32(self, value):
            self.pack('I', value)

        def writeInt64(self, value):
            self.pack('q', value)

        def writeUInt64(self, value):
            self.pack('Q', value)

        def writeFloat(self, value):
            self.pack('f', value)

        def writeDouble(self, value):
            self.pack('d', value)

        def writeString(self, value):
            """Write *value* (bytes) with the 7-bit length prefix that readString expects."""
            length = len(value)
            self.encode_to_7bit(length)
            self.pack(str(length) + 's', value)

        def pack(self, fmt, data):
            # '<' fixes standard sizes and little-endian order regardless of the host.
            return self.writeBytes(pack('<' + fmt, data))

        def unpack(self, fmt, length=1):
            return unpack('<' + fmt, self.readBytes(length))[0]

    class Relation:
        """A (subject, relation, object) triple parsed from a neighbourhood file line."""

        def __init__(self, line):
            if line is None:
                self.subj = self.rel = self.obj = None
                return
            e1, rel, e2 = line.strip().split(None, 2)
            e1 = self.canonicalize(e1)
            e2 = self.canonicalize(e2)
            self.subj = e1
            self.rel = rel
            self.obj = e2

        def __eq__(self, other):
            if not isinstance(other, type(self)):
                return NotImplemented
            return (self.subj, self.rel, self.obj) == (other.subj, other.rel, other.obj)

        def __hash__(self):
            return hash((self.subj, self.rel, self.obj))

        def _filter_relation(self):
            # same criteria as GraftNet: drop type.* and common.* relations,
            # except common.topic.notable_types which is kept.
            relation = self.rel
            if relation == "<fb:common.topic.notable_types>":
                return False
            domain = relation[4:-1].split(".")[0]
            return domain == "type" or domain == "common"

        def should_ignore(self):
            return self._filter_relation()

        def canonicalize(self, ent):
            """Map '<fb:m.X>' / '<fb:g.X>' to '/m/X' / '/g/X'; leave anything else unchanged."""
            if ent.startswith("<fb:m."):
                return "/m/" + ent[6:-1]
            elif ent.startswith("<fb:g."):
                return "/g/" + ent[6:-1]
            else:
                return ent

        def __repr__(self):
            return f"Subj: {self.subj}; Rel: {self.rel}; Obj: {self.obj}"

    def read_relations_for_question(qid, ignore_rel=True):
        """Read the neighbourhood triples of *qid*; None if its file is missing."""
        infname = os.path.join(args.kg_data_path, "stagg.neighborhoods", f"{qid}.nxhd")
        if not os.path.exists(infname):
            return None
        relations = []
        with open(infname) as inf:
            for line in inf:
                if not line.strip():
                    continue  # a blank line cannot be unpacked into a triple
                rel = Relation(line)
                if ignore_rel and rel.should_ignore():
                    continue
                relations.append(rel)
        return relations

    def read_condensed_relations_for_question(qid):
        """Read tab-separated (docid, subj, rel, obj) rows for *qid*; None if the file is missing."""
        infname = os.path.join(args.kg_data_path, "condensed.stagg.neighborhoods/condensed_edges_only", f"{qid}.nxhd")
        if not os.path.exists(infname):
            return None
        relations = []
        with open(infname) as inf:
            for line in inf:
                if not line.strip():
                    continue
                docid, subj, rel, obj = line.strip().split('\t')
                relations.append((docid, (subj, rel, obj)))
        return relations

    def convert_relation_to_text(relation, entity_names):
        """Verbalise a triple as (subject surface, "subj rel obj.") using known entity names."""
        if isinstance(relation, Relation):
            subj, rel, obj = relation.subj, relation.rel, relation.obj
        else:
            subj, rel, obj = relation
        subj_surface = entity_names.get(subj, subj)
        obj_surface = entity_names.get(obj, obj)

        # relation, e.g. <fb:film.film.other_crew>
        rel_surface = rel[4:-1]                      # remove '<fb:' and '>'
        rel_surface = rel_surface.replace('.', ' ')
        rel_surface = ' '.join(rel_surface.split(' ')[-2:])  # keep the last two words
        rel_surface = rel_surface.replace('_', ' ')

        return subj_surface, ' '.join([subj_surface, rel_surface, obj_surface]) + '.'

    print("Loading freebase entity names...")
    ALL_ENTITY_NAME_BIN = os.path.join(args.entity_data_path, "namesTable.bin")
    entity_names = {}
    with open(ALL_ENTITY_NAME_BIN, 'rb') as inf:
        stream = BinaryStream(inf)
        dict_cnt = stream.readInt32()
        print("total entities:", dict_cnt)
        for _ in range(dict_cnt):
            key = stream.readString().decode()
            if key.startswith('m.') or key.startswith('g.'):
                key = '/' + key[0] + '/' + key[2:]
            value = stream.readString().decode()
            entity_names[key] = value
    print("Done!")

    # Honour --gpu_id instead of always using the first GPU.
    device = f"cuda:{getattr(args, 'gpu_id', 0)}"
    passage_encoder = SentenceTransformer('facebook-dpr-ctx_encoder-single-nq-base').to(device)
    query_encoder = SentenceTransformer('facebook-dpr-question_encoder-single-nq-base').to(device)
    embedding_size = 768

    def get_context_from_kb(qid, qtext, topk):
        """Return the *topk* verbalised relations closest to *qtext*, space-joined.

        Returns None when the question has no relations or more than 20000.

        Raises:
            ValueError: if a condensed-relation docid does not match its position.
        """
        print("Processing", qid, "...")
        passages = []
        rel_texts = []

        def add_relation(relation):
            title, text = convert_relation_to_text(relation, entity_names)
            passages.append(title + ' [SEP] ' + text)
            rel_texts.append(text)

        relations = read_relations_for_question(qid)
        if relations is None:
            print(f"No relations found for {qid}")
        else:
            for rel in relations:
                add_relation(rel)

        cond_relations = read_condensed_relations_for_question(qid)
        if cond_relations is None:
            print(f"No condensed relations found for {qid}")
        else:
            for i, (docid, rel) in enumerate(cond_relations):
                expected_docid = f"{qid}.condensed_relations.{i}"
                if docid != expected_docid:
                    raise ValueError(f"unexpected docid {docid!r}, expected {expected_docid!r}")
                add_relation(rel)

        if not passages:
            # no relations for this question
            print("no relations for", qid)
            return None
        print(f"{len(passages)} relations for", qid)
        if len(passages) > 20000:
            # presumably skipped to bound encoding cost/memory
            return None

        # Encode the query only once there is something to rank it against.
        query_embedding = query_encoder.encode(qtext)
        passage_embeddings = passage_encoder.encode(passages)
        index = faiss.IndexFlatIP(embedding_size)
        index.add(np.array(passage_embeddings))
        _, idxs = index.search(np.expand_dims(query_embedding, 0), min(topk, len(passages)))
        return ' '.join(rel_texts[idx] for idx in idxs[0])

    for split in ['train', 'test']:
        demonstration_list = []
        for question in args.webqsp_questions[split]:
            qid = question['QuestionId']
            qtext = question['ProcessedQuestion']
            qcontext = get_context_from_kb(qid, qtext, topk=20)
            if not qcontext:
                continue

            qans = ','.join([f"'{str.lower(term['EntityName'])}'" for term in question['Answers'] if term['EntityName']])
            demonstration_list.append({
                'instruction': f"Use the following pieces of context to answer the users question.  If you don't know the answer, just say that you don't know, don't try to make up an answer.  \n ----------------\n {qcontext}",
                'input': qtext,
                'output': f"[{qans}]",
            })

        with open(os.path.join(args.save_path, f'WebQSP_reader_{split}.json'), 'w') as f:
            json.dump(demonstration_list, f)

if __name__ == '__main__':
    args = parse_args()

    if args.datasets == 'WebQSP':
        args.qa_data_path = os.path.join(args.data_path, 'WebQSP/data/webqsp')
        os.makedirs(args.save_path, exist_ok=True)
        args.webqsp_questions = preprocess_webqsp(args)

        # Optional: build retrieval-augmented reader data (needs the Freebase
        # neighbourhoods, the FastRDFStore name table and a GPU).
        # args.kg_data_path = os.path.join(args.data_path, 'WebQSP/data/freebase_2hop')
        # args.entity_data_path = os.path.join(args.data_path, 'WebQSP/data/FastRDFStore/data')
        # preprocess_webqsp_kg(args)

    elif args.datasets == 'MetaQA':
        # Only the 3-hop split is processed; loop over ['1-hop', '2-hop', '3-hop'] for all.
        # preprocess_metaqa reads its inputs from args.qa_data_path.
        args.qa_data_path = os.path.join(args.data_path, 'MetaQA', '3-hop')
        os.makedirs(args.save_path, exist_ok=True)
        preprocess_metaqa(args)

    else:
        raise ValueError(f"Unsupported --datasets value: {args.datasets!r} (expected 'WebQSP' or 'MetaQA')")












