import pickle
import itertools
import fire
import collections
import os

def add_common_validation(kv_opts: collections.OrderedDict, args_from_trained_model: collections.OrderedDict) -> collections.OrderedDict:
    """Populate *kv_opts* in place with the options needed to validate a trained model.

    The data path, task and checkpoint directory are copied from the training
    arguments; the user dir and token budget are fixed values. The checkpoint
    loaded is ``checkpoint_best.pt`` inside the training ``--save-dir``.

    Args:
        kv_opts: Option mapping to fill; it is mutated and also returned.
        args_from_trained_model: Options used for training; must provide the
            ``'data'``, ``'--task'`` and ``'--save-dir'`` keys.

    Returns:
        The same ``kv_opts`` object.

    Raises:
        KeyError: If one of the required training keys is missing.
    """
    trained = args_from_trained_model

    kv_opts['data'] = trained['data']
    kv_opts['--user-dir'] = '../fairseq_module'
    kv_opts['--task'] = trained['--task']

    # Validate against the best checkpoint written during training.
    best_checkpoint = os.path.join(trained['--save-dir'], 'checkpoint_best.pt')
    kv_opts['--path'] = best_checkpoint

    kv_opts['--max-tokens'] = '2048'
    return kv_opts

def add_train_params(kv_opts: collections.OrderedDict) -> collections.OrderedDict:
    """Fill *kv_opts* in place with the default training options and return it.

    Existing keys are overwritten but keep their original position; new keys
    are appended in the order listed below. ``--no-epoch-checkpoints`` is
    stored as ``True`` (a value-less flag); every other value is a string.

    Args:
        kv_opts: Option mapping to extend; it is mutated and also returned.

    Returns:
        The same ``kv_opts`` object.
    """
    kv_opts.update({
        '--task': 'summarization_utilization',
        '--user-dir': '../fairseq_module/',
        '--optimizer': 'adam',
        # Quoted so the tuple survives being pasted into a shell command.
        '--adam-betas': '\'(0.9, 0.98)\'',
        '--lr-scheduler': 'inverse_sqrt',
        '--warmup-updates': '4000',
        '--weight-decay': '0.0001',
        '--criterion': 'utilization_nll_loss',
        '--max-tokens': '4096',
        '--arch': 'transformer_iwslt14_deen',
        '--clip-norm': '0.0',
        '--lr': '5e-4',
        '--dropout': '0.3',
        '--no-epoch-checkpoints': True,
        '--best-checkpoint-metric': 'loss',
        '--patience': '5',
        '--validate-interval-updates': '2000',
        '--label-smoothing': '0.1',
    })
    return kv_opts

def all_vs_all_grid(grid_dict: collections.OrderedDict) -> list:
    """Expand a parameter grid into every combination of its values.

    Args:
        grid_dict: Ordered mapping from option name to an iterable of
            candidate values. It must be an ``OrderedDict`` so the key order
            of the generated sweeps is explicit.

    Returns:
        A list with one ``OrderedDict`` per element of the Cartesian product
        of the value iterables, keyed in ``grid_dict`` order. An empty grid
        yields ``[OrderedDict()]``; any empty value iterable yields ``[]``.

    Raises:
        AssertionError: If ``grid_dict`` is not an ``OrderedDict``.
    """
    assert isinstance(grid_dict, collections.OrderedDict), (
        "grid dictionary is supposed to be the ordered dict to avoid "
        "ambiguous key ordering in the generated sweeps"
    )
    keys = list(grid_dict.keys())
    return [
        collections.OrderedDict(zip(keys, values))
        for values in itertools.product(*grid_dict.values())
    ]

def compose_cmd_args(dict_args, newlines=False):
    """Render an argument mapping as a command-line argument string.

    The positional ``'data'`` entry is always emitted first. Every key that
    starts with ``--`` follows in the mapping's iteration order:

    * boolean ``True`` values become bare flags (``--flag``),
    * boolean ``False`` values are omitted,
    * any other value is rendered as ``--key value``.

    Keys without the ``--`` prefix (including ``'data'`` itself) are not
    repeated.

    Args:
        dict_args: Mapping of option names to values; must contain ``'data'``.
        newlines: When True, put each argument on its own line and end every
            line except the last with a shell line-continuation backslash.

    Returns:
        The composed argument string, with no trailing separator.

    Raises:
        KeyError: If ``dict_args`` has no ``'data'`` entry.
    """
    # The positional data path must come first.
    parts = [f"{dict_args['data']}"]
    for arg, val in dict_args.items():
        if not arg.startswith('--'):
            continue
        if isinstance(val, bool):
            # Boolean options are value-less flags: include them only when set.
            if val:
                parts.append(arg)
        else:
            parts.append(f"{arg} {val}")

    # Putting the continuation marker in the separator places it after every
    # line but the last, so no trailing "\" or space needs stripping afterwards.
    separator = ' \\\n' if newlines else ' '
    return separator.join(parts)

def load_pkl_args(pkl_args_filename):
    """Print the command line stored in a pickled argument mapping.

    Args:
        pkl_args_filename: Path to a pickle file holding a mapping in the
            format accepted by ``compose_cmd_args`` (it must contain a
            ``'data'`` entry).

    The arguments are printed one per line (``newlines=True``); nothing is
    returned.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only load argument files produced by trusted runs.
    with open(pkl_args_filename, 'rb') as fh:
        dict_args = pickle.load(fh)
    print(compose_cmd_args(dict_args, newlines=True))

def main(call_fn, **kwargs):
    """Dispatch to the module-level function named *call_fn*.

    Args:
        call_fn: Name of a function defined at module level in this file
            (looked up in this module's ``globals()``).
        **kwargs: Keyword arguments forwarded unchanged to that function.

    The target's return value is discarded, so ``main`` itself returns None.

    Raises:
        KeyError: If this module defines no name ``call_fn``.
    """
    module_namespace = globals()
    target = module_namespace[call_fn]
    target(**kwargs)


if __name__ == "__main__":
    # Expose `main` as the command-line entry point via python-fire; its
    # `call_fn` argument names the module-level function to run, and the
    # remaining CLI options are forwarded to that function as keyword args.
    fire.Fire(main)
