from __future__ import absolute_import, division, print_function

from bert import load_model

# import json
# import os
# import numpy as np
# import tensorflow as tf
# import matplotlib
# import subprocess

# matplotlib.use('Agg')

# import networkx as nx
# import json
# import argparse
# import matplotlib.pyplot as plt
# import random
import string
import timeit

import collections
import logging
import math

import numpy as np
import torch
from transformers import (WEIGHTS_NAME, BertConfig,
                          BertForQuestionAnswering, BertTokenizer,
                          XLMConfig, XLMForQuestionAnswering,
                          XLMTokenizer, XLNetConfig,
                          XLNetForQuestionAnswering,
                          XLNetTokenizer,
                          DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer,
                          AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer)
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

from utils_squad import (get_predictions, read_squad_example,
                         convert_example_to_features, to_list, convert_examples_to_features, get_all_predictions)

import timeit

# Per-feature model output: the feature's unique id together with the start
# and end logits the QA model produced for it.
RawResult = collections.namedtuple(
    "RawResult", "unique_id start_logits end_logits")

# Word lists. `articles` is what extract_entity() consults to drop a leading
# determiner; note it also holds possessives and demonstratives, not only
# grammatical articles. `conjunctions` and `pronouns` are not referenced in
# this part of the file (possibly imported elsewhere — verify before removing).
conjunctions = [
    'and',
    'or',
    'nor',
]
articles = [
    "the", 'a', 'an',
    'his', 'her', 'their', 'my', 'its',
    'those', 'these', 'that', 'this',
    'the',
]
pronouns = [
    " He ", " She ",
    " he ", " she ",
    " they ", " them ",
]




# Directory handed to load_model(); the commented line is an earlier
# absolute location kept for reference.
# model_path = '/home/mhmh/KG-A2C-chained/kga2c/askbert/models'
model_path = '/KG-A2C-chained/kga2c/askbert/models'

# Seed torch's RNG before the model is created.
torch.manual_seed(42)

# Forwarded to the utils_squad feature converters.
max_seq_length = 512
doc_stride = 128
max_query_length = 64

# Forwarded to the utils_squad prediction helpers.
n_best_size = 20
max_answer_length = 30
do_lower_case = True

# The model and tokenizer are loaded once at import time and shared by every
# function in this module.
model, tokenizer = load_model(model_path)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # inference only

def predict(passage: str, question: str):
    """Answer one question about one passage using the module-level QA model.

    The passage/question pair is converted to features, each feature is run
    through ``model`` on ``device`` without gradient tracking, and the
    collected start/end logits are handed to ``get_predictions``.

    Args:
        passage: Text to search for the answer.
        question: Question to ask about ``passage``.

    Returns:
        Whatever ``get_predictions`` returns for this example, called with
        the module-level ``n_best_size``, ``max_answer_length`` and
        ``do_lower_case`` settings and the constants ``True`` and ``100`` in
        its last two positions.
    """
    example = read_squad_example(passage, question, 0)
    features = convert_example_to_features(
        example, tokenizer, max_seq_length, doc_stride, max_query_length, False)

    def _as_long_tensor(field):
        # Stack one attribute of every feature into a single long tensor.
        return torch.tensor([getattr(f, field) for f in features], dtype=torch.long)

    input_ids = _as_long_tensor("input_ids")
    input_mask = _as_long_tensor("input_mask")
    segment_ids = _as_long_tensor("segment_ids")
    # Row index of each feature, so a batch row can be mapped back to it.
    feature_index = torch.arange(input_ids.size(0), dtype=torch.long)

    dataset = TensorDataset(input_ids, input_mask, segment_ids, feature_index)
    loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=1)

    raw_results = []
    for batch in loader:
        ids, mask, segments, indices = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            outputs = model(input_ids=ids,
                            attention_mask=mask,
                            token_type_ids=segments)

        for row, feature_idx in enumerate(indices):
            feature = features[feature_idx.item()]
            raw_results.append(
                RawResult(unique_id=int(feature.unique_id),
                          start_logits=to_list(outputs[0][row]),
                          end_logits=to_list(outputs[1][row])))

    return get_predictions(example, features, raw_results, n_best_size,
                           max_answer_length, do_lower_case, True, 100)

