import json
import torch
import os

from pytorch_transformers.tokenization_bert import BertTokenizer

from formatter.Basic import BasicFormatter
from tools.load_tool import check_cache, save_cache


class BasicREFormatter(BasicFormatter):
    """Formatter that represents a relation sample by its entity embeddings only.

    For every sample the pre-trained knowledge-embedding vectors of the head
    and the tail entity are concatenated into one feature vector; the
    relation name is mapped to its integer id.
    """

    # Resource locations. They are class attributes so that a subclass (or a
    # test) can redirect them without touching ``__init__``.
    ENTITY2ID_PATH = '/data2/private/zzy/ke_exp/ke_data/emb_64/entity2id.txt'
    REL2ID_PATH = '/data2/private/zzy/ke_exp/datasets/fewrel/fewrel_rel2id.json'

    def __init__(self, config, mode, *args, **params):
        """Load the embedding table, the entity index and the relation index.

        Args:
            config: configuration object; ``config.get("data", "ke_path")``
                must return the path of the embedding text file.
            mode: mode identifier, stored unchanged in ``self.mode``.
            *args, **params: forwarded to ``BasicFormatter.__init__``.
        """
        super().__init__(config, mode, *args, **params)

        self.embs = self._load_embeddings(config.get("data", "ke_path"))
        self.ent_map = self._load_entity_map(self.ENTITY2ID_PATH)

        with open(self.REL2ID_PATH, 'r', encoding='utf-8') as fin:
            self.rel_map = json.load(fin)

        self.mode = mode
        # Fallback vector for entities without an embedding; it has the same
        # length as the first embedding row.
        self.zeros = [0] * len(self.embs[0])

    @staticmethod
    def _load_embeddings(ke_path):
        """Return the embedding rows for ``ke_path``, using the cache if valid.

        When ``check_cache`` reports no usable cache, the text file is parsed
        (one whitespace-separated row of floats per line) and the result is
        stored with ``save_cache``.
        """
        cached, embs = check_cache(ke_path)
        if not cached:
            with open(ke_path, 'r') as fin:
                embs = [[float(x) for x in line.split()] for line in fin]
            save_cache(ke_path, embs)
        return embs

    @staticmethod
    def _load_entity_map(path):
        """Parse ``<name> <id>`` lines into a ``{name: int(id)}`` dict."""
        ent_map = {}
        with open(path, 'r') as fin:
            # The first line is skipped; presumably it is a header holding
            # the number of entities -- TODO confirm against the data file.
            fin.readline()
            for line in fin:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                name, eid = fields
                ent_map[name] = int(eid)
        return ent_map

    def _entity_vector(self, entity_id):
        """Return the embedding row of ``entity_id`` or the zero vector."""
        index = self.ent_map.get(entity_id)
        return self.zeros if index is None else self.embs[index]

    def process(self, data, config, mode, *args, **params):
        """Convert a batch of samples into model tensors.

        Args:
            data: iterable of samples; each sample provides
                ``sample['h']['id']``, ``sample['t']['id']`` and
                ``sample['relation']``.
            config, mode, *args, **params: accepted for interface
                compatibility; not used here.

        Returns:
            dict with
              * ``'input'``: ``torch.FloatTensor`` holding, per sample, the
                head vector followed by the tail vector;
              * ``'label'``: ``torch.LongTensor`` of relation ids.

        Raises:
            KeyError: if a sample's relation is missing from ``self.rel_map``.
        """
        features = []
        labels = []

        for sample in data:
            head_vec = self._entity_vector(sample['h']['id'])
            tail_vec = self._entity_vector(sample['t']['id'])

            # List concatenation: head embedding followed by tail embedding.
            features.append(head_vec + tail_vec)
            labels.append(self.rel_map[sample["relation"]])

        return {
            'input': torch.FloatTensor(features),
            'label': torch.LongTensor(labels),
        }


