import numpy as np
import pandas as pd 
from sklearn.svm import SVC


class ULER_YZ:
    """SVC-based rejector trained on ``[ypred, Z]`` rows to predict ``y``.

    Before training, rows with ``y == 1`` that have at least one flagged
    feature (row sum of ``features`` > 0) are oversampled ``k`` times.  Every
    copy except the first gets Gaussian noise of scale ``eps`` added to the
    features whose flag is 0.  After training, a sample is rejected when its
    SVC decision score exceeds the ``1 - contamination`` quantile of the
    training scores.
    """

    def __init__(self, kernel='rbf', C=1, k=10, eps=0.1, seed=9, contamination=0.1):
        self.seed = seed                    # SVC random_state and augmentation-noise seed
        self.kernel = kernel                # SVC kernel
        self.C = C                          # SVC regularisation parameter
        self.k = k                          # copies per augmented row (1 exact + k-1 noisy)
        self.eps = eps                      # std of the Gaussian augmentation noise
        self.contamination = contamination  # quantile level used for the threshold
        # Placeholder until fit() replaces it with a quantile of training scores.
        self.threshold = 1 - self.contamination

    def augment(self, ypred, Z, y, features):
        """Oversample ``y == 1`` rows that have at least one flagged feature.

        Parameters
        ----------
        ypred : array-like of shape (n,)
            Base-model predictions; copied alongside each row, never perturbed.
        Z : DataFrame or array-like of shape (n, d)
            Input features; converted to a float ndarray.
        y : Series or array-like of shape (n,)
            Targets; only rows with ``y == 1`` can be augmented.
        features : array-like of shape (n, >= d)
            Per-feature flags (presumably 0/1).  A row is augmented when its
            sum is > 0; features whose flag is 0 receive the noise.

        Returns
        -------
        (ypredaug, Zaug, yaug) : ndarrays
            Augmented rows come first, each repeated ``k`` times consecutively
            (first copy unperturbed).  The untouched rows follow in their
            original order.

        Raises
        ------
        ValueError
            If ``self.k < 1``.
        """
        if self.k < 1:
            raise ValueError(f"k must be >= 1, got {self.k}")

        ypred = np.asarray(ypred)
        y = np.asarray(y)
        # Float dtype so the in-place noise addition also works for integer Z.
        Z = np.asarray(Z, dtype=float)
        features = np.asarray(features)

        to_aug = (y == 1) & (np.sum(features, axis=1) > 0)
        keep = ~to_aug
        n_aug, n_feat = int(np.sum(to_aug)), Z.shape[1]

        # (n_aug, k, d): k consecutive copies of every selected row.
        Zaug = np.repeat(Z[to_aug], self.k, axis=0).reshape(n_aug, self.k, n_feat)
        # Only unflagged features are jittered.  Selected rows all have y == 1,
        # so there is no y == 0 case to handle here.
        jitter = features[to_aug, :n_feat] == 0
        # Seeded so augmentation is reproducible alongside the SVC's random_state.
        rng = np.random.default_rng(self.seed)
        noise = rng.normal(loc=0.0, scale=self.eps, size=(n_aug, self.k - 1, n_feat))
        Zaug[:, 1:, :] += noise * jitter[:, None, :]
        Zaug = Zaug.reshape(n_aug * self.k, n_feat)

        Zaug = np.concatenate((Zaug, Z[keep]), axis=0)
        yaug = np.concatenate((np.repeat(y[to_aug], self.k), y[keep]))
        ypredaug = np.concatenate((np.repeat(ypred[to_aug], self.k), ypred[keep]))

        return ypredaug, Zaug, yaug

    def fit(self, ypred, Z, y, features):
        """Augment the data, train the SVC rejector and set the threshold.

        Arguments are as for :meth:`augment`.  Stores the fitted SVC in
        ``self.rejector`` and the training decision scores in
        ``self.train_scores``.  Returns ``self``.
        """
        ypredaug, Zaug, yaug = self.augment(ypred, Z, y, features)
        training = np.concatenate((ypredaug.reshape(-1, 1), Zaug), axis=1)
        self.rejector = SVC(C=self.C, kernel=self.kernel, class_weight='balanced',
                            random_state=self.seed, max_iter=100000)
        # Kept so set_rejection_rate() can re-derive the threshold without refitting.
        self.train_scores = self.rejector.fit(training, yaug).decision_function(training)
        self.threshold = np.quantile(self.train_scores, 1 - self.contamination)
        return self

    def set_rejection_rate(self, rr):
        """Recompute the threshold as the ``1 - rr`` quantile of the training scores.

        Raises
        ------
        AttributeError
            If called before :meth:`fit`.
        """
        if not hasattr(self, "train_scores"):
            raise AttributeError("set_rejection_rate() requires fit() to be called first")
        self.contamination = rr
        self.threshold = np.quantile(self.train_scores, 1 - self.contamination)

    def reject(self, ypred, Z):
        """Return 1 where the decision score exceeds ``self.threshold``, else 0."""
        return np.where(self.score(ypred, Z) > self.threshold, 1, 0)

    def score(self, ypred, Z):
        """Return the fitted SVC's decision-function values for ``[ypred, Z]`` rows."""
        xz = np.concatenate((np.asarray(ypred).reshape(-1, 1), np.asarray(Z)), axis=1)
        return self.rejector.decision_function(xz)