import os
from struct import *

class BinaryStream:
    """Typed binary reader/writer on top of a file-like object.

    Numeric formats are packed little-endian with struct's standard sizes
    (int16 = 2 bytes, int32 = 4, int64 = 8, float = 4, double = 8), so the
    encoding does not depend on the host machine's native byte order.
    Strings are a 7-bit variable-length byte count followed by the raw bytes;
    readString returns those bytes undecoded.
    NOTE(review): this string layout appears to mirror .NET's
    BinaryWriter/BinaryReader (likely producer of namesTable.bin) — confirm.
    """

    # Prefix added to struct formats that carry no byte-order character.
    _DEFAULT_BYTE_ORDER = '<'
    _BYTE_ORDER_CHARS = ('@', '=', '<', '>', '!')

    def __init__(self, base_stream):
        # read* methods call base_stream.read(n); write* methods call
        # base_stream.write(b). Files should be opened in binary mode.
        self.base_stream = base_stream

    def readByte(self):
        """Return the result of base_stream.read(1) (a single byte as bytes)."""
        return self.base_stream.read(1)

    def readBytes(self, length):
        """Return the result of reading *length* bytes from the stream."""
        return self.base_stream.read(length)

    def readChar(self):
        """Read a signed 8-bit integer."""
        return self.unpack('b')

    def readUChar(self):
        """Read an unsigned 8-bit integer."""
        return self.unpack('B')

    def readBool(self):
        """Read a one-byte boolean."""
        return self.unpack('?')

    def readInt16(self):
        """Read a signed 16-bit integer."""
        return self.unpack('h')

    def readUInt16(self):
        """Read an unsigned 16-bit integer."""
        return self.unpack('H')

    def readInt32(self):
        """Read a signed 32-bit integer."""
        return self.unpack('i')

    def readUInt32(self):
        """Read an unsigned 32-bit integer."""
        return self.unpack('I')

    def readInt64(self):
        """Read a signed 64-bit integer."""
        return self.unpack('q')

    def readUInt64(self):
        """Read an unsigned 64-bit integer."""
        return self.unpack('Q')

    def readFloat(self):
        """Read a 4-byte IEEE float."""
        return self.unpack('f')

    def readDouble(self):
        """Read an 8-byte IEEE double."""
        return self.unpack('d')

    def decode_from_7bit(self):
        """Read a 7-bit variable-length unsigned integer from the stream.

        Each byte contributes its low 7 bits, least-significant group first;
        a set high bit (0x80) means another byte follows.
        """
        result = 0
        shift = 0
        while True:
            byte_value = self.readUChar()
            result |= (byte_value & 0x7F) << shift
            if not byte_value & 0x80:
                return result
            shift += 7

    def encode_to_7bit(self, value):
        """Write non-negative int *value* in the format decode_from_7bit reads.

        Raises:
            ValueError: if *value* is negative (the encoding is unsigned).
        """
        if value < 0:
            raise ValueError(f"cannot 7-bit encode negative value {value}")
        encoded = bytearray()
        while value >= 0x80:
            encoded.append((value & 0x7F) | 0x80)
            value >>= 7
        encoded.append(value)
        self.writeBytes(bytes(encoded))

    def readString(self):
        """Read a 7-bit length-prefixed string and return its raw bytes."""
        length = self.decode_from_7bit()
        return self.unpack(str(length) + 's', length)

    def writeBytes(self, value):
        """Write raw bytes to the underlying stream."""
        self.base_stream.write(value)

    def writeChar(self, value):
        """Write one byte.

        Accepts a signed int in [-128, 127] (the inverse of readChar) or, for
        backward compatibility, a length-1 bytes/bytearray.
        """
        if isinstance(value, (bytes, bytearray)):
            self.pack('c', value)
        else:
            self.pack('b', value)

    def writeUChar(self, value):
        """Write an unsigned 8-bit integer."""
        # 'B' is struct's unsigned char; the previous 'C' is not a valid code.
        self.pack('B', value)

    def writeBool(self, value):
        """Write a one-byte boolean."""
        self.pack('?', value)

    def writeInt16(self, value):
        """Write a signed 16-bit integer."""
        self.pack('h', value)

    def writeUInt16(self, value):
        """Write an unsigned 16-bit integer."""
        self.pack('H', value)

    def writeInt32(self, value):
        """Write a signed 32-bit integer."""
        self.pack('i', value)

    def writeUInt32(self, value):
        """Write an unsigned 32-bit integer."""
        self.pack('I', value)

    def writeInt64(self, value):
        """Write a signed 64-bit integer."""
        self.pack('q', value)

    def writeUInt64(self, value):
        """Write an unsigned 64-bit integer."""
        self.pack('Q', value)

    def writeFloat(self, value):
        """Write a 4-byte IEEE float."""
        self.pack('f', value)

    def writeDouble(self, value):
        """Write an 8-byte IEEE double."""
        self.pack('d', value)

    def writeString(self, value):
        """Write *value* in the layout readString expects.

        ``str`` values are UTF-8 encoded first; bytes-like values are written
        as-is. The byte count is 7-bit encoded so readString can read it back.
        """
        if isinstance(value, str):
            value = value.encode('utf-8')
        self.encode_to_7bit(len(value))
        self.writeBytes(value)

    def _with_byte_order(self, fmt):
        """Return *fmt*, prefixed with the default byte order if it has none."""
        if fmt[:1] in self._BYTE_ORDER_CHARS:
            return fmt
        return self._DEFAULT_BYTE_ORDER + fmt

    def pack(self, fmt, data):
        """Encode *data* with struct format *fmt* and write it to the stream."""
        return self.writeBytes(pack(self._with_byte_order(fmt), data))

    def unpack(self, fmt, length=None):
        """Read one value of struct format *fmt* and return it.

        *length* defaults to the format's size (struct.calcsize). A short read
        surfaces as struct.error from struct.unpack.
        """
        fmt = self._with_byte_order(fmt)
        if length is None:
            length = calcsize(fmt)
        return unpack(fmt, self.readBytes(length))[0]


