import numpy as np
import pandas as pd
from metrics import Faithfulness
import os

class FaithRej:
    """Reject explanations whose unfaithfulness exceeds a quantile threshold.

    Faithfulness values are read from
    ``./explanations/<dataset>_faithfulness.csv`` when that file exists and
    are otherwise computed with ``metrics.Faithfulness``. The rejection
    threshold is the ``1 - contamination`` quantile of ``1 - faithfulness``
    observed during :meth:`fit`.

    Attributes set by the methods:
        faithfulness: values from the most recent ``fit``/``score`` call.
        threshold: current rejection threshold on ``1 - faithfulness``.
        faith: the ``Faithfulness`` helper; only created when ``fit`` had to
            compute values because no cache file was present.
    """

    def __init__(self, dataset, func, seed=9, contamination=0.1,
                 explainer=None):
        """Store the configuration; no computation happens here.

        Args:
            dataset: dataset name, used to build the cache file name. The
                value ``'xG'`` selects the (game_id, action_id, event_id)
                index when reading the cache.
            func: callable applied to ``X``; column 1 of its result is used
                (presumably a predict_proba-style output -- confirm with
                the caller).
            seed: seed forwarded to ``Faithfulness``.
            contamination: fraction of fit samples that should be rejected.
            explainer: optional value matched against the cache file's
                ``explainer`` column. ``None`` (default) disables filtering.
        """
        self.dataset = dataset
        self.func = func
        self.seed = seed
        self.contamination = contamination
        self.explainer = explainer

    def _cache_path(self):
        """Return the path of the precomputed faithfulness CSV."""
        return f'./explanations/{self.dataset}_faithfulness.csv'

    def _load_cached_faithfulness(self, X):
        """Return cached faithfulness values aligned to ``X.index``.

        Returns ``None`` when the cache file does not exist, so that callers
        can fall back to computing the values.
        """
        path = self._cache_path()
        if not os.path.exists(path):
            return None

        df = pd.read_csv(path)
        # Only filter when an explainer was requested and the file actually
        # distinguishes explainers.
        if self.explainer is not None and "explainer" in df.columns:
            df = df[df["explainer"] == self.explainer]
        df = df.drop_duplicates()

        if self.dataset != 'xG':
            df = df.set_index('id')
        else:
            df = df.set_index(["game_id", "action_id", "event_id"])

        return df.loc[X.index]["faithfulness"].values

    def fit(self, X, Z, categorical):
        """Obtain faithfulness for ``X`` and derive the rejection threshold.

        Args:
            X: DataFrame of instances; its index selects rows of the cache.
            Z: DataFrame passed to ``compute_faithfulness``; its columns that
                are not in ``categorical`` are treated as numerical.
            categorical: names of the categorical columns of ``Z``.
        """
        cached = self._load_cached_faithfulness(X)
        if cached is not None:
            self.faithfulness = cached
        else:
            # The column split is only needed by the Faithfulness helper.
            numerical = np.setdiff1d(Z.columns.values, categorical)
            df_info = {"categorical": categorical, "numerical": numerical}
            y = self.func(X)[:, 1]
            self.faith = Faithfulness(self.func, X, df_info, seed=self.seed)
            self.faithfulness = self.faith.compute_faithfulness(X, y, Z)

        # Keep a fit-time reference: score() replaces self.faithfulness, and
        # the threshold must keep being calibrated on the fit data.
        self._fit_faithfulness = np.asarray(self.faithfulness)
        self.threshold = np.quantile(1 - self._fit_faithfulness,
                                     1 - self.contamination)

    def set_rejection_rate(self, contamination):
        """Change the contamination and recompute the threshold.

        The threshold is recomputed from the faithfulness seen by ``fit``,
        even if ``score``/``reject`` were called in between.
        """
        self.contamination = contamination
        reference = getattr(self, "_fit_faithfulness", None)
        if reference is None:
            # No fit-time copy available; use whatever is stored.
            reference = self.faithfulness
        self.threshold = np.quantile(1 - reference, 1 - self.contamination)

    def score(self, X, Z):
        """Return the unfaithfulness (``1 - faithfulness``) for ``X``.

        Uses the cache file when it exists; otherwise the ``Faithfulness``
        helper created by ``fit`` is required. ``self.faithfulness`` is
        updated to the values obtained here.
        """
        cached = self._load_cached_faithfulness(X)
        if cached is not None:
            self.faithfulness = cached
        else:
            y = self.func(X)[:, 1]
            self.faithfulness = self.faith.compute_faithfulness(X, y, Z)
        return 1 - self.faithfulness

    def reject(self, X, Z):
        """Return 1 where unfaithfulness exceeds the threshold, else 0."""
        unfaithfulness = self.score(X, Z)
        return np.where(unfaithfulness > self.threshold, 1, 0)
    
        
        