from experiment import experiment
from dynamic_panel_dgp import SemiSynthetic
import joblib
from joblib import Parallel, delayed


def all_exps(period_list, n_samples_list, n_exps, semi, lr, *, alpha=0.01):
    """Run a grid of replicated experiments and dump summarized results.

    For every ``n_periods`` in ``period_list`` and every ``(n_samples, cv)``
    pair in ``n_samples_list``, ``experiment`` is called ``n_exps`` times in
    parallel via joblib (once per replication index ``t``). The results are
    summarized and written with ``joblib.dump`` to
    ``test_n_periods_{}_n_samples_{}_n_exps_{}_semi_{}_lr_{}.jbl`` in the
    current working directory.

    Parameters
    ----------
    period_list : iterable of int
        Values of ``n_periods`` to run.
    n_samples_list : iterable of (int, cv) pairs
        Each pair is unpacked as ``(n_samples, cv)``. Both values are
        forwarded to ``experiment``.
    n_exps : int
        Number of replications per ``(n_periods, n_samples)`` cell.
    semi : bool
        If true, a single ``SemiSynthetic`` instance is built once (and
        ``create_instance()`` is called on it) and passed to every
        experiment. Otherwise ``dgp=None`` is passed. ``semi`` itself is
        also forwarded to ``experiment``.
    lr : bool
        Forwarded to ``experiment`` and embedded in the output filename.
        NOTE(review): the ``__main__`` comments pair ``lr=True`` with
        LinearRegression and ``lr=False`` with Lasso; confirm in ``experiment``.
    alpha : float, keyword-only
        Level passed to every ``summary_frame(alpha=...)`` call
        (default 0.01, matching the previous hard-coded value).

    Notes
    -----
    Each experiment result ``res`` is assumed to satisfy two conditions:

    * ``res[0]`` is indexable at 0..5. Element 0 is kept as-is and
      elements 1..5 expose ``summary_frame``.
    * ``res[1]`` is indexable at 0..2. Element 0 is kept as-is and
      elements 1..2 expose ``summary_frame``.

    This mirrors how the results are consumed below. TODO confirm against
    ``experiment``.
    """
    # The DGP does not depend on the grid cell, so build it once up front
    # and share it across all experiments.
    dgp = None
    if semi:
        dgp = SemiSynthetic()
        dgp.create_instance()

    def summarize(inference):
        # Single place where the significance level is applied.
        return inference.summary_frame(alpha=alpha)

    for n_periods in period_list:
        print(n_periods)
        for n_samples, cv in n_samples_list:
            print(n_samples)
            results = Parallel(n_jobs=-1, verbose=3)(
                delayed(experiment)(n_periods, n_samples, dgp, cv, t, semi, lr)
                for t in range(n_exps))
            frames = [
                ((res[0][0],) + tuple(summarize(res[0][k]) for k in range(1, 6)),
                 (res[1][0], summarize(res[1][1]), summarize(res[1][2])))
                for res in results
            ]
            joblib.dump(frames, 'test_n_periods_{}_n_samples_{}_n_exps_{}_semi_{}_lr_{}.jbl'.format(
                n_periods, n_samples, n_exps, semi, lr))


if __name__ == "__main__":
    # (n_samples, cv) pairs as unpacked by all_exps; reused by most runs.
    sample_grid = [(2000, 3), (5000, 2), (10000, 2)]

    # Each entry is the positional argument tuple for one all_exps call:
    # (period_list, n_samples_list, n_exps, semi, lr). Runs execute in order.
    run_table = [
        # Synthetic Data, Low-dim, LinearRegression
        ([2, 4, 8], sample_grid, 100, False, True),
        # Synthetic Data, High-dim, Lasso
        ([2, 4, 8], sample_grid, 100, False, False),
        # Semi-Synthetic Data, Lasso
        ([2], sample_grid, 1000, True, False),
        ([4], sample_grid, 100, True, False),
        # The 8-period semi-synthetic run omits the 10000-sample setting.
        ([8], sample_grid[:2], 100, True, False),
    ]

    for run_args in run_table:
        all_exps(*run_args)
