# Main imports
from econml.orf import DROrthoForest,DMLOrthoForest
# Helper imports
import numpy as np

from sklearn.linear_model import Lasso, LogisticRegression
from causally.model.abstract_model import SKAbstractModel
from sklearn.ensemble import RandomForestRegressor

class RF(SKAbstractModel):
    """Causal-effect model backed by econml's ``DROrthoForest``.

    Hyperparameters are read from ``config`` under the keys ``n_trees``,
    ``min_leaf_size``, ``max_depth``, ``solver``, ``subsample_ratio`` and
    ``lambda_reg``.  ``lambda_reg`` may be ``None``, in which case a
    data-size-dependent heuristic is used instead (see
    ``_default_lambda_reg``).
    """

    # Constant used inside the log() of the lambda_reg fallback heuristic.
    # NOTE(review): this was hard-coded as 30 in the original code and
    # presumably stands for the number of control features W — confirm.
    _LAMBDA_REG_N_CONTROLS = 30

    def __init__(self, config, dataset):
        super(RF, self).__init__(config, dataset)
        # First entry of the dataset's X size, stored as the number of units;
        # only used by the lambda_reg fallback heuristic.
        self.n_units = dataset.get_X_size()[0]
        self.n_trees = config['n_trees']
        self.min_leaf_size = config['min_leaf_size']
        self.max_depth = config['max_depth']
        # NOTE(review): `solver` is stored but not used anywhere in this class.
        self.solver = config['solver']
        self.subsample_ratio = config['subsample_ratio']
        # The configured value is used as-is; the heuristic is only a
        # fallback when the config leaves it unset (None).
        self.lambda_reg = config['lambda_reg']
        if self.lambda_reg is None:
            self.lambda_reg = self._default_lambda_reg()

        self.model = DROrthoForest(
            n_trees=self.n_trees,
            max_depth=self.max_depth,
            min_leaf_size=self.min_leaf_size,
            # Forward the configured ratio instead of silently using
            # econml's default.
            subsample_ratio=self.subsample_ratio,
            lambda_reg=self.lambda_reg,
        )

    def _default_lambda_reg(self):
        """Return sqrt(log(C) / (10 * subsample_ratio * n_units)), C = _LAMBDA_REG_N_CONTROLS."""
        return np.sqrt(
            np.log(self._LAMBDA_REG_N_CONTROLS)
            / (10 * self.subsample_ratio * self.n_units)
        )

    def calculate_loss(self, x, t, y, w):
        """Fit the forest on features ``x``, treatment ``t``, outcome ``y``, controls ``w``.

        Despite its name, this method computes no loss: it only fits
        ``self.model`` and returns ``None``.
        """
        self.model.fit(Y=y, T=t, X=x, W=w)

    def predict(self, x, t_0, t_1):
        """Return the estimated effect of changing treatment from ``t_0`` to ``t_1`` at ``x``.

        Delegates to ``DROrthoForest.effect``; ``self.model`` is fitted in
        ``calculate_loss``, which is expected to have been called first.
        """
        return self.model.effect(X=x, T0=t_0, T1=t_1)