class Relation:
    """A (subject, relation, object) triple parsed from one whitespace-separated line.

    Subject/object tokens of the form ``<fb:m.XXX>`` / ``<fb:g.XXX>`` are
    rewritten to ``/m/XXX`` / ``/g/XXX``; other tokens are kept verbatim.
    Two relations are equal (and hash equally) when all three fields match.
    """

    def __init__(self, line):
        # ``None`` builds a placeholder whose three fields are all None.
        if line is None:
            self.subj = self.rel = self.obj = None
            return
        # At most two splits, so any whitespace inside the object token is kept.
        e1, rel, e2 = line.strip().split(None, 2)
        self.subj = self.canonicalize(e1)
        self.rel = rel
        self.obj = self.canonicalize(e2)

    def _key(self):
        """Return the tuple that defines equality and hashing."""
        return (self.subj, self.rel, self.obj)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        # Must agree with __hash__: without it, equal triples would hash alike
        # yet compare by identity, so sets/dicts would never deduplicate them.
        if not isinstance(other, Relation):
            return NotImplemented
        return self._key() == other._key()

    def _filter_relation(self):
        """Return True if this relation should be filtered out.

        Same criteria as GraftNet: drop relations in the ``type`` and
        ``common`` domains, except ``common.topic.notable_types``.
        """
        relation = self.rel
        if relation == "<fb:common.topic.notable_types>":
            return False
        # "<fb:domain.type.prop>" -> "domain"
        domain = relation[4:-1].split(".")[0]
        return domain == "type" or domain == "common"

    def should_ignore(self):
        """Return True if callers should skip this relation."""
        return self._filter_relation()

    def canonicalize(self, ent):
        """Map ``<fb:m.X>`` to ``/m/X`` and ``<fb:g.X>`` to ``/g/X``; else return *ent*."""
        if ent.startswith("<fb:m."):
            return "/m/" + ent[6:-1]
        elif ent.startswith("<fb:g."):
            return "/g/" + ent[6:-1]
        else:
            return ent

    def __repr__(self):
        return f"Subj: {self.subj}; Rel: {self.rel}; Obj: {self.obj}"


