import json
import torch
import os

from pytorch_transformers.tokenization_bert import BertTokenizer

from formatter.Basic import BasicFormatter
from tools.load_tool import check_cache, save_cache


class BasicETFormatter(BasicFormatter):
    """Entity-typing formatter whose only model input is the entity's KE embedding."""

    def __init__(self, config, mode, *args, **params):
        """Load the knowledge-embedding table and the label -> index map.

        Args:
            config: configuration object; ``[data] ke_path`` names a text file
                with one whitespace-separated float vector per line.
            mode: run mode, stored on ``self.mode``.
        """
        super().__init__(config, mode, *args, **params)

        ke_path = config.get("data", "ke_path")
        # check_cache presumably returns (hit, cached_embeddings); on a miss the
        # text file is parsed once and written back through save_cache.
        mark, self.embs = check_cache(ke_path)
        if not mark:
            with open(ke_path, 'r', encoding='utf-8') as fin:
                self.embs = [[float(x) for x in line.split()] for line in fin]
            save_cache(ke_path, self.embs)

        # Each label string maps to its column in the multi-hot label vector
        # (assumed to be integers in [0, len(label_map)) — see process()).
        with open('/data2/private/zzy/ke_exp/datasets/typing/label2id.json', 'r', encoding='utf-8') as fin:
            self.label_map = json.load(fin)

        self.mode = mode
        # All-zero vector with the embedding width (not used in this class).
        self.zeros = [0] * len(self.embs[0])

    def process(self, data, config, mode, *args, **params):
        """Build a batch from a list of examples.

        Each example is read for 'id' (row index into the embedding table,
        converted with int()) and 'labels' (keys of ``self.label_map``).

        Returns:
            dict with 'input' (FloatTensor, one embedding row per example) and
            'label' (LongTensor multi-hot, one row of len(label_map) per example).
        """
        inputs = []
        label = []

        for temp in data:
            inputs.append(self.embs[int(temp['id'])])

            label_vec = [0] * len(self.label_map)
            for type_name in temp["labels"]:
                label_vec[self.label_map[type_name]] = 1
            label.append(label_vec)

        return {'input': torch.FloatTensor(inputs), 'label': torch.LongTensor(label)}


