import argparse
import itertools
import os
from constants import debias_methods, dataset_names, params, reduced_params


def _build_parser():
    """Return the command-line parser for the script generator."""
    cli = argparse.ArgumentParser(description='config')
    # (flag, default, converter) for every option that takes a value.
    valued_options = (
        ('--run', 'train', str),
        ('--init_seed', 0, int),
        ('--round', 16, int),
        ('--experiment_step', 1, int),
        ('--num_experiments', 1, int),
        ('--split_round', 65536, int),
        ('--mode', 'baseline', str),
        ('--dataset', 'Adult', str),
    )
    for flag, default, converter in valued_options:
        cli.add_argument(flag, default=default, type=converter)
    cli.add_argument('--multi_group', action='store_true', default=False)
    return cli


parser = _build_parser()

# Parsed once at import time; main() reads its options from this dict.
args = vars(parser.parse_args())


def make_controls(script_name, init_seeds, num_experiments, control_name, params, num_groups=None):
    """Expand experiment settings into the cartesian product of launch arguments.

    Args:
        script_name: list of value lists for the script-file axis.
        init_seeds: list of value lists for the seed axis.
        num_experiments: list of value lists for the experiments-per-run axis.
        control_name: iterable of ``(datasets, models, methods)`` triples; each
            combination yields names ``{dataset}_{model}_{method}_{param}`` (plus
            ``_{num_group}`` when ``num_groups`` is truthy).
        params: mapping from method to its parameter series; methods missing
            from it get the single parameter ``'none'``.
        num_groups: optional sequence of group counts; an empty/None value
            disables the group suffix.

    Returns:
        A list of tuples, one per element of the product of all axes, with the
        generated control name as the last element.
    """
    names = []
    for datasets, models, methods in control_name:
        for dataset, model, method in itertools.product(datasets, models, methods):
            prefix = f"{dataset}_{model}_{method}"
            series = params.get(method, ['none'])
            if num_groups:
                # Group count is the outer loop, parameter the inner one.
                names += [f"{prefix}_{param}_{group}" for group in num_groups for param in series]
            else:
                names += [f"{prefix}_{param}" for param in series]
    axes = script_name + init_seeds + num_experiments + [names]
    return list(itertools.product(*axes))


def _task_type(data_name):
    """Return which ``dataset_names`` group ('b_clf', 'nb_clf', 'reg') contains data_name.

    Raises:
        ValueError: if data_name is in none of the groups.
    """
    for target in ('b_clf', 'nb_clf', 'reg'):
        if data_name in dataset_names[target]:
            return target
    raise ValueError(f'Not valid dataset: {data_name}')


def _write_script(script_path, name, content):
    """Echo ``content`` to stdout and save it as ``script_path/name``."""
    print(content)
    os.makedirs(script_path, exist_ok=True)
    # Force LF line endings so the bash script is not written with CRLF on
    # Windows (bash would treat the trailing '\r' as part of each command).
    with open(os.path.join(script_path, name), 'w', newline='\n') as run_file:
        run_file.write(content)


def main():
    """Generate bash launcher scripts for every experiment configuration.

    Reads the module-level ``args`` (parsed CLI options), expands them into
    ``(script, init_seed, num_experiments, control_name)`` tuples via
    ``make_controls`` and emits one ``python ... &`` line per tuple into
    ``output/script/{run}_{mode}_{dataset}_{k}.sh`` files.

    Commands are grouped ``--round`` at a time: the last command of each group
    runs in the foreground and is followed by ``wait``, so at most ``--round``
    commands run concurrently. After every ``--split_round`` groups the current
    script is flushed to disk and a new file (``k + 1``) is started; any
    remainder is written to a final file ending in ``wait``.

    Raises:
        ValueError: if ``--mode`` or ``--dataset`` is not recognised, or if
            ``--round`` / ``--split_round`` is not a positive integer.
    """
    run = args['run']
    batch_size = args['round']  # not named `round` to avoid shadowing the builtin
    experiment_step = args['experiment_step']
    init_seed = args['init_seed']
    num_experiments = args['num_experiments']
    split_round = args['split_round']
    mode = args['mode']
    data_name = args['dataset']
    # Both values are used as modulo divisors below; reject them up front
    # instead of failing later with ZeroDivisionError.
    if batch_size < 1:
        raise ValueError(f'--round must be a positive integer, got {batch_size}')
    if split_round < 1:
        raise ValueError(f'--split_round must be a positive integer, got {split_round}')
    if mode != 'baseline':
        raise ValueError('Not valid mode')

    script_path = os.path.join('output', 'script')
    init_seeds = [list(range(init_seed, init_seed + num_experiments, experiment_step))]
    # Every generated command receives `--num_experiments experiment_step`;
    # presumably each process runs that many seeds from its own --init_seed.
    per_run = [[experiment_step]]
    file_prefix = '{}_{}_{}'.format(run, mode, data_name)

    script_name = [['{}_baseline.py'.format(run)]]
    model_name = ['baseline']
    debias_model = debias_methods[_task_type(data_name)]
    control_name = [[[data_name], model_name, debias_model]]
    if args['multi_group']:
        if data_name == 'AdultM':
            num_groups = [2, 10, 20, 30, 40, 50]
        else:
            num_groups = [2 ** (i + 1) for i in range(10)]
        controls = make_controls(script_name, init_seeds, per_run, control_name, reduced_params, num_groups)
    else:
        controls = make_controls(script_name, init_seeds, per_run, control_name, params)

    header = '#!/bin/bash\n'
    s = header
    batch_no = 1  # 1-based index of the current group of `batch_size` commands
    file_no = 1  # 1-based index of the next script file to write
    for i, control in enumerate(controls):
        s += 'python {} --init_seed {} --num_experiments {} --control_name {}&\n'.format(*control)
        if i % batch_size == batch_size - 1:
            # Drop the trailing '&\n' so the group's last command runs in the
            # foreground, then wait for the backgrounded ones.
            s = s[:-2] + '\nwait\n'
            if batch_no % split_round == 0:
                _write_script(script_path, '{}_{}.sh'.format(file_prefix, file_no), s)
                s = header
                file_no += 1
            batch_no += 1
    if s != header:
        if not s.endswith('wait\n'):
            s += 'wait\n'
        _write_script(script_path, '{}_{}.sh'.format(file_prefix, file_no), s)
    return


if __name__ == '__main__':
    main()