
import argparse
import json
import os
import time

# Positions of the individual settings inside each TASK_DICT value tuple.
(
    INDEX_BS,
    INDEX_LR,
    INDEX_WARMUP,
    INDEX_MAX_STEPS,
    INDEX_MAX_LENGTH,
) = range(5)

# Per-task settings, keyed by task name.  Each value is a 5-tuple laid out in
# INDEX_* order; the launcher functions in this file forward the entries, in
# that order, as positional arguments to ``run_single_task.sh``.
# NOTE(review): "BS"/"LR" presumably stand for batch size / learning rate --
# confirm against run_single_task.sh.
TASK_DICT = dict(
    rte=(16, 2e-5, 122, 2036, 512),
    mrpc=(16, 1e-5, 137, 2296, 512),
    cola=(16, 1e-5, 320, 5336, 512),
    stsb=(16, 2e-5, 214, 3598, 512),
    sst2=(32, 1e-5, 1256, 20935, 512),
    qnli=(32, 1e-5, 1986, 33112, 128),
    qqp=(32, 1e-5, 28318, 113272, 128),
    mnli=(32, 1e-5, 7432, 123873, 128),
)

def run_cmd(cmd):
    """Echo a shell command line, then execute it with ``os.system``.

    Args:
        cmd: The complete command line to hand to the shell.

    Returns:
        The value returned by ``os.system(cmd)``.
    """
    # Flush so the echoed command is written before the child process starts
    # producing output; otherwise a block-buffered stdout (pipe, log file)
    # can show the command after the output it produced.
    print(cmd, flush=True)
    return os.system(cmd)


def run_glue(args):
    """Run ``run_single_task.sh`` for one task, or for every task in TASK_DICT.

    When ``args.task_name`` is set only that task is launched; otherwise all
    keys of ``TASK_DICT`` are launched one after another, in dictionary order.

    Args:
        args: Parsed command-line namespace.  Reads ``device``, ``task_name``,
            ``exp_name``, ``enable``, ``pb``, ``gb``, ``mb``, ``sqmb``,
            ``truncated_mode``, ``truncated_factor``,
            ``truncated_global_factor``, ``scale_type``, ``round_type``,
            ``q_oracle`` and ``mode``.

    Raises:
        KeyError: If ``args.task_name`` is not a key of ``TASK_DICT``.
    """
    # The port is derived from the device index, so runs on different
    # device indices get different MAIN_PORT values.
    main_port = args.device + 25900

    # One explicit task, or every configured task when none was requested.
    if args.task_name is not None:
        tasks = [args.task_name]
    else:
        tasks = list(TASK_DICT)

    for task in tasks:
        hparams = TASK_DICT[task]
        exp_name = task + '-' + args.exp_name
        cmd = (
            f"MAIN_PORT={main_port} CUDA_VISIBLE_DEVICES={args.device} "
            f"sh run_single_task.sh {task} results/{exp_name} "
            f"{hparams[INDEX_BS]} {hparams[INDEX_LR]} {hparams[INDEX_WARMUP]} "
            f"{hparams[INDEX_MAX_STEPS]} {hparams[INDEX_MAX_LENGTH]} "
            f"{args.enable} {args.pb} {args.gb} {args.mb} {args.sqmb} "
            f"{args.truncated_mode} {args.truncated_factor} {args.truncated_global_factor} "
            f"{args.scale_type} {args.round_type} {args.q_oracle} {args.mode}"
        )
        status = run_cmd(cmd)
        if status:
            # Keep going with the remaining tasks, but make the failure visible
            # instead of silently discarding the exit status.
            print(f"[warning] task '{task}' returned non-zero status {status}",
                  flush=True)


