import numpy as np
from .utils import sk_model, sample_celoss
from config import cfg

class Vanilla:
    """Baseline ("vanilla") model that applies no debiasing at all.

    A plain model from ``sk_model`` is fitted on the raw features. The task is
    inferred from the dtype of the ``target`` column:

    * floating-point target -> regression (``sk_model('linear_reg')``); the
      ``sensitive`` column is excluded from the features;
    * integer target -> classification (``sk_model('linear')``); every column
      except ``target`` (including ``sensitive``) is used as a feature.

    The original comment described this class as targeting binary
    classification with binary sensitive attributes; the code itself also
    handles regression targets.
    """

    def __init__(self):
        # Fitted model; populated by ``debias(..., train=True)``.
        self.debias_model = None

    @staticmethod
    def _split_features(x):
        """Return ``(X, y, model_name)`` for DataFrame *x*, chosen by the target dtype.

        Raises:
            ValueError: if the ``target`` column is neither floating nor integer.
        """
        y = x['target'].values
        kind = y.dtype.kind
        if kind == 'f':
            # Regression: the sensitive attribute is not used as a feature.
            return x.drop(['target', 'sensitive'], axis=1).values, y, 'linear_reg'
        if kind in ('i', 'u'):
            # Classification: all non-target columns are features.
            return x.drop(['target'], axis=1).values, y, 'linear'
        raise ValueError(
            f"Unsupported target dtype {y.dtype!r}: expected a floating "
            "(regression) or integer (classification) 'target' column."
        )

    def debias(self, x, train=True):
        """Fit (``train=True``) or evaluate (``train=False``) the baseline model.

        Args:
            x: DataFrame with a ``target`` column and, for regression targets,
                a ``sensitive`` column; all remaining columns are features.
            train: when True, fit a fresh model on *x*, store it in
                ``self.debias_model`` and return its predictions on *x*.
                When False, use the previously fitted model.

        Returns:
            Model predictions for *x*. In evaluation mode, when
            ``cfg['setting']['metric_name'] == 'loss'``, the float32 result of
            ``sample_celoss`` on the predicted class probabilities is returned
            instead (presumably a per-sample cross-entropy — verify in utils).

        Raises:
            ValueError: unsupported ``target`` dtype.
            RuntimeError: ``train=False`` before the model has been trained.
        """
        X, y, model_name = self._split_features(x)

        if train:
            model = sk_model(model_name)
            model.fit(X, y)
            self.debias_model = model
            return model.predict(X)

        if self.debias_model is None:
            raise RuntimeError(
                "Vanilla.debias(train=False) called before the model was trained."
            )

        setting = cfg.get('setting')
        if setting and setting.get('metric_name') == 'loss':
            # NOTE(review): requires a model with predict_proba, i.e. a
            # classification target; a regression model will fail here.
            prob = self.debias_model.predict_proba(X).astype(np.float32)
            if prob.shape[1] == 1:
                # Only one probability column was returned; build the
                # complementary column so sample_celoss receives two columns.
                prob = np.hstack((1 - prob, prob))
            loss = sample_celoss(prob, y.ravel())
            return loss.astype(np.float32)

        return self.debias_model.predict(X)