import econml
from joblib import Parallel, delayed
from econml.dml import DML, LinearDML, SparseLinearDML, CausalForestDML
from econml.grf import RegressionForest
import numpy as np
from sklearn.linear_model import (Lasso, LassoCV, LogisticRegression,
                                  LogisticRegressionCV,LinearRegression,
                                  MultiTaskElasticNet,MultiTaskElasticNetCV)
from sklearn.ensemble import RandomForestRegressor,RandomForestClassifier
from dnn import DNN_weights, DNN

import warnings
# Silence every warning category process-wide: simplefilter with no explicit
# category matches all Warning subclasses, from any module and any line.
warnings.simplefilter("ignore")


def _nuisance_model_factories(n, n_eff, w, w_max):
    """Build the ``(name, factory)`` table of nuisance models to compare.

    Each factory is a zero-argument callable returning a fresh, unfitted
    model. The factories close over this call's arguments, so a table built
    for one ``n_eff`` is never affected by later iterations of the caller's
    loop.

    Parameters
    ----------
    n : int
        Full simulated sample size.
    n_eff : int
        ``n // cv``; scales most of the subsample sizes below.
    w, w_max
        Results of ``DNN_weights(n_eff, k=1).find_weights()`` and
        ``.find_eps_weight()``. Only the ``dnn_*`` factories use them, so
        they may be ``None`` when no ``dnn_*`` factory will be called.

    Returns
    -------
    list of (str, callable)
        Model name and zero-argument model factory, in a fixed order.
    """
    def dnn(s):
        # Factory for a DNN nuisance model with subsample size ``s``.
        return lambda: DNN(w, w_max, s=s, sigma_e=1)

    def forest(max_samples, n_estimators=100):
        # Factory for a non-honest RegressionForest with the given subsample size.
        return lambda: RegressionForest(n_estimators=n_estimators,
                                        min_samples_leaf=1,
                                        min_samples_split=2,
                                        max_samples=max_samples,
                                        inference=False, subforest_size=1,
                                        honest=False, random_state=123)

    # NOTE(review): the ``*_rootn`` subsample size scales with the full sample
    # size ``n`` while every other size scales with ``n_eff``. Kept as-is so
    # results are unchanged -- confirm this is intended.
    s_rootn = int(np.ceil(n ** 0.49))
    s_n34 = int(np.ceil(n_eff ** (3 / 4)))
    s_n10_11 = int(np.ceil(n_eff ** (10 / 11)))
    return [
        ('dnn_rootn', dnn(s_rootn)),
        ('dnn_n34', dnn(s_n34)),
        ('dnn_n10_11', dnn(s_n10_11)),
        ('dnn_n', dnn(n_eff)),
        ('rf_rootn', forest(s_rootn)),
        ('rf_n34', forest(s_n34)),
        ('rf_n10_11', forest(s_n10_11)),
        # NOTE(review): a single tree here, versus 100 for the other
        # forests -- kept as-is; confirm this is intentional.
        ('rf_n', forest(n_eff, n_estimators=1)),
    ]


def run_experiment(n, n_w, model_list=None, n_reps=1000):
    """Monte Carlo study of LinearDML with DNN / random-forest nuisance models.

    A sparse linear DGP is drawn once from a fixed seed: a single column of
    the standard-normal controls ``W`` enters both the treatment and the
    outcome equation, and the true treatment effect is 0.5. For each
    cross-fitting choice ``cv`` in (1, 2) and each selected nuisance model,
    ``n_reps`` independent replications are run in parallel with joblib.

    Note: this reseeds NumPy's global RNG (``np.random.seed``).

    Parameters
    ----------
    n : int
        Sample size of each simulated data set.
    n_w : int
        Number of control variables in ``W``.
    model_list : iterable of str or None, optional
        Names of the nuisance models to run (e.g. ``'dnn_n34'``, ``'rf_n'``).
        ``None`` (the default) runs every model.
    n_reps : int, optional
        Number of Monte Carlo replications per (cv, model) pair.

    Returns
    -------
    dict
        Maps ``'cv_{cv}_{model_name}'`` to a list of ``n_reps``
        ``(point_estimate, stderr)`` pairs for the treatment effect.
    """
    # DGP constants, shared by every replication and every model.
    support_size = 1
    np.random.seed(123)
    # Outcome support
    support_Y = np.random.choice(np.arange(n_w), size=support_size, replace=False)
    coefs_Y = np.random.uniform(0, 1, size=support_size)
    epsilon_sample = lambda n: np.random.normal(0, 1, size=n)
    # Treatment support: the same control column(s) as the outcome.
    support_T = support_Y
    coefs_T = np.random.uniform(0, 1, size=support_size)
    eta_sample = lambda n: np.random.normal(0, 1, size=n)

    def experiment(i, cv=3, model_gen=lambda: Lasso(alpha=.01)):
        """Simulate replication ``i`` and return LinearDML's (estimate, stderr)."""
        np.random.seed(123 + i)
        # Generate controls, treatments and outcomes (true effect = 0.5).
        W = np.random.normal(0, 1, size=(n, n_w))
        T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
        Y = .5 * T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)

        est = LinearDML(model_y=model_gen(),
                        model_t=model_gen(),
                        cv=cv,
                        linear_first_stages=False,
                        random_state=123)
        est.fit(Y, T, X=None, W=W)
        inf = est.effect_inference()
        return inf.point_estimate[0], inf.stderr[0]

    # ``None`` selects every model. Test for it before any membership check,
    # because ``name in None`` raises TypeError.
    selected = None if model_list is None else set(model_list)
    # The DNN weights are only used by dnn_* models; skip them otherwise.
    needs_dnn = selected is None or any(name.startswith('dnn_') for name in selected)

    res = {}
    for cv in [1, 2]:
        n_eff = n // cv
        if needs_dnn:
            w = DNN_weights(n_eff, k=1).find_weights()
            w_max = DNN_weights(n_eff, k=1).find_eps_weight()
        else:
            w = w_max = None
        for model_name, model_gen in _nuisance_model_factories(n, n_eff, w, w_max):
            if selected is not None and model_name not in selected:
                continue
            res[f'cv_{cv}_{model_name}'] = Parallel(n_jobs=-1, verbose=2)(
                delayed(experiment)(i, cv=cv, model_gen=model_gen)
                for i in range(n_reps))

    return res
