import argparse
import os
import subprocess
import sys

import wandb

from nesim.utils.folder import get_filenames_in_a_folder
from nesim.utils.json_stuff import dict_to_json, load_json_as_dict

# Command-line interface of the launcher.
#   --dry-run               : only print the train.py commands, do not run them.
#   --wandb-log             : forward --wandb-log to every train.py invocation.
#   --num-epochs-per-layer  : forwarded to train.py as --num-epochs.
parser = argparse.ArgumentParser(description="")
parser.add_argument("--dry-run", action="store_true", help="Run in dry-run mode")
parser.add_argument(
    "--wandb-log", action="store_true", help="Weights and biases logging"
)
parser.add_argument(
    "--num-epochs-per-layer",
    type=int,
    # Required: without it the value would be None and every launched
    # command would receive the literal "--num-epochs None".
    required=True,
    help="number of epochs of ring loss training per layer.",
)

args = parser.parse_args()
dry_run_arg = args.dry_run
wandb_log = args.wandb_log
num_epochs = args.num_epochs_per_layer

# Alternative, wider sweep: hidden_sizes = [128, 256, 512, 1024, 2048]
hidden_sizes = [1024]

# All NeSim configs in ./nesim_configs, sorted, excluding the baseline for now.
# Filtering on the basename (instead of list.remove on one exact path string)
# does not raise ValueError when baseline.json is missing. It also works
# whether or not the helper returns "./"-prefixed paths.
nesim_configs = [
    config_path
    for config_path in sorted(get_filenames_in_a_folder("./nesim_configs"))
    if os.path.basename(config_path) != "baseline.json"
]

# Seed resume_from.json with a freshly generated W&B run id and no checkpoint.
# model_path is the *string* "None" (not JSON null) because it is interpolated
# verbatim into train.py's --load-weights-path argument. Presumably train.py
# treats "None" as "train from scratch" -- TODO confirm.
resume_data = {"model_path": "None", "wandb_run_id": wandb.util.generate_id()}
dict_to_json(resume_data, "resume_from.json")

apply_every_n_steps_values = [1]

# Extra flag forwarded to train.py; empty when W&B logging is disabled.
wandb_log_arg = "--wandb-log" if wandb_log else ""

# Launch one train.py run per (hidden size, config, apply-every-n-steps) combo.
# The baseline config is filtered out of `nesim_configs` above, so a single
# code path handles every config. The previous special-case branch for the
# baseline was unreachable: it compared against a path without the "./"
# prefix. It would also have raised NameError on an unbound
# `apply_every_n_steps` if it had been reached first.
for hidden_size in hidden_sizes:
    # NOTE(review): `hidden_size` is never passed to train.py, so every entry
    # in `hidden_sizes` repeats identical runs. Forward it (flag name TBD in
    # train.py) before widening the sweep.
    for nesim_config in nesim_configs:
        # Re-read per config: the file may be rewritten between runs
        # (presumably by train.py with a new model_path) -- TODO confirm.
        resume_data = load_json_as_dict("resume_from.json")
        for apply_every_n_steps in apply_every_n_steps_values:
            # argv list instead of a shell string: no quoting issues with
            # paths containing spaces, and no shell involved.
            command = [
                "python3",
                "train.py",
                "--nesim-config",
                str(nesim_config),
                "--load-weights-path",
                str(resume_data["model_path"]),
                "--nesim-apply-after-n-steps",
                str(apply_every_n_steps),
                "--num-epochs",
                str(num_epochs),
            ]
            if wandb_log_arg:
                command.append(wandb_log_arg)

            command_str = " ".join(command)
            print(f"Running command\n{command_str}\n")
            if dry_run_arg:
                continue

            # Best-effort: keep going on failure like before, but no longer
            # silently -- os.system's exit status used to be discarded.
            result = subprocess.run(command, check=False)
            if result.returncode != 0:
                print(
                    f"WARNING: command exited with status {result.returncode}: {command_str}",
                    file=sys.stderr,
                )
