import argparse
import time
from itertools import product

from run_experiment import run_experiment

# Command-line options. Both are read further down when the sweep is built:
# --seed picks which slice of the seed range this process handles, and
# --data_type picks the family of datasets.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--seed', type=int, default=-1,
    help='index of the seed chunk to run; -1 (the default) runs every seed',
)
parser.add_argument(
    '--data_type', type=str, default='real',
    help="data family for the sweep, e.g. 'real' or 'synthetic' "
         "(default: 'real')",
)
args = parser.parse_args()


def cartesian_product(inp):
    """Expand a mapping of option lists into one dict per combination.

    Args:
        inp: mapping from parameter name to an iterable of candidate values.

    Returns:
        An empty list when ``inp`` is empty; otherwise a lazy generator that
        yields one ``{name: value}`` dict for every combination of values,
        in the order produced by ``itertools.product``.
    """
    if not inp:
        return []
    value_combinations = product(*inp.values())
    return (dict(zip(inp, combination)) for combination in value_combinations)



# Upper bound on the number of experiments alive at the same time.
processes_to_run_in_parallel = 1

# Seeds are taken from range(min_seed, max_seed); max_seed is exclusive.
min_seed = 0
max_seed = 9
n_seeds = max_seed - min_seed
# Number of chunks the seed range is split into when --seed is given.
count = 5

if args.seed == -1:
    # No chunk requested: run every seed exactly once.
    seeds = list(range(min_seed, max_seed))
else:
    # Chunk `args.seed` gets a contiguous run of n_seeds // count seeds ...
    chunk_size = n_seeds // count
    chunk_start = min_seed + args.seed * chunk_size
    seeds = list(range(chunk_start, chunk_start + chunk_size))
    # ... and the n_seeds % count leftover seeds at the end of the range are
    # handed out one each to the lowest-numbered chunks. This must only
    # happen in chunked mode: with --seed -1 the full list above already
    # contains every seed, and appending here would duplicate one of them.
    if args.seed < n_seeds % count:
        seeds.append(min_seed + chunk_size * count + args.seed)

real_datasets = ['facebook_1', 'bio', 'house', 'facebook_2', 'meps_19']
# Other corruption types that appeared in earlier configurations of this
# script: 'noised_x', 'noised_y', 'missing_x', 'dispersive_noised_y'.
corruption_types = ['missing_y']
dataset_names = [f"{c}_{d}" for c in corruption_types for d in real_datasets]
# The same datasets again, once per value in e_values, under an
# "e_<value>_" name prefix.
e_values = [1, 25, 50, 100, 500, 1000, 5000, 10000, 50000]
dataset_names_e = [
    f"e_{e}_{c}_{d}"
    for c in corruption_types
    for d in real_datasets
    for e in e_values
]

# Each grid maps a parameter name to the list of values to sweep over.
real_params = {
    'main_program_name': ['regression_main'],
    'seed': seeds,
    'dataset_name': dataset_names + dataset_names_e,
    'data_type': ['real'],
    'epochs': [1000],
}

syn_params = {
    'main_program_name': ['regression_main'],
    'seed': seeds,
    # Alternative synthetic datasets used previously:
    #   'missing_y_regression_synthetic', 'missing_y_regression_synthetic_z3',
    #   'missing_y_regression_synthetic_with_overcoverage_z3'
    'dataset_name': ['missing_y_pcp_fail_z3'],
    'data_type': ['synthetic'],
    'epochs': [1000],
    # 'beta': [0.005, 0.002, 0.001, 0.0001, 0.00001, 0.000001, 0.0075, 0.01, 0.025, 0.05, 0.075, 0.085, 0.09, 0.095, 0.099, 0.0999, 0.09999],
}

# Honour --data_type: any value containing "real" selects the real-data
# grid, everything else selects the synthetic one.
if 'real' in args.data_type.lower():
    selected_grid = real_params
else:
    selected_grid = syn_params

# One dict per job: the full cartesian product of the selected grid.
params = list(cartesian_product(selected_grid))

# Never start more workers than there are jobs.
processes_to_run_in_parallel = min(processes_to_run_in_parallel, len(params))
# Execution settings forwarded to run_experiment for every job.
run_on_slurm = False
cpus = 2
gpus = 0
def _launch_next(pending):
    """Start the next queued experiment and return its handle.

    Removes the first parameter dict from ``pending``, takes the
    'main_program_name' entry out of it, and passes the remaining parameters
    to ``run_experiment`` together with the module-level execution settings.

    Args:
        pending: list of parameter dicts still waiting to run; mutated in place.

    Returns:
        Whatever ``run_experiment`` returns. The scheduler below only relies
        on it exposing ``poll()`` and ``communicate()`` (presumably a
        subprocess.Popen-like object; confirm against run_experiment).
    """
    job = pending.pop(0)
    program = job.pop('main_program_name')
    return run_experiment(job, program, run_on_slurm=run_on_slurm,
                          cpus=cpus, gpus=gpus)


if __name__ == '__main__':
    # Seconds to wait between two polls of the running workers.
    poll_interval = 10

    print("jobs to do: ", len(params))

    n_workers = min(processes_to_run_in_parallel, len(params))
    if params and n_workers < 1:
        # Without at least one worker the queue below could never drain.
        raise ValueError("processes_to_run_in_parallel must be at least 1")

    # Fill every worker slot with an initial job.
    workers = [_launch_next(params) for _ in range(n_workers)]
    jobs_finished_so_far = 0

    # Refill a slot with the next queued job whenever its worker has exited.
    while params:
        finished_slots = [slot for slot, worker in enumerate(workers)
                          if worker.poll() is not None]
        for slot in finished_slots:
            workers[slot].communicate()
            jobs_finished_so_far += 1
            if not params:
                # Queue drained while handling this batch; the slot keeps its
                # finished worker and is joined again (harmlessly) below.
                continue
            workers[slot] = _launch_next(params)
            if jobs_finished_so_far % n_workers == 0:
                print(f"finished so far: {jobs_finished_so_far}, {len(params)} jobs left")
        if params:
            # Sleep once per polling round, whether or not anything finished,
            # so the loop neither busy-waits nor delays refilling other slots.
            time.sleep(poll_interval)

    # Wait for the last batch of workers to finish.
    for worker in workers:
        worker.communicate()

    print("finished all")
