import subprocess
import time

# Learning-rate grid consumed by the sweep loop further down.
# Only the uncommented assignment is live; the commented-out lists below are
# alternative grids that can be swapped in by moving the comment marker.
learning_rates = [
    9e-4,
]
# learning_rates = [1e-3, 3e-3, 3e-4, 5e-4, 7e-4, 9e-4, 2e-5, 5e-5, 7e-5]
# learning_rates = [1e-3, 3e-3, 5e-3, 7e-3, 9e-3, 1e-4, 5e-4, 9e-4]
# learning_rates = [1e-4, 3e-4, 5e-4, 7e-4, 9e-4]
# learning_rates = [5e-3, 7e-3, 9e-3]
# learning_rates = [2e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5]
# learning_rates = [3e-5]
# learning_rates = [1e-4, 5e-5, 1e-5, 5e-4, 2e-4, 3e-4, 7e-4]
# learning_rates = [5e-5, 1e-5, 1e-4, 3e-4]
# learning_rates = [1e-4, 5e-4, 2e-4, 3e-4, 7e-4]

# # Function to check if a Slurm job is done
# def is_job_done(job_id):
#     result = subprocess.run(['sacct', '-j', job_id, '--format=State', '--noheader'],
#                             stdout=subprocess.PIPE, text=True)
#     state = result.stdout.strip()
#     return state in ['COMPLETED', 'FAILED', 'CANCELLED']

# Submit the sweep: NUM_PASSES passes over every (weight decay, learning rate)
# pair. After each submission the script sleeps for a fixed time before moving
# on. Job state is never polled, so that sleep is the only thing spacing
# consecutive submissions.
NUM_PASSES = 5
WEIGHT_DECAYS = [0.1, 0.0001]
# Pause after each submission; the command below passes --budget_minutes=30.
WAIT_MINUTES = 33
BANNER = "#" * 30