class CnnETFormatter(BasicFormatter):
    """Entity-typing formatter for a CNN encoder over GloVe word ids.

    A batch carries the sentence as word ids, each word's position relative
    to the entity mention, the entity's KE embedding and a multi-hot label.
    """

    def __init__(self, config, mode, *args, **params):
        """Load KE embeddings, the label -> index map and the GloVe vocabulary.

        Args:
            config: configuration object; reads ``[data] ke_path`` and
                ``[data] max_seq_length``.
            mode: run mode, stored on ``self.mode``.
        """
        super().__init__(config, mode, *args, **params)

        ke_path = config.get("data", "ke_path")
        # check_cache presumably returns (hit, cached_embeddings); on a miss the
        # text file is parsed once and written back through save_cache.
        mark, self.embs = check_cache(ke_path)
        if not mark:
            with open(ke_path, 'r', encoding='utf-8') as fin:
                self.embs = [[float(x) for x in line.split()] for line in fin]
            save_cache(ke_path, self.embs)

        # Each label string maps to its column in the multi-hot label vector.
        with open('/data2/private/zzy/ke_exp/datasets/typing/label2id.json', 'r', encoding='utf-8') as fin:
            self.label_map = json.load(fin)

        self.mode = mode
        # All-zero vector with the embedding width (not used in this class).
        self.zeros = [0] * len(self.embs[0])

        # Context manager so the vocabulary file handle is closed promptly.
        with open('/data1/private/zzy/nre/data/pretrain/glove/word2id.json', 'r', encoding='utf-8') as fin:
            self.word2id = json.load(fin)
        self.max_length = config.getint('data', 'max_seq_length')

    def tokenize(self, sentence, pos_head, padding=True):
        """Map words to vocabulary ids and compute mention-relative positions.

        Args:
            sentence: sequence of word strings (already whitespace-split).
            pos_head: ``[start, end)`` token span of the entity mention. The
                list is not modified.
            padding: if True, pad with the 'BLANK' id / position 0 and
                truncate both outputs to ``self.max_length``.

        Returns:
            ``(indexed_tokens, pos1)``: two LongTensors of shape (1, L), where
            L is ``max_length`` when padding, else ``len(sentence)``.
        """
        max_len = self.max_length

        # Lookup is case-insensitive; unknown words map to 'UNK'.
        indexed_tokens = []
        for word in sentence:
            word = word.lower()
            if word in self.word2id:
                indexed_tokens.append(self.word2id[word])
            else:
                indexed_tokens.append(self.word2id['UNK'])

        # Positions are offset by max_len so mention tokens sit at max_len,
        # then clipped to 2 * max_len - 1.
        head_start = pos_head[0]
        head_last = pos_head[1] - 1  # inclusive index of the last mention token
        head_start_clipped = min(head_start, max_len)
        head_last_clipped = min(head_last, max_len)
        cap = 2 * max_len - 1
        pos1 = []
        for i in range(len(indexed_tokens)):
            if (i - head_start) * (i - head_last) <= 0:
                pos1.append(min(max_len, cap))
            elif i - head_start > 0:
                pos1.append(min(i - head_last_clipped + max_len, cap))
            else:
                pos1.append(min(i - head_start_clipped + max_len, cap))

        if padding:
            missing = max_len - len(indexed_tokens)
            if missing > 0:
                indexed_tokens += [self.word2id['BLANK']] * missing
                pos1 += [0] * missing
            indexed_tokens = indexed_tokens[:max_len]
            pos1 = pos1[:max_len]

        indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)
        pos1 = torch.tensor(pos1).long().unsqueeze(0)  # (1, L)
        return indexed_tokens, pos1

    @staticmethod
    def _locate_mention(text, words, start, end):
        """Return the ``[begin, end)`` span in ``words`` of the mention ``text[start:end]``.

        ``words`` must be ``text.split()``.

        Raises:
            ValueError: if the mention is empty or cannot be aligned to whole words.
        """
        mention_words = text[start:end].split()
        if not mention_words:
            raise ValueError("empty mention at offsets %r-%r in %r" % (start, end, text))
        n = len(mention_words)

        # Prefer the occurrence at the annotated character offset, so a mention
        # whose text repeats in the sentence resolves to the right copy.
        begin = len(text[:start].split())
        if words[begin:begin + n] == mention_words:
            return begin, begin + n

        # Offset does not fall on a word boundary: fall back to the last
        # textual match, as earlier versions of this formatter did.
        for i in range(len(words) - 1, -1, -1):
            if words[i:i + n] == mention_words:
                return i, i + n
        raise ValueError("mention %r not found as whole words in %r" % (text[start:end], text))

    def process(self, data, config, mode, *args, **params):
        """Build a CNN batch from a list of examples.

        Each example is read for 'sent' (whitespace-tokenizable text),
        'start'/'end' (character offsets of the mention in 'sent'), 'id' (row
        index into the KE embeddings, converted with int()) and 'labels'
        (keys of ``self.label_map``).

        Returns:
            dict with 'token' and 'pos1' (LongTensor, (batch, max_length)),
            'ent_emb' (FloatTensor, one KE embedding row per example) and
            'label' (LongTensor multi-hot, (batch, len(label_map))).

        Raises:
            ValueError: if a mention cannot be aligned to whole words.
        """
        tokens = []
        pos1s = []
        ent_embs = []
        label = []

        for temp in data:
            text = temp['sent']
            words = text.split()
            begin, end = self._locate_mention(text, words, temp['start'], temp['end'])
            token_ids, pos1 = self.tokenize(words, [begin, end])
            tokens.append(token_ids)
            pos1s.append(pos1)

            ent_embs.append(self.embs[int(temp['id'])])

            label_vec = [0] * len(self.label_map)
            for type_name in temp["labels"]:
                label_vec[self.label_map[type_name]] = 1
            label.append(label_vec)

        return {
            'token': torch.cat(tokens, 0),
            'pos1': torch.cat(pos1s, 0),
            "ent_emb": torch.FloatTensor(ent_embs),
            'label': torch.LongTensor(label),
        }


