import time

# Wall-clock time (seconds since the epoch) captured when the script starts,
# before the remaining imports and argument parsing.
wtime_start: float = time.time()


import argparse 
import run_toy_noise


base_parser = run_toy_noise.build_parser()

#=== Adding sweep arguments on top of run_toy_noise's own parser ===
# Steppable parameters are swept over a range built from --<name>_min, --<name>_max
# and --<name>_Nstep (linearly spaced, or geometrically with --<name>_geom).
steppable_parameters = [('P_noise_scale', float), ('replicate', int)]
# Flaggable parameters get a --<name>_vary switch that tries both True and False.
flaggable_parameters = []
# List parameters get a --<name>_list option taking explicit values; entries are (name, type).
list_parameters = []

# Type used to coerce each swept parameter's value before it is written onto args.
parameter_types = dict(steppable_parameters)

for name, kind in steppable_parameters:
    add = base_parser.add_argument
    add(f'--{name}_min', type=kind, default=None, help=f'Minimum value for {name}.')
    add(f'--{name}_max', type=kind, default=None, help=f'Maximum value for {name}.')
    add(f'--{name}_Nstep', type=int, default=None, help=f'Number of steps for {name}.')
    add(f'--{name}_geom', action='store_true', help=f'Vary the steps of {name} geometrically.')

for name in flaggable_parameters:
    parameter_types[name] = bool
    base_parser.add_argument(
        f'--{name}_vary',
        action='store_true',
        help=f'If set, will test {name} with both True and False.',
    )

for name, kind in list_parameters:
    parameter_types[name] = kind
    base_parser.add_argument(
        f'--{name}_list',
        nargs='+',
        type=kind,
        default=None,
        help=f'List of values for {name}.',
    )

#=== Splitting the sweep across several independent jobs ===
# The permutation range [start_permutation_index, end_permutation_index) is divided
# into N_jobs contiguous chunks; this process runs chunk (job_index - min_job_index).
base_parser.add_argument('--N_jobs',type=int,default=1,help='Number of parallel jobs')
base_parser.add_argument('--job_index',type=int,default=0,help='Index of the current job')
base_parser.add_argument('--min_job_index',type=int,default=0,help='Index of the first worker -- relevant if we are running a slurm job array, where we start at a nonzero worker index.')
# Restricting the permutation range is useful e.g. to restart a failed or incomplete sweep.
base_parser.add_argument('--start_permutation_index',type=int,default=0,help='Index of the first permutation to run.')
# The end index is exclusive: it is used as the upper bound of a half-open range/slice.
base_parser.add_argument('--end_permutation_index',type=int,default=None,help='Index one past the last permutation to run (exclusive). If None, defaults to the total number of permutations.')
# Debug mode still walks every permutation and prints the settings, but never runs an experiment.
base_parser.add_argument('--debug_params',action='store_true',help='If set, will print out the parameter grid and the per-permutation settings without running any experiments.')

#=== Parsing the command line (base run_toy_noise options plus the sweep options above).
args=base_parser.parse_args()




#=== Setting up the parameter grid
import numpy as np

# param_grid maps each swept parameter name to the sequence of values to try.
# Only parameters explicitly requested on the command line appear here; every other
# option keeps the single value parsed by the base parser.
param_grid = {}
for param, param_type in steppable_parameters:
    p_min = getattr(args, f'{param}_min')
    if p_min is None:
        continue
    p_max = getattr(args, f'{param}_max')
    n_step = getattr(args, f'{param}_Nstep')
    # Validate explicitly rather than with `assert`, which is stripped under `python -O`.
    if p_max is None or n_step is None:
        base_parser.error(f'--{param}_min requires both --{param}_max and --{param}_Nstep.')
    spacer_func = np.geomspace if getattr(args, f'{param}_geom') else np.linspace
    values = spacer_func(p_min, p_max, n_step)
    if param_type == int:
        # Round before casting so that e.g. 2.9999 becomes 3 instead of being truncated to 2.
        # NOTE: rounding can map neighbouring (especially geometric) steps to the same
        # integer, which yields duplicate grid values.
        values = np.round(values)
    param_grid[param] = values.astype(param_type)

for param in flaggable_parameters:
    if getattr(args, f'{param}_vary'):
        param_grid[param] = np.array([True, False], dtype=bool)

for param, param_type in list_parameters:
    values = getattr(args, f'{param}_list')
    if values is not None:
        param_grid[param] = values

import itertools
# Cartesian product of all swept values; each tuple is ordered like param_grid.keys().
permutations = list(itertools.product( *param_grid.values() ))
N_permutations = len(permutations)
print('number of runnable permutations: ',N_permutations)

#selecting which permutations this iteration will run:
if(args.end_permutation_index is None):
    args.end_permutation_index = N_permutations

# Validate the job/range options up front. Without these checks an out-of-range end
# index crashes later with an IndexError, and a negative job offset silently wraps
# around in start_indices[...] and re-runs another job's chunk.
if args.N_jobs < 1:
    base_parser.error(f'--N_jobs must be at least 1, got {args.N_jobs}.')
if not 0 <= args.start_permutation_index <= args.end_permutation_index <= N_permutations:
    base_parser.error(
        f'permutation range [{args.start_permutation_index}, {args.end_permutation_index}) '
        f'must satisfy 0 <= start <= end <= {N_permutations}.'
    )
local_job_index = args.job_index - args.min_job_index
if not 0 <= local_job_index < args.N_jobs:
    base_parser.error(
        f'--job_index minus --min_job_index must be in [0, {args.N_jobs}), got {local_job_index}.'
    )

# N_jobs+1 boundaries splitting [start, end) into N_jobs contiguous, nearly equal chunks.
start_indices = np.linspace(args.start_permutation_index,args.end_permutation_index,args.N_jobs+1).astype(int)
chunk_start = start_indices[local_job_index]
chunk_end = start_indices[local_job_index + 1]
inds_to_run = np.arange(chunk_start, chunk_end)
print('this thread will run indices: ',inds_to_run)
permutations_to_run = permutations[chunk_start:chunk_end]
if(args.debug_params):
    print(param_grid)
    print(permutations)
    print('this thread will run permuations: ',permutations_to_run)

#=== actually running the permutations.
# inds_to_run[i] is the global index of permutations_to_run[i]; that global index
# names the output files and is passed to run_experiment.
for i,permutation_ind in enumerate(inds_to_run):
    permutation = permutations_to_run[i]
    print('running permutation',i,'/',len(permutations_to_run),flush=True)
    for param, value in zip(param_grid.keys(),permutation):
        # Unwrap NumPy scalars (anything with .item()) to plain Python values, then
        # coerce to the declared type. An explicit conditional replaces the
        # `x and y or z` idiom, which misbehaves whenever .item() is falsy (0, 0.0, False).
        if hasattr(value, 'item'):
            value = value.item()
        value = parameter_types[param](value)
        print('setting: ',param,value,flush=True)
        setattr(args,param,value)
    if(not args.debug_params):
        # NOTE(review): postprocess_args is re-applied to the same, already
        # post-processed namespace on every iteration. This is only correct if it is
        # idempotent -- confirm against run_toy_noise.
        args = run_toy_noise.postprocess_args(args)
        args.output_filename_prefix=f'perm_{permutation_ind}_'
        training_output = run_toy_noise.run_experiment(args,int(permutation_ind))

    print('finished permutation',i,'/',len(permutations_to_run),flush=True)
    print('===')