from httpx import delete
from torch import unsafe_chunk
from . import utils
from . import anchor_base
from . import anchor_explanation
import numpy as np
import json
import os
import string
import sys
from io import open
import numpy as np
import time

def id_generator(size=15):
    """Return a random identifier of ``size`` upper-case letters and digits.

    Used to build unique div ids when HTML is embedded in IPython notebooks.
    Characters are drawn with replacement through ``np.random.choice``.
    """
    alphabet = list(string.ascii_uppercase + string.digits)
    picked = np.random.choice(alphabet, size, replace=True)
    return ''.join(picked)


def exp_normalize(x):
    """Return the softmax of array ``x``.

    The maximum of ``x`` is subtracted before exponentiating so that large
    scores do not overflow; the result sums to one.
    """
    shifted = x - x.max()
    weights = np.exp(shifted)
    total = weights.sum()
    return weights / total


class TextGenerator(object):
    """Proposes fillers for masked tokens with a DistilBERT masked-LM.

    The model and tokenizer are only loaded when ``url`` is None; otherwise
    just ``url`` and the ``torch`` module are stored on the instance.
    """

    def __init__(self, url=None):
        from transformers import DistilBertTokenizer, DistilBertForMaskedLM
        import torch
        self.torch = torch
        self.url = url
        if url is not None:
            return
        checkpoint = 'distilbert-base-cased'
        # Prefer the GPU when one is visible to torch.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.bert_tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
        self.bert = DistilBertForMaskedLM.from_pretrained(checkpoint)
        self.bert.to(self.device)
        self.bert.eval()

    def unmask(self, text_with_mask):
        """Return the 500 best-scoring fillers for every mask token.

        Args:
            text_with_mask: string that is encoded with the tokenizer
                (special tokens added).

        Returns:
            One ``(tokens, scores)`` pair per mask position, in order of
            appearance. ``tokens`` comes from
            ``tokenizer.convert_ids_to_tokens`` and ``scores`` is a float
            numpy array of the matching raw model outputs (not normalized).
        """
        torch = self.torch
        tokenizer = self.bert_tokenizer
        encoded = np.array(tokenizer.encode(text_with_mask, add_special_tokens=True))
        input_ids = torch.tensor(encoded)
        # Indexes (within the encoded sequence) that hold the mask token.
        mask_positions = (input_ids == tokenizer.mask_token_id).numpy().nonzero()[0]
        batch = torch.tensor([encoded], device=self.device)
        with torch.no_grad():
            logits = self.bert(batch)[0]

        def top_candidates(position):
            scores, token_ids = torch.topk(logits[0, position], 500)
            tokens = tokenizer.convert_ids_to_tokens(token_ids)
            return tokens, np.array([float(score) for score in scores])

        return [top_candidates(position) for position in mask_positions]


class SentencePerturber:
    """Replaces masked words of one sentence with masked-LM proposals.

    On construction every word is masked in turn and the normalized score the
    generator gives to the original word (0.01 when it is not among the
    proposals, capped at 0.5) is stored in ``self.pr``.
    """

    def __init__(self, words, tg, onepass=False):
        self.tg = tg
        self.words = words
        self.cache = {}
        self.mask = self.tg.bert_tokenizer.mask_token
        self.array = np.array(words, '|U80')
        self.onepass = onepass
        self.pr = np.zeros(len(self.words))
        for idx, word in enumerate(words):
            masked = self.array.copy()
            masked[idx] = self.mask
            candidates, weights = self.probs(' '.join(masked))[0]
            by_word = dict(zip(candidates, weights))
            self.pr[idx] = min(0.5, by_word.get(word, 0.01))

    def sample(self, data):
        """Return a copy of the sentence with positions where ``data == 0`` refilled.

        With ``onepass`` all masks are filled from a single model query;
        otherwise they are filled left to right, re-querying after each fill.
        """
        filled = self.array.copy()
        hidden = data == 0
        masks = np.where(hidden)[0]
        filled[hidden] = self.mask
        if self.onepass:
            proposals = self.probs(' '.join(filled))
            filled[masks] = [np.random.choice(cands, p=weights)
                             for cands, weights in proposals]
            return filled
        for position in masks:
            # Only the first remaining mask is cached in this mode.
            cands, weights = self.probs(' '.join(filled))[0]
            filled[position] = np.random.choice(cands, p=weights)
        return filled

    def probs(self, s):
        """Return cached ``(tokens, probabilities)`` pairs for masked string ``s``."""
        if s in self.cache:
            return self.cache[s]
        distributions = [(tokens, exp_normalize(scores))
                         for tokens, scores in self.tg.unmask(s)]
        # Without onepass only the first mask's distribution is ever used.
        self.cache[s] = distributions if self.onepass else distributions[:1]
        return self.cache[s]

    # def perturb_sentence(self, n, prob_change=0.5):
    #     raw = np.zeros((n, len(self.words)), '|U80')
    #     data = np.ones((n, len(self.words)))


