from FindOptSet import Find_OptSets,plot
import multiprocessing as mp
import os
import copy
import time

def list_trainsplits(base):
    """Return the sorted entry names in *base* that denote train splits.

    An entry is skipped if its name contains the substring ``'csv'`` or
    ``'ipynb_checkpoints'`` (the same filter the script always used).
    The result is sorted because ``os.listdir`` returns entries in an
    arbitrary, filesystem-dependent order. That order decides how per-split
    results are concatenated by ``merge_recorders``, so sorting makes runs
    reproducible.

    Args:
        base: Directory to scan.

    Returns:
        Sorted list of entry names (not full paths).
    """
    return sorted(
        name
        for name in os.listdir(base)
        if 'csv' not in name and 'ipynb_checkpoints' not in name
    )


def merge_recorders(recorders):
    """Fold a list of per-split recorders into the first one and return it.

    Each recorder is indexed as ``recorder[key]['test_errors']`` and
    ``recorder[key]['h_stars']``. For every key of the first recorder,
    those two entries from each remaining recorder are combined into the
    first with ``+=``. For lists this concatenates.
    NOTE(review): if these are numpy arrays, ``+=`` adds element-wise
    instead; confirm the type returned by Find_OptSets.

    The first recorder is modified in place and returned.

    Args:
        recorders: Non-empty sequence of recorder dicts.

    Returns:
        The first recorder, with the others' results appended.

    Raises:
        ValueError: if *recorders* is empty.
    """
    if not recorders:
        raise ValueError('merge_recorders() needs at least one recorder')
    merged, *rest = recorders
    for record in rest:
        for key in merged:
            merged[key]['test_errors'] += record[key]['test_errors']
            merged[key]['h_stars'] += record[key]['h_stars']
    return merged


if __name__ == '__main__':

    # Directory whose entries are the train splits handed to Find_OptSets.
    BASE = '/home/anonymous/data/CausallyInvariant_output/ADNIB2FAQ/FindOptSets/'
    N_WORKERS = 6

    trainsplits = list_trainsplits(BASE)
    if not trainsplits:
        # Fail early instead of an opaque IndexError after an empty pool run.
        raise SystemExit('No train splits found in {}'.format(BASE))

    # perf_counter is monotonic and meant for measuring durations.
    start = time.perf_counter()

    with mp.Pool(N_WORKERS) as pool:
        # map() blocks until all splits finish and keeps input order.
        recorders = pool.map(Find_OptSets, trainsplits)

        # Let workers exit cleanly before the context manager terminates the pool.
        pool.close()
        pool.join()

    end = time.perf_counter()

    print('Total time cost: {:.4f}h'.format((end - start) / 3600))

    recorder = merge_recorders(recorders)

    plot(recorder, save=True, date_time='anonymous')