
import argparse
import json
import os
import time

# Positions of the fields inside each TASK_DICT tuple (0..4, in this order).
# NOTE(review): meanings below are inferred from the constant names; the
# consumers of TASK_DICT are not in this file -- confirm against them.
(
    INDEX_BS,          # presumably per-device batch size
    INDEX_LR,          # learning rate
    INDEX_WARMUP,      # presumably warmup length in steps
    INDEX_MAX_STEPS,   # presumably total training steps
    INDEX_MAX_LENGTH,  # presumably max sequence length
) = range(5)

# Per-task hyper-parameters, keyed by task name (these look like GLUE tasks).
# Each value: (bs, lr, warmup, max_steps, max_length), indexed by INDEX_* above.
TASK_DICT = {
    'rte':  (16, 2e-5, 122, 2036, 512),
    'mrpc': (16, 1e-5, 137, 2296, 512),
    'cola': (16, 1e-5, 320, 5336, 512),
    'stsb': (16, 2e-5, 214, 3598, 512),
    'sst2': (32, 1e-5, 1256, 20935, 512),
    'qnli': (32, 1e-5, 1986, 33112, 128),
    'qqp':  (32, 1e-5, 28318, 113272, 128),
    'mnli': (32, 1e-5, 7432, 123873, 128),
}

def run_cmd(cmd):
    """Echo *cmd* and execute it through the system shell.

    Args:
        cmd: Shell command line, passed verbatim to ``os.system`` so that
            environment-variable prefixes (``FOO=1 sh run.sh ...``) work.
            It is interpreted by the shell, so it must come from a trusted
            source.

    Returns:
        The raw value returned by ``os.system`` (on POSIX this is an encoded
        wait status rather than the plain exit code); 0 means success.
    """
    # Flush so the echoed command is written before the child process's
    # output, even when stdout is block-buffered (redirected to a file/pipe).
    print(cmd, flush=True)
    status = os.system(cmd)
    if status != 0:
        # The callers in this script ignore the return value, so report the
        # failure here instead of letting a failed launch pass silently.
        import sys
        print(f"[run_cmd] command failed with status {status}",
              file=sys.stderr, flush=True)
    return status


def get_device_ids(num_devices, start_device):
    """Return the device ids ``start_device .. start_device + num_devices - 1``.

    Returns:
        A tuple ``(ids, count)`` where ``ids`` is the ids joined with commas
        (the format used for ``CUDA_VISIBLE_DEVICES`` by the launchers in
        this script) and ``count`` is how many ids were produced (0 when
        ``num_devices`` <= 0).
    """
    device_list = [str(dev) for dev in range(start_device, start_device + num_devices)]
    return ','.join(device_list), len(device_list)

def run_batch(args):
    """Sweep learning rates at a fixed per-device batch size, one run.sh call each.

    Coerces ``args.num_devices`` / ``args.start_device`` to ``int`` in place,
    then, for every learning rate in the grid except index 1, launches
    ``sh run.sh ./results/<exp_name>-bs<global_bs>-lr<idx> <bs> <lr> ...``.
    The global batch size in the name is ``bs * args.num_devices``.

    NOTE(review): this passes ``<bs> <lr>`` and no ``pb``/``mode``, whereas
    ``run_lpmm`` passes ``<lr> <enable> <pb> ... <mode>`` to the same
    ``run.sh``; one of the two argument layouts is presumably stale --
    confirm against run.sh before re-enabling this sweep.
    """
    args.num_devices = int(args.num_devices)
    args.start_device = int(args.start_device)
    device_ids, num_devices = get_device_ids(args.num_devices, args.start_device)
    # Port offset by the first device id, presumably so jobs on different
    # device ranges do not share a rendezvous port.
    main_port = 29500 + args.start_device

    learning_rates = [1e-5, 2e-5, 3e-5, 1e-4, 5e-4, 1e-3]
    skipped_lr_indices = {1}  # 2e-5 is excluded from the sweep

    env_prefix = f"MAIN_PORT={main_port} CUDA_VISIBLE_DEVICES={device_ids} NUM_GPUS={num_devices}"
    for bs in (8,):
        for lr_idx, lr in enumerate(learning_rates):
            if lr_idx in skipped_lr_indices:
                continue
            exp_name = args.exp_name + f'-bs{bs * args.num_devices}-lr{lr_idx}'
            script_args = " ".join(
                f"{value}"
                for value in (bs, lr, args.enable, args.mb, args.sqmb,
                              args.scale_type, args.round_type, args.q_oracle)
            )
            # Trailing space kept to match the exact command text produced before.
            run_cmd(f"{env_prefix} sh run.sh ./results/{exp_name} {script_args} ")


def run_lpmm(args, lr=1.5e-5):
    """Launch one ``run.sh`` training run configured from the CLI *args*.

    Coerces ``args.num_devices`` / ``args.start_device`` to ``int`` in place,
    then runs::

        MAIN_PORT=... CUDA_VISIBLE_DEVICES=... NUM_GPUS=... DATA=<data> \\
            sh run.sh ./results/<exp_name> <lr> <enable> <pb> <mb> <sqmb> \\
            <scale_type> <round_type> <q_oracle> <mode>

    Args:
        args: Parsed CLI namespace (see the argparse setup in ``__main__``).
        lr: Learning rate forwarded to run.sh. Defaults to the previously
            hard-coded value, so existing ``run_lpmm(args)`` calls are unchanged.
    """
    args.num_devices = int(args.num_devices)
    args.start_device = int(args.start_device)
    device_ids, num_devices = get_device_ids(args.num_devices, args.start_device)
    # Port offset by the first device id, presumably so jobs on different
    # device ranges do not share a rendezvous port.
    main_port = (args.start_device) + 29500
    # NOTE(review): no batch size is passed to run.sh. A per-device value of
    # 12 (12 * 4 GPUs = 48 global) was noted here before but never forwarded;
    # the batch size is presumably fixed inside run.sh -- confirm there.
    exp_name = args.exp_name
    cmd = (f"MAIN_PORT={main_port} CUDA_VISIBLE_DEVICES={device_ids} NUM_GPUS={num_devices} DATA={args.data} "
           f"sh run.sh ./results/{exp_name} {lr} "
           f"{args.enable} {args.pb} {args.mb} {args.sqmb} "
           f"{args.scale_type} {args.round_type} {args.q_oracle} {args.mode} "
           )
    run_cmd(cmd)



if __name__ == '__main__':
    # Parse the experiment configuration and launch a single run via run_lpmm.
    parser = argparse.ArgumentParser(
        # Shown by --help; previously this example lived in a no-op string
        # statement that nobody could see.
        epilog=("example:\n"
                "  python run.py --exp_name test --enable 12 --mb 4 --sqmb 4 --data squad"),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--num_devices", type=int, default=1)
    parser.add_argument("--start_device", type=int, default=0)
    parser.add_argument("--exp_name", type=str, required=True)
    parser.add_argument("--data", type=str, default='squad')
    parser.add_argument('--enable', type=int, default=0)
    parser.add_argument('--pb', type=int, default=8, help='parameter compression bits')
    parser.add_argument('--mb', type=int, default=4, help='momentum compression bits')
    parser.add_argument('--sqmb', type=int, default=4, help='square momentum compression bits')
    parser.add_argument("--scale_type", type=str, default='none')
    parser.add_argument("--round_type", type=str, default='none')
    parser.add_argument("--q_oracle", type=str, default='none')
    parser.add_argument("--mode", type=str, default='base')
    args = parser.parse_args()
    # Alternative entry point (disabled): learning-rate sweep via run_batch.
    # run_batch(args)
    run_lpmm(args)
