import numpy as np
import pandas as pd

class StabRej:
    """Reject samples whose explanation is unstable across repeated runs.

    For every sample the explanation vector passed in ``Z`` is compared, via
    the Pearson correlation, with the reference explanations stored on disk
    for the same sample. The mean correlation ``c`` is mapped to the score
    ``0.5 - 0.5 * c`` (0 = perfectly correlated, 1 = perfectly
    anti-correlated). Samples whose score reaches the threshold learned in
    :meth:`fit` are rejected.

    Parameters
    ----------
    dataset : str
        Name used to build the CSV file names. The value ``'xG'`` selects a
        file keyed by ``game_id``/``action_id``/``event_id``; every other
        value selects ``<dataset>_stability.csv`` keyed by an ``id`` column.
    seed : int, default 9
        Seed of the explanations passed in ``Z``; rows of the xG file with
        this seed are excluded from the reference set.
    explainer : str, default 'KernelSHAP'
        Stored for reference only; not used by any computation here.
    contamination : float, default 0.1
        Fraction of the training samples that should fall at or above the
        rejection threshold.
    """

    # Columns identifying one sample in the xG explanation file.
    _XG_KEYS = ['game_id', 'action_id', 'event_id']
    # Non-feature columns left in the xG frame after the keys become the index.
    _XG_LEADING_COLS = 3

    def __init__(self, dataset, seed = 9, explainer='KernelSHAP',contamination = 0.1):
        self.dataset = dataset
        self.seed = seed
        self.explainer = explainer
        self.contamination = contamination

    def _load_explanations(self, directory, n_col):
        """Read the reference explanations found in *directory*.

        Returns ``(frame, n_col)`` where ``n_col`` is the number of leading
        non-feature columns to skip. For non-xG datasets it is the value
        passed in by the caller.
        """
        if self.dataset != 'xG':
            frame = pd.read_csv(f'{directory}/{self.dataset}_stability.csv')
            return frame, n_col
        frame = pd.read_csv(f'{directory}/{self.dataset}.csv')
        # Drop the run that produced Z so it is not correlated with itself.
        frame = frame[frame['seed'] != self.seed].set_index(self._XG_KEYS)
        return frame, self._XG_LEADING_COLS

    def _mean_correlations(self, Z, frame, n_col):
        """Return, per row of *Z*, its mean correlation with the references.

        The result is NaN for a sample that has no reference rows in a non-xG
        file, or whose correlation is undefined (e.g. a constant vector).
        """
        is_xg = self.dataset == 'xG'
        result = np.zeros(len(Z.index))
        for i, index in enumerate(Z.index):
            z = Z.loc[index].values
            z = np.asarray(z.reshape((len(z),)), dtype=float)
            if is_xg:
                rows = frame.loc[index]
                if isinstance(rows, pd.Series):
                    # A single matching row comes back as a Series.
                    rows = rows.to_frame().T
            else:
                rows = frame.loc[frame['id'] == index]
            expls = rows.iloc[:, n_col:]
            if len(expls) == 0:
                # Nothing to compare against: leave undefined, not a crash.
                result[i] = np.nan
                continue
            total = 0.0
            for _, expl in expls.iterrows():
                other = np.asarray(expl.values, dtype=float).reshape((len(expl),))
                total += np.corrcoef(z, other)[0, 1]
            result[i] = total / len(expls)
        return result

    @staticmethod
    def _to_scores(correlations):
        """Map correlations in [-1, 1] to scores in [0, 1]; NaN becomes 1."""
        return np.nan_to_num(-0.5 * correlations + 0.5, nan=1)

    def fit(self, Z):
        """Learn the rejection threshold from the training explanations *Z*.

        Sets ``self.correlations`` (raw mean correlations, possibly NaN),
        ``self.train_scores`` and ``self.threshold`` (the
        ``1 - contamination`` quantile of the training scores).
        """
        # NOTE(review): fit reads './csvFiles/explanations' and skips 4
        # leading columns, score reads './explanations' and skips 1. Both are
        # kept as found; confirm which layout is intended.
        frame, n_col = self._load_explanations('./csvFiles/explanations', 4)
        self.correlations = self._mean_correlations(Z, frame, n_col)
        # Use the same NaN handling as score() so one undefined correlation
        # cannot turn the threshold into NaN.
        self.train_scores = self._to_scores(self.correlations)
        self.threshold = np.quantile(self.train_scores, 1-self.contamination)

    def set_rejection_rate(self, contamination):
        """Change the contamination and recompute the threshold (after fit)."""
        self.contamination = contamination
        self.threshold = np.quantile(self.train_scores, 1-self.contamination)

    def score(self, Z):
        """Return one instability score in [0, 1] per row of *Z*.

        Samples whose correlation is undefined get the maximum score 1.
        """
        frame, n_col = self._load_explanations('./explanations', 1)
        return self._to_scores(self._mean_correlations(Z, frame, n_col))

    def reject(self, Z):
        """Return 1 for rows of *Z* scoring at or above the threshold, else 0."""
        consistency = self.score(Z)
        return np.where(consistency >= self.threshold, 1, 0)
    
        
        