import functools
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from aif360.sklearn.inprocessing import GridSearchReduction
from .utils import sk_model


class FairReg:
    """Fairness-constrained regression via aif360's ``GridSearchReduction``.

    Wraps the ``'linear_reg'`` base learner returned by ``sk_model`` in a
    grid-search reduction with a ``BoundedGroupLoss`` constraint under
    ``"Square"`` loss, using the ``'sensitive'`` column as the protected
    attribute.
    """

    def __init__(self, params):
        """Store the hyperparameter and start with no fitted model.

        Args:
            params: Hyperparameter that sets the reduction's
                ``constraint_weight`` to ``1 - params``. Presumably a float in
                ``[0, 1]`` -- TODO confirm against callers.
        """
        self.params = params
        # Set by ``debias(..., train=True)``; reused by later inference calls.
        self.debias_model = None

    def debias(self, x, train=True):
        """Optionally fit the fair regressor, then return its predictions.

        Args:
            x: DataFrame containing a ``'sensitive'`` column (the protected
                attribute) plus the feature columns. A ``'target'`` column is
                required when ``train`` is true and optional otherwise; it is
                never passed to the model as a feature.
            train: If true, fit a new ``GridSearchReduction`` on ``x``, store
                it in ``self.debias_model`` and return its in-sample
                predictions. If false, predict with the previously fitted
                model.

        Returns:
            Whatever the fitted model's ``predict`` returns for the feature
            frame (``x`` without ``'target'``).

        Raises:
            KeyError: If ``train`` is true and ``x`` has no ``'target'`` column.
            AttributeError: If ``train`` is false and no model has been fitted.
        """
        if not train:
            if self.debias_model is None:
                # AttributeError keeps the exception type the old code raised
                # (``None.predict``) while giving an actionable message.
                raise AttributeError(
                    "FairReg.debias(train=False) called before a model was "
                    "fitted; call debias(x, train=True) first."
                )
            # Labels are not needed at inference time, so 'target' is optional.
            features = x.drop(columns=["target"], errors="ignore")
            return self.debias_model.predict(features)

        y = x["target"]
        # 'sensitive' stays in the frame: it is named as ``prot_attr`` and,
        # with ``drop_prot_attr=True``, the reduction is asked to exclude it
        # from the learner's features itself.
        features = x.drop(columns=["target"])
        grid_search_red = GridSearchReduction(
            prot_attr="sensitive",
            estimator=sk_model("linear_reg"),
            constraints="BoundedGroupLoss",
            loss="Square",
            # Loss bounds come from the observed training-target range.
            min_val=y.min(),
            max_val=y.max(),
            grid_size=100,
            drop_prot_attr=True,
            constraint_weight=1 - self.params,
        )
        grid_search_red.fit(features, y)
        self.debias_model = grid_search_red
        return grid_search_red.predict(features)
        

