# Main imports
from econml.orf import DROrthoForest,DMLOrthoForest

from econml.sklearn_extensions.linear_model import  WeightedLasso
from causalml.inference.tree import CausalRandomForestRegressor
# Helper imports
import numpy as np

from sklearn.linear_model import Lasso, LogisticRegression
from causally.model.abstract_model import SKAbstractModel
from sklearn.ensemble import RandomForestRegressor

class RF(SKAbstractModel):
    """Causal random forest model wrapping causalml's ``CausalRandomForestRegressor``.

    Features are min-max scaled before being handed to the forest. The scaler
    is fitted once on the training features in :meth:`calculate_loss` and the
    same fitted scaler is reused in :meth:`predict`, so training and
    prediction inputs share one feature scale.
    """

    def __init__(self, config, dataset):
        """Read hyper-parameters from ``config`` and build the forest.

        Args:
            config: Mapping providing at least ``'n_trees'``,
                ``'min_leaf_size'``, ``'max_depth'``, ``'solver'`` and
                ``'subsample_ratio'``.
            dataset: Object exposing ``get_X_size()``; its first element is
                stored as ``n_units`` (presumably the number of samples --
                confirm against the dataset class).
        """
        super(RF, self).__init__(config, dataset)
        self.n_units = dataset.get_X_size()[0]
        self.n_trees = config['n_trees']
        self.min_leaf_size = config['min_leaf_size']
        self.max_depth = config['max_depth']
        self.solver = config['solver']
        self.subsample_ratio = config['subsample_ratio']
        # The value of config['lambda_reg'] used to be read here and then
        # overwritten on the very next line, so it never had any effect; only
        # the computed value is kept.
        # NOTE(review): the meaning of the constant 30 is not visible here.
        self.lambda_reg = np.sqrt(
            np.log(30) / (10 * self.subsample_ratio * self.n_units)
        )

        # Fitted in calculate_loss(); reused by predict().
        self.scaler = None

        # NOTE(review): min_leaf_size, solver, subsample_ratio and lambda_reg
        # are stored on the instance but are not passed to the forest below;
        # min_samples_leaf is hard-coded instead of using min_leaf_size.
        # Confirm whether that is intentional before wiring them in.
        self.model = CausalRandomForestRegressor(
            n_estimators=self.n_trees,
            criterion="causal_mse",
            alpha=0.05,
            max_depth=self.max_depth,
            min_samples_split=2,
            min_samples_leaf=100,
            min_weight_fraction_leaf=0.0001,
            max_features=1.0,
            max_leaf_nodes=100,
        )

    def calculate_loss(self, x, t, y, w):
        """Fit the feature scaler and the forest on the training data.

        Despite its name this method returns nothing; it only fits.

        Args:
            x: Feature matrix; min-max scaled before fitting.
            t: Treatment values, passed to the forest's ``fit``.
            y: Outcome values, passed to the forest's ``fit``.
            w: Unused.
        """
        from sklearn.preprocessing import MinMaxScaler

        # Keep the fitted scaler so predict() can apply the same transform.
        self.scaler = MinMaxScaler()
        x = self.scaler.fit_transform(x)
        self.model.fit(x, t, y)

    def predict(self, x, t_0, t_1):
        """Return the forest's prediction for ``x``.

        Args:
            x: Feature matrix, scaled with the scaler fitted in
                :meth:`calculate_loss`.
            t_0: Unused.
            t_1: Unused.

        Returns:
            The output of ``self.model.predict`` on the scaled features.
        """
        scaler = getattr(self, "scaler", None)
        if scaler is None:
            # No scaler fitted on this instance (calculate_loss was not run
            # here): fall back to the previous behaviour of scaling the batch
            # by its own min/max.
            from sklearn.preprocessing import MinMaxScaler

            x = MinMaxScaler().fit_transform(x)
        else:
            # Apply the training-time scaling; re-fitting on prediction data
            # would put the features on a different scale than the one the
            # forest's split thresholds were learned on.
            x = scaler.transform(x)
        return self.model.predict(x)