class AnchorText(object):
    """bla"""

    def __init__(self, nlp, class_names, use_unk_distribution=True,
                 mask_string='UNK', switch_pi=1, midist=True, **kwargs):
        """
        Args:
            nlp: spacy object
            class_names: list of strings
            use_unk_distribution: if True, the perturbation distribution
                will just replace words randomly with mask_string.
                If False, a TextGenerator is created so that words can be
                replaced by masked-language-model proposals.
            mask_string: String used to mask tokens if use_unk_distribution is True.
            switch_pi: stored as ``self.switch_pi``; ``my_sample_fn_3`` in
                ``get_sample_fn`` only implements the value 1.
            midist: if truthy ``get_sample_fn`` returns ``my_sample_fn``,
                otherwise ``my_sample_fn_3``.
            **kwargs: accepted and ignored.
        """
        self.nlp, self.class_names = nlp, class_names
        self.use_unk_distribution = use_unk_distribution
        self.mask_string = mask_string
        self.switch_pi, self.midist = switch_pi, midist
        # Accumulated wall-clock seconds spent generating samples.
        self.sample_tm = 0
        # The generator is only needed when words are not replaced by mask_string.
        self.tg = None if self.use_unk_distribution else TextGenerator()

    def get_sample_fn(self, text, classifier_fn, onepass=False, use_proba=False, rules=[]) -> (np.ndarray, np.ndarray, int, callable):
        """Tokenize ``text`` and build the sampler used by the anchor search.

        Args:
            text: sentence to explain; tokenized with ``self.nlp``.
            classifier_fn: callable applied to a batch of strings. Its first
                output for ``[text]`` is the reference label that sampled
                sentences are compared against.
            onepass: forwarded to ``SentencePerturber`` (only built when
                ``self.use_unk_distribution`` is False).
            use_proba: not read anywhere in this method.
            rules: sequence of ``(x, y, d)`` triples. ``x`` and ``y`` index
                tokens of ``text``; ``d == -1`` is replaced by
                ``-len(words)`` ("distance not limited"). The argument is
                only converted with ``np.array`` and never mutated.

        Returns:
            ``(words, positions, true_label, sampler)``: the token texts as a
            ``'|U80'`` array, the list of each token's ``idx`` attribute, the
            reference label, and ``my_sample_fn`` when ``self.midist`` is
            truthy, otherwise ``my_sample_fn_3``. Both samplers take
            ``(present, num_samples, compute_labels=True)`` and return
            ``(raw_data, data, labels)``.

            NOTE(review): the return annotation on the signature does not
            match this tuple, and the nested ``sample_fn`` is defined but
            never returned.
        """
        rules = np.array(rules)
        true_label = classifier_fn([text])[0]
        processed = self.nlp(text)
        words = np.array([x.text for x in processed], dtype='|U80')
        positions = [x.idx for x in processed]
        perturber = None
        if not self.use_unk_distribution:
            perturber = SentencePerturber(words, self.tg, onepass=onepass)
        def my_sample_fn_3(present,num_samples,compute_labels=True):
            """Vectorized sampler: mask, delete and swap words in one batch.

            ``present`` indexes ``rules``. Words used by the present rules are
            protected from masking/deletion so that the rules are meant to
            hold for every sample; this is verified at the end and an
            AssertionError is raised (after printing debug state) otherwise.
            Only the UNK distribution and ``self.switch_pi == 1`` are
            implemented; other settings raise NotImplementedError.
            """
            def fit(pos):
                """Evaluate every rule on every sample; returns a 0/1 matrix.

                ``pos[i, j]`` is the output position of word ``j`` in sample
                ``i`` or -1 when the word was masked or deleted.
                """
                data = np.zeros((num_samples, len(rules)))
                for idx,(x,y,d) in enumerate(rules):
                    if d==-1:
                        # -1 means "no distance limit"
                        d = -len(words)
                    if x!=y:
                        # both words kept and y placed more than d after x
                        data[(pos[:,x]>=0) & (pos[:,y]>=0) & (pos[:,y]-pos[:,x]>d), idx] = 1
                    else:
                        # single-word rule: position of the word is at least d
                        data[pos[:,y]>=d, idx] = 1
                return data
                    


            b = time.time()
            pos = np.ones((num_samples, len(words)),dtype=int) # will finally represent the position of the j-th word in the i-th instance
            rpos = np.zeros((num_samples,len(words)),dtype=int) # reverse projection of pos: the j-th position of the i-th instance is words[rpos[i][j]]
            rpos[:] = -2 # -2 marks an output slot that is not used
            # replace: words referenced by a present rule are never masked
            fixed = []
            for x,y,d in rules[present]:
                fixed += [x,y]
            fixed = list(set(fixed))
            for i in range(len(words)):
                if i in fixed:
                    continue
                if self.use_unk_distribution:
                    # mask word i in a Binomial(num_samples, 0.5)-sized random subset of samples
                    n_changed = np.random.binomial(num_samples, .5)
                    changed = np.random.choice(num_samples, n_changed, replace=False)
                    pos[changed, i] = 0 # 0 -> UNK
                else:
                    raise NotImplementedError

            # deletion: `fixed` still holds the rule words from the replace
            # step; for present rules with d > 0 the words between x and y
            # (or before x when x == y) are protected as well
            for x,y,d in rules[present]:
                if d>0:
                    if x!=y:
                        fixed += list(range(x+1,y))
                    else:
                        fixed += list(range(x))
                
            fixed = list(set(fixed))
            for i in range(len(words)):
                if i in fixed:
                    continue
                n_changed = np.random.binomial(num_samples, .5)
                changed = np.random.choice(num_samples, n_changed, replace=False)
                pos[changed, i] = -1  # delete (overrides an earlier UNK mark)
            
            cpos = pos.copy() # convenient for the switch step: masked words keep their slot here
            # cnt[i] = number of output slots already filled in sample i
            cnt = np.zeros(num_samples, dtype=int)
            for i in range(0,len(words)):
                deleted = (pos[:,i]==-1)
                empty = (pos[:,i]==0)
                rem = (pos[:, i]==1)
                
                # pos: output slot for kept words, -1 for masked or deleted ones
                pos[rem,i] = cnt[rem]
                pos[empty,i] = -1
                pos[deleted,i] = -1
                # cpos: output slot for kept AND masked words, -1 only when deleted
                cpos[rem,i] = cnt[rem]
                cpos[empty,i] = cnt[empty]
                cpos[deleted,i] = -1

                # rpos: slot -> source word index, or -1 for a masked slot
                rpos[rem,cnt[rem]] = i
                rpos[empty,cnt[empty]] = -1
                
                # deleted words do not consume an output slot
                cnt[empty] += 1
                cnt[rem] += 1
                
            # print(pos)

            # switch: collect word indexes that must not be picked as the
            # left element of a swap for the present rules
            fixed = []
            for x,y,d in rules[present]:
                if y!=x and y-x==d+1:
                    fixed += [x,y-1]
                elif x==y and d == x:
                    fixed += [x-1]
            
            if self.switch_pi == 1:
                # each sample is a swap candidate with probability 0.5
                switched_idx = np.random.random(num_samples)<0.5
                switched_idx = np.where(switched_idx)[0]
                # the last index is always offered; drawing it means "no swap"
                # because it has no right neighbour
                non_fixed = list(set(range(len(words)-1))-set(fixed))+[len(words)-1]
                switched_pos = np.random.choice(non_fixed,len(switched_idx),replace=True)
                switched_idx = switched_idx[switched_pos!=len(words)-1]
                switched_pos = switched_pos[switched_pos!=len(words)-1]
                # print(switched_idx.dtype,switched_pos.dtype,cpos.dtype)

                # skip swaps where either of the two words was deleted
                deteled = (cpos[switched_idx,switched_pos]==-1) | (cpos[switched_idx,switched_pos+1]==-1)
                switched_pos = switched_pos[~deteled]
                switched_idx = switched_idx[~deteled]

                # exchange the two output slots in rpos, then mirror the
                # exchange in pos for the words that are not masked
                rpos[switched_idx,cpos[switched_idx,switched_pos]], rpos[switched_idx,cpos[switched_idx,switched_pos+1]] = rpos[switched_idx,cpos[switched_idx,switched_pos+1]], rpos[switched_idx,cpos[switched_idx,switched_pos]]
                rem = (pos[switched_idx,switched_pos]!=-1)
                pos[switched_idx[rem],switched_pos[rem]] = cpos[switched_idx[rem],switched_pos[rem]+1]
                rem = (pos[switched_idx,switched_pos+1]!=-1)
                pos[switched_idx[rem],switched_pos[rem]+1] = cpos[switched_idx[rem],switched_pos[rem]]
            else:
                raise NotImplementedError
            
            if self.use_unk_distribution:
                # unused slots (rpos == -2) stay '' and are trimmed by strip()
                raw_data = np.zeros((num_samples, len(words)), '|U80')
                raw_data[rpos==-1] = self.mask_string
                raw_data[rpos>=0] = words[rpos[rpos>=0]]
                raw_data = np.array([' '.join(x).strip() for x in raw_data])
                data = fit(pos)
            else:
                raise NotImplementedError
            self.sample_tm += (time.time() - b)

            labels = []
            if compute_labels:
                b = time.time()
                # 1 where the classifier keeps the reference label
                labels = (classifier_fn(raw_data) == true_label).astype(int)
                # print("predict_time: %s" % (time.time() - b))
                # print(labels)
            labels = np.array(labels)
            raw_data = raw_data.reshape(-1, 1)
            # sanity check: every sample must satisfy every present rule
            if not data[:,present].all():
                print(pos)
                # print(rules)
                print(raw_data)
                print(present)
                print(rules[present])
                print(data)
                print(labels)
                fit(pos)
                raise AssertionError
            # print(pos)
            return raw_data, data, labels


        def my_sample_fn(present, num_samples, compute_labels=True):
            """Rejection sampler: perturb, then keep samples satisfying ``present``.

            ``present`` indexes ``rules``. Batches of ``remain_num`` candidates
            are generated until ``num_samples`` of them satisfy every present
            rule. Returns the accepted sentences as a ``(-1, 1)`` string
            array, their stacked rule-feature vectors and the 0/1 labels
            (empty array when ``compute_labels`` is False).
            """
            b = time.time()
            def fit(data):
                """Evaluate every rule on one sample's position vector."""
                # rule_list = rules[nrule_idx]
                countn0 = lambda a:sum([1 if x != -1 else 0 for x in a])  # not used below
                new_data = np.ones(len(rules))
                for i, rule in enumerate(rules):
                    x,y,d = rule
                    if d == -1: # if d is -1, the distance is not limited
                        d = -len(words)
                    if x != y:
                        new_data[i] = 1 if data[x] > 0 and data[y] > 0 and data[y]-data[x] > d else 0
                    else:
                        new_data[i] = 1 if data[y] > d else 0
                return new_data

            remain_num = num_samples
            raw_data = []
            ret_data = []
 
            # generate randomly and filter
            # debugcnt = 0

            # probability, drawn once per call, that a candidate gets a swap
            neg_rate = np.random.random()

            # NOTE(review): this loop has no iteration cap; it only ends once
            # enough candidates satisfy the present rules.
            while remain_num > 0:

                # data: 1 = kept, 0 = replaced, -1 = deleted (set further down)
                data = np.ones((remain_num, len(words)))
                # data = np.array(list(range(len(words)))*remain_num).reshape(remain_num, len(words))
                # pos = data.copy()
                pos = np.zeros((remain_num,len(words)), dtype=int)
                raw = np.zeros((remain_num, len(words)), '|U80')
                raw[:] = words
                for i, t in enumerate(words):
                    # words referenced by a present rule are never replaced
                    flag = False
                    for x in present:
                        if i in rules[x][0:2]:
                            flag = True
                    if flag:
                        continue
                    # choose some places to become UNK
                    n_changed = np.random.binomial(remain_num, .5)
                    changed = np.random.choice(remain_num, n_changed, replace=False)
                    data[changed, i] = 0
                    if self.use_unk_distribution:
                        raw[changed, i] = self.mask_string
                if not self.use_unk_distribution:
                    # let the perturber fill the zeroed positions; data then
                    # records which words still equal the original
                    raw=[]
                    for i, d in enumerate(data):
                        r = perturber.sample(d)
                        data[i] = r == words
                        # print("data[i]")
                        # print(data[i])
                        raw.append(r)
                    raw=np.array(raw)

                # delete some places (rule words are not protected here; the
                # filter below discards candidates that break a present rule)

                for i, t in enumerate(words):
                    n_changed = np.random.binomial(remain_num, .5)
                    changed = np.random.choice(remain_num, n_changed, replace=False)
                    if n_changed != 0:
                        raw[changed, i] = ''
                        data[changed, i] = -1
                
                
                # turn data into 1-based positions among the non-deleted words
                cnt = np.zeros(remain_num, dtype=int)
                for i in range(len(words)):
                    non_empty = (data[:,i]!=-1)
                    rem = (data[:, i]>0)
                    pos[non_empty,cnt[non_empty]] = i  # the original word at this place, even if it was replaced
                    cnt[non_empty] += 1
                    data[rem,i] = cnt[rem]
                
                # switch two words, only two contiguous words
                # NOTE(review): only data and pos are permuted here; raw is
                # left untouched, so the text given to the classifier keeps
                # the original word order - confirm this is intended.
                n_convert = np.random.binomial(remain_num,neg_rate)
                coverted = np.random.choice(remain_num,n_convert,replace=False)
                for i in coverted:
                    if cnt[i]<=1:
                        continue
                    x = np.random.choice(cnt[i]-1)
                    d = data[i]
                    p = pos[i]
                    d[p[x]],d[p[x+1]] = d[p[x+1]],d[p[x]]
                    p[x], p[x+1] = p[x+1], p[x]
                    
                # keep the candidates for which every present rule holds;
                # deleted words are '' in raw, so the joined text can contain
                # consecutive spaces
                for x, y in zip(data, raw):
                    temp_data =  fit(x)
                    if sum(temp_data[present])==len(present):
                        remain_num -= 1
                        raw_data.append(' '.join(y))
                        ret_data.append(temp_data)

            labels = []
            tm = (time.time() - b)
            self.sample_tm += tm
            # print("sample_time: %s" % tm)
            if compute_labels:
                b = time.time()
                labels = (classifier_fn(raw_data) == true_label).astype(int)
                # print("predict_time: %s" % (time.time() - b))
                # print(labels)
            labels = np.array(labels)
            # string dtype wide enough for the longest sentence, at least 80
            max_len = max([len(x) for x in raw_data])
            dtype = '|U%d' % (max(80, max_len))
            raw_data = np.array(raw_data, dtype).reshape(-1, 1)
            ret_data = np.vstack(ret_data)
            # print("sampled")
            # print(raw_data,flush=True)
            return raw_data, ret_data, labels

        def sample_fn(present, num_samples, compute_labels=True):
            """Word-level sampler; here ``present`` indexes words, not rules.

            NOTE(review): this function is not returned by ``get_sample_fn``.
            """
            if self.use_unk_distribution:
                data = np.ones((num_samples, len(words)))
                raw = np.zeros((num_samples, len(words)), '|U80')
                raw[:] = words
                for i, t in enumerate(words):
                    if i in present:
                        continue
                    # mask word i in a random subset of the samples
                    n_changed = np.random.binomial(num_samples, .5)
                    changed = np.random.choice(num_samples, n_changed,
                                               replace=False)
                    raw[changed, i] = self.mask_string
                    data[changed, i] = 0
                raw_data = [' '.join(x) for x in raw]
            else:
                data = np.zeros((num_samples, len(words)))
                for i in range(len(words)):
                    if i in present:
                        continue
                    # keep word i with probability perturber.pr[i]
                    probs = [1 - perturber.pr[i], perturber.pr[i]]
                    data[:, i] = np.random.choice([0, 1], num_samples, p=probs)
                data[:, present] = 1
                raw_data = []
                for i, d in enumerate(data):
                    r = perturber.sample(d)
                    # record which words still equal the original
                    data[i] = r == words
                    raw_data.append(' '.join(r))
            labels = []
            if compute_labels:
                labels = (classifier_fn(raw_data) == true_label).astype(int)
            labels = np.array(labels)
            max_len = max([len(x) for x in raw_data])
            dtype = '|U%d' % (max(80, max_len))
            raw_data = np.array(raw_data, dtype).reshape(-1, 1)
            # print(raw_data)
            return raw_data, data, labels
        if self.midist:
            return words, positions, true_label, my_sample_fn
        else:
            return words, positions, true_label, my_sample_fn_3

    def explain_instance(self, text, classifier_fn, rules=[], threshold=0.95,
                         delta=0.1, tau=0.15, batch_size=10, onepass=False,
                         use_proba=False, beam_size=4, coverage_samples = 1000,
                         **kwargs):
        """Search for an anchor over ``rules`` that explains the prediction on ``text``.

        Args:
            text: sentence to explain; ``bytes`` (including subclasses) are
                decoded with the default codec first.
            classifier_fn: callable applied to batches of strings.
            rules: sequence of ``(x, y, d)`` triples over token indexes; each
                rule is one candidate feature of the anchor. The default
                list is never mutated.
            threshold: passed to the beam search as ``desired_confidence``.
            delta: passed to the beam search as ``delta``.
            tau: passed to the beam search as ``epsilon``.
            batch_size: passed to the beam search as ``batch_size``.
            onepass: forwarded to ``get_sample_fn``.
            use_proba: forwarded to ``get_sample_fn``.
            beam_size: passed to the beam search as ``beam_size``.
            coverage_samples: passed to the beam search as ``coverage_samples``.
            **kwargs: forwarded unchanged to ``AnchorBaseBeam.anchor_beam``.

        Returns:
            An ``anchor_explanation.AnchorExplanation`` of type ``'text'``
            whose ``names`` are ``(word_x, word_y, d)`` tuples for the
            selected rules.
        """
        # isinstance instead of an exact type comparison so bytes subclasses
        # are decoded as well.
        if isinstance(text, bytes):
            text = text.decode()
        words, positions, true_label, sample_fn = self.get_sample_fn(
            text, classifier_fn, onepass=onepass, use_proba=use_proba, rules=rules)
        exp = anchor_base.AnchorBaseBeam.anchor_beam(
            sample_fn, n_features=len(rules), delta=delta, epsilon=tau, batch_size=batch_size,
            desired_confidence=threshold, stop_on_first=True,
            coverage_samples=coverage_samples,beam_size=beam_size, **kwargs)
        # exp['feature'] holds indexes into `rules`; name each selected rule
        # by its two words and its distance.
        exp['names'] = [(words[rules[x][0]],words[rules[x][1]],rules[x][2]) for x in exp['feature']]
        exp['instance'] = text
        exp['prediction'] = true_label
        explanation = anchor_explanation.AnchorExplanation('text', exp,
                                                           self.as_html)
        return explanation

    def explain_instance_with_distance(self, text, classifier_fn):
        """Unfinished stub: decodes ``bytes`` input and returns None.

        Args:
            text: sentence to explain; ``bytes`` (including subclasses) are
                decoded with the default codec.
            classifier_fn: not used yet.
        """
        # isinstance instead of an exact type comparison so bytes subclasses
        # are decoded as well.
        if isinstance(text, bytes):
            text = text.decode()

    def as_html(self, exp):
        """Render an explanation dictionary as a self-contained HTML page.

        Args:
            exp: explanation dict. The keys read here are 'prediction',
                'feature', 'names', 'precision', 'all_precision',
                'coverage', 'examples' and 'instance'. ``exp['prediction']``
                is converted to ``int`` in place.

        Returns:
            An HTML string that inlines ``bundle.js`` (read from this
            module's directory) and a script calling
            ``lime.RenderExplanationFrame``.

        NOTE(review): tokens are looked up with ``processed[i]`` for ``i`` in
        ``exp['feature']``, while ``explain_instance`` produces features that
        index ``rules`` - confirm these indexes are meant to address tokens.
        """
        predict_proba = np.zeros(len(self.class_names))
        exp['prediction'] = int(exp['prediction'])
        # One-hot vector on the predicted class.
        predict_proba[exp['prediction']] = 1
        predict_proba = list(predict_proba)

        def jsonize(x):
            return json.dumps(x)

        this_dir, _ = os.path.split(__file__)
        # Context manager so the bundle's file handle is always closed.
        with open(os.path.join(this_dir, 'bundle.js'), encoding='utf8') as bundle_file:
            bundle = bundle_file.read()
        random_id = 'top_div' + id_generator()

        def process_examples(examples, idx):
            # Features of the anchor up to and including step idx.
            idxs = exp['feature'][:idx + 1]
            out_dict = {}
            new_names = {'covered_true': 'coveredTrue', 'covered_false': 'coveredFalse', 'covered': 'covered'}
            for name, new in new_names.items():
                out = []
                for e in (x[0] for x in examples[name]):
                    processed = self.nlp(str(e))
                    raw_indexes = [(processed[i].text, processed[i].idx, exp['prediction']) for i in idxs]
                    out.append({'text': e, 'rawIndexes': raw_indexes})
                out_dict[new] = out
            return out_dict

        example_obj = [process_examples(examples, i)
                       for i, examples in enumerate(exp['examples'])]

        explanation = {'names': exp['names'],
                       'certainties': exp['precision'] if len(exp['precision']) else [exp['all_precision']],
                       'supports': exp['coverage'],
                       'allPrecision': exp['all_precision'],
                       'examples': example_obj}
        processed = self.nlp(exp['instance'])
        raw_indexes = [(processed[i].text, processed[i].idx, exp['prediction'])
                       for i in exp['feature']]
        raw_data = {'text': exp['instance'], 'rawIndexes': raw_indexes}

        out = u'''<html>
        <meta http-equiv="content-type" content="text/html; charset=UTF8">
        <head><script>%s </script></head><body>''' % bundle
        out += u'''
        <div id="{random_id}" />
        <script>
            div = d3.select("#{random_id}");
            lime.RenderExplanationFrame(div,{label_names}, {predict_proba},
            {true_class}, {explanation}, {raw_data}, "text", "anchor");
        </script>'''.format(random_id=random_id,
                            label_names=jsonize(self.class_names),
                            predict_proba=jsonize(list(predict_proba)),
                            true_class=jsonize(False),
                            explanation=jsonize(explanation),
                            raw_data=jsonize(raw_data))
        out += u'</body></html>'
        return out

    def show_in_notebook(self, exp, true_class=False, predict_proba_fn=None):
        """Render ``exp`` with ``as_html`` and display it in an IPython notebook.

        Args:
            exp: explanation dict accepted by ``as_html``.
            true_class: unused; kept so existing callers keep working.
            predict_proba_fn: unused; kept so existing callers keep working.
        """
        # as_html only takes the explanation; forwarding the two extra
        # arguments made every call fail with a TypeError.
        out = self.as_html(exp)
        # IPython.display is the public home of display/HTML; importing them
        # from IPython.core.display is deprecated.
        from IPython.display import display, HTML
        display(HTML(out))
