from subprocess import call
import sys


# Placeholder: point this at the directory holding the ImageNet datasets.
DATAROOT = '/path/to/imagenet/datasets/'

# Experiment name -> (corruption names, levels to evaluate).
# 'imagenetc' sweeps 15 corruptions at levels 1-5; the other two experiments
# have a single "corruption" evaluated only at level 0.
EXPERIMENTS = {
    'imagenetc': (
        ['gaussian_noise', 'shot_noise', 'impulse_noise',
         'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur',
         'snow', 'frost', 'fog', 'brightness',
         'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'],
        [1, 2, 3, 4, 5],
    ),
    'imagenetr': (['rendition'], [0]),
    'imageneta': (['adversarial'], [0]),
}

USAGE = ('usage: python <this script> EXPERIMENT NAME\n'
         '  EXPERIMENT: one of %s\n'
         '  NAME: model name; results are resumed from results/imagenet_NAME/'
         % ', '.join(sorted(EXPERIMENTS)))


def _model_settings(name):
    """Return (extra model flags, optimizer, lr, weight_decay) for *name*.

    The 'rvt' model gets the --use_rvt flag and AdamW settings; every other
    name uses SGD with no extra model flag.
    """
    if name == 'rvt':
        return ['--use_rvt'], 'adamw', 0.00001, 0.01
    return [], 'sgd', 0.00025, 0.0


def _run(script, args):
    """Run *script* with *args* in a child interpreter and return its exit code.

    Arguments are passed as a list with shell=False, so values containing
    spaces or shell metacharacters reach the child unchanged. The child is run
    with the same interpreter as this driver (sys.executable). A non-zero exit
    status is reported on stderr but does not stop the sweep.
    """
    cmd = [sys.executable, script] + args
    returncode = call(cmd)
    if returncode != 0:
        print('warning: %s exited with status %d' % (' '.join(cmd), returncode),
              file=sys.stderr)
    return returncode


def main(argv=None):
    """Evaluate a model before and after adaptation over one experiment's sweep.

    Args:
        argv: ``[experiment, name]``; defaults to ``sys.argv[1:]``. Extra
            trailing arguments are ignored.

    Raises:
        SystemExit: with a usage message if arguments are missing or the
            experiment name is not a key of EXPERIMENTS.
    """
    argv = sys.argv[1:] if argv is None else list(argv)
    if len(argv) < 2:
        sys.exit(USAGE)
    experiment, name = argv[0], argv[1]
    if experiment not in EXPERIMENTS:
        sys.exit('unknown experiment %r\n%s' % (experiment, USAGE))

    corruptions, levels = EXPERIMENTS[experiment]
    model_flags, optimizer, lr, weight_decay = _model_settings(name)

    for corruption in corruptions:
        for level in levels:
            # Flush so the progress line appears before the child's output
            # when stdout is redirected to a file or pipe.
            print(corruption, 'level', level, flush=True)
            common = ['--dataroot', DATAROOT,
                      *model_flags,
                      '--level', str(level),
                      '--corruption', corruption,
                      '--resume', 'results/imagenet_%s/' % name]
            # Script paths are relative: run from the directory containing
            # test_calls/.
            _run('test_calls/test_initial.py', common)
            # '%f' keeps the exact argument text the original produced
            # (e.g. 0.000010); note it rounds to 6 decimals, so values below
            # 5e-7 would collapse to 0.000000.
            _run('test_calls/test_adapt.py',
                 common + ['--optimizer', optimizer,
                           '--lr', '%f' % lr,
                           '--weight_decay', '%f' % weight_decay])


if __name__ == '__main__':
    main()
