import numpy as np
import pandas as pd

class ComplRej:
    """Reject rows whose absolute values are spread evenly across columns.

    Each row of ``Z`` is turned into a distribution
    ``p_j = |z_j| / sum_k |z_k|`` and scored by its Shannon entropy
    ``-sum_j p_j * log(p_j)`` (natural log). Rows dominated by a single
    column score near zero; rows whose mass is spread evenly score high.
    After ``fit``, rows scoring at or above the ``1 - contamination``
    quantile of the training scores are rejected.

    Parameters
    ----------
    contamination : float
        Target fraction of training rows to reject. It is passed to
        ``np.quantile`` as ``1 - contamination``, which raises
        ``ValueError`` unless that value lies in [0, 1].
    seed : int
        Stored on the instance; not used by any method of this class.
    """

    def __init__(self, contamination=0.1, seed=0):
        self.contamination = contamination
        self.seed = seed

    def fit(self, Z):
        """Score the training rows of ``Z`` and set the rejection threshold.

        Stores the per-row scores in ``self.scores`` and the
        ``1 - contamination`` quantile of them in ``self.threshold``.

        Returns
        -------
        ComplRej
            ``self``, so calls can be chained.
        """
        self.scores = self.score(Z)
        self.threshold = np.quantile(self.scores, 1 - self.contamination)
        return self

    def set_rejection_rate(self, contamination):
        """Change ``contamination`` and recompute the threshold.

        The threshold is recomputed from the scores stored by ``fit``
        (``self.scores``), so ``fit`` must have been called first.
        """
        self.contamination = contamination
        self.threshold = np.quantile(self.scores, 1 - self.contamination)

    def score(self, Z):
        """Return the entropy of each row's absolute-value distribution.

        Parameters
        ----------
        Z : pandas.DataFrame or array-like
            2-D input. Anything ``np.asarray`` accepts works; a DataFrame is
            converted to its underlying values.

        Returns
        -------
        numpy.ndarray
            One score per row. Zero entries contribute nothing
            (``0 * log(0)`` is taken as 0). Rows that are all zero, or that
            contain NaN or infinite values, score 0.
        """
        magnitudes = np.abs(np.asarray(Z))
        # An all-zero row gives 0/0 and a zero entry gives log(0). Both are
        # expected here, and their nan/-inf results are zeroed below, so the
        # RuntimeWarnings they would raise are suppressed.
        with np.errstate(divide="ignore", invalid="ignore"):
            fractions = magnitudes / np.sum(magnitudes, axis=1, keepdims=True)
            terms = fractions * np.log(fractions)
        terms = np.where(np.isfinite(terms), terms, 0)
        return -np.sum(terms, axis=1)

    def reject(self, Z):
        """Return 1 for each row of ``Z`` to reject and 0 otherwise.

        A row is rejected when its score is ``>= self.threshold``, so ties at
        the threshold are rejected. ``fit`` must have been called first.
        """
        # NOTE(review): because of ``>=``, contamination=0 still rejects the
        # training rows that share the maximum score. Confirm this is intended.
        return np.where(self.score(Z) >= self.threshold, 1, 0)
    
        
        