import argparse
import time
from itertools import product

from run_experiment import run_experiment


def cartesian_product(inp):
    """Expand a dict of option lists into one dict per combination.

    Each key of ``inp`` maps to a sequence of candidate values. The result
    yields one dict per element of ``itertools.product`` over those sequences,
    with keys in ``inp``'s iteration order.

    Returns an empty list when ``inp`` has no keys; otherwise returns a lazy
    generator of dicts.
    """
    # No keys at all: hand back an empty list rather than a generator.
    if not inp:
        return []
    # The product iterator is created eagerly here; the dicts are built lazily.
    combos = product(*inp.values())
    return (dict(zip(inp, combo)) for combo in combos)

# Seeds 0 .. max_seed-1; each grid below sweeps over every one of them.
max_seed = 30
seeds = list(range(max_seed))
# Max number of concurrently running experiments (later clamped to the job count).
processes_to_run_in_parallel = 1

real_datasets = ['facebook_1', 'bio', 'house', 'facebook_2', 'meps_19']
corruption_types = ['noised_y', 'missing_y', 'dispersive_noised_y']
# "<corruption>_<dataset>" for every pair (corruption-major order), followed by
# the stand-alone 'missing_y_nslm' dataset.
dataset_names = [
    f"{corruption}_{dataset}"
    for corruption in corruption_types
    for dataset in real_datasets
]
dataset_names.append('missing_y_nslm')


def _regression_grid(names, data_type):
    """Return a ``regression_main`` option grid over ``names`` for one data type.

    Keys appear in the order the generated job dicts will have them; each value
    is the list of options to sweep for that key.
    """
    return {
        'main_program_name': ['regression_main'],
        'seed': seeds,
        'dataset_name': names,
        'data_type': [data_type],
        'epochs': [1000],
    }


real_params = _regression_grid(dataset_names, 'real')
syn_params = _regression_grid(
    ['missing_y_regression_synthetic_z3', 'missing_y_pcp_fail_z3'],
    'synthetic',
)

# One job dict per grid combination: all real-data runs, then all synthetic runs.
params = [*cartesian_product(real_params), *cartesian_product(syn_params)]


# Never plan for more parallel workers than there are jobs.
processes_to_run_in_parallel = min(len(params), processes_to_run_in_parallel)
# Launch settings forwarded to run_experiment for every job.
run_on_slurm = False
cpus, gpus = 2, 0
if __name__ == '__main__':
    from collections import deque

    def _launch(job):
        """Start the experiment described by one ``params`` entry.

        The entry's 'main_program_name' value selects the program; the
        remaining keys are passed to ``run_experiment`` as its parameters.
        Returns whatever ``run_experiment`` returns; the loop below uses it
        as a process handle via ``poll()`` and ``communicate()``.
        """
        job_params = dict(job)  # copy so the entry in ``params`` is left intact
        main_program_name = job_params.pop('main_program_name')
        return run_experiment(job_params, main_program_name, run_on_slurm=run_on_slurm,
                              cpus=cpus, gpus=gpus)

    print("jobs to do: ", len(params))
    # deque gives O(1) removal from the front (list.pop(0) is O(n)).
    pending = deque(params)
    jobs_finished_so_far = 0

    # At least one worker whenever there is work (a configured value of 0
    # would otherwise poll forever), but never more workers than jobs.
    n_workers = min(max(processes_to_run_in_parallel, 1), len(pending))

    # initializing n_workers workers
    workers = [_launch(pending.popleft()) for _ in range(n_workers)]

    # creating a new process when an old one dies
    while pending:
        for i, worker in enumerate(workers):
            # Skip slots that were already reaped and workers still running.
            if worker is None or worker.poll() is None:
                continue
            worker.communicate()
            jobs_finished_so_far += 1
            if pending:
                workers[i] = _launch(pending.popleft())
                if jobs_finished_so_far % n_workers == 0:
                    print(f"finished so far: {jobs_finished_so_far}, {len(pending)} jobs left")
            else:
                # Already reaped and nothing left to start: clear the slot so
                # the final join does not communicate() with (and count) this
                # worker a second time.
                workers[i] = None
        # Pause once per polling round. The sleep used to sit inside the
        # per-dead-worker loop, so rounds where nothing had finished spun
        # without sleeping and burned a CPU core.
        time.sleep(10)

    # joining all last processes
    for worker in workers:
        if worker is not None:
            worker.communicate()
            jobs_finished_so_far += 1

    print("finished all")
