import os
import sys
import yaml
import copy
import numpy as np
import argparse

# Base YAML config shared by every generated experiment; generate_yaml
# deep-copies it and then fills in the per-experiment fields.
yaml_template = dict(gamma=None, server_params={})

# Shell command written into each run_{i}.sh. Placeholders: the config path,
# the dumps root, the experiment name and the per-config result file name.
run_template = (
    'source setup.sh && python3 -m simulator.run_experiment'
    ' --config {config_name}'
    ' --save_path {dumps_path}/{experiments_name}/results/{result_name}'
)

# Algorithms whose server takes a ``num_grads`` parameter. For each of them one
# config is generated per value of ``--num_grads_list``; every other algorithm
# gets a single config per step size.
NUM_GRADS_ALGORITHMS = frozenset({
    'ringmaster_sgd',
    'rennala_sgd',
    'rennala_softreduce_sgd',
    'rennala_sgd_history_window',
    'local_sgd',
    'ringmaster_sgd_compcomm',
    'subset_ring_reduce',
})


def _check_required_lists(args):
    """Raise ValueError if an option needed by the requested configs is missing."""
    if not args.times_to_calculate_sqrt and not args.times_to_calculate:
        raise ValueError(
            '--times_to_calculate is required unless --times_to_calculate_sqrt is set')
    if args.communicate_to_calculate_ratio <= 0 and not args.times_to_communicate:
        raise ValueError(
            '--times_to_communicate is required unless --communicate_to_calculate_ratio > 0')
    needs_num_grads = sorted(set(args.algorithm_names) & NUM_GRADS_ALGORITHMS)
    if needs_num_grads and not args.num_grads_list:
        raise ValueError(
            '--num_grads_list is required for: {}'.format(', '.join(needs_num_grads)))


def _sample_worker_times(args, generator):
    """Return per-worker ``(times_to_calculate, times_to_communicate)`` arrays.

    Computation times are ``sqrt(1..num_workers)`` when
    ``--times_to_calculate_sqrt`` is set, otherwise drawn with replacement from
    ``args.times_to_calculate``. Communication times are the computation times
    scaled by ``--communicate_to_calculate_ratio`` when that ratio is positive,
    otherwise drawn from ``args.times_to_communicate``. Random draws advance
    ``generator``, so the call order determines the sampled values.
    """
    if args.times_to_calculate_sqrt:
        times_to_calculate = np.sqrt(np.arange(args.num_workers) + 1)
    else:
        times_to_calculate = generator.choice(args.times_to_calculate, (args.num_workers,))

    if args.communicate_to_calculate_ratio > 0:
        times_to_communicate = times_to_calculate * args.communicate_to_calculate_ratio
    else:
        times_to_communicate = generator.choice(args.times_to_communicate, (args.num_workers,))
    return times_to_calculate, times_to_communicate


