import os
import time

from multiprocessing.pool import Pool
from functools import partial

import papermill as pm
import psycopg2 as pg
from psycopg2.extras import DictCursor


def _fetch_run_parameters(storage_args, run_id):
    """Return the ``runs`` row for *run_id* as a plain dict.

    Opens a short-lived connection that is always closed, even on error.

    Raises:
        LookupError: if no row in ``runs`` has the given ``run_id``.
    """
    connection = pg.connect(**storage_args)
    try:
        cursor = connection.cursor(cursor_factory=DictCursor)
        cursor.execute('select * from runs where run_id = %s limit 1;', (run_id,))
        row = cursor.fetchone()
    finally:
        connection.close()
    if row is None:
        # Without this check, dict(None) would raise an opaque TypeError.
        raise LookupError(f'No run with run_id={run_id!r} found in the runs table.')
    return dict(row)


def _count_remaining_runs(storage_args, sweep_name):
    """Return how many runs of *sweep_name* still have ``complete = 'false'``."""
    connection = pg.connect(**storage_args)
    try:
        cursor = connection.cursor()
        cursor.execute('select count(*) from runs where sweep_name = %s and complete = \'false\'', (sweep_name,))
        return cursor.fetchone()[0]
    finally:
        connection.close()


def run_notebook(sweep_name, storage_args, notebook_output_dir, source_notebook, run_id):
    """Execute *source_notebook* with papermill for a single run of a sweep.

    The notebook parameters are the columns of the run's row in the ``runs``
    table, merged with *storage_args*. On a key collision, *storage_args* wins.
    The executed notebook is written to
    ``<notebook_output_dir>/run_<run_id>.ipynb``. Afterwards, the elapsed time
    and the number of still-incomplete runs in the sweep are printed.

    This function does not update the ``runs`` table itself. Presumably the
    notebook marks the run complete using the storage args it receives.
    TODO confirm.

    Args:
        sweep_name: Value of ``runs.sweep_name`` used to count the remaining runs.
        storage_args: Keyword arguments for ``psycopg2.connect``. They are also
            passed into the notebook as parameters.
        notebook_output_dir: Directory that receives the executed notebook.
        source_notebook: Path of the notebook to execute.
        run_id: Primary key of the run to execute.

    Raises:
        LookupError: if *run_id* is not present in the ``runs`` table.
    """
    # Read the parameters and release the connection before executing, so that
    # no idle transaction is held open for the notebook's whole runtime.
    notebook_args = {**_fetch_run_parameters(storage_args, run_id), **storage_args}

    results_notebook_filename = os.path.join(notebook_output_dir, f'run_{run_id}.ipynb')

    # perf_counter is monotonic, so it is the right clock for measuring durations.
    start = time.perf_counter()
    pm.execute_notebook(
        source_notebook,
        results_notebook_filename,
        parameters=notebook_args,
        progress_bar=False,
    )
    elapsed = time.perf_counter() - start

    remaining_runs = _count_remaining_runs(storage_args, sweep_name)
    print(f'Run {run_id} complete in {elapsed:.2f}s. {remaining_runs} runs remaining.')


def run_sweep(sweep_name, source_notebook, notebook_output_dir, storage_args, num_threads):
    """Execute every incomplete run of *sweep_name* in parallel.

    Incomplete runs are rows of the ``runs`` table where ``sweep_name``
    matches and ``complete = 'false'``. Each run is handed to
    ``run_notebook`` in a ``multiprocessing`` worker pool.

    Args:
        sweep_name: Sweep whose incomplete runs should be executed.
        source_notebook: Path of the notebook to execute for every run.
        notebook_output_dir: Directory for executed notebooks. It is created
            if it does not exist.
        storage_args: Keyword arguments for ``psycopg2.connect``.
        num_threads: Number of worker *processes*. ``Pool`` here is a
            multiprocessing process pool, not a thread pool.

    Raises:
        Any exception raised by a worker. ``Pool.map`` re-raises the first one.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(notebook_output_dir, exist_ok=True)

    print('Connecting to database...')
    connection = pg.connect(**storage_args)
    try:
        cursor = connection.cursor()
        print('Querying for incomplete runs...')
        cursor.execute('select run_id from runs where sweep_name = %s and complete = \'false\'', (sweep_name,))
        incomplete_run_ids = [row[0] for row in cursor.fetchall()]
    finally:
        # psycopg2's `with connection:` does not close the connection, so close explicitly.
        connection.close()

    print(f'{len(incomplete_run_ids)} remaining to execute.')
    if not incomplete_run_ids:
        # Nothing to do: skip spawning worker processes.
        return

    runner = partial(run_notebook, sweep_name, storage_args, notebook_output_dir, source_notebook)
    with Pool(num_threads) as worker_pool:
        # Notebook runs are long and uneven in length. chunksize=1 hands out one
        # run at a time, so idle workers pick up the remaining work instead of
        # waiting behind a pre-assigned batch.
        worker_pool.map(runner, incomplete_run_ids, chunksize=1)