import os
from itertools import product
import argparse


# Command-line interface: one required positional integer. main() uses the
# parsed value as the key into device_map to choose which slice of the sweep
# this process runs. Parsing happens at module level, so sys.argv is read as
# soon as this file is executed or imported.
parser = argparse.ArgumentParser()
parser.add_argument(
    "device",
    help="input the integer corresponding to the GPU",
    type=int,
)
args = parser.parse_args()
device = args.device


# Axes of the sweep. main() walks the cross product of these lists and keeps
# only the combinations assigned to the selected device.
envs = ["quadruped"]
tasks = ["walk", "run"]
orl_algs = ["td3", "cql", "crr"]
# One full pass over the selected combinations is made per seed.
seeds = [1, 1087, 1604, 1776, 2040]

data_gen_algs = ["icm_apt_gbe_alpha", "icm_apt_renyi_q"]

# Parameter strings tried for each data-generation algorithm; each one is
# appended to the algorithm name (joined by "_") in the `expl_agent=` argument.
data_param_map = {
    "icm_apt_gbe_alpha": ["2e-1", "5e-1", "1e0", "3e0", "5e0"],
    "icm_apt_renyi_q": ["0.2", "0.5", "0.7", "0.9", "1.1"],
}

# Device index -> (data-generation algorithm, task). Enumerating the pairs
# algorithm-major yields:
#   0: (icm_apt_gbe_alpha, walk)   1: (icm_apt_gbe_alpha, run)
#   2: (icm_apt_renyi_q, walk)     3: (icm_apt_renyi_q, run)
device_map = dict(enumerate(product(data_gen_algs, tasks)))


def main():
    """Run every experiment assigned to the selected device, one at a time.

    The module-level ``device`` value selects a ``(data_gen_alg, task)`` pair
    from ``device_map``. For each seed, every combination of environment and
    offline-RL algorithm matching that pair is expanded over the parameter
    values in ``data_param_map`` and launched as a ``train_offline.py``
    command through ``os.system``. Each command blocks until it finishes
    before the next one starts.

    Raises:
        KeyError: if ``device`` has no entry in ``device_map``.
    """
    # The assignment depends only on the device, so resolve it once up front
    # and fail early with a readable message instead of mid-sweep.
    try:
        desired_data_gen_alg, desired_task = device_map[device]
    except KeyError:
        raise KeyError(
            f'device {device} has no entry in device_map; '
            f'expected one of {sorted(device_map)}'
        ) from None

    experiment_count = 0

    for seed in seeds:
        for env, task, orl_alg, data_gen_alg in product(
            envs, tasks, orl_algs, data_gen_algs
        ):
            # Skip combinations that belong to a different device.
            if data_gen_alg != desired_data_gen_alg or task != desired_task:
                continue

            for param in data_param_map[data_gen_alg]:
                command = (
                    f'python3 train_offline.py ++seed={seed} agent={orl_alg} '
                    f'expl_agent={data_gen_alg}_{param} task={env}_{task}'
                )
                print(command)
                experiment_count += 1
                print(f'Experiment #{experiment_count} running...')
                status = os.system(command)
                # Keep sweeping after a failed run, but make the failure
                # visible rather than silently discarding the status.
                if status != 0:
                    print(
                        f'Experiment #{experiment_count} returned '
                        f'non-zero status {status}'
                    )

    # Printed once, after the last seed has finished.
    print('Experiments complete!')
        

# Start the sweep only when this file is executed as a script; importing it
# does not call main().
if __name__ == '__main__':
    main()