from __future__ import absolute_import, division, print_function

from abc import ABC
import json
import logging
import os

import torch
#from transformers import AutoModelForSequenceClassification, AutoTokenizer

import collections
import logging
import math

import numpy as np
import torch
from transformers import (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer)
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

from utils_squad import (get_predictions, read_squad_example,
                         convert_example_to_features, to_list, convert_examples_to_features, get_all_predictions)
from ts.torch_handler.base_handler import BaseHandler
import string
import timeit
#logger = logging.getLogger(__name__)


class TransformersClassifierHandler(BaseHandler, ABC):
    """
    TorchServe handler that extracts entities from text with an ALBERT QA model.

    Despite its name, this handler does not classify text. For every request it asks
    an ``AlbertForQuestionAnswering`` model three fixed questions about the request's
    ``state`` text (location, surrounding objects, inventory) and, depending on the
    request's ``attribute`` flag, one follow-up question per extracted object.
    """
    def __init__(self):
        super(TransformersClassifierHandler, self).__init__()
        self.initialized = False
        # Working copies of the request texts. ``extract_entity`` masks accepted
        # entities in these strings; ``inference`` rebuilds the list on every call.
        self.input_texts = []

    def initialize(self, ctx):
        """
        Load the model/tokenizer and set the inference hyper-parameters.

        :param ctx: serving context; ``manifest`` and ``system_properties``
                    (keys ``model_dir`` and ``gpu_id``) are read from it
        """
        self.manifest = ctx.manifest

        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        # One record per tokenized feature, holding the raw start/end logits.
        self.RawResult = collections.namedtuple("RawResult",
                                                ["unique_id", "start_logits", "end_logits"])

        self.conjunctions = ['and', 'or', 'nor']
        self.articles = ["the", 'a', 'an', 'his', 'her', 'their', 'my', 'its', 'those', 'these', 'that', 'this', 'the']
        self.pronouns = [" He ", " She ", " he ", " she ", " they ", " them "]

        torch.manual_seed(42)
        self.max_seq_length = 512
        self.doc_stride = 128
        self.do_lower_case = True
        self.max_query_length = 64
        self.n_best_size = 20
        self.max_answer_length = 30
        self.model, self.tokenizer = self.load_model(model_dir, self.do_lower_case)

        if torch.cuda.is_available():
            # Use the GPU assigned to this worker when the context names one;
            # otherwise fall back to the default CUDA device.
            gpu_id = properties.get("gpu_id")
            self.device = 'cuda' if gpu_id is None else 'cuda:' + str(gpu_id)
        else:
            self.device = 'cpu'
        self.model.to(self.device)
        self.model.eval()

        self.initialized = True

    def load_model(self, model_path: str, do_lower_case=True):
        """
        Build the QA model from ``model_path`` and the matching tokenizer.

        :param model_path: directory containing ``config.json`` and the model weights
        :param do_lower_case: forwarded to the tokenizer
        :return: ``(model, tokenizer)``
        """
        config = AlbertConfig.from_pretrained(os.path.join(model_path, "config.json"))
        # NOTE(review): the tokenizer is resolved by the name 'albert-large-v2', not
        # from ``model_path`` -- confirm this is intended.
        tokenizer = AlbertTokenizer.from_pretrained('albert-large-v2', do_lower_case=do_lower_case)
        model = AlbertForQuestionAnswering.from_pretrained(model_path, from_tf=False, config=config)
        return model, tokenizer

    def extract_entity(self, preds, probs, idx, threshold=0.1, inv=False):
        """
        Turn the answer candidates of one question into a list of entity strings.

        Every accepted candidate is cleaned (surrounding punctuation and one leading
        article removed), recorded as an entity, and replaced by ``[MASK]`` in
        ``self.input_texts[idx]``.

        :param preds: candidate answer texts, or ``None``
        :param probs: scores aligned with ``preds``
        :param idx: index of the text in ``self.input_texts`` to mask
        :param threshold: a candidate is accepted only if its score is strictly greater
        :param inv: when true, an accepted candidate containing ``'empty'`` makes the
                    method return ``[]`` immediately
        :return: unique entity strings, in the order they were accepted
        """
        if preds is None:
            return []

        # Insertion-ordered stand-in for a set, so callers that take the first
        # element get the first accepted candidate rather than an arbitrary one.
        entities = collections.OrderedDict()

        for candidate, score in zip(preds, probs):
            if len(candidate) < 1:
                continue
            if not (score > threshold and "MASK" not in candidate):
                continue

            # Prefer the first shorter candidate contained in this one, if it qualifies.
            for other, other_score in zip(preds, probs):
                if candidate != other and other in candidate and other_score > threshold and len(other) > 2:
                    candidate = other
                    break

            candidate = candidate.strip(string.punctuation)
            words = candidate.split()
            if not words:
                # NOTE(review): a candidate consisting only of punctuation/whitespace
                # ends the scan of the remaining candidates (behaviour kept from the
                # original code) -- confirm whether skipping it was intended instead.
                break

            # Drop a single leading article ("the lamp" -> "lamp").
            if words[0].lower() in self.articles:
                mention = " ".join(words[1:])
            else:
                mention = candidate

            if not mention:
                # The candidate was only an article. Replacing the empty string would
                # insert '[MASK]' between every character of the text, so skip it.
                continue

            if ',' in mention:
                # A comma-separated answer lists several entities.
                parts = [part.strip() for part in mention.split(',')]
                parts = [part for part in parts if part]
                if not parts:
                    continue
                for part in parts:
                    entities[part] = None
                head = parts[0]
            else:
                entities[mention] = None
                head = mention

            if 'empty' in head and inv:
                return []

            self.input_texts[idx] = self.input_texts[idx].replace(mention, '[MASK]').replace('  ', ' ').replace(' .', '.')

        return list(entities)

    def inference(self, data):
        """
        Extract location, objects and object attributes for a batch of requests.

        :param data: sequence of mappings with the keys ``'state'`` (text),
                     ``'threshold'`` (converted with ``float``) and ``'attribute'``
                     (converted with ``bool``)
        :return: one dict per request with the keys ``'location'``, ``'object_surr'``,
                 ``'objs_inv'`` and ``'attributes'``
        """
        if not data:
            return []

        primers = ["Where am I located?", "What is here?", "Which objects are in my inventory?"]

        # Each request keeps its own threshold and flag instead of the whole batch
        # silently using the values of the last request.
        self.input_texts = []
        thresholds = []
        attribute_flags = []
        for request in data:
            self.input_texts.append(str(request['state']))
            thresholds.append(float(request['threshold']))
            # NOTE(review): ``bool`` of any non-empty string is True -- confirm the
            # flag arrives as a real boolean.
            attribute_flags.append(bool(request['attribute']))

        # Stage 1: ask every primer about every text. Example ``i`` belongs to
        # request ``i // 3`` and primer ``i % 3``.
        passages = []
        for text in self.input_texts:
            passages.extend([text] * len(primers))
        answers = self.batch_predictTopK(passages, primers * len(data), 10)

        locs_all = []
        objs_surr_all = []
        objs_inv_all = []
        for example_id, candidates in answers.items():
            preds = [c['text'] for c in candidates]
            probs = [c['probability'] for c in candidates]
            request_idx, primer_idx = divmod(example_id, len(primers))
            threshold = thresholds[request_idx]
            if primer_idx == 0:
                # Only a single location is reported.
                locs_all.append(self.extract_entity(preds, probs, request_idx, threshold)[:1])
            elif primer_idx == 1:
                objs_surr_all.append(self.extract_entity(preds, probs, request_idx, threshold / 10))
            else:
                objs_inv_all.append(self.extract_entity(preds, probs, request_idx, threshold / 10, True))

        # Stage 2: one follow-up question per extracted object. ``owners`` records
        # which request/object each question belongs to, so requests without any
        # objects cannot desynchronise the bookkeeping.
        attributes_all = []
        follow_up_passages = []
        follow_up_questions = []
        owners = []
        for request_idx, (objs_surr, objs_inv) in enumerate(zip(objs_surr_all, objs_inv_all)):
            objs = objs_surr + objs_inv
            attributes_all.append({obj: [] for obj in objs})
            # NOTE(review): follow-up questions are asked only when the request's
            # 'attribute' flag is falsy (condition kept from the original code) --
            # confirm the polarity is intended.
            if attribute_flags[request_idx]:
                continue
            for obj in objs:
                follow_up_questions.append("What attribute does " + obj + " have?")
                follow_up_passages.append(self.input_texts[request_idx])
                owners.append((request_idx, obj))

        if follow_up_questions:
            answers = self.batch_predictTopK(follow_up_passages, follow_up_questions, 10)
            # Answers are consumed in iteration order, matching the question order.
            for (request_idx, obj), candidates in zip(owners, answers.values()):
                preds = [c['text'] for c in candidates]
                probs = [c['probability'] for c in candidates]
                attributes_all[request_idx][obj] += self.extract_entity(
                    preds, probs, request_idx, thresholds[request_idx] / 10)

        return [
            {'location': locs, 'object_surr': objs_surr, 'objs_inv': objs_inv, 'attributes': attributes}
            for locs, objs_surr, objs_inv, attributes
            in zip(locs_all, objs_surr_all, objs_inv_all, attributes_all)
        ]

    def batch_predictTopK(self, passages, questions, k, cutoff=8, batch_size=16):
        """
        Run the QA model on paired passages/questions and collect the top answers.

        :param passages: texts to search for answers
        :param questions: questions aligned with ``passages``
        :param k: forwarded to ``get_all_predictions``
        :param cutoff: forwarded to ``get_all_predictions``
        :param batch_size: number of tokenized features per forward pass
        :return: the result of ``get_all_predictions``; ``inference`` consumes it as a
                 mapping from example id to a list of dicts with the keys ``'text'``
                 and ``'probability'``. An empty mapping is returned for empty input.
        """
        examples = [
            read_squad_example(passage, question, example_id)
            for example_id, (passage, question) in enumerate(zip(passages, questions))
        ]
        if not examples:
            return collections.OrderedDict()

        features = convert_examples_to_features(examples, self.tokenizer, self.max_seq_length, self.doc_stride,
                                                self.max_query_length, False)
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        # Position of each feature in ``features``, used to map outputs back.
        all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                all_feature_index)
        eval_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)

        all_results = []
        for batch in eval_dataloader:
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                # NOTE(review): the segment ids (batch[2]) are not passed to the
                # model -- confirm this matches how the checkpoint was trained.
                outputs = self.model(input_ids=batch[0], attention_mask=batch[1])

            for i, feature_index in enumerate(batch[3]):
                eval_feature = features[feature_index.item()]
                all_results.append(self.RawResult(unique_id=int(eval_feature.unique_id),
                                                  start_logits=to_list(outputs[0][i]),
                                                  end_logits=to_list(outputs[1][i])))

        return get_all_predictions(examples, features, all_results, k, self.max_answer_length,
                                   self.do_lower_case, True, cutoff)

    def preprocess(self, data):
        """Return ``data`` unchanged."""
        return data

    def postprocess(self, data):
        """Return ``data`` unchanged."""
        return data



# Module-level handler instance, created at import time and shared by every
# call to ``handle`` below; it is initialized lazily on the first call.
_service = TransformersClassifierHandler()


def handle(data, context):
    """
    Entry point: run the module-level handler on one batch of requests.

    The handler is initialized with ``context`` on first use. Exceptions raised by
    the handler propagate to the caller unchanged.

    :param data: batch of requests, or ``None``
    :param context: serving context forwarded to ``initialize``
    :return: ``None`` when ``data`` is ``None``; otherwise a list with one
             ``{'entities': ...}`` dict per request
    """
    if not _service.initialized:
        _service.initialize(context)

    if data is None:
        return None

    # Remember the request count up front so exactly one response is produced
    # per incoming request, whatever the pipeline stages return.
    batch_size = len(data)
    results = _service.postprocess(_service.inference(_service.preprocess(data)))
    return [{'entities': results[i]} for i in range(batch_size)]
