import json
import torch
import os
import re
import copy
import gc
import time

from pytorch_transformers.tokenization_bert import BertTokenizer

from formatter.Basic import BasicFormatter
from tools.load_tool import check_cache, save_cache


class BasicIRFormatter(BasicFormatter):
    """Formatter that represents queries and documents only by entity embeddings.

    Rows of ``data`` hold integer indices into a JSON-lines corpus whose records
    carry an ``'eids'`` list; each entity id selects a row of the embedding
    table read from the ``[data] ke_path`` config option. Queries keep at most
    5 entities and documents at most 50, zero-padded to exactly that many rows.
    """

    def __init__(self, config, mode, *args, **params):
        super().__init__(config, mode, *args, **params)

        self.embs = self._read_embedding_table(config.get("data", "ke_path"))
        self.objs = self._read_corpus(
            "/data2/private/zzy/ke_exp/datasets/clueweb09/output.json")

        self.mode = mode
        # Padding row: as wide as the first embedding row, filled with zeros.
        self.zeros = [0.] * len(self.embs[0])

    @staticmethod
    def _read_embedding_table(path):
        """Return the embedding rows for *path*, preferring the cached copy.

        When ``check_cache`` reports no cached copy (flag ``== False``), every
        line of the file is parsed as whitespace-separated floats and the
        parsed table is handed to ``save_cache``.
        """
        cached, table = check_cache(path)
        if cached == False:  # noqa: E712 -- keep the exact flag comparison
            with open(path, 'r') as handle:
                table = [list(map(float, row.split())) for row in handle]
            save_cache(path, table)
        return table

    @staticmethod
    def _read_corpus(path):
        """Return the JSON-lines records for *path* (one object per line), preferring the cache."""
        cached, records = check_cache(path)
        if cached == False:  # noqa: E712 -- keep the exact flag comparison
            with open(path, 'r') as handle:
                records = [json.loads(row.strip()) for row in handle]
            save_cache(path, records)
        return records

    def _pad_entities(self, eids, width):
        """Embed the first *width* entity ids, then append zero rows up to *width* rows."""
        rows = [self.embs[eid] for eid in eids[:width]]
        return rows + [self.zeros] * (width - len(rows))

    def process(self, data, config, mode, *args, **params):
        """Convert a batch of corpus-index rows into entity-embedding tensors.

        In ``"train"`` mode each row is ``(query, positive doc, negative doc)``
        corpus indices; the result holds ``query`` (5 entity rows per sample),
        ``doc_pos`` and ``doc_neg`` (50 entity rows per sample).
        In any other mode each row starts with ``(label, query, doc)``; the
        result holds ``query``, ``doc_pos``, ``qid``, ``did`` and ``label``.
        """
        if mode != "train":
            return self._process_eval(data)

        queries, positives, negatives = [], [], []
        for row in data:
            ids = [int(x) for x in row]
            q_eids = self.objs[ids[0]]['eids']
            pos_eids = self.objs[ids[1]]['eids']
            neg_eids = self.objs[ids[2]]['eids']

            queries.append(self._pad_entities(q_eids, 5))
            positives.append(self._pad_entities(pos_eids, 50))
            negatives.append(self._pad_entities(neg_eids, 50))

        query = torch.tensor(queries)
        doc_pos = torch.tensor(positives)
        doc_neg = torch.tensor(negatives)
        return {'query': query, "doc_pos": doc_pos, "doc_neg": doc_neg}

    def _process_eval(self, data):
        """Build evaluation tensors from ``(label, query, doc, ...)`` index rows."""
        queries, docs = [], []
        qids, dids, labels = [], [], []
        for row in data:
            ids = [int(x) for x in row[:3]]
            qids.append(ids[1])
            dids.append(ids[2])
            labels.append(ids[0])

            q_eids = self.objs[ids[1]]['eids']
            d_eids = self.objs[ids[2]]['eids']
            queries.append(self._pad_entities(q_eids, 5))
            docs.append(self._pad_entities(d_eids, 50))

        query = torch.tensor(queries)
        doc_pos = torch.tensor(docs)
        qid = torch.LongTensor(qids)
        did = torch.LongTensor(dids)
        label = torch.LongTensor(labels)
        return {'query': query, "doc_pos": doc_pos, "qid": qid, "did": did, 'label': label}