def read_relations_for_question(qid, ignore_rel=True, kg_data_path=None):
    """Parse the neighborhood triples stored for question *qid*.

    Reads ``<kg_data_path>/stagg.neighborhoods/<qid>.nxhd`` and builds one
    Relation per line. When *ignore_rel* is true, relations whose
    ``should_ignore()`` returns true are dropped.

    Returns:
        list of Relation, or None if the file does not exist.
    """
    path = os.path.join(kg_data_path, "stagg.neighborhoods", f"{qid}.nxhd")
    if not os.path.exists(path):
        return None
    with open(path) as handle:
        parsed = (Relation(text) for text in handle)
        # The list is fully built before the file is closed.
        return [r for r in parsed if not (ignore_rel and r.should_ignore())]


def read_condensed_relations_for_question(qid, kg_data_path):
    """Parse the condensed, document-tagged triples stored for question *qid*.

    Each line of the ``.nxhd`` file is ``docid<TAB>subj<TAB>rel<TAB>obj``;
    a line with any other field count raises ValueError.

    Returns:
        list of (docid, Relation) tuples, or None if the file does not exist.
    """
    path = os.path.join(kg_data_path, "condensed.stagg.neighborhoods/condensed_edges_only", f"{qid}.nxhd")
    if not os.path.exists(path):
        return None
    with open(path) as handle:
        rows = (text.strip().split('\t') for text in handle)
        return [
            (docid, Relation(f"{subj} {rel} {obj}"))
            for docid, subj, rel, obj in rows
        ]


def convert_relation_to_text(relation, entity_names):
    """Render a triple as a short sentence.

    Args:
        relation: a Relation, or any 3-item (subj, rel, obj) iterable.
        entity_names: mapping from entity id to surface name; ids without an
            entry are used verbatim.

    Returns:
        (subject surface form, sentence). The relation text is built by
        dropping the first 4 characters and the last one (e.g. the
        ``<fb:`` ... ``>`` wrapper), turning '.' into spaces, keeping the
        last two space-separated words, then turning '_' into spaces.
        Example: ``<fb:film.film.other_crew>`` -> ``film other crew``.
    """
    if isinstance(relation, Relation):
        triple = (relation.subj, relation.rel, relation.obj)
    else:
        triple = relation
    subj, rel, obj = triple

    subj_text = entity_names[subj] if subj in entity_names else subj
    obj_text = entity_names[obj] if obj in entity_names else obj

    words = rel[4:-1].replace('.', ' ').split(' ')
    rel_text = ' '.join(words[-2:]).replace('_', ' ')

    sentence = ' '.join([subj_text, rel_text, obj_text]) + '.'
    return subj_text, sentence


def load_entities(entity_data_path):
    """Load the entity id -> name table from FastRDFStore's namesTable.bin.

    The file holds an int32 entry count followed by that many (key, value)
    strings, each read with BinaryStream.readString and decoded with the
    default (UTF-8) codec. Keys starting with ``m.`` / ``g.`` are rewritten
    to ``/m/...`` / ``/g/...``, the same id form Relation.canonicalize
    produces.

    Returns:
        dict mapping entity id to its name.
    """
    print("Loading freebase entity names...")
    names_path = os.path.join(entity_data_path, "FastRDFStore", "data", "namesTable.bin")
    entity_names = {}
    with open(names_path, 'rb') as handle:
        reader = BinaryStream(handle)
        total = reader.readInt32()
        print("total entities:", total)
        for _ in range(total):
            # The key must be read before its value: they are consecutive records.
            raw_key = reader.readString().decode()
            if raw_key.startswith(('m.', 'g.')):
                raw_key = f"/{raw_key[0]}/{raw_key[2:]}"
            entity_names[raw_key] = reader.readString().decode()
    print("Done!")
    return entity_names