import fire
import collections
import os
import pickle

from sweep_utils import add_train_params, all_vs_all_grid, compose_cmd_args, add_common_validation

def prepare_configuration(sweep_step):
    """Build and persist the training options for one step of the utilization-rate sweep.

    The sweep grid varies ``--seed`` (5 values), ``--utilization-weight``
    (5 values) and ``--utilization-type`` (3 values). ``all_vs_all_grid``
    expands it into a list of per-step option dicts (presumably the full
    Cartesian product, i.e. 75 entries -- confirm against sweep_utils).

    Args:
        sweep_step: 1-based index of the sweep configuration to use.

    Returns:
        OrderedDict mapping option names (CLI flags such as ``--save-dir``,
        plus the ``data`` entry) to their values. The same dict is pickled to
        ``$EXPERIMENTS_DIRECTORY/utilization_rate/sweep_step_<n>/
        utilization_rate_<n>_args.pkl``, which is the path
        ``validate_trained_sweep_ontest`` reloads it from.

    Raises:
        KeyError: if the ``DATA`` or ``EXPERIMENTS_DIRECTORY`` environment
            variable is not set.
        ValueError: if ``sweep_step`` is outside ``1..len(configurations)``.
    """
    experiment_name = 'utilization_rate'

    kv_opts = collections.OrderedDict()
    kv_opts['--user-dir'] = '../fairseq_module'
    kv_opts = add_train_params(kv_opts)

    # os.environ[...] fails loudly with the variable's name instead of the
    # opaque TypeError os.path.join(None, ...) would raise.
    kv_opts['data'] = os.path.join(os.environ['DATA'], 'data.tokenized')

    # Override the pretraining output locations for this experiment/step.
    save_dir = os.path.join(os.environ['EXPERIMENTS_DIRECTORY'],
                            experiment_name, f'sweep_step_{sweep_step}')
    kv_opts['--save-dir'] = save_dir
    kv_opts['--tensorboard-logdir'] = os.path.join(save_dir, 'tb')

    grid = collections.OrderedDict()
    grid['--seed'] = ['2421', '2804', '9361', '4872', '6765']
    grid['--utilization-weight'] = [0.0, 0.25, 0.5, 0.75, 1.0]
    grid['--utilization-type'] = ['u', 'c', 's']

    configurations = all_vs_all_grid(grid)
    # Sweep steps are 1-based. Without this check, 0 or a negative step would
    # silently wrap around to the end of the list via negative indexing.
    if not 1 <= sweep_step <= len(configurations):
        raise ValueError(
            f'sweep_step must be in 1..{len(configurations)}, got {sweep_step}')
    sweep_step_dict = configurations[sweep_step - 1]

    # Grid values override any defaults set by add_train_params.
    kv_opts.update(sweep_step_dict.items())

    os.makedirs(save_dir, exist_ok=True)

    cmd_args_filename = os.path.join(save_dir, f'{experiment_name}_{sweep_step}_args.pkl')
    with open(cmd_args_filename, 'wb') as args_file:
        pickle.dump(kv_opts, args_file)

    return kv_opts

def validate_trained_sweep_ontest(sweep_step, experiment_name_to_validate, beam):
    """Build and persist options for evaluating one trained sweep step on the test split.

    Loads the options pickled at training time for ``sweep_step`` of
    ``experiment_name_to_validate``. It derives validation options from them
    via ``add_common_validation`` and configures BLEU evaluation on the
    ``test`` subset with the given beam size.

    Args:
        sweep_step: sweep step whose trained model should be evaluated.
        experiment_name_to_validate: experiment directory name under
            ``$EXPERIMENTS_DIRECTORY`` (e.g. ``utilization_rate``).
        beam: beam size; formatted with ``%d``, so it must be an integer.

    Returns:
        OrderedDict of validation options. It is also pickled to
        ``<trained --save-dir>/validate_beam<beam>_testset/validate_<n>_args.pkl``.

    Raises:
        KeyError: if ``DATA``/``EXPERIMENTS_DIRECTORY`` is unset, or the
            trained args lack ``--save-dir``.
        FileNotFoundError: if the trained-args pickle does not exist.
    """
    experiment_name = f'validate_beam{beam}_testset'

    pretrain_args_pkl_filename = os.path.join(
        os.environ['EXPERIMENTS_DIRECTORY'], experiment_name_to_validate,
        f'sweep_step_{sweep_step}', f'{experiment_name_to_validate}_{sweep_step}_args.pkl')
    # NOTE: pickle can execute arbitrary code on load. Only point this at
    # files written by this script's own training step.
    with open(pretrain_args_pkl_filename, 'rb') as args_file:
        args_from_trained_model = pickle.load(args_file)

    kv_opts = collections.OrderedDict()
    kv_opts = add_common_validation(kv_opts, args_from_trained_model)

    kv_opts['data'] = os.path.join(os.environ['DATA'], 'data.tokenized')

    kv_opts['--max-tokens'] = '256'
    kv_opts['--valid-subset'] = 'test'

    kv_opts['--eval-bleu'] = True
    # The outer single quotes are kept inside the value, presumably so the JSON
    # survives as one shell word when the command line is composed -- confirm
    # against compose_cmd_args.
    kv_opts['--eval-bleu-args'] = '\'{"beam": %d, "max_len_a": 1.2, "max_len_b": 10, "min_length": 0, "unnormalized": true}\'' % beam
    kv_opts['--eval-bleu-detok'] = 'moses'

    output_dir = os.path.join(args_from_trained_model['--save-dir'], experiment_name)
    kv_opts['--stat-save-path'] = os.path.join(output_dir, 'best_extra_state.pkl')

    # Create the directory unconditionally: it is needed for both the stats
    # file and the args pickle written below.
    os.makedirs(output_dir, exist_ok=True)

    cmd_args_filename = os.path.join(output_dir, f'validate_{sweep_step}_args.pkl')
    with open(cmd_args_filename, 'wb') as args_file:
        pickle.dump(kv_opts, args_file)

    return kv_opts


def main(call_fn, sweep_step, **kwargs):
    """Dispatch to a configuration builder in this module and compose its command line.

    Args:
        call_fn: name of a function defined in this module, e.g.
            ``prepare_configuration`` or ``validate_trained_sweep_ontest``.
        sweep_step: sweep step forwarded as the builder's first argument.
        **kwargs: extra keyword arguments forwarded to the builder.

    Returns:
        Whatever ``compose_cmd_args`` produces for the builder's option dict.

    Raises:
        ValueError: if ``call_fn`` does not name a function defined in this
            module. This stops the CLI from invoking arbitrary imported
            globals such as ``os`` or ``pickle``.
    """
    builder = globals().get(call_fn)
    if not callable(builder) or getattr(builder, '__module__', None) != __name__:
        raise ValueError(f'unknown configuration function: {call_fn!r}')
    dict_args = builder(sweep_step, **kwargs)
    return compose_cmd_args(dict_args)


if __name__ == "__main__":
    # Expose main() on the command line, e.g.
    #   python <this_script>.py prepare_configuration 3
    fire.Fire(main)