class CnnIRFormatter(BasicFormatter):
    """Build word-id tensors and entity-embedding tensors for CNN-style IR models.

    Text is lower-cased, stripped of characters outside ``[a-z0-9]``/whitespace,
    and mapped through a GloVe ``word2id`` vocabulary ('UNK' for unknown words,
    'BLANK' for padding) to a fixed ``[data] max_seq_length``. Entities listed
    in each corpus record's ``'eids'`` are embedded from the table at the
    ``[data] ke_path`` config option and zero-padded (5 for queries, 50 for
    documents).
    """

    # Number of entity rows kept (and padded to) for queries and documents.
    _QUERY_ENTITY_LIMIT = 5
    _DOC_ENTITY_LIMIT = 50

    def __init__(self, config, mode, *args, **params):
        super().__init__(config, mode, *args, **params)

        self.embs = self._load_embeddings(config.get("data", "ke_path"))
        self.objs = self._load_corpus(
            "/data2/private/zzy/ke_exp/datasets/clueweb09/output.json")

        self.mode = mode
        # Padding row: as wide as the first embedding row, filled with zeros.
        self.zeros = [0.] * len(self.embs[0])

        # Use a context manager so the vocabulary file handle is closed promptly.
        with open('/data1/private/zzy/nre/data/pretrain/glove/word2id.json') as fin:
            self.word2id = json.load(fin)
        self.max_length = config.getint('data', 'max_seq_length')

        # Raw strings: '\s' inside a plain literal is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning on Python 3.12+). Patterns unchanged.
        self.regex_drop_char = re.compile(r'[^a-z0-9\s]+')
        self.regex_multi_space = re.compile(r'\s+')

    @staticmethod
    def _load_embeddings(path):
        """Return embedding rows for *path*; parse whitespace-separated floats if not cached."""
        mark, embs = check_cache(path)
        if mark == False:  # noqa: E712 -- keep the exact flag comparison
            with open(path, 'r') as fin:
                embs = [[float(x) for x in line.split()] for line in fin]
            save_cache(path, embs)
        return embs

    @staticmethod
    def _load_corpus(path):
        """Return JSON-lines records for *path* (one object per line); parse if not cached."""
        mark, objs = check_cache(path)
        if mark == False:  # noqa: E712 -- keep the exact flag comparison
            with open(path, 'r') as fin:
                objs = [json.loads(line.strip()) for line in fin]
            save_cache(path, objs)
        return objs

    def _entity_matrix(self, eids, limit):
        """Embed the first *limit* entity ids and zero-pad to exactly *limit* rows."""
        rows = [self.embs[x] for x in eids[:limit]]
        return rows + [self.zeros] * (limit - len(rows))

    def tokenize(self, sentence, padding=True):
        """Convert *sentence* into a ``(1, L)`` LongTensor of vocabulary ids.

        Args:
            sentence: string, the input sentence.
            padding: when True, append the 'BLANK' id and truncate so that
                ``L == self.max_length``; when False, ``L`` is the token count.

        Returns:
            torch.LongTensor of shape ``(1, L)``.

        Raises:
            KeyError: if an unknown word is seen and the vocabulary has no
                'UNK' entry, or padding is needed and it has no 'BLANK' entry.
        """
        # Lower-case, replace runs of characters outside [a-z0-9 whitespace] with a
        # space, collapse whitespace, then split into tokens.
        cleaned = self.regex_drop_char.sub(' ', sentence.lower())
        tokens = self.regex_multi_space.sub(' ', cleaned).strip().split()

        word2id = self.word2id
        # Out-of-vocabulary tokens map to the 'UNK' id.
        indexed_tokens = [word2id[t] if t in word2id else word2id['UNK'] for t in tokens]

        if padding:
            shortfall = self.max_length - len(indexed_tokens)
            if shortfall > 0:
                indexed_tokens += [word2id['BLANK']] * shortfall
            indexed_tokens = indexed_tokens[:self.max_length]

        # .long() matters for the empty case, where torch.tensor([]) is float.
        return torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)

    def process(self, data, config, mode, *args, **params):
        """Convert a batch of corpus-index rows into word-id and entity tensors.

        In ``"train"`` mode each row is ``(query, positive doc, negative doc)``
        corpus indices; the result holds word ids ``query``/``doc_pos``/``doc_neg``
        (``(batch, max_length)``) and entity embeddings ``query_ent`` (5 rows per
        sample), ``doc_pos_ent`` and ``doc_neg_ent`` (50 rows per sample).
        In any other mode each row starts with ``(label, query, doc)``; the
        result holds ``query``, ``doc_pos``, ``qid``, ``did``, ``query_ent``,
        ``doc_pos_ent`` and ``label``.
        """
        vec_query = []
        vec_doc_pos = []
        vec_query_ent = []
        vec_doc_pos_ent = []

        if mode == "train":
            vec_doc_neg = []
            vec_doc_neg_ent = []
            for temp in data:
                temp = [int(x) for x in temp]
                q_obj = self.objs[temp[0]]
                pos_obj = self.objs[temp[1]]
                neg_obj = self.objs[temp[2]]

                vec_query.append(self.tokenize(q_obj['text']))
                vec_doc_pos.append(self.tokenize(pos_obj['text']))
                vec_doc_neg.append(self.tokenize(neg_obj['text']))

                vec_query_ent.append(self._entity_matrix(q_obj['eids'], self._QUERY_ENTITY_LIMIT))
                vec_doc_pos_ent.append(self._entity_matrix(pos_obj['eids'], self._DOC_ENTITY_LIMIT))
                vec_doc_neg_ent.append(self._entity_matrix(neg_obj['eids'], self._DOC_ENTITY_LIMIT))

            return {'query': torch.cat(vec_query, 0),
                    'doc_pos': torch.cat(vec_doc_pos, 0),
                    'doc_neg': torch.cat(vec_doc_neg, 0),
                    'query_ent': torch.tensor(vec_query_ent),
                    'doc_pos_ent': torch.tensor(vec_doc_pos_ent),
                    'doc_neg_ent': torch.tensor(vec_doc_neg_ent)}

        qid = []
        did = []
        label = []
        for temp in data:
            temp = [int(x) for x in temp[:3]]
            qid.append(temp[1])
            did.append(temp[2])
            label.append(temp[0])
            q_obj = self.objs[temp[1]]
            d_obj = self.objs[temp[2]]

            vec_query.append(self.tokenize(q_obj['text']))
            vec_doc_pos.append(self.tokenize(d_obj['text']))

            vec_query_ent.append(self._entity_matrix(q_obj['eids'], self._QUERY_ENTITY_LIMIT))
            vec_doc_pos_ent.append(self._entity_matrix(d_obj['eids'], self._DOC_ENTITY_LIMIT))

        return {'query': torch.cat(vec_query, 0),
                'doc_pos': torch.cat(vec_doc_pos, 0),
                'qid': torch.LongTensor(qid),
                'did': torch.LongTensor(did),
                'query_ent': torch.tensor(vec_query_ent),
                'doc_pos_ent': torch.tensor(vec_doc_pos_ent),
                'label': torch.LongTensor(label)}

