from . import utils
from . import anchor_base
from . import anchor_explanation
import numpy as np
import json
import os
import string
import sys
from io import open
import numpy as np

def id_generator(size=15):
    """Return a random identifier of `size` characters.

    Characters are drawn (with replacement) from uppercase ASCII letters and
    digits using NumPy's global random state. Used to build unique div ids
    when embedding HTML into IPython notebooks.
    """
    alphabet = list(string.ascii_uppercase + string.digits)
    picks = np.random.choice(alphabet, size, replace=True)
    return ''.join(str(ch) for ch in picks)

def exp_normalize(x):
    """Return the softmax of `x`, computed in a numerically stable way.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow; this does not change the normalized result. Accepts any
    array-like (e.g. a plain list) and returns a NumPy array normalized to
    sum to 1.
    """
    # asarray lets lists/tuples work; ndarray inputs pass through unchanged.
    x = np.asarray(x)
    y = np.exp(x - x.max())
    return y / y.sum()


class AnchorTimeseries(object):
    """Anchor explainer for (univariate) time series.

    Adapted from the anchor text explainer. Candidate anchors are *rules*
    over positions of the series. Each rule is a triple ``[x, y, dist]``;
    see ``_rule_holds`` for the exact condition it expresses.
    ``gen_default_rule`` yields one rule ``[i, i, 0]`` per position, meaning
    "position i is left untouched".
    """

    def __init__(self, class_names=None, nlp=None):
        """
        Args:
            class_names: display names for the classifier's integer labels;
                defaults to ['normal', 'anomaly'].
            nlp: optional tokenizer used only by as_html. It must be a
                callable returning an indexable sequence of tokens with
                ``.text`` and ``.idx`` attributes. When None, as_html renders
                the series as space-separated values.
        """
        self.class_names = (['normal', 'anomaly'] if class_names is None
                            else list(class_names))
        self.nlp = nlp

    @staticmethod
    def _rule_holds(state, rule):
        """Return True if perturbation-state vector `state` satisfies `rule`.

        state[k] is 1 for an untouched position, 2 for a position perturbed
        with Gaussian noise, and 0 for a deleted (NaN) position.
        rule = (x, y, dist):
          * x != y: positions x and y are untouched, and at least `dist`
            entries of state[x:y] are not deleted;
          * x == y: position x is untouched, and at least `dist` entries of
            state[:x] are not deleted.
        """
        x, y, dist = rule
        if state[x] != 1:
            return False
        if x != y:
            return state[y] == 1 and np.count_nonzero(state[x:y]) >= dist
        return np.count_nonzero(state[:x]) >= dist

    def get_sample_fn(self, text, classifier_fn, rules):
        """Build the perturbation sampler used by the anchor search.

        Args:
            text: the series to explain (a 1-D sequence of numbers).
            classifier_fn: callable taking a batch of series and returning
                one label per series.
            rules: list of ``[x, y, dist]`` rules (see ``_rule_holds``).

        Returns:
            ``(text, positions, true_label, sample_fn)``, where `positions` is
            ``list(range(len(text)))`` and `true_label` is the classifier's
            prediction on `text`.

            ``sample_fn(present, num_samples, compute_labels=True)`` returns
            ``(raw, data, labels)``:
              * raw: (num_samples, len(text)) perturbed copies of `text`.
                Changed positions get N(0, 1) noise added; deleted positions
                are NaN.
              * data: (num_samples, len(rules)) matrix whose [i, r] entry is
                1 when rule r holds for sample i, else 0.
              * labels: int array, 1 where the classifier's prediction on
                the sample equals `true_label` (empty if not computed).
            Every returned sample satisfies all rules indexed by `present`.
        """
        true_label = classifier_fn([text])[0]
        holds = self._rule_holds

        def sample_fn(present, num_samples, compute_labels=True):
            n_steps = len(text)
            data = np.ones((num_samples, len(rules)))
            raw = np.zeros((num_samples, n_steps))
            raw[:] = text
            anchored = [rules[r] for r in present]

            # Rejection sampling: draw a perturbation and keep it only if it
            # satisfies every anchored rule. NOTE(review): acceptance drops
            # roughly geometrically with the number of anchored positions.
            i = 0
            while i < num_samples:
                state = np.ones(n_steps)
                # Positions are drawn from the series, not from the sample
                # batch; drawing from range(num_samples) went out of bounds
                # whenever num_samples > len(text).
                n_changed = np.random.binomial(min(n_steps, 2), .5)
                changed = np.random.choice(n_steps, n_changed, replace=False)
                state[changed] = 2
                n_deleted = np.random.binomial(n_steps, .5)
                deleted = np.random.choice(n_steps, n_deleted, replace=False)
                state[deleted] = 0  # deletion overrides a change
                if not all(holds(state, rule) for rule in anchored):
                    continue
                data[i] = [holds(state, rule) for rule in rules]
                for j in changed:
                    raw[i, j] = np.random.normal(0, 1) + raw[i, j]
                raw[i, deleted] = np.nan
                i += 1

            labels = []
            if compute_labels:
                labels = (classifier_fn(raw) == true_label).astype(int)
            return raw, data, np.array(labels)

        return text, list(range(len(text))), true_label, sample_fn

    def gen_default_rule(self, text_instance):
        """Return one ``[i, i, 0]`` rule per position of `text_instance`."""
        return [[i, i, 0] for i in range(len(text_instance))]

    def explain_instance(self, text, classifier_fn, threshold=0.95,
                          delta=0.1, tau=0.15, batch_size=10, rules=None,onepass=False,
                          use_proba=False, beam_size=4,
                          **kwargs):
        """Search for an anchor explaining classifier_fn's prediction on `text`.

        Args:
            text: the series to explain.
            classifier_fn: callable mapping a batch of series to labels.
            threshold: desired anchor precision (anchor_beam's
                desired_confidence).
            delta, tau: confidence parameters forwarded as delta / epsilon.
            batch_size: samples per anchor_beam batch.
            rules: candidate rules; defaults to gen_default_rule(text).
            onepass, use_proba, beam_size: accepted for API compatibility;
                currently unused.
            **kwargs: forwarded to anchor_base.AnchorBaseBeam.anchor_beam.

        Returns:
            An anchor_explanation.AnchorExplanation of type 'timeseries'.
        """
        # `is None`: `== None` is ambiguous (raises) for ndarray rules.
        if rules is None:
            rules = self.gen_default_rule(text)
        _, _, true_label, sample_fn = self.get_sample_fn(
            text, classifier_fn, rules)
        exp = anchor_base.AnchorBaseBeam.anchor_beam(
            sample_fn, n_features=len(rules), delta=delta, epsilon=tau,
            batch_size=batch_size, desired_confidence=threshold,
            stop_on_first=True, coverage_samples=1, **kwargs)
        exp['names'] = exp['feature']
        exp['instance'] = text
        exp['prediction'] = true_label
        return anchor_explanation.AnchorExplanation('timeseries', exp,
                                                    self.as_html)

    @staticmethod
    def _json_default(obj):
        """json.dumps fallback converting NumPy scalars/arrays to builtins."""
        if isinstance(obj, (np.generic, np.ndarray)):
            return obj.tolist()
        raise TypeError('Object of type %s is not JSON serializable'
                        % type(obj).__name__)

    def _tokenize(self, instance):
        """Render `instance` as text; return ``(text, [(token, offset), ...])``.

        Uses ``self.nlp`` when one was supplied. Otherwise each value of the
        series becomes one token, joined by single spaces, and `offset` is
        that token's character offset in the returned text.
        """
        nlp = getattr(self, 'nlp', None)
        if nlp is not None:
            processed = nlp(str(instance))
            return str(instance), [(tok.text, tok.idx) for tok in processed]
        tokens = []
        offset = 0
        for value in np.ravel(instance):
            token = str(value)
            tokens.append((token, offset))
            offset += len(token) + 1  # +1 for the joining space
        return ' '.join(tok for tok, _ in tokens), tokens

    def as_html(self, exp):
        """Render an explanation dict (as built by explain_instance) as HTML.

        Args:
            exp: dict with 'prediction', 'feature', 'names', 'precision',
                'all_precision', 'coverage', 'examples' and 'instance'.
                exp['prediction'] is coerced to int in place.

        Returns:
            A standalone HTML string embedding bundle.js (read from this
            module's directory) and calling lime.RenderExplanationFrame.
        """
        predict_proba = np.zeros(len(self.class_names))
        exp['prediction'] = int(exp['prediction'])
        predict_proba[exp['prediction']] = 1
        predict_proba = list(predict_proba)

        def jsonize(x):
            # Series values and feature indices may be NumPy types.
            return json.dumps(x, default=self._json_default)

        this_dir, _ = os.path.split(__file__)
        with open(os.path.join(this_dir, 'bundle.js'), encoding='utf8') as f:
            bundle = f.read()
        random_id = 'top_div' + id_generator()

        def highlight(instance, idxs):
            # NOTE(review): idxs are rule indices; they coincide with series
            # positions only under gen_default_rule.
            text, tokens = self._tokenize(instance)
            raw_indexes = [(tokens[i][0], tokens[i][1], exp['prediction'])
                           for i in idxs]
            return text, raw_indexes

        def process_examples(examples, idx):
            idxs = exp['feature'][:idx + 1]
            new_names = {'covered_true': 'coveredTrue',
                         'covered_false': 'coveredFalse',
                         'covered': 'covered'}
            out_dict = {}
            for name, new in new_names.items():
                out = []
                # Each example is assumed to be one whole perturbed series
                # (a row of sample_fn's raw output), not a 1-element wrapper.
                for example in examples[name]:
                    text, raw_indexes = highlight(example, idxs)
                    out.append({'text': text, 'rawIndexes': raw_indexes})
                out_dict[new] = out
            return out_dict

        example_obj = [process_examples(examples, i)
                       for i, examples in enumerate(exp['examples'])]

        explanation = {'names': exp['names'],
                       'certainties': exp['precision'] if len(exp['precision']) else [exp['all_precision']],
                       'supports': exp['coverage'],
                       'allPrecision': exp['all_precision'],
                       'examples': example_obj}
        text, raw_indexes = highlight(exp['instance'], exp['feature'])
        raw_data = {'text': text, 'rawIndexes': raw_indexes}

        out = u'''<html>
        <meta http-equiv="content-type" content="text/html; charset=UTF8">
        <head><script>%s </script></head><body>''' % bundle
        out += u'''
        <div id="{random_id}" />
        <script>
            div = d3.select("#{random_id}");
            lime.RenderExplanationFrame(div,{label_names}, {predict_proba},
            {true_class}, {explanation}, {raw_data}, "text", "anchor");
        </script>'''.format(random_id=random_id,
                            label_names=jsonize(self.class_names),
                            predict_proba=jsonize(list(predict_proba)),
                            true_class=jsonize(False),
                            explanation=jsonize(explanation),
                            raw_data=jsonize(raw_data))
        out += u'</body></html>'
        return out

    def show_in_notebook(self, exp, true_class=False, predict_proba_fn=None):
        """Display the explanation's HTML inline in a Jupyter notebook.

        `true_class` and `predict_proba_fn` are accepted for signature
        compatibility but ignored; as_html takes only `exp`. Previously they
        were forwarded to as_html, which raised TypeError.
        """
        out = self.as_html(exp)
        from IPython.display import display, HTML
        display(HTML(out))
