import subprocess
import shlex
import argparse
import os
import datetime

parser = argparse.ArgumentParser(
    description="Launch a sweep of spiral.py training runs over train_set_size."
)

parser.add_argument("--noise_scale", type=float, default=1,
                    help="forwarded to spiral.py as --noise_scale")
parser.add_argument("--wandb_project", type=str, required=True,
                    help="W&B project prefix; also used as the root output directory")
parser.add_argument("--N", type=int, default=300,
                    help="forwarded to spiral.py as --N")
parser.add_argument("--rad", type=int, default=6,
                    help="forwarded to spiral.py as --radius_coeff")

args = parser.parse_args()

# Validate with parser.error rather than `assert`: asserts are stripped under
# `python -O`, while parser.error prints the usage line and exits (status 2).
# Whitespace-only names are rejected too, since the value becomes a directory.
if not args.wandb_project.strip():
    parser.error("--wandb_project must be a non-empty string")

def run_python_script(command_str, *, check=False):
    """Run a shell-style command line and echo its captured output.

    The string is tokenized with ``shlex.split`` and executed directly
    (no shell), so shell features such as pipes or ``&&`` are unavailable.
    Output is captured and printed only after the process exits.

    Args:
        command_str: Command line, e.g. ``"python spiral.py --N 300"``.
        check: If True, raise ``subprocess.CalledProcessError`` when the
            process exits with a non-zero status (after printing its output).

    Returns:
        The ``subprocess.CompletedProcess`` of the finished command.
    """
    command = shlex.split(command_str)
    result = subprocess.run(command, capture_output=True, text=True)
    print("Output:", result.stdout)
    if result.stderr:
        print("Errors:", result.stderr)
    # Without this, a failed run is indistinguishable from a successful one
    # unless stderr happens to say so.
    if result.returncode != 0:
        print(f"Command exited with status {result.returncode}: {command_str}")
    if check:
        result.check_returncode()
    return result

# Output layout: <wandb_project>/<timestamp>/train_size_<n>/ -- one directory
# per launched run, passed to spiral.py as --save_model_dir.
# os.makedirs(exist_ok=True) tolerates an existing directory but, unlike a
# bare `try/except: pass`, still surfaces real failures (e.g. PermissionError).
ROOT = args.wandb_project
os.makedirs(ROOT, exist_ok=True)

# One timestamp per sweep, shared by the output directory and the W&B project
# name so all runs of this sweep are grouped together.
now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
ROOT = os.path.join(ROOT, now)
os.makedirs(ROOT, exist_ok=True)

# NOTE(review): `sizes` is not used; the loop below sweeps
# range(500, 10_000, 100) instead -- confirm which grid was intended.
sizes = list(range(5, 200, 5)) + list(range(200, 5_001, 100))
for train_size in range(500, 10_000, 100):
    print(f"Running for N={args.N}, rad={args.rad}, train_size={train_size}")
    save_model_dir = os.path.join(ROOT, f"train_size_{train_size}")
    os.makedirs(save_model_dir, exist_ok=True)
    # Build the argv as a list so values containing spaces or quotes (the
    # project name, the save path) stay single arguments. It is shell-quoted
    # below because run_python_script re-splits its input with shlex.
    command = [
        "python", "spiral.py",
        "--use_wandb", "True",
        "--noise_scale", str(args.noise_scale),
        "--T", "5",
        "--train_set_size", str(train_size),
        "--loss", "bc",
        "--N", str(args.N),
        "--radius_coeff", str(args.rad),
        "--n_layers", "1",
        "--sorting", "outward",
        "--use_D", "False",
        "--wandb_project", f"{args.wandb_project}_{now}",
        "--save_model_dir", save_model_dir,
        # NOTE(review): this label records 2 * train_size and "adamw", while
        # the run passes --train_set_size train_size and --optimizer adam --
        # confirm the label is intentional.
        "--config_str",
        f"train_size {2 * train_size} kifele nu_ini 0.250, 1/1+nu^2, B normed adamw",
        "--lr", "0.01",
        "--optimizer", "adam",
        "--use_encoder", "False",
        "--use_decoder", "False",
        "--save_models", "True",
    ]
    command_str = " ".join(shlex.quote(part) for part in command)
    print(command_str)
    run_python_script(command_str)