import numpy as np
# from bartpy.extensions.baseestimator import ResidualBART
from sklearn.linear_model import LinearRegression
from causally.model.abstract_model import SKAbstractModel
from causalml.inference.meta import BaseXRegressor, BaseRRegressor, BaseSRegressor, BaseTRegressor
from causalml.inference.meta import XGBTRegressor, MLPTRegressor,LRSRegressor,XGBRRegressor

class T_learner(SKAbstractModel):
    """Causal-effect model that delegates fitting/prediction to ``LRSRegressor``.

    NOTE(review): despite the ``T_learner`` name, the estimator constructed
    here is causalml's ``LRSRegressor``, which the causalml documentation
    describes as an S-learner backed by linear regression, not a two-model
    T-learner. The file also imports ``XGBTRegressor`` / ``MLPTRegressor``,
    so a T-learner may have been intended -- confirm before relying on the
    class name.
    """

    def __init__(self, config, dataset):
        """Read settings from ``config``/``dataset`` and build the regressor.

        Args:
            config: mapping that must contain the keys ``'n_jobs'`` and
                ``'n_trees'``. Both values are stored on the instance but are
                not forwarded to the regressor created below.
            dataset: object exposing ``get_X_size()``; the first element of
                its result is stored as ``self.n_units`` (presumably the
                number of samples -- verify against the dataset class).
        """
        super().__init__(config, dataset)

        # Stored on the instance only; LRSRegressor() below takes neither.
        self.n_jobs = config['n_jobs']
        self.n_trees = config['n_trees']

        x_size = dataset.get_X_size()
        self.n_units = x_size[0]

        self.model = LRSRegressor()

    def calculate_loss(self, x, t, y, w):
        """Fit the wrapped regressor on covariates, treatment and outcome.

        Despite the method name, no loss value is computed or returned; the
        method only calls ``self.model.fit`` and returns ``None``.

        Args:
            x: covariates, passed to ``fit`` as ``X``.
            t: treatment assignment, passed to ``fit`` as ``treatment``.
            y: observed outcome, passed to ``fit`` as ``y``.
            w: unused here (presumably required by the base-class interface).
        """
        self.model.fit(
            X=x,
            treatment=t,
            y=y,
        )

    def predict(self, x, t_0, t_1):
        """Return the treated-minus-control prediction for each row of ``x``.

        The wrapped model is asked for its per-arm components, and the
        control-arm prediction is subtracted from the treated-arm prediction
        for treatment-group key ``1``. ``t_0`` and ``t_1`` are accepted but
        unused.

        NOTE(review): the hard-coded group key ``1`` assumes treatment is
        coded 0 (control) / 1 (treated); other codings would raise a KeyError
        here -- confirm against the data pipeline.
        """
        treated_key = 1
        _te, control_by_group, treated_by_group = self.model.predict(
            x, return_components=True
        )
        treated_pred = treated_by_group[treated_key]
        control_pred = control_by_group[treated_key]
        return treated_pred - control_pred
