import numpy as np
from scipy.stats import rankdata
from sklearn.svm import SVR
class AmbRejClas:
    """Ambiguity-based rejector for a binary classifier.

    A sample's *unconfidence* is ``1 - 2 * |p - 0.5|``. Here ``p`` is either
    column 1 of ``func(X)`` or a value supplied directly through ``y``. It
    equals 1 at ``p == 0.5`` and 0 at ``p`` in {0, 1}.

    ``fit`` sets ``threshold`` to the ``1 - contamination`` quantile of the
    fitted unconfidence values. ``reject`` flags (returns 1 for) samples
    whose unconfidence is at or above that threshold.

    Parameters
    ----------
    func : callable
        Called as ``func(X)``; its result is indexed as ``[:, 1]``.
        Presumably this is a ``predict_proba``-style function (TODO confirm).
    nt : int
        Stored only; not used by any method of this class.
    seed : int
        Stored only; not used by any method of this class.
    contamination : float
        Target rejection rate used to pick the quantile threshold.
    """

    def __init__(self, func, nt = 1, seed = 9, contamination = 0.1):
        self.func = func
        self.nt = nt
        self.seed = seed
        self.contamination = contamination

    def _probabilities(self, X, y):
        """Return ``y`` as an array, or column 1 of ``func(X)`` if ``y`` is None/empty."""
        if y is None or len(y) == 0:
            return np.asarray(self.func(X))[:, 1]
        # asarray so plain lists behave like arrays in the arithmetic below.
        return np.asarray(y)

    @staticmethod
    def _unconfidence(probs):
        """Map probabilities to ``1 - 2 * |p - 0.5|`` (1 = maximally ambiguous)."""
        return 1 - 2 * np.abs(probs - 0.5)

    def fit(self, X, y=None):
        """Compute unconfidence on the fitting data and set ``threshold``.

        If ``y`` is None or empty, the probabilities come from ``func(X)``;
        otherwise ``y`` is taken as the probabilities directly.
        """
        self.probs = self._probabilities(X, y)
        self.unconfidence = self._unconfidence(self.probs)
        self.threshold = np.quantile(self.unconfidence, 1 - self.contamination)

    def set_rejection_rate(self, contamination):
        """Change ``contamination`` and recompute ``threshold``.

        Requires a prior ``fit``, which sets ``self.unconfidence``.
        """
        self.contamination = contamination
        self.threshold = np.quantile(self.unconfidence, 1 - self.contamination)

    def score(self, X, y=None):
        """Return per-sample unconfidence for ``X`` (or for probabilities ``y``)."""
        return self._unconfidence(self._probabilities(X, y))

    def reject(self, X, y=None):
        """Return 1 where unconfidence >= ``threshold``, else 0."""
        return np.where(self.score(X, y) >= self.threshold, 1, 0)

class AmbRejReg:
    """Ambiguity-based rejector for a regression model.

    A secondary estimator (``estimator_sigma``) is trained to predict the
    squared residual of ``func`` on the training data. When scoring, a
    sample's predicted squared residual is mapped through the empirical CDF
    of the estimator's predictions on a validation set. Samples whose score
    is at least ``threshold = 1 - contamination`` are rejected.

    Parameters
    ----------
    func : callable
        Prediction function of the base regressor, called as ``func(Xtrain)``.
    seed : int
        Seed for the ``RandomState`` that draws the tie-breaking jitter.
    contamination : float
        Target rejection rate; the threshold is ``1 - contamination``.
    """

    def __init__(self, func, seed = 9, contamination = 0.1):
        self.func = func
        self.seed = seed
        self.contamination = contamination
        self.threshold = 1 - contamination
        self.rs = np.random.RandomState(seed)

    def set_rejection_rate(self, contamination):
        """Change ``contamination`` and the derived ``threshold``."""
        self.contamination = contamination
        self.threshold = 1 - contamination

    def _jitter(self, X):
        """Return ``X`` plus a per-row uniform offset in [0, 1e-10).

        The offset is presumably there to break ties between identical rows.
        Every column of a row receives the same offset. Each call consumes
        draws from ``self.rs``.
        """
        n = len(X)
        return X + self.rs.uniform(0, 1e-10, n).reshape((n, 1))

    def fit(self, Xtrain, ytrain, Xtrainenc, Xvalenc, model = None):
        """Train the squared-residual estimator and build the validation ECDF.

        Parameters
        ----------
        Xtrain, ytrain
            Training inputs and targets for ``func``; the squared residuals
            ``(func(Xtrain) - ytrain) ** 2`` are the estimator's targets.
        Xtrainenc
            Features the estimator is fitted on (paired with the residuals).
        Xvalenc
            Validation features whose predictions define the ECDF.
        model : estimator with ``fit``/``predict``, optional
            Defaults to a new ``SVR(kernel='linear', max_iter=10000)`` per call.
        """
        if model is None:
            # Build a fresh estimator on each call. A default instance in the
            # signature would be shared, and refitted, by every AmbRejReg.
            model = SVR(kernel='linear', max_iter=10000)
        prediction = self.func(Xtrain)
        y_res = (prediction - ytrain) ** 2

        self.estimator_sigma = model
        self.estimator_sigma.fit(Xtrainenc, y_res)
        self.sigma_pred = self.estimator_sigma.predict(self._jitter(Xvalenc))

        # ECDF value of each validation prediction, in the order of sigma_pred.
        self.ecdf = rankdata(self.sigma_pred) / len(self.sigma_pred)

    def score(self, X):
        """Return the ECDF position of the predicted squared residual for ``X``."""
        pred = self.estimator_sigma.predict(self._jitter(X))
        # np.interp needs fp aligned with the sorted xp. rankdata is monotone
        # in sigma_pred, so sorting the ranks pairs each one with its value.
        return np.interp(pred, np.sort(self.sigma_pred), np.sort(self.ecdf))

    def reject(self, X):
        """Return 1 where ``score(X) >= threshold``, else 0."""
        return np.where(self.score(X) >= self.threshold, 1, 0)
        