import functools
import itertools
from multiprocessing.pool import Pool

import papermill as pm

import results


def run_notebook(basename, source_notebook, run, learning_rate):
    """Execute one parameterized copy of ``source_notebook`` with papermill.

    The executed notebook is written to the path returned by
    ``results.full_path(basename, ...)``. The run id and learning rate are
    encoded in the filename, so each (run, learning_rate) pair gets its own
    output notebook.

    Args:
        basename: Experiment name. Used to build the output path and passed to
            the notebook as the ``basename`` parameter.
        source_notebook: Path of the template notebook to execute.
        run: Run id, passed to the notebook as the ``run`` parameter.
        learning_rate: Passed to the notebook as the ``learning_rate``
            parameter.

    Returns:
        The output notebook path produced by ``results.full_path``.

    Raises:
        Propagates any exception from ``pm.execute_notebook`` (e.g. papermill's
        ``PapermillExecutionError`` when a cell fails).
    """
    # NOTE(review): with the basename used in main() ('2_layer_large_relu')
    # this produces '2_layer_large_relu_large_relu.<run>.<lr>.ipynb' -- the
    # '_large_relu' suffix is repeated. Left as-is so existing result paths
    # stay stable; confirm before renaming.
    copied_filename = results.full_path(basename, f'{basename}_large_relu.{run}.{learning_rate}.ipynb')

    pm.execute_notebook(
        source_notebook,
        copied_filename,
        parameters=dict(
            basename=basename,
            run=run,
            learning_rate=learning_rate,
        ),
    )
    # Return the output path instead of None so callers collecting pool
    # results get something meaningful to print or log.
    return copied_filename


def main():
    """Run the learning-rate sweep for the 2-layer large ReLU experiment.

    Executes ``source_notebook`` once for every (run, learning_rate) pair in
    ``runs`` x ``learning_rates``, two notebooks at a time in a process pool.
    A failing notebook does not abort the other jobs. Every job runs to
    completion, each failure is printed, and a ``RuntimeError`` is raised at
    the end if any job failed (chained to the first failure).
    """
    basename = '2_layer_large_relu'
    source_notebook = 'difficulty_sweep_2_layer_large_relu.ipynb'
    results.ensure_results_dir(basename)

    # Seeding is done by run id: each run id induces a deterministic notebook.
    # To add more runs after a first set, use a list of new run ids.
    # (Re-using a run id overwrites its earlier results, currently.)
    runs = list(range(30))
    learning_rates = [1e-2, 5e-3, 1e-3, 5e-4, 1e-4, 5e-5, 1e-5, 5e-6]

    runner = functools.partial(run_notebook, basename, source_notebook)
    jobs = list(itertools.product(runs, learning_rates))
    outcomes = []
    failures = []
    with Pool(2) as p:
        # Submit jobs individually. With starmap, the first failure makes the
        # call raise, and leaving this `with` block then terminates the pool,
        # killing every job still queued or running.
        pending = [p.apply_async(runner, job) for job in jobs]
        for (run, learning_rate), async_result in zip(jobs, pending):
            try:
                outcomes.append(async_result.get())
            except Exception as exc:  # top-level boundary: record, re-raise below
                failures.append((run, learning_rate, exc))
                outcomes.append(None)
    print(outcomes)

    for run, learning_rate, exc in failures:
        print(f'FAILED run={run} learning_rate={learning_rate}: {exc!r}')
    if failures:
        raise RuntimeError(
            f'{len(failures)} of {len(jobs)} notebooks failed'
        ) from failures[0][2]


if __name__ == "__main__":
    # Start the sweep only when run as a script: multiprocessing workers
    # started with the spawn/forkserver methods import this module and must
    # not launch the sweep themselves.
    main()
