import os
import subprocess
import sys
# Expose only GPU index 0; processes launched later by this script inherit this environment.
os.environ.update(CUDA_VISIBLE_DEVICES='0')
 
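# Training arguments are kept fixed for this sweep. The commented-out argparse
# definitions below are kept for reference only; this script does not parse them.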
#parser.add_argument('--lr', default=2.5e-3, type=float, help='learning rate for intermediate finetune')
#parser.add_argument('--batch_size', default=128, type=int, help='batch size')  # default 128
#parser.add_argument('--lr_type', default='fixed', type=str, help='lr scheduler (exp/cos/step3/fixed)')
#parser.add_argument('--n_epoch', default=0.1, type=float, help='number of epochs for intermediate finetune')
#parser.add_argument('--wd', default=4e-5, type=float, help='weight decay')
#parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
#parser.add_argument('--num_evaluate', default=50, type=int)

# settings for pruning
# Source checkpoints; each one maps to a finetuned-model directory prefix.
model_fp = ['./model_checkpoints/modelmobilenetv2_size50000_model_final.pth']
# './model_checkpoints/<name>.pth' -> './finetuned_models/<name>_finetuned_model'
# (basename/splitext instead of split('/')[-1][:-4] so other separators and
# extension lengths are handled correctly).
output_fp = [
    os.path.join(
        './finetuned_models',
        os.path.splitext(os.path.basename(fp))[0] + '_finetuned_model',
    )
    for fp in model_fp
]
n_gpu = 1
model = 'mobilenet'
dataset = 'cifar10'
dataroot = './data/'
seed = 2
num_eval = 50
top1_tols = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1]
lrs = [3e-4, 1e-3, 3e-3, 1e-2]
# output_fp and top1_tols are later zipped one-to-one. A single checkpoint is
# broadcast to every tolerance. Any other length mismatch would make zip()
# silently mis-pair or drop configurations, so reject it up front.
if len(output_fp) == 1:
    output_fp *= len(top1_tols)
elif output_fp and len(output_fp) != len(top1_tols):
    raise ValueError(
        f'got {len(output_fp)} checkpoints but {len(top1_tols)} top1 tolerances; '
        'provide either one checkpoint or one per tolerance'
    )
skip_eval_conv = 0.05

# run pruning
PRUNE_SCRIPT = 'mobile2_prune.py'


def build_prune_command(load_path, top1_tol, *, model, dataset, data_root,
                        seed, n_gpu, num_evaluate, skip_eval_converge,
                        python=sys.executable):
    """Build the argv list for one ``mobile2_prune.py`` evaluation run.

    The command is returned as a list rather than a shell string. Paths
    containing spaces or shell metacharacters are then passed through
    verbatim instead of being re-split by a shell.

    Args:
        load_path: checkpoint file passed as ``--load_path``.
        top1_tol: value passed as ``--top1_tol``.
        model, dataset, data_root, seed, n_gpu, num_evaluate,
        skip_eval_converge: forwarded to the command-line flag of the same name.
        python: interpreter used to launch the script. Defaults to the one
            running this sweep, so the child uses the same environment.

    Returns:
        A list of strings suitable for ``subprocess.run``.
    """
    return [
        python, PRUNE_SCRIPT,
        '--model', str(model),
        '--dataset', str(dataset),
        '--data_root', str(data_root),
        '--seed', str(seed),
        '--n_gpu', str(n_gpu),
        '--num_evaluate', str(num_evaluate),
        '--load_path', str(load_path),
        '--eval',
        '--top1_tol', str(top1_tol),
        '--skip_eval_converge', str(skip_eval_converge),
        '--isfullnetpruned', '1',
    ]


def run_sweep():
    """Evaluate every (checkpoint, tolerance, learning-rate) combination.

    For each ``(dest, top1_tol)`` pair from ``zip(output_fp, top1_tols)`` and
    each ``lr`` in ``lrs``, this sequentially launches ``mobile2_prune.py`` on
    ``<dest>_<lr>/mbv2_ft_<top1_tol>.pth.tar``. A failing run is reported and
    the sweep continues with the next combination.

    Returns:
        A list of ``(load_path, exit_code)`` for runs that exited non-zero.
    """
    failures = []
    for dest, top1_tol in zip(output_fp, top1_tols):
        for lr in lrs:
            print(f'\n\nTesting tol {top1_tol}, lr {lr}, {dest}')
            fn = f'mbv2_ft_{top1_tol}.pth.tar'
            full_dest = os.path.join(dest + f'_{lr}', fn)
            argv = build_prune_command(
                full_dest, top1_tol,
                model=model, dataset=dataset, data_root=dataroot, seed=seed,
                n_gpu=n_gpu, num_evaluate=num_eval,
                skip_eval_converge=skip_eval_conv,
            )
            # check=False: one failed configuration should not abort the whole
            # sweep, but its exit status must not be silently discarded either.
            result = subprocess.run(argv, check=False)
            if result.returncode == 0:
                print('Test Done!\n\n')
            else:
                print(f'Test FAILED with exit code {result.returncode}\n\n')
                failures.append((full_dest, result.returncode))
    if failures:
        print(f'{len(failures)} run(s) failed:')
        for path, code in failures:
            print(f'  exit code {code}: {path}')
    return failures


if __name__ == '__main__':
    run_sweep()
