import os
import shlex

# ---------------------------------------------------------------------------
# Run configuration. Every name below is a plain module-level constant; the
# launch loop further down turns most of them into `--name value` flags for
# train.py.
# ---------------------------------------------------------------------------

# Optimiser and schedule.
opt = "adamw"
lr = 0.001
weight_decay = 0.2
warmup_ratio = 0.002
num_epochs = 32
backward_passes_per_step = 1

# Sizes.
image_size = 224
embedding_size = 768

# Per-dataset ("list_*") options. Multi-valued options are serialised to
# comma-separated strings here so they can be passed on the command line.
list_dataset = "LAION400M"
list_batch_size = 410
list_margin = 0.3
list_filter = 0.75
list_lr_pfc_weight = 1
list_num_class = ",".join(map(str, [1000000]))
list_sample_rate = ",".join(map(str, [0.1]))

# PFC-related switches.
random_diff = 10
# NOTE(review): repeat_pfc is not forwarded to train.py by the launch loop;
# confirm whether it should be.
repeat_pfc = 3
save_pfc = 1

# Checkpointing, logging cadence, debugging and data-loader workers.
ckpt_interval = 1000
frequent = 2
debug = 1
workers = 32

# Output directory is named after this launcher file (basename incl. ".py").
output = f"./models/{os.path.basename(__file__)}"

# Hosts to launch on, one torchrun node per entry. A host's position in this
# list becomes its --node_rank, and the first entry is passed to every node as
# --master_addr, so for multi-node runs it must be an address all hosts can
# reach (not "localhost").
# Keep a trailing comma on every entry: without it, uncommenting a host would
# silently concatenate it onto the previous string literal.
ip_list = [
    "localhost",
    # "172.16.9.10",
    # "172.16.9.11",
    # "172.16.9.12",
    # "172.16.9.13",
    # "172.16.9.14",
    # "172.16.9.15",
    # "172.16.9.16",
    # "172.16.9.17",
    # "172.16.9.18",
    # "172.16.9.19",
]
# Passed to torchrun as --master_port.
port = 39999

# Number of torchrun worker processes (one per visible GPU) started per host.
gpus_per_node = 8

# Environment set in front of torchrun on every host. PATH is copied from the
# launching shell so the remote side resolves the same interpreter/torchrun.
remote_env = {
    "PATH": os.environ["PATH"],
    "NCCL_SOCKET_IFNAME": "eth0",
    "NCCL_SOCKET_NTHREADS": 4,
    "NCCL_NSOCKS_PERTHREAD": 4,
    "NCCL_ALGO": "Ring",
    "USE_CHECKPOINT": 1,
    # "0,1,...,gpus_per_node-1" -- kept in sync with --nproc_per_node.
    "CUDA_VISIBLE_DEVICES": ",".join(str(i) for i in range(gpus_per_node)),
}

# Flags forwarded to train.py, in order; keys are flag names without "--".
train_args = {
    "backward_passes_per_step": backward_passes_per_step,
    "ckpt_interval": ckpt_interval,
    "debug": debug,
    "list_batch_size": list_batch_size,
    "list_dataset": list_dataset,
    "list_filter": list_filter,
    "list_margin": list_margin,
    "list_num_class": list_num_class,
    "list_sample_rate": list_sample_rate,
    "list_lr_pfc_weight": list_lr_pfc_weight,
    "image_size": image_size,
    "embedding_size": embedding_size,
    "lr": lr,
    "num_epochs": num_epochs,
    "opt": opt,
    "output": output,
    "random_diff": random_diff,
    "save_pfc": save_pfc,
    "frequent": frequent,
    "warmup_ratio": warmup_ratio,
    "weight_decay": weight_decay,
    "workers": workers,
}


def _as_flags(options):
    """Render ``{"name": value}`` as shell-quoted ``--name value`` tokens."""
    return [f"--{name} {shlex.quote(str(value))}" for name, value in options.items()]


# Quoted once: the remote shell must see the directory as a single word even
# if it contains spaces.
workdir = shlex.quote(os.getcwd())

# enumerate (not ip_list.index) so duplicate hosts still get distinct ranks.
for node_rank, ip in enumerate(ip_list):
    torchrun_args = {
        "nproc_per_node": gpus_per_node,
        "nnodes": len(ip_list),
        "node_rank": node_rank,
        "master_addr": ip_list[0],
        "master_port": port,
    }
    remote_cmd = " ".join(
        # "&&" so a missing working directory aborts instead of launching
        # torchrun from the remote home directory.
        [f"cd {workdir} &&"]
        + [f"{name}={shlex.quote(str(value))}" for name, value in remote_env.items()]
        + ["torchrun"]
        + _as_flags(torchrun_args)
        + ["train.py"]
        + _as_flags(train_args)
    )
    host = shlex.quote(f"root@{ip}")
    # The whole remote command is quoted as one ssh argument; the trailing "&"
    # backgrounds ssh in the local shell so all hosts are started in parallel.
    cmd = f"ssh {host} {shlex.quote(remote_cmd)} &"
    print(cmd)
    os.system(cmd)