class BertIRFormatter(BasicFormatter):
    """Build BERT ``[CLS] query [SEP] doc [SEP]`` id tensors plus entity embeddings.

    Corpus records (JSON lines) provide ``'text'`` for the BERT input and
    ``'eids'`` for entity features; entity ids index the embedding table at the
    ``[data] ke_path`` config option. Queries keep 5 entity rows, documents 50,
    zero-padded.
    """

    # Number of entity rows kept (and padded to) for queries and documents.
    _QUERY_ENTITY_LIMIT = 5
    _DOC_ENTITY_LIMIT = 50

    def __init__(self, config, mode, *args, **params):
        super().__init__(config, mode, *args, **params)

        self.objs = self._load_corpus(
            "/data2/private/zzy/ke_exp/datasets/clueweb09/output.json")
        self.embs = self._load_embeddings(config.get("data", "ke_path"))

        self.mode = mode
        # Padding row: as wide as the first embedding row, filled with zeros.
        self.zeros = [0.] * len(self.embs[0])

        self.max_length = config.getint('data', 'max_seq_length')
        self.tokenizer = BertTokenizer.from_pretrained(config.get("model", "bert_path"))

    @staticmethod
    def _load_corpus(path):
        """Return JSON-lines records for *path* (one object per line); parse if not cached."""
        mark, objs = check_cache(path)
        if mark == False:  # noqa: E712 -- keep the exact flag comparison
            with open(path, 'r') as fin:
                objs = [json.loads(line.strip()) for line in fin]
            save_cache(path, objs)
        return objs

    @staticmethod
    def _load_embeddings(path):
        """Return embedding rows for *path*; parse whitespace-separated floats if not cached."""
        mark, embs = check_cache(path)
        if mark == False:  # noqa: E712 -- keep the exact flag comparison
            with open(path, 'r') as fin:
                embs = [[float(x) for x in line.split()] for line in fin]
            save_cache(path, embs)
        return embs

    def _entity_matrix(self, eids, limit):
        """Embed the first *limit* entity ids and zero-pad to exactly *limit* rows."""
        rows = [self.embs[x] for x in eids[:limit]]
        return rows + [self.zeros] * (limit - len(rows))

    def _encode_pair(self, query_tokens, doc_tokens):
        """Return token ids for ``[CLS] query [SEP] doc [SEP]`` padded/truncated to max_length."""
        tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + doc_tokens + ["[SEP]"]
        tokens += ["[PAD]"] * (self.max_length - len(tokens))
        # NOTE(review): plain truncation drops the trailing [SEP] for long
        # documents; kept to preserve the existing encoding.
        tokens = tokens[:self.max_length]
        return self.tokenizer.convert_tokens_to_ids(tokens)

    def process(self, data, config, mode, *args, **params):
        """Convert a batch of corpus-index rows into BERT ids and entity tensors.

        In ``"train"`` mode each row is ``(query, positive doc, negative doc)``
        corpus indices; the result holds ``pos_input``/``neg_input`` (token ids,
        ``(batch, max_length)``) and ``query_ent``, ``doc_pos_ent``, ``doc_neg_ent``.
        In any other mode each row starts with ``(label, query, doc)``; the
        result holds ``pos_input``, ``qid``, ``did``, ``query_ent``,
        ``doc_pos_ent`` and ``label``.
        """
        pos_input = []
        vec_query_ent = []
        vec_doc_pos_ent = []

        if mode == "train":
            neg_input = []
            vec_doc_neg_ent = []
            for temp in data:
                temp = [int(x) for x in temp]
                # Text and entities both come from this sample's own indices.
                # (Previously the query/positive text was read from the fixed
                # rows objs[1]/objs[-1], and a debug print + exit(0) ended the
                # process on the first sample.)
                q_obj = self.objs[temp[0]]
                pos_obj = self.objs[temp[1]]
                neg_obj = self.objs[temp[2]]

                query = self.tokenizer.tokenize(q_obj['text'])
                doc_pos = self.tokenizer.tokenize(pos_obj['text'])
                doc_neg = self.tokenizer.tokenize(neg_obj['text'])

                pos_input.append(self._encode_pair(query, doc_pos))
                neg_input.append(self._encode_pair(query, doc_neg))

                vec_query_ent.append(self._entity_matrix(q_obj['eids'], self._QUERY_ENTITY_LIMIT))
                vec_doc_pos_ent.append(self._entity_matrix(pos_obj['eids'], self._DOC_ENTITY_LIMIT))
                vec_doc_neg_ent.append(self._entity_matrix(neg_obj['eids'], self._DOC_ENTITY_LIMIT))

            return {'pos_input': torch.LongTensor(pos_input),
                    'neg_input': torch.LongTensor(neg_input),
                    'query_ent': torch.tensor(vec_query_ent),
                    'doc_pos_ent': torch.tensor(vec_doc_pos_ent),
                    'doc_neg_ent': torch.tensor(vec_doc_neg_ent)}

        qid = []
        did = []
        label = []
        for temp in data:
            temp = [int(x) for x in temp[:3]]
            qid.append(temp[1])
            did.append(temp[2])
            label.append(temp[0])
            q_obj = self.objs[temp[1]]
            d_obj = self.objs[temp[2]]

            query = self.tokenizer.tokenize(q_obj['text'])
            doc_pos = self.tokenizer.tokenize(d_obj['text'])
            pos_input.append(self._encode_pair(query, doc_pos))

            vec_query_ent.append(self._entity_matrix(q_obj['eids'], self._QUERY_ENTITY_LIMIT))
            vec_doc_pos_ent.append(self._entity_matrix(d_obj['eids'], self._DOC_ENTITY_LIMIT))

        return {'pos_input': torch.LongTensor(pos_input),
                'qid': torch.LongTensor(qid),
                'did': torch.LongTensor(did),
                'query_ent': torch.tensor(vec_query_ent),
                'doc_pos_ent': torch.tensor(vec_doc_pos_ent),
                'label': torch.LongTensor(label)}


