from LassoOptSet import Lasso_OptSet
import matplotlib.pyplot as plt
import multiprocessing as mp
import os
from functools import partial
from util import result_analyze
import tqdm


if __name__ == '__main__':

    # Directory with one entry per train split. Lasso_OptSet receives the bare
    # entry name (not the joined path); presumably it resolves the location
    # itself -- TODO confirm against LassoOptSet.
    BASE = '/home/anonymous/data/CausallyInvariant_output/ADNIB2FAQ/FindOptSets/'

    # Keep every entry except those whose name contains 'csv' and Jupyter
    # checkpoint folders. os.listdir returns entries in arbitrary,
    # filesystem-dependent order, so sort to make the processing order (and
    # the order of saved results) reproducible across machines and runs.
    trainsplits = sorted(
        filename for filename in os.listdir(BASE)
        if 'csv' not in filename and 'ipynb_checkpoints' not in filename
    )
    if not trainsplits:
        # Fail loudly instead of handing empty lists to result_analyze.
        raise SystemExit(f'No train splits found in {BASE}')

    # Hyperparameters forwarded unchanged to every Lasso_OptSet call; the
    # meaning of each key is defined in LassoOptSet (not visible here).
    params = {'lassoIter': 350, 'lassoLr': 0.05, 'lassoLam': 2,
              'regenOpt': 'SGD', 'regenBaseinit': True,
              'fSOpt': 'SGD', 'fSLr': 0.25, 'fSIter': 5000, 'fSstep': 4000, 'fSgamma': 0.4, 'fSBaseinit': True,
              'fJLr': 0.25, 'fJIter': 2000, 'fJBaseinit': True,
              'patience': 10, 'verbose': True}

    # Bind params so each worker call only needs the split name.
    paramsLasso_OptSet = partial(Lasso_OptSet, params=params)

    with mp.Pool(9) as pool:
        # imap yields results in input order, so logs[i] belongs to trainsplits[i].
        logs = list(tqdm.tqdm(pool.imap(paramsLasso_OptSet, trainsplits), total=len(trainsplits)))
        # Shut workers down cleanly before the context manager's terminate().
        pool.close()
        pool.join()

    result_analyze(logs, trainsplits, params, save=True, date_time=None)