class CnnREFormatter(BasicFormatter):
    """Formatter producing word ids, relative positions and entity embeddings.

    Besides the concatenated head/tail knowledge embeddings, every sample is
    turned into a fixed-length sequence of vocabulary ids and two sequences
    of position indices relative to the head and the tail entity.
    """

    # Resource locations. They are class attributes so that a subclass (or a
    # test) can redirect them without touching ``__init__``.
    ENTITY2ID_PATH = '/data2/private/zzy/ke_exp/ke_data/emb_64/entity2id.txt'
    REL2ID_PATH = '/data2/private/zzy/ke_exp/datasets/fewrel/fewrel_rel2id.json'
    WORD2ID_PATH = '/data1/private/zzy/nre/data/pretrain/glove/word2id.json'

    def __init__(self, config, mode, *args, **params):
        """Load embeddings, entity/relation/word indices and the max length.

        Args:
            config: configuration object; must provide
                ``config.get("data", "ke_path")`` and
                ``config.getint("data", "max_seq_length")``.
            mode: mode identifier, stored unchanged in ``self.mode``.
            *args, **params: forwarded to ``BasicFormatter.__init__``.
        """
        super().__init__(config, mode, *args, **params)

        self.embs = self._load_embeddings(config.get("data", "ke_path"))
        self.ent_map = self._load_entity_map(self.ENTITY2ID_PATH)

        with open(self.REL2ID_PATH, 'r', encoding='utf-8') as fin:
            self.rel_map = json.load(fin)

        self.mode = mode
        # Fallback vector for entities without an embedding; it has the same
        # length as the first embedding row.
        self.zeros = [0] * len(self.embs[0])

        # Use a context manager so the vocabulary file handle is closed.
        with open(self.WORD2ID_PATH, 'r', encoding='utf-8') as fin:
            self.word2id = json.load(fin)
        self.max_length = config.getint('data', 'max_seq_length')

    @staticmethod
    def _load_embeddings(ke_path):
        """Return the embedding rows for ``ke_path``, using the cache if valid.

        When ``check_cache`` reports no usable cache, the text file is parsed
        (one whitespace-separated row of floats per line) and the result is
        stored with ``save_cache``.
        """
        cached, embs = check_cache(ke_path)
        if not cached:
            with open(ke_path, 'r') as fin:
                embs = [[float(x) for x in line.split()] for line in fin]
            save_cache(ke_path, embs)
        return embs

    @staticmethod
    def _load_entity_map(path):
        """Parse ``<name> <id>`` lines into a ``{name: int(id)}`` dict."""
        ent_map = {}
        with open(path, 'r') as fin:
            # The first line is skipped; presumably it is a header holding
            # the number of entities -- TODO confirm against the data file.
            fin.readline()
            for line in fin:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                name, eid = fields
                ent_map[name] = int(eid)
        return ent_map

    def _entity_vector(self, entity_id):
        """Return the embedding row of ``entity_id`` or the zero vector."""
        index = self.ent_map.get(entity_id)
        return self.zeros if index is None else self.embs[index]

    def tokenize(self, sentence, pos_head, pos_tail, padding=True):
        """Index one sentence and compute its relative position features.

        Args:
            sentence: iterable of token strings (``process`` passes the
                sample's ``'token'`` field). Each element is lower-cased
                before the vocabulary lookup; a plain string would be
                iterated character by character.
            pos_head: position of the head entity; only ``pos_head[0]``
                (the start index) is used.
            pos_tail: position of the tail entity; only ``pos_tail[0]``
                (the start index) is used.
            padding: when true, all three sequences are padded and truncated
                to exactly ``self.max_length`` entries.

        Returns:
            Tuple ``(indexed_tokens, pos1, pos2)`` of long tensors, each of
            shape ``(1, L)``; ``L`` is ``self.max_length`` when ``padding``
            is true, otherwise the number of tokens.
        """
        word2id = self.word2id
        max_length = self.max_length

        # Token -> vocabulary id (case-insensitive); unknown words map to
        # the 'UNK' entry, which is only looked up when actually needed.
        indexed_tokens = []
        for token in sentence:
            token = token.lower()
            indexed_tokens.append(
                word2id[token] if token in word2id else word2id['UNK'])

        # Relative positions: the offset to the entity start is shifted by
        # max_length (so it cannot become negative, because the start is
        # clamped to max_length) and capped at 2 * max_length - 1.
        head_start = min(pos_head[0], max_length)
        tail_start = min(pos_tail[0], max_length)
        upper = 2 * max_length - 1
        token_count = len(indexed_tokens)
        pos1 = [min(i - head_start + max_length, upper)
                for i in range(token_count)]
        pos2 = [min(i - tail_start + max_length, upper)
                for i in range(token_count)]

        if padding:
            missing = max_length - token_count
            if missing > 0:
                # Words are padded with the 'BLANK' id, positions with 0.
                indexed_tokens.extend([word2id['BLANK']] * missing)
                pos1.extend([0] * missing)
                pos2.extend([0] * missing)
            indexed_tokens = indexed_tokens[:max_length]
            pos1 = pos1[:max_length]
            pos2 = pos2[:max_length]

        indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)
        pos1 = torch.tensor(pos1).long().unsqueeze(0)  # (1, L)
        pos2 = torch.tensor(pos2).long().unsqueeze(0)  # (1, L)

        return indexed_tokens, pos1, pos2

    def process(self, data, config, mode, *args, **params):
        """Convert a batch of samples into model tensors.

        Args:
            data: iterable of samples; each sample provides ``'token'``,
                ``['h']['pos']``, ``['t']['pos']``, ``['h']['id']``,
                ``['t']['id']`` and ``'relation'``.
            config, mode, *args, **params: accepted for interface
                compatibility; not used here.

        Returns:
            dict with
              * ``'token'``, ``'pos1'``, ``'pos2'``: the per-sample results
                of ``tokenize`` concatenated along dimension 0;
              * ``'ent_emb'``: ``torch.FloatTensor`` holding, per sample, the
                head vector followed by the tail vector;
              * ``'label'``: ``torch.LongTensor`` of relation ids.

        Raises:
            KeyError: if a sample's relation is missing from ``self.rel_map``.
        """
        token_rows = []
        pos1_rows = []
        pos2_rows = []
        labels = []
        ent_embs = []

        for sample in data:
            token_row, pos1_row, pos2_row = self.tokenize(
                sample['token'], sample['h']['pos'], sample['t']['pos'])

            head_vec = self._entity_vector(sample['h']['id'])
            tail_vec = self._entity_vector(sample['t']['id'])
            # List concatenation: head embedding followed by tail embedding.
            ent_embs.append(head_vec + tail_vec)

            token_rows.append(token_row)
            pos1_rows.append(pos1_row)
            pos2_rows.append(pos2_row)
            labels.append(self.rel_map[sample["relation"]])

        return {
            'token': torch.cat(token_rows, 0),
            'pos1': torch.cat(pos1_rows, 0),
            'pos2': torch.cat(pos2_rows, 0),
            "ent_emb": torch.FloatTensor(ent_embs),
            'label': torch.LongTensor(labels),
        }


