import numpy as np
from preprocessing import *
import scipy

class Faithfulness:
    """Score feature attributions by perturbing features and re-querying a model.

    For every explained row, the feature weights are thresholded relative to
    the largest absolute weight of that row, splitting features into
    "relevant" (non-zero after thresholding) and "irrelevant" (zero). Noisy
    copies of the row are generated with ``add_noise`` and the mean absolute
    change of ``function``'s output with respect to ``y`` is recorded.

    ``add_noise`` and ``compute_matrix_distances`` come from the
    ``preprocessing`` module; their exact semantics are not visible here.
    """

    def __init__(self, function, df_train, df_info, columns_to_select=None, threshold=0.1, seed=9, nsamples=1000):
        """Store the model, the reference data and the sampling settings.

        Args:
            function: Callable applied to the generated samples; its output is
                compared against ``y``.
            df_train: Reference data forwarded to ``compute_matrix_distances``
                and ``add_noise``.
            df_info: Metadata forwarded to the same two helpers.
            columns_to_select: Optional column selector applied to the output
                of ``function`` as ``output[:, columns_to_select]``.
            threshold: Fraction of the largest absolute weight below which a
                feature weight is treated as zero.
            seed: Seed of the ``numpy.random.RandomState`` handed to
                ``add_noise``.
            nsamples: Number of samples requested from ``add_noise`` per row.
        """
        self.function = function
        self.df_train = df_train
        self.threshold = threshold
        self.df_info = df_info
        self.rng = np.random.RandomState(seed)
        self.nsamples = nsamples
        self.columns_to_select = columns_to_select

    def _mean_prediction_shift(self, X, y, feature_weights, distances, perturb_relevant):
        """Shared implementation of ``compute_necessity`` / ``compute_sufficiency``.

        Args:
            X: Rows to evaluate; iterated through ``X.index`` and ``X.loc``.
            y: Reference outputs. Indexed as ``y[i]`` when ``X`` has more than
                one row, used as a whole otherwise.
            feature_weights: Per-row attributions, read via
                ``feature_weights.loc[idx]``.
            distances: Pre-computed distance matrix, or ``None`` / empty to
                compute it with ``compute_matrix_distances``.
            perturb_relevant: If true the 0/1 mask passed to ``add_noise``
                flags the features whose thresholded weight is non-zero;
                otherwise it flags the features whose weight is zero.

        Returns:
            ``(shifts, distances)``. ``shifts`` is an array with one entry per
            row of ``X``, or a scalar when ``X`` has exactly one row.
        """
        if distances is None or len(distances) == 0:
            distances = compute_matrix_distances(X, self.df_train, self.df_info)

        per_row = len(X) > 1
        shifts = np.zeros((len(X),))
        for i, idx in enumerate(X.index):
            # Work on a plain ndarray so both metrics threshold identically.
            weights = feature_weights.loc[idx].values
            cutoff = self.threshold * np.abs(weights).max()
            weights = np.where(np.abs(weights) < cutoff, 0, weights)

            # 0/1 mask over features; presumably 1 marks the features that
            # add_noise perturbs -- confirm against preprocessing.add_noise.
            if perturb_relevant:
                p = np.where(weights != 0, 1, 0)
            else:
                p = np.where(weights == 0, 1, 0)

            samples = add_noise(X.loc[idx], self.df_train, distances[i, :], self.df_info, p,
                                self.rng, nsamples=self.nsamples)

            predictions = self.function(samples)
            if self.columns_to_select is not None:
                predictions = predictions[:, self.columns_to_select]

            if per_row:
                shifts[i] = np.mean(np.abs(predictions - y[i]))
            else:
                # Single-row input: y is used as a whole and a scalar is
                # returned, which existing callers may rely on.
                shifts = np.mean(np.abs(predictions - y))

        return shifts, distances

    def compute_necessity(self, X, y, feature_weights, distances=None):
        """Mean absolute output change when the relevant features are flagged.

        The mask passed to ``add_noise`` is 1 for features whose thresholded
        weight is non-zero.

        NOTE(review): the original comment here read "the lower the better",
        yet ``compute_faithfulness`` rewards larger values through
        ``1 - exp(-necessity)`` -- confirm the intended direction.

        Args:
            X: Rows to evaluate.
            y: Reference outputs for the rows of ``X``.
            feature_weights: Per-row feature attributions indexed like ``X``.
            distances: Optional pre-computed distances; ``None`` or an empty
                sequence triggers ``compute_matrix_distances``.

        Returns:
            ``(necessity, distances)``; ``necessity`` is an array of length
            ``len(X)``, or a scalar when ``X`` has exactly one row.
        """
        return self._mean_prediction_shift(X, y, feature_weights, distances, perturb_relevant=True)

    def compute_faithfulness(self, X, y, feature_weights, distances=None):
        """Harmonic mean of ``1 - exp(-necessity)`` and ``exp(-sufficiency)``.

        Sufficiency is computed first, then necessity, reusing the same
        distance matrix. Both intermediate results are stored on the instance
        as ``self.sufficiency`` and ``self.necessity``.

        Args:
            X: Rows to evaluate.
            y: Reference outputs for the rows of ``X``.
            feature_weights: Per-row feature attributions indexed like ``X``.
            distances: Optional pre-computed distances.

        Returns:
            The harmonic mean of the two terms, taken along axis 0.
        """
        # Imported explicitly: a bare ``import scipy`` does not load the
        # ``stats`` submodule on every SciPy version.
        from scipy.stats import hmean

        self.sufficiency, distances = self.compute_sufficiency(X, y, feature_weights, distances)
        self.necessity, distances = self.compute_necessity(X, y, feature_weights, distances)
        return hmean([1 - np.exp(-self.necessity), np.exp(-self.sufficiency)], axis=0)

    def compute_sufficiency(self, X, y, feature_weights, distances=None):
        """Mean absolute output change when the irrelevant features are flagged.

        The mask passed to ``add_noise`` is 1 for features whose thresholded
        weight is zero.

        Args:
            X: Rows to evaluate.
            y: Reference outputs for the rows of ``X``.
            feature_weights: Per-row feature attributions indexed like ``X``.
            distances: Optional pre-computed distances; ``None`` or an empty
                sequence triggers ``compute_matrix_distances``.

        Returns:
            ``(sufficiency, distances)``; ``sufficiency`` is an array of
            length ``len(X)``, or a scalar when ``X`` has exactly one row.
        """
        return self._mean_prediction_shift(X, y, feature_weights, distances, perturb_relevant=False)
    