def predictTopK(self, passage: str, question: str, k: int, cutoff=8):
    """Run the module-level QA model on one passage/question pair.

    Same pipeline as ``predict`` but with caller-supplied ``k`` and
    ``cutoff`` instead of the module defaults.

    Args:
        self: Unused. This is a module-level function; the parameter is kept
            so existing positional callers keep working.
        passage: Text to search for the answer.
        question: Question to ask about ``passage``.
        k: Forwarded to ``get_predictions`` in the position where
            ``predict`` passes ``n_best_size``.
        cutoff: Forwarded as the last argument of ``get_predictions``.

    Returns:
        Whatever ``get_predictions`` returns for this example.
    """
    example = read_squad_example(passage, question, 0)
    features = convert_example_to_features(
        example, tokenizer, max_seq_length, doc_stride, max_query_length, False)

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    # Row index of each feature, so a batch row can be mapped back to it.
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                            all_example_index)
    eval_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset),
                                 batch_size=1)

    all_results = []
    for batch in eval_dataloader:
        input_ids, attention_mask, token_type_ids, example_indices = tuple(
            t.to(device) for t in batch)
        with torch.no_grad():
            outputs = model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)

        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            all_results.append(
                RawResult(unique_id=int(eval_feature.unique_id),
                          start_logits=to_list(outputs[0][i]),
                          end_logits=to_list(outputs[1][i])))

    return get_predictions(example, features, all_results, k,
                           max_answer_length, do_lower_case, True, cutoff)

def batch_predictTopK(self, passages, questions, k, cutoff=8):
    """Run the module-level QA model on several passage/question pairs.

    ``passages`` and ``questions`` are paired positionally with ``zip``; the
    i-th pair becomes the example with id ``i``. Features for all examples
    are evaluated in batches of 16 without gradient tracking.

    Args:
        self: Unused. This is a module-level function; the parameter is kept
            because callers (e.g. ``generate``) pass a value in this slot.
        passages: Iterable of passage strings.
        questions: Iterable of question strings, aligned with ``passages``.
        k: Forwarded to ``get_all_predictions`` after the raw results.
        cutoff: Forwarded as the last argument of ``get_all_predictions``.

    Returns:
        Whatever ``get_all_predictions`` returns; ``generate`` iterates it
        with ``.items()`` and compares the keys against example ids.
    """
    examples = [
        read_squad_example(passage, question, example_id)
        for example_id, (passage, question) in enumerate(zip(passages, questions))
    ]
    features = convert_examples_to_features(examples, tokenizer, max_seq_length,
                                            doc_stride, max_query_length, False)

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    # Row index of each feature, so a batch row can be mapped back to it.
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                            all_example_index)
    eval_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset),
                                 batch_size=16)

    all_results = []
    for batch in eval_dataloader:
        batch = tuple(t.to(device) for t in batch)
        input_ids, attention_mask, example_indices = batch[0], batch[1], batch[3]
        with torch.no_grad():
            # NOTE(review): unlike predict()/predictTopK(), token_type_ids
            # (batch[2]) is not passed to the model here. Left unchanged
            # because adding it would alter the model's outputs — confirm
            # whether the omission is intentional for the loaded model.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            all_results.append(
                RawResult(unique_id=int(eval_feature.unique_id),
                          start_logits=to_list(outputs[0][i]),
                          end_logits=to_list(outputs[1][i])))

    return get_all_predictions(examples, features, all_results, k,
                               max_answer_length, do_lower_case, True, cutoff)

def extract_entity(input_text, preds, probs, threshold=0.1, inv=False):
    """Turn scored answer candidates into a list of distinct entity strings.

    A candidate is used only if it is non-empty, its probability is strictly
    greater than ``threshold`` and it does not contain ``"MASK"``. For each
    such candidate:

    1. If another candidate (different text, longer than 2 characters,
       probability above ``threshold``) is a substring of it, the first such
       candidate replaces it.
    2. Surrounding punctuation is stripped and one leading word found in the
       module-level ``articles`` list is dropped.
    3. A comma-separated candidate contributes each of its parts.

    Args:
        input_text: Passage the candidates came from. Not needed for the
            extraction; kept so existing callers keep working.
        preds: Candidate answer strings, or ``None`` when there are none.
        probs: Probabilities aligned one-to-one with ``preds``.
        threshold: Minimum (exclusive) probability for a candidate.
        inv: When true, a used candidate whose text (first comma part)
            contains ``"empty"`` makes the function return ``[]`` at once.
            ``generate`` sets this for the inventory question.

    Returns:
        Distinct, whitespace-trimmed, non-empty entity strings in the order
        they were first produced; ``[]`` when ``preds`` is ``None``.
    """
    if preds is None:
        return []

    # dict keys: set semantics with a stable first-seen order.
    found = {}

    for candidate, confidence in zip(preds, probs):
        if len(candidate) < 1:
            continue
        if not confidence > threshold or "MASK" in candidate:
            continue

        # Prefer a more minimal candidate contained in this one, if any.
        for other, other_conf in zip(preds, probs):
            if (candidate != other and other in candidate
                    and other_conf > threshold and len(other) > 2):
                candidate = other
                break

        text = candidate.strip(string.punctuation)
        words = text.split()
        if not words:
            # Only punctuation/whitespace: skip this candidate but keep
            # processing the remaining ones.
            continue

        # Drop a single leading article/determiner.
        if words[0].lower() in articles:
            text = " ".join(words[1:])

        pieces = text.split(',') if ',' in text else [text]

        # The "empty" check looks only at the first comma-separated part.
        if inv and 'empty' in pieces[0]:
            return []

        for piece in pieces:
            piece = piece.strip()
            if piece:  # never emit an empty entity (e.g. candidate was just "the")
                found[piece] = None

    return list(found)