class BertREFormatter(BasicFormatter):
    """Formatter producing BERT word-piece ids plus entity embeddings.

    The head entity is wrapped in ``*`` markers and the tail entity in ``#``
    markers before the text is tokenized with a ``BertTokenizer``. The
    concatenated head/tail knowledge embeddings are returned alongside.
    """

    # Resource locations. They are class attributes so that a subclass (or a
    # test) can redirect them without touching ``__init__``.
    ENTITY2ID_PATH = '/data2/private/zzy/ke_exp/ke_data/emb_64/entity2id.txt'
    REL2ID_PATH = '/data2/private/zzy/ke_exp/datasets/fewrel/fewrel_rel2id.json'

    def __init__(self, config, mode, *args, **params):
        """Load embeddings, entity/relation indices and the BERT tokenizer.

        Args:
            config: configuration object; must provide
                ``config.get("data", "ke_path")``,
                ``config.getint("data", "max_seq_length")`` and
                ``config.get("model", "bert_path")``.
            mode: mode identifier, stored unchanged in ``self.mode``.
            *args, **params: forwarded to ``BasicFormatter.__init__``.
        """
        super().__init__(config, mode, *args, **params)

        self.embs = self._load_embeddings(config.get("data", "ke_path"))
        self.ent_map = self._load_entity_map(self.ENTITY2ID_PATH)

        with open(self.REL2ID_PATH, 'r', encoding='utf-8') as fin:
            self.rel_map = json.load(fin)

        self.mode = mode
        # Fallback vector for entities without an embedding; it has the same
        # length as the first embedding row.
        self.zeros = [0] * len(self.embs[0])

        self.max_length = config.getint('data', 'max_seq_length')
        self.tokenizer = BertTokenizer.from_pretrained(config.get("model", "bert_path"))

    @staticmethod
    def _load_embeddings(ke_path):
        """Return the embedding rows for ``ke_path``, using the cache if valid.

        When ``check_cache`` reports no usable cache, the text file is parsed
        (one whitespace-separated row of floats per line) and the result is
        stored with ``save_cache``.
        """
        cached, embs = check_cache(ke_path)
        if not cached:
            with open(ke_path, 'r') as fin:
                embs = [[float(x) for x in line.split()] for line in fin]
            save_cache(ke_path, embs)
        return embs

    @staticmethod
    def _load_entity_map(path):
        """Parse ``<name> <id>`` lines into a ``{name: int(id)}`` dict."""
        ent_map = {}
        with open(path, 'r') as fin:
            # The first line is skipped; presumably it is a header holding
            # the number of entities -- TODO confirm against the data file.
            fin.readline()
            for line in fin:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                name, eid = fields
                ent_map[name] = int(eid)
        return ent_map

    def _entity_vector(self, entity_id):
        """Return the embedding row of ``entity_id`` or the zero vector."""
        index = self.ent_map.get(entity_id)
        return self.zeros if index is None else self.embs[index]

    @staticmethod
    def _mark_entities(words, head_pos, tail_pos):
        """Return ``words`` with ``*`` around the head and ``#`` around the tail.

        ``head_pos`` and ``tail_pos`` are indexed as ``[start, end]`` slices
        into ``words``. The span that starts first is marked first; when the
        head does not start strictly before the tail, the tail is treated as
        the first span.
        """
        if head_pos[0] < tail_pos[0]:
            first, first_mark = head_pos, "*"
            second, second_mark = tail_pos, "#"
        else:
            first, first_mark = tail_pos, "#"
            second, second_mark = head_pos, "*"

        return (
            words[:first[0]]
            + [first_mark] + words[first[0]:first[1]] + [first_mark]
            + words[first[1]:second[0]]
            + [second_mark] + words[second[0]:second[1]] + [second_mark]
            + words[second[1]:]
        )

    def _encode(self, words):
        """Turn marked, lower-cased words into ``self.max_length`` BERT ids."""
        pieces = ["[CLS]"] + self.tokenizer.tokenize(" ".join(words))
        pieces.extend(["[PAD]"] * (self.max_length - len(pieces)))
        # NOTE(review): "[SEP]" is appended after the "[PAD]" tokens, so it
        # always sits in the last position rather than directly after the
        # text. This is unusual for BERT inputs; it is kept unchanged because
        # existing trained models may rely on it -- confirm it is intended.
        pieces = pieces[:self.max_length - 1] + ["[SEP]"]
        return self.tokenizer.convert_tokens_to_ids(pieces)

    def process(self, data, config, mode, *args, **params):
        """Convert a batch of samples into model tensors.

        Args:
            data: iterable of samples; each sample provides ``'token'`` (a
                list of token strings), ``['h']['pos']``, ``['t']['pos']``,
                ``['h']['id']``, ``['t']['id']`` and ``'relation'``.
            config, mode, *args, **params: accepted for interface
                compatibility; not used here.

        Returns:
            dict with
              * ``'input'``: ``torch.LongTensor`` of word-piece ids, one row
                per sample;
              * ``'label'``: ``torch.LongTensor`` of relation ids;
              * ``'ent_emb'``: ``torch.FloatTensor`` holding, per sample, the
                head vector followed by the tail vector.

        Raises:
            KeyError: if a sample's relation is missing from ``self.rel_map``.
        """
        input_ids = []
        labels = []
        ent_embs = []

        for sample in data:
            words = sample['token']
            head_pos = sample['h']['pos']
            tail_pos = sample['t']['pos']

            head_vec = self._entity_vector(sample['h']['id'])
            tail_vec = self._entity_vector(sample['t']['id'])
            # List concatenation: head embedding followed by tail embedding.
            ent_embs.append(head_vec + tail_vec)

            lowered = [word.lower() for word in words]
            marked = self._mark_entities(lowered, head_pos, tail_pos)

            input_ids.append(self._encode(marked))
            labels.append(self.rel_map[sample["relation"]])

        return {
            'input': torch.LongTensor(input_ids),
            'label': torch.LongTensor(labels),
            'ent_emb': torch.FloatTensor(ent_embs),
        }