def generate_yaml(args):
    """Write experiment configs and matching run scripts under the dumps path.

    Creates ``{dumps_path}/{experiments_name}`` with three subfolders:
    ``source_folder`` (the YAML configs plus ``cmd.txt`` holding the invoking
    command line), ``run_folder`` (one ``run_{i}.sh`` per config) and
    ``results`` (created empty; the run scripts' ``--save_path`` targets it).

    One config is emitted for every combination of algorithm in
    ``args.algorithm_names``, step size ``gamma = 2**i`` for ``i`` in
    ``range(step_size_range[0], step_size_range[1])`` and, for algorithms in
    ``NUM_GRADS_ALGORITHMS``, every value in ``args.num_grads_list``. Configs
    are numbered consecutively from 0 in that iteration order.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).

    Raises:
        ValueError: if a required per-worker time list or ``num_grads_list``
            was not supplied. Checked before anything is written to disk.
        FileExistsError: if the experiment directory already exists; existing
            dumps are never overwritten.
    """
    _check_required_lists(args)

    experiment_dir = os.path.abspath('{dumps_path}/{experiments_name}'.format(
        dumps_path=args.dumps_path, experiments_name=args.experiments_name))
    os.mkdir(experiment_dir)
    source_dir = os.path.join(experiment_dir, 'source_folder')
    os.mkdir(source_dir)
    print("Source dir: {}".format(source_dir))
    run_dir = os.path.join(experiment_dir, 'run_folder')
    os.mkdir(run_dir)
    print("Run dir: {}".format(run_dir))
    os.mkdir(os.path.join(experiment_dir, 'results'))
    with open(os.path.join(source_dir, 'cmd.txt'), 'w') as fd:
        fd.write(" ".join(sys.argv))

    # Fixed seed: rerunning with identical arguments reproduces the same
    # sampled worker times.
    generator = np.random.default_rng(seed=42)
    low, high = args.step_size_range[0], args.step_size_range[1]
    index = 0

    for algorithm_name in args.algorithm_names:
        # [None] means "one config, no num_grads server parameter".
        if algorithm_name in NUM_GRADS_ALGORITHMS:
            num_grads_options = args.num_grads_list
        else:
            num_grads_options = [None]

        for exponent in range(low, high):
            gamma = 2**exponent
            for num_grads in num_grads_options:
                yaml_prepared = copy.deepcopy(yaml_template)
                yaml_prepared['gamma'] = gamma
                yaml_prepared['server'] = algorithm_name
                yaml_prepared['sim_time'] = args.sim_time
                yaml_prepared['metric_check_num'] = args.metric_check_num
                yaml_prepared['num_workers'] = args.num_workers
                yaml_prepared['batch_size'] = args.batch_size
                yaml_prepared['optimizer'] = args.optimizer

                # Resampled for every config, matching the original draw order.
                times_to_calculate, times_to_communicate = _sample_worker_times(args, generator)
                yaml_prepared['times_to_calculate'] = times_to_calculate.tolist()
                yaml_prepared['times_to_communicate'] = times_to_communicate.tolist()

                if num_grads is not None:
                    yaml_prepared['server_params']['num_grads'] = num_grads
                if algorithm_name == 'rennala_sgd_history_window':
                    yaml_prepared['server_params']['history_window'] = args.history_window_rennala

                config_name = os.path.join(source_dir, 'config_{}.yaml'.format(index))
                with open(config_name, 'w') as fd:
                    yaml.dump(yaml_prepared, fd, default_flow_style=None)
                sh_name = os.path.join(run_dir, 'run_{}.sh'.format(index))
                with open(sh_name, 'w') as fd:
                    fd.write(run_template.format(dumps_path=args.dumps_path,
                                                 experiments_name=args.experiments_name,
                                                 config_name=config_name,
                                                 result_name='config_{}.json'.format(index)))
                index += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Generate simulator YAML configs and run scripts for a sweep over '
                    'algorithms, step sizes and num_grads values.')
    parser.add_argument('--dumps_path', required=True,
                        help='root directory; outputs go to DUMPS_PATH/EXPERIMENTS_NAME')
    parser.add_argument('--experiments_name', required=True,
                        help='experiment directory name (must not exist yet)')
    parser.add_argument('--num_workers', required=True, type=int,
                        help='number of simulated workers')
    parser.add_argument('--step_size_range', required=True, nargs=2, type=int,
                        metavar=('LOW', 'HIGH'),
                        help='step sizes 2**i for integer i in [LOW, HIGH)')
    parser.add_argument('--sim_time', type=int, default=20000)
    parser.add_argument('--metric_check_num', type=int, default=1)
    parser.add_argument('--algorithm_names', required=True, nargs='+',
                        help='server algorithms to generate configs for')
    parser.add_argument('--num_grads_list', nargs='+', type=int,
                        help='num_grads values; one config per value for algorithms '
                             'that take num_grads')
    parser.add_argument('--times_to_calculate', nargs='+', type=float,
                        help='values per-worker computation times are sampled from')
    parser.add_argument('--times_to_calculate_sqrt', action='store_true',
                        help='use sqrt(i) for worker i = 1..num_workers instead of sampling')
    parser.add_argument('--times_to_communicate', nargs='+', type=float,
                        help='values per-worker communication times are sampled from')
    parser.add_argument('--communicate_to_calculate_ratio', type=float, default=0.0,
                        help='if > 0, communication time = computation time * ratio')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--history_window_rennala', type=int, default=10,
                        help='history_window for rennala_sgd_history_window')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'])
    args = parser.parse_args()
    # nargs=2 already enforces exactly two values; an empty range would
    # silently generate zero configs, so reject it with a usage error.
    if args.step_size_range[0] >= args.step_size_range[1]:
        parser.error('--step_size_range LOW HIGH requires LOW < HIGH')
    # Kept as a module-level global in case generate_yaml (or other code)
    # reads it rather than args.step_size_range.
    step_size_range = args.step_size_range
    generate_yaml(args)
