import numpy as np
# from bartpy.extensions.baseestimator import ResidualBART
from sklearn.linear_model import LinearRegression
from causally.model.abstract_model import SKAbstractModel
from causalml.inference.meta import BaseXRegressor, BaseRRegressor, BaseSRegressor, BaseTRegressor
from causalml.inference.meta import XGBTRegressor,LRSRegressor,XGBRRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.linear_model import LinearRegression

class S_learner(SKAbstractModel):
    """S-learner wrapper around causalml's ``LRSRegressor``.

    Fitting and prediction are delegated to ``self.model``. Treatment effects
    are obtained by contrasting the model's predicted outcomes under treatment
    and under control (the ``return_components`` output of causalml's
    ``predict``).
    """

    def __init__(self, config, dataset):
        """Initialise the wrapper and build the underlying model.

        Args:
            config: Mapping that must provide ``'n_jobs'`` and ``'n_trees'``.
                Both are stored on the instance; neither is passed to the
                ``LRSRegressor`` built below.
            dataset: Object exposing ``get_X_size()``. The first element of
                its result is stored as ``n_units``.
        """
        super().__init__(config, dataset)

        # Kept as attributes only (presumably read by the base class or by
        # callers); the linear S-learner below takes no such settings.
        self.n_jobs = config['n_jobs']
        self.n_trees = config['n_trees']

        x_size = dataset.get_X_size()
        self.n_units = x_size[0]

        # Linear-regression S-learner with default settings.
        # NOTE: BaseSRegressor(learner=MLPRegressor()) was an alternative
        # previously tried here.
        self.model = LRSRegressor()

    def calculate_loss(self, x, t, y, w):
        """Fit the underlying S-learner on ``(x, t, y)``.

        Despite its name, this computes no loss and returns ``None``. It only
        forwards the data to ``self.model.fit``.

        Args:
            x: Covariates.
            t: Treatment assignments.
            y: Observed outcomes.
            w: Ignored; not forwarded to the model (presumably part of the
                ``SKAbstractModel`` interface — confirm).
        """
        model = self.model
        model.fit(X=x, treatment=t, y=y)

    def predict(self, x, t_0, t_1):
        """Return the estimated treatment effect for treatment group ``1``.

        Args:
            x: Covariates to predict on.
            t_0: Ignored (kept for interface compatibility).
            t_1: Ignored (kept for interface compatibility).

        Returns:
            Predicted outcome under treatment minus predicted outcome under
            control, both taken from the per-group components that the
            model's ``predict`` returns for group key ``1``.
        """
        # NOTE(review): assumes the treated arm is labelled 1 in the training
        # treatment vector; another label would raise KeyError here — confirm.
        treated_key = 1
        _te, control_hat, treated_hat = self.model.predict(
            x, return_components=True
        )
        return treated_hat[treated_key] - control_hat[treated_key]
