from lime import explanation
from io import open
import os
import os.path
import json
import string
import numpy as np
from sklearn.utils import check_random_state


def id_generator(size=15, random_state=None):
    """Generate a random alphanumeric id, e.g. for an HTML div.

    This is useful for embedding HTML into IPython notebooks, where each
    embedded element needs its own id.

    Args:
        size: number of characters in the id.
        random_state: ``None``, an integer seed, or an object exposing a
            numpy-style ``choice(a, size, replace)`` method such as
            ``np.random.RandomState``.

    Returns:
        A string of ``size`` characters drawn from A-Z and 0-9.
    """
    # Promote None / integer seeds to a RandomState so the default argument
    # works; calling ``.choice`` on ``None`` would raise AttributeError.
    if random_state is None or isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)
    chars = list(string.ascii_uppercase + string.digits)
    return ''.join(random_state.choice(chars, size, replace=True))

class SvsvlExplanation(explanation.Explanation):
    """LIME explanation that also carries interaction-level attributions.

    ``self.explanation``, ``self.mobius`` and ``self.all_importance`` are
    dicts keyed by label. Their values are lists of ``(feature_ids, weight)``
    pairs, where ``len(feature_ids)`` is compared against
    ``interacting_features``. These dicts are filled outside this class.
    """

    def __init__(self,
                 domain_mapper=None,
                 mode='classification',
                 class_names=None,
                 random_state=None):
        super().__init__(domain_mapper,
                         mode,
                         class_names,
                         random_state)
        # label -> list of (feature_ids, weight) pairs.
        self.explanation = {}
        # Not read anywhere in this class; presumably filled by the explainer.
        self.m = {}
        self.score = {}
        self.local_pred = {}
        # label -> list of (feature_ids, weight) pairs used by plot_type='mobius'.
        self.mobius = {}
        # label -> list of (feature_ids, value) pairs used by plot_type='importance'.
        self.all_importance = {}

    @staticmethod
    def _rank_entries(entries, max_size, min_size=0, limit=None):
        """Filter ``(feature_ids, weight)`` pairs by size and rank them.

        Keeps entries with ``min_size <= len(feature_ids) <= max_size``. Sorts
        them by ``(|weight|, len(feature_ids))`` in descending order (ties
        keep their input order). Truncates to ``limit`` entries; ``None``
        keeps all of them.
        """
        kept = [e for e in entries if min_size <= len(e[0]) <= max_size]
        kept.sort(key=lambda e: (np.abs(e[1]), len(e[0])), reverse=True)
        return kept[:limit]

    def _resolve_labels(self, labels):
        """Return the labels whose ``local_exp`` should be prepared.

        Regression mode always uses ``self.dummy_label``, matching ``as_list``.
        Classification mode uses ``labels``, or ``available_labels()`` when
        ``labels`` is None.
        """
        if self.mode != "classification":
            return [self.dummy_label]
        return self.available_labels() if labels is None else labels

    def _store_local_exp(self, labels, interacting_features, plot_features):
        """Set ``self.local_exp[label]`` to the top entries for every label."""
        for label in labels:
            self.local_exp[label] = self._rank_entries(
                self.explanation[label], interacting_features,
                limit=plot_features)

    def available_labels(self):
        """Return the labels that can be explained (classification only).

        Returns:
            ``self.top_labels`` when it is set, otherwise the keys of
            ``self.explanation``, as a list.

        Raises:
            NotImplementedError: in regression mode.
        """
        # Explicit check instead of ``assert``: asserts are stripped under -O.
        if self.mode != "classification":
            raise NotImplementedError('Not supported for regression explanations.')
        labels = self.top_labels if self.top_labels else self.explanation.keys()
        return list(labels)

    def as_list(self, interacting_features=1, label=1, type=None, **kwargs):
        """Return the explanation for ``label`` as a ranked list.

        Keeps entries involving at most ``interacting_features`` features and
        sorts them by absolute weight (ties broken by size), largest first.
        Also stores the ranked list in ``self.local_exp`` for the label.

        NOTE(review): base-class code that calls ``self.as_list(label)``
        positionally would bind the label to ``interacting_features`` here and
        would also overwrite ``local_exp``. Confirm this against the installed
        lime version.

        Args:
            interacting_features: maximum number of features per entry.
            label: label to explain. Ignored in regression mode, which uses
                ``self.dummy_label``.
            type: when ``'full'`` and ``interacting_features == 1``, return a
                dense numpy array with one weight per feature instead.
            **kwargs: forwarded to ``domain_mapper.map_exp_ids``.

        Returns:
            A list of ``(name, weight)`` tuples with float weights. For
            ``type='full'``, a numpy array of length
            ``len(domain_mapper.feature_names)``.
        """
        label_to_use = label if self.mode == "classification" else self.dummy_label
        ranked = self._rank_entries(self.explanation[label_to_use], interacting_features)
        self.local_exp[label_to_use] = ranked

        if type == 'full' and interacting_features == 1:
            full_list = np.zeros((len(self.domain_mapper.feature_names),))
            for entry in ranked:
                # Single-feature entries: feature_ids[0] is the feature index.
                full_list[entry[0][0]] = entry[1]
            return full_list

        mapped = self.domain_mapper.map_exp_ids(ranked, **kwargs)
        return [(x[0], float(x[1])) for x in mapped]

    def as_pyplot_figure(self, interacting_features=1, plot_features=3, label=1,
                         figsize=(4, 4), plot_type='explanation', **kwargs):
        """Plot the top entries as a horizontal bar chart.

        Args:
            interacting_features: maximum number of features per plotted entry.
            plot_features: maximum number of bars.
            label: label to plot. Ignored in regression mode, which uses
                ``self.dummy_label``.
            figsize: matplotlib figure size.
            plot_type: one of the following values.
                ``'explanation'``: top entries of ``self.explanation``.
                ``'interaction'``: the same, restricted to entries with at
                least two features.
                ``'mobius'``: top entries of ``self.mobius``.
                ``'importance'``: the ``self.all_importance`` entries for the
                feature sets selected by ``'explanation'``, in that order.
                Any other value falls back to ``'explanation'``.
            **kwargs: forwarded to ``domain_mapper.map_exp_ids``.

        Returns:
            The matplotlib figure. Positive values are green, others red.
        """
        import matplotlib.pyplot as plt

        label_to_use = label if self.mode == "classification" else self.dummy_label
        # The default selection is also the lookup list for 'importance'.
        exp_list = self._rank_entries(self.explanation[label_to_use],
                                      interacting_features, limit=plot_features)

        if plot_type == 'interaction':
            selected = self._rank_entries(self.explanation[label_to_use],
                                          interacting_features, min_size=2,
                                          limit=plot_features)
            title = "Interaction plot"
        elif plot_type == 'mobius':
            selected = self._rank_entries(self.mobius[label_to_use],
                                          interacting_features,
                                          limit=plot_features)
            title = "Mobius plot - Interaction effect"
        elif plot_type == 'importance':
            importance = self.all_importance[label_to_use]
            keys = [entry[0] for entry in importance]
            # First matching entry per selected feature set; an empty
            # selection gives an empty plot instead of an IndexError.
            selected = [importance[keys.index(entry[0])] for entry in exp_list]
            title = "Feature importance plot"
        else:
            selected = exp_list
            title = "Explanation"

        # Materialize once: the mapped result is iterated more than once below.
        exp = [(x[0], float(x[1]))
               for x in self.domain_mapper.map_exp_ids(selected, **kwargs)]
        # Reverse so the highest-ranked entry is drawn at the top.
        names = [x[0] for x in reversed(exp)]
        vals = [x[1] for x in reversed(exp)]
        colors = ['green' if v > 0 else 'red' for v in vals]
        pos = np.arange(len(exp)) + .5

        fig = plt.figure(figsize=figsize)
        plt.barh(pos, vals, align='center', color=colors)
        plt.yticks(pos, names)
        plt.yticks(rotation=70, fontsize=7, fontweight='bold')
        plt.margins(x=.1, y=0.1)
        plt.title(title)
        return fig

    def show_in_notebook(self, interacting_features=1, plot_features=3,
                         labels=None, **kwargs):
        """Show the explanation in a notebook via the base class.

        Before delegating, sets ``self.local_exp`` for each label to the top
        ``plot_features`` entries with at most ``interacting_features``
        features.

        Args:
            interacting_features: maximum number of features per entry.
            plot_features: maximum number of entries per label.
            labels: labels to show. Defaults to ``available_labels()`` in
                classification mode; regression always uses
                ``self.dummy_label``.
            **kwargs: forwarded to the base ``show_in_notebook``.
        """
        resolved = self._resolve_labels(labels)
        self._store_local_exp(resolved, interacting_features, plot_features)
        base_labels = resolved if self.mode == "classification" else labels
        super().show_in_notebook(labels=base_labels, **kwargs)

    def save_to_file(self, file_path, interacting_features=1, plot_features=1, labels=None, **kwargs):
        """Save the HTML explanation to ``file_path`` via the base class.

        Before delegating, sets ``self.local_exp`` for each label to the top
        ``plot_features`` entries with at most ``interacting_features``
        features.

        Args:
            file_path: destination path, passed to the base ``save_to_file``.
            interacting_features: maximum number of features per entry.
            plot_features: maximum number of entries per label.
            labels: labels to save. Defaults to ``available_labels()`` in
                classification mode; regression always uses
                ``self.dummy_label``.
            **kwargs: forwarded to the base ``save_to_file``.
        """
        resolved = self._resolve_labels(labels)
        self._store_local_exp(resolved, interacting_features, plot_features)
        base_labels = resolved if self.mode == "classification" else labels
        super().save_to_file(file_path, labels=base_labels, **kwargs)
    