class BertETFormatter(BasicFormatter):
    """Entity-typing formatter for a BERT encoder.

    The mention is wrapped in '#' markers inside the sentence, which is then
    WordPiece-tokenized into a fixed-length id sequence.
    """

    def __init__(self, config, mode, *args, **params):
        """Load KE embeddings, the label -> index map and the BERT tokenizer.

        Args:
            config: configuration object; reads ``[data] ke_path``,
                ``[data] max_seq_length`` and ``[model] bert_path``.
            mode: run mode, stored on ``self.mode``.
        """
        super().__init__(config, mode, *args, **params)

        ke_path = config.get("data", "ke_path")
        # check_cache presumably returns (hit, cached_embeddings); on a miss the
        # text file is parsed once and written back through save_cache.
        mark, self.embs = check_cache(ke_path)
        if not mark:
            with open(ke_path, 'r', encoding='utf-8') as fin:
                self.embs = [[float(x) for x in line.split()] for line in fin]
            save_cache(ke_path, self.embs)

        # Each label string maps to its column in the multi-hot label vector.
        with open('/data2/private/zzy/ke_exp/datasets/typing/label2id.json', 'r', encoding='utf-8') as fin:
            self.label_map = json.load(fin)

        self.mode = mode
        # All-zero vector with the embedding width (not used in this class).
        self.zeros = [0] * len(self.embs[0])

        self.max_length = config.getint('data', 'max_seq_length')
        self.tokenizer = BertTokenizer.from_pretrained(config.get("model", "bert_path"))

    def process(self, data, config, mode, *args, **params):
        """Build a BERT batch from a list of examples.

        Each example is read for 'sent', 'start'/'end' (character offsets of
        the mention in 'sent'), 'id' (row index into the KE embeddings,
        converted with int()) and 'labels' (keys of ``self.label_map``).

        Returns:
            dict with 'input' (LongTensor of token ids, (batch, max_length)),
            'label' (LongTensor multi-hot, (batch, len(label_map))) and
            'ent_emb' (FloatTensor, one KE embedding row per example).
        """
        input_ids = []
        label = []
        ent_embs = []

        for temp in data:
            text = temp['sent']
            start, end = temp['start'], temp['end']
            mention = text[start:end]

            ent_embs.append(self.embs[int(temp['id'])])

            # Splice the markers in at the annotated offsets. str.replace would
            # mark the first textual occurrence instead, which is wrong when
            # the same string appears earlier (e.g. inside another word).
            marked = text[:start] + "# " + mention + " #" + text[end:]

            pieces = ["[CLS]"] + self.tokenizer.tokenize(marked)
            # Pad to max_length, keep max_length - 1 pieces, then end with [SEP].
            # NOTE(review): for short sentences [SEP] therefore follows the
            # padding ([CLS] w1..wn [PAD].. [SEP]) and no attention mask is
            # produced; layout kept as-is since trained models presumably
            # depend on it.
            pieces += ["[PAD]"] * (self.max_length - len(pieces))
            pieces = pieces[:self.max_length - 1] + ["[SEP]"]
            input_ids.append(self.tokenizer.convert_tokens_to_ids(pieces))

            label_vec = [0] * len(self.label_map)
            for type_name in temp["labels"]:
                label_vec[self.label_map[type_name]] = 1
            label.append(label_vec)

        return {
            'input': torch.LongTensor(input_ids),
            'label': torch.LongTensor(label),
            'ent_emb': torch.FloatTensor(ent_embs),
        }