def generate(input_text, threshold=0.3, attribute=False):
    """Extract location, objects and object attributes from ``input_text``.

    Three fixed questions (location, surrounding objects, inventory) are
    asked about ``input_text`` in one batch; the answers are filtered with
    ``extract_entity``. Every object found is then asked about with
    "What attribute does <object> have?" in a second batch.

    Args:
        input_text: Passage to query.
        threshold: Probability threshold for locations. Objects and
            attributes use ``threshold / 10``.
        attribute: The attribute pass runs when this is false (the default).
            NOTE(review): that reads as inverted relative to the parameter
            name; left as-is because callers may rely on the default —
            confirm the intended meaning.

    Returns:
        dict with keys ``'location'``, ``'object_surr'`` and ``'objs_inv'``
        (lists of strings) and ``'attributes'`` (object string -> list of
        attribute strings; empty lists when the attribute pass is skipped).
    """
    locs = []
    objs_surr = []
    objs_inv = []
    primers = ["Where am I located?", "What is here?", "Which objects are in my inventory?"]

    # batch_predictTopK's first parameter is an unused placeholder; the
    # module-level model is passed to fill that slot.
    res = batch_predictTopK(model, [input_text] * 3, primers, 10)

    # Keys are compared against the position of the question in `primers`.
    for r, v in res.items():
        preds = [a['text'] for a in v]
        probs = [a['probability'] for a in v]
        if r == 0:
            locs = extract_entity(input_text, preds, probs, threshold)
        elif r == 1:
            objs_surr = extract_entity(input_text, preds, probs, threshold / 10)
        elif r == 2:
            # inv=True: an "empty" inventory answer yields no objects.
            objs_inv = extract_entity(input_text, preds, probs, threshold / 10, True)

    objs = objs_surr + objs_inv
    attributes = {o: [] for o in objs}

    # Skip the second model pass entirely when there is nothing to ask about.
    if not attribute and objs:
        primers = ["What attribute does " + o + " have?" for o in objs]
        res = batch_predictTopK(model, [input_text] * len(objs), primers, 10)

        for r, v in res.items():
            preds = [a['text'] for a in v]
            probs = [a['probability'] for a in v]
            attributes[objs[r]] += extract_entity(input_text, preds, probs, threshold / 10)

    return {'location': locs, 'object_surr': objs_surr, 'objs_inv': objs_inv, 'attributes': attributes}


if __name__ == "__main__":
    # Smoke run: push 16 copies of one sample observation through generate()
    # in a single worker process.
    import multiprocessing

    sent = "[loc]  undying garden barren branches and pines suggest winter's quiet. metal and wood, painted and sculpted, the false garden is heedless of harvest, drought or rain.  a jade compass rests at the garden's metaphysical center, between translucence and opacity, curve and line, and other dichotomies too subtle for your eye. a gentle warmth from the surface invites touch.  transactions of light show other rooms southward and westward.  you can see a tortoise shell comb here. [inv] you are carrying:   a bronze hilted dagger   half a porcelain mask   armor and silks  worn  [obs] the shapes of light you saw as you entered fade irrevocably.   undying garden barren branches and pines suggest winter's quiet. metal and wood, painted and sculpted, the false garden is heedless of harvest, drought or rain.  a jade compass rests at the garden's metaphysical center, between translucence and opacity, curve and line, and other dichotomies too subtle for your eye. a gentle warmth from the surface invites touch.  transactions of light show other rooms southward and westward.  you can see a tortoise shell comb here. [atr] animate, clothing, concealed, door, enterable, light, openable, proper, scenery, static, supporter, transparent, pluralname"
    sents = [sent] * 16

    # With 'forkserver' the worker imports this module itself, so the
    # module-level model load also runs inside the worker.
    multiprocessing.set_start_method('forkserver')
    pool = multiprocessing.Pool(processes=1)
    try:
        pool.map(generate, sents)
    finally:
        # Always release the worker, even if generate() raised.
        pool.close()
        pool.join()