def run_ablation(args):
    """Run the ablation sweep for a single task.

    For every scale type in the sweep, ``run_single_task.sh`` is launched ten
    times (run index 0-9 is appended to the experiment name).  The three
    truncation arguments are always passed as ``none`` and the quantization
    mapping is taken from ``args.q_oracle``.

    Args:
        args: Parsed command-line namespace.  Reads ``device``, ``task_name``,
            ``exp_name``, ``enable``, ``pb``, ``gb``, ``mb``, ``sqmb``,
            ``round_type``, ``q_oracle`` and ``mode``.

    Returns:
        None.  If ``args.task_name`` is None an error message is printed and
        nothing is launched.

    Raises:
        KeyError: If ``args.task_name`` is not a key of ``TASK_DICT``.
    """
    main_port = args.device + 25900
    if args.task_name is None:
        print("[error] ablation requires running with specified task.")
        return

    task = args.task_name
    num_repeats = 10
    scales = ['group128', 'group2048', 'sm3', 'dim0', 'dim1', 'tensor']
    settings = [(scale, args.q_oracle) for scale in scales]
    # Alternative sweeps, left disabled:
    # for scale in ['group2048']:
    #     for mapping in ['nonlinear', 'power-1', 'power-2']:
    #         settings.append((scale, mapping))
    # for scale in ['group128', 'group2048', 'sm3', 'dim0', 'dim1', 'tensor']:
    #     for mapping in ['float-point', 'power-3']:
    #         settings.append((scale, mapping))

    print(settings)
    # Loop-invariant: the task's settings do not change across the sweep.
    hparams = TASK_DICT[task]
    for scale, mapping in settings:
        for i in range(num_repeats):
            exp_name = f"{task}-{scale}-{mapping}-{args.mb}bit-{args.exp_name}-{i}"
            cmd = (
                f"MAIN_PORT={main_port} CUDA_VISIBLE_DEVICES={args.device} "
                f"sh run_single_task.sh {task} results/{exp_name} "
                f"{hparams[INDEX_BS]} {hparams[INDEX_LR]} {hparams[INDEX_WARMUP]} "
                f"{hparams[INDEX_MAX_STEPS]} {hparams[INDEX_MAX_LENGTH]} "
                f"{args.enable} {args.pb} {args.gb} {args.mb} {args.sqmb} "
                "none none none "
                f"{scale} {args.round_type} {mapping} {args.mode}"
            )
            status = run_cmd(cmd)
            if status:
                # Continue the sweep, but surface the failure instead of
                # silently discarding the exit status.
                print(f"[warning] run '{exp_name}' returned non-zero status {status}",
                      flush=True)


def build_parser():
    """Build the command-line parser for this launcher script.

    The three ``--truncated-*`` options are also accepted with underscores
    (e.g. ``--truncated_mode``) so they match the spelling of the other flags.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--exp_name", type=str, required=True)
    parser.add_argument('--enable', type=int, default=15)
    parser.add_argument('--pb', type=int, default=8, help='parameter compression bits')
    parser.add_argument('--gb', type=int, default=4, help='gradient compression bits')
    parser.add_argument('--mb', type=int, default=4, help='momentum compression bits')
    parser.add_argument('--sqmb', type=int, default=4, help='square momentum compression bits')
    # argparse derives the destination from the first long option, so these
    # still populate args.truncated_mode / truncated_factor / truncated_global_factor.
    parser.add_argument('--truncated-mode', '--truncated_mode', type=str, default='none')
    parser.add_argument('--truncated-factor', '--truncated_factor', type=float, default=-1)
    parser.add_argument('--truncated-global-factor', '--truncated_global_factor',
                        type=float, default=-1)
    parser.add_argument('--task_name', type=str, default=None)
    parser.add_argument("--scale_type", type=str, default='none')
    parser.add_argument("--round_type", type=str, default='none')
    parser.add_argument("--q_oracle", type=str, default='none')
    parser.add_argument("--mode", type=str, default='base')
    parser.add_argument("--script_mode", type=str, default='main', choices=['main', 'ablation'])
    return parser


def main(argv=None):
    """Parse the command line and dispatch to the selected script mode.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        NotImplementedError: If ``script_mode`` is neither ``main`` nor
            ``ablation`` (normally prevented by the parser's ``choices``).
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Reject unknown tasks up front with a usage error rather than letting a
    # KeyError surface from the TASK_DICT lookup later on.
    if args.task_name is not None and args.task_name not in TASK_DICT:
        parser.error(
            f"unknown --task_name '{args.task_name}' "
            f"(choose from: {', '.join(TASK_DICT)})"
        )

    if args.script_mode == 'main':
        run_glue(args)
    elif args.script_mode == 'ablation':
        run_ablation(args)
    else:
        # NotImplemented is a comparison sentinel, not an exception; raising
        # it would produce a TypeError instead of the intended error.
        raise NotImplementedError(f"unsupported script_mode: {args.script_mode}")


if __name__ == '__main__':
    main()
