import tensorflow as tf
import click
import os
import yaml

def mix_data(path_list, size_list, save_path):
    """Concatenate prefixes of several saved datasets and save the result.

    The first ``size_list[i]`` elements of the dataset stored at
    ``path_list[i]`` are taken, and the pieces are concatenated in
    ``path_list`` order.

    Args:
        path_list: Paths of datasets readable by
            ``tf.data.experimental.load``; concatenated in this order.
        size_list: Number of elements to take from each dataset, parallel to
            ``path_list``.
        save_path: Destination passed to ``tf.data.experimental.save``.

    Returns:
        The concatenated ``tf.data.Dataset`` (the same one that was saved).

    Raises:
        ValueError: ``path_list`` is empty, or ``path_list`` and
            ``size_list`` have different lengths.
    """
    path_list = list(path_list)
    size_list = list(size_list)
    # Validate before loading anything: an empty list previously crashed with
    # UnboundLocalError, and a short size_list failed only after some loads.
    if not path_list:
        raise ValueError('path_list must contain at least one dataset path')
    if len(path_list) != len(size_list):
        raise ValueError(
            'path_list and size_list must have the same length, got {} and {}'.format(
                len(path_list), len(size_list)))

    dataset = None
    for path, size in zip(path_list, size_list):
        part = tf.data.experimental.load(path).take(size)
        dataset = part if dataset is None else dataset.concatenate(part)

    tf.data.experimental.save(dataset, save_path)
    return dataset
    

def _env_dir(params):
    """Return the per-environment directory ``<save_path>/<env_id>_<env_noise>``."""
    return os.path.join(params['save_path'],
                        params['env_id'] + '_' + str(params['env_noise']))


def _parse_bool(text):
    """Parse a boolean command-line override case-insensitively.

    Raises:
        click.BadParameter: ``text`` is not a recognized boolean spelling.
    """
    lowered = text.strip().lower()
    if lowered in ('true', '1', 'yes', 'y', 'on'):
        return True
    if lowered in ('false', '0', 'no', 'n', 'off'):
        return False
    raise click.BadParameter(
        'cannot interpret {!r} as a boolean'.format(text), param_hint='--options')


def _apply_overrides(params, options):
    """Overwrite top-level entries of ``params`` in place with ``(key, value)`` pairs.

    Each string value is converted to the type of the entry it replaces.

    Raises:
        click.BadParameter: a key is not present in ``params`` or a boolean
            value cannot be parsed.
    """
    for key, raw_value in options:
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if key not in params:
            raise click.BadParameter(
                'unknown config key {!r}'.format(key), param_hint='--options')
        current = params[key]
        if isinstance(current, bool):
            # bool('False') would be True, so booleans need dedicated parsing.
            params[key] = _parse_bool(raw_value)
        else:
            params[key] = type(current)(raw_value)


def _mix_plan(params):
    """Return ``(path_list, size_list)`` of source datasets for ``params['mix_type']``.

    Source datasets are expected at ``<env dir>/<policy>_<seed>``. The
    ``data_size`` budget is split evenly (floor division) across the sources.

    Raises:
        NotImplementedError: ``mix_type`` is not a supported value.
    """
    mix_type = params['mix_type']
    seed = params['seed']
    env_dir = _env_dir(params)

    def get_path(policy_type, policy_seed):
        return os.path.join(env_dir, policy_type + '_' + str(policy_seed))

    if mix_type in ('med_seed', 'exp_seed'):
        # Five consecutive seeds of one policy type.
        policy = mix_type.split('_')[0]
        path_list = [get_path(policy, seed + i) for i in range(5)]
    elif mix_type in ('uni_med', 'uni_exp'):
        # The uniform policy paired with one other policy type, same seed.
        path_list = [get_path('uni', seed), get_path(mix_type.split('_')[1], seed)]
    elif mix_type in ('uni', 'med', 'exp'):
        path_list = [get_path(mix_type, seed)]
    else:
        raise NotImplementedError('unsupported mix_type: {!r}'.format(mix_type))

    n_sources = len(path_list)
    size_list = [params['data_size'] // n_sources] * n_sources
    return path_list, size_list


@click.command()
@click.option('--config', '-c', default='generate_mix_data', help='config file name')
@click.option('--options', '-o', multiple=True, nargs=2, type=click.Tuple([str, str]))
def main(config, options):
    """Mix saved datasets as specified by a YAML config file.

    \f
    Reads ``<cwd>/<config>.yaml``, applies ``-o KEY VALUE`` overrides, selects
    the source datasets and per-source sizes for ``mix_type``, and writes the
    mixture with ``mix_data``.

    Args:
        config: Config file name without the ``.yaml`` extension, resolved
            against the current working directory.
        options: ``(key, value)`` pairs overriding top-level config entries.

    Raises:
        click.ClickException: the config file does not contain a mapping.
        click.BadParameter: an override names an unknown key or has an
            unparseable boolean value.
        NotImplementedError: ``mix_type`` is not supported.
    """
    cfg_file = os.path.join(os.getcwd(), config + '.yaml')
    # ``with`` closes the config file instead of leaking the handle.
    with open(cfg_file, 'r') as f:
        params = yaml.safe_load(f)
    if not isinstance(params, dict):
        raise click.ClickException(
            'config file {} must contain a mapping'.format(cfg_file))

    _apply_overrides(params, options)

    # For catch, use one tenth of the configured data_size; the reduced value
    # also feeds the "<N>k" suffix of the output name below.
    if params['env_id'] == 'catch':
        params['data_size'] = params['data_size'] // 10

    path_list, size_list = _mix_plan(params)

    save_path = os.path.join(_env_dir(params),
                             params['mix_type'] + '_' + str(params['seed'])
                             + '_' + str(params['data_size'] // 1000) + 'k')

    mix_data(path_list, size_list, save_path)

if __name__ == "__main__":
    # Script entry point: invoke the click command, which reads its
    # arguments from the command line.
    main()