for i in range(NUM_PASSES):
    for wd in WEIGHT_DECAYS:
        for lr in learning_rates:
            # A single name is used for both --run_name and
            # --sub_output_dir_name so the two can never drift apart.
            # The name does not depend on i, so every pass reuses it.
            run_name = f'siglip_sweep_lr_{lr}_wd_{wd}_warmup_6k_fineweb_100B-retrieval-dual-causal-pythia-160m-mbsz-10-wbsz-80-ctx-var-batch_negative_ddp_RR_max_iters_57691'

            # Construct the command with the current learning rate and weight decay.
            # NOTE(review): the '${WRKSPC}' placeholders are passed through
            # literally (no shell is involved in subprocess.run with a list) --
            # presumably the launcher expands them; confirm.
            command = [
                'python', '/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/launch_frontier.py',
                '--python_script=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/train_retrieval_w_anticausal.py',
                '--rccl_installdir=${WRKSPC}/tiny_plugins_rccl.tar.gz',
                '--environment=${WRKSPC}/frontier_conda_60.tar.gz',
                '--budget_minutes=30',
                '--nodes', '1',
                '--config', '/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/XXXX-22/frontier_jobs/fineweb_retrieval_train_pythia.json',
                '--run_name', run_name,
                '--extra_args', f'--siglip_loss=true --warmup_steps=6500 --optim_config.weight_decay={wd} --optim_config.lr={lr} --min_lr={lr*0.1} --max_iters=57691 --max_tokens=null --save_n_min_before_job_done=3 --save_step_interval=10000 --eval_step_interval=10000 --micro_batch_size=10 --world_batch_size=80  --fabric_strategy="ddp"',
                '--disable_net_gdr',
                '--sub_output_dir_name', run_name,
                '--debug_qos'
            ]

            # Alternative launch configurations, left commented out.
            # command = [
            #     'python', '/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/launch_frontier.py',
            #     '--python_script=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/train_retrieval_w_anticausal.py',
            #     '--rccl_installdir=${WRKSPC}/tiny_plugins_rccl.tar.gz',
            #     '--environment=${WRKSPC}/frontier_conda_60.tar.gz',
            #     '--budget_minutes=30',
            #     '--nodes', '1',
            #     '--config', '/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/XXXX-22/frontier_jobs/retrieval_train_pythia.json',
            #     '--run_name', f'sweep_lr_{str(lr)}_dolma-retrieval-dual-causal-pythia-160m-mbsz-24-wbsz-192-ctx-var-batch_negative_ddp_RR_cum_mean',
            #     '--extra_args', f'--optim_config.lr={lr} --min_lr={lr*0.1} --save_n_min_before_job_done=3 --save_step_interval=10000 --eval_step_interval=10000 --micro_batch_size=24 --world_batch_size=192 --mean_pooling=True',
            #     '--sub_output_dir_name', f'sweep_lr_{str(lr)}_dolma-retrieval-dual-causal-pythia-160m-mbsz-24-wbsz-192-ctx-var-batch_negative_ddp_RR_cum_mean',
            #     '--debug_qos'
            # ]
            # command = [
            #     'python', '/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/launch_frontier.py',
            #     '--python_script=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/pretrain_umd/train.py',
            #     '--rccl_installdir=${HOME}/tiny_plugins_rccl/lib',
            #     '--env_packed=${HOME}/frontier_conda_env_packed.tar.gz',
            #     '--budget_minutes=30',
            #     '--nodes', '1',
            #     '--config', '/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/XXXX-22/frontier_jobs/llm_train.json',
            #     '--run_name', f'cosmopedia-lm-pythia-160m-bsz-25_ddp_lr_{str(lr)}_debug',
            #     '--extra_args', f'--learning_rate={lr} --min_lr={lr*0.1} --fabric_strategy="ddp"',
            #     '--sub_output_dir_name', f'cosmopedia-lm-pythia-160m-bsz-25_ddp_lr_{str(lr)}_debug',
            #     '--debug_qos'
            # ]
            # command = [
            #     'python', '/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/launch_frontier.py',
            #     '--python_script=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/pretrain_umd/train_retrieval_w_anticausal.py',
            #     '--rccl_installdir=${HOME}/tiny_plugins_rccl/lib',
            #     '--env_packed=${HOME}/frontier_conda_env_packed.tar.gz',
            #     '--budget_minutes=30',
            #     '--nodes', '1',
            #     '--config', '/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/XXXX-22/frontier_jobs/retrieval_train_pythia.json',
            #     '--run_name', f'dolma-retrieval-dual-causal-pythia-160m-bsz-25-ctx-rand-batch_negative_ddp_RR_lr_{str(lr)}_warmup_4k_{lr_scheduler}_debug',
            #     '--extra_args', f'--save_and_eval_interval=1000 --learning_rate={lr} --min_lr={lr*0.1} --fabric_strategy="ddp" --warmup_steps=4000 --lr_schedule={lr_scheduler}',
            #     '--sub_output_dir_name', f'dolma-retrieval-dual-causal-pythia-160m-bsz-25-ctx-rand-batch_negative_ddp_RR_lr_{str(lr)}_warmup_4k_{lr_scheduler}_debug',
            #     '--debug_qos'
            # ]

            # Print the command (for debugging purposes)
            print("Running command:", ' '.join(command))

            # Execute the command and surface a failed launch instead of
            # silently ignoring it. The loop still continues (best effort) so
            # one bad submission does not abort the rest of the sweep.
            result = subprocess.run(command)
            if result.returncode != 0:
                print(f"WARNING: launcher exited with status {result.returncode}; "
                      "this submission may not have been queued.")

            # Pause before the next submission.
            print(BANNER)
            print(f"{i} - Waiting for {WAIT_MINUTES} minutes before launching the next job...")
            print(BANNER)
            time.sleep(WAIT_MINUTES * 60)  # minutes -> seconds

    # Only announce another pass when one actually follows.
    if i < NUM_PASSES - 1:
        print(BANNER)
        print("Running the jobs again...")
        print(BANNER)

# Reached once every submission has been issued and its pause has elapsed;
# the outcome of the jobs themselves is not checked here.
print("All jobs completed!")

# for _ in range(4):
#     command = [
#         'python', '/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/launch_frontier.py',
#         '--python_script=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/pretrain_umd/train_retrieval_w_anticausal.py',
#         '--rccl_installdir=${HOME}/tiny_plugins_rccl/lib',
#         '--env_packed=${HOME}/frontier_conda_env_packed.tar.gz',
#         '--budget_minutes=30',
#         '--nodes', '1',
#         '--config', '/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/XXXX-22/frontier_jobs/ddp_pp.json',
#         '--extra_args', '--learning_rate=1e-5 --min_lr=1e-6 --fabric_strategy="ddp" --pretrained_prefix_model=True --pretrained_suffix_model=True',
#         '--run_name', 'concatedOrca-retrieval-dual-causal-llama-1.1b-bsz-2-ctx-rand-batch_negative_ddp_PP_lr_1e-5',
#         '--sub_output_dir_name', 'concatedOrca-retrieval-dual-causal-llama-1.1b-bsz-2-ctx-rand-batch_negative_ddp_PP_lr_1e-5',
#         '--debug_qos'
#     ]
