import numpy as np
import pandas as pd 
from sklearn.svm import SVC


class ULER:
    """Rejector trained on per-feature flags, with noise-based oversampling.

    ``fit`` augments the training data (see ``augment``), trains a
    class-balanced ``SVC`` on it, and sets ``threshold`` to the
    ``1 - contamination`` quantile of the SVC decision scores on the augmented
    training set. ``reject`` then returns 1 for inputs whose score exceeds
    that threshold and 0 otherwise.
    """

    def __init__(self, kernel='rbf', C=1, k=10, eps=0.1, seed=9, contamination=0.1):
        """Store hyper-parameters.

        Args:
            kernel: SVC kernel name.
            C: SVC regularisation parameter.
            k: number of copies made of each augmented row; the first copy is
                left unperturbed, the remaining ``k - 1`` receive noise.
            eps: noise std as a fraction of each column's (population) std.
            seed: seed for both the augmentation noise and the SVC.
            contamination: target rejection rate used to set ``threshold``.
        """
        self.seed = seed
        self.kernel = kernel
        self.C = C
        self.k = k
        self.eps = eps
        self.contamination = contamination
        # Placeholder only; ``fit`` replaces it with a decision-score quantile.
        self.threshold = 1 - self.contamination

    def augment(self, Z, y, features):
        """Oversample and jitter label-1 rows that have at least one flagged feature.

        Each selected row is repeated ``k`` times. The first copy is kept as-is;
        the other ``k - 1`` copies get Gaussian noise (std = ``eps`` times the
        column's population std) on every column whose flag is 0 for that row.
        Flagged columns are never perturbed. All other rows pass through
        unchanged.

        Args:
            Z: 2-D array-like (DataFrame or ndarray) of shape (n, d).
            y: 1-D array-like of n labels.
            features: 2-D array-like with n rows of 0/1 flags; its first d
                columns are matched against the columns of ``Z``.

        Returns:
            (Zaug, yaug): a float ndarray holding the augmented rows first,
            followed by the untouched rows, and the matching label ndarray.
        """
        # Work on plain ndarrays so boolean/positional indexing behaves the
        # same for pandas and numpy inputs (iloc rejects boolean Series masks),
        # and so float noise can be added to integer-valued data.
        Z_arr = np.asarray(Z, dtype=float)
        y_arr = np.asarray(y)
        flags = np.asarray(features)
        n_cols = Z_arr.shape[1]

        # ddof=0, matching np.std's default.
        std = np.std(Z_arr, axis=0) * self.eps
        to_aug = (y_arr == 1) & (flags.sum(axis=1) > 0)

        Z_sel = Z_arr[to_aug]
        flags_sel = flags[to_aug]
        n_sel = Z_sel.shape[0]

        # np.repeat groups copies per source row: row0 x k, row1 x k, ...
        Zaug = np.repeat(Z_sel, self.k, axis=0)
        if n_sel and self.k > 1:
            # Seeded so that augmentation (and hence fit) is reproducible.
            rng = np.random.default_rng(self.seed)
            noise = rng.normal(loc=0.0, scale=std, size=(n_sel, self.k - 1, n_cols))
            # Zero the noise on columns flagged for that row.
            noise *= (flags_sel[:, :n_cols] == 0)[:, None, :]
            grouped = Zaug.reshape(n_sel, self.k, n_cols)
            grouped[:, 1:, :] += noise  # copy 0 of each row stays unperturbed
            Zaug = grouped.reshape(n_sel * self.k, n_cols)

        Zaug = np.concatenate((Zaug, Z_arr[~to_aug]), axis=0)
        yaug = np.concatenate((np.repeat(y_arr[to_aug], self.k), y_arr[~to_aug]))
        return Zaug, yaug

    def fit(self, Z, y, features):
        """Augment the data, train the SVC rejector and calibrate ``threshold``.

        Returns:
            self, to allow call chaining.
        """
        Zaug, yaug = self.augment(Z, y, features)
        self.rejector = SVC(C=self.C, kernel=self.kernel, class_weight='balanced',
                            random_state=self.seed, max_iter=100000)
        # Stored so ``set_rejection_rate`` can re-calibrate without refitting.
        self.train_scores = self.rejector.fit(Zaug, yaug).decision_function(Zaug)
        self.threshold = np.quantile(self.train_scores, 1 - self.contamination)
        return self

    def set_rejection_rate(self, rr):
        """Set a new target rejection rate and recompute ``threshold``.

        Requires a prior call to ``fit`` (uses the stored training scores).
        ``rr`` must lie in [0, 1]; np.quantile raises ValueError otherwise.
        """
        self.contamination = rr
        self.threshold = np.quantile(self.train_scores, 1 - self.contamination)

    def reject(self, Z):
        """Return 1 for rows whose decision score exceeds ``threshold``, else 0."""
        return np.where(self.score(Z) > self.threshold, 1, 0)

    def score(self, Z):
        """Return the fitted SVC's decision-function values for ``Z``."""
        return self.rejector.decision_function(Z)