# sweep_generate.py

import os
import sys
import subprocess
from itertools import product

# ==== Fixed arguments ====
# These values are identical for every sweep point and are forwarded to
# dataset_generation.py as command-line arguments.

# -- Dataset / run configuration --
DATASET_SIZE: int = 4
BATCH_SIZE: int = 64
LABELS: str = "[1,6]"  # list literal, forwarded as a single argv string
DATASET_NAME: str = "ConstructiveSweep"
CHECKPOINT_PATH: str = "./output_dir/checkpoint-cond-699.pth"
TIMESTEP_SIZE: int = 20
SAMPLE_SEED: int = 42

# -- Generation-paradigm configuration --
GEN_PARADIGM: str = "constructive"
N_MOD: int = 2
N_U: int = 2
# Change as needed!
# NOTE(review): the eta / rho_theta / DAG_theta argument strings must carry
# N_U_THETA entries per row -- confirm they are built to match.
N_U_THETA: int = 5
PERCENTILE: float = 0.95  # forwarded as --percentile_to_align
# Alternative source: "./generation_arguments/T_vectors/constructive_test_arguments"
T_VECTORS_METHOD: str = "random"

# -- Scalars held fixed for simplicity --
dagul = 1      # diagonal value of the --DAG_ul matrix
alpha = 1      # forwarded as --alpha
beta = 1       # forwarded as --beta
# NOTE: the two values below are not read when building the command; the
# swept rho0/rho1 and dagtheta0/dagtheta1 are used instead.
dagtheta = 1
rho = 1

# ==== Hyperparameter sweep values ====
SWEEP_VALUES = [0.01, 0.1, 1.0]  # Small grid for testing, expand as needed
DAGTHETA_VALUES = [0.5, 1.0]
# Alternative grid: SWEEP_VALUES = [0.0, 1.0] flips the DAGs on and off.

# Ordered sweep axes. The order fixes the layout of each tuple in param_list
# (unpacked by position into the job's parameters) and the enumeration order:
# itertools.product varies the LAST axis fastest.
_SWEEP_AXES = (
    ("rho0", SWEEP_VALUES),
    ("rho1", SWEEP_VALUES),
    ("eta", SWEEP_VALUES),
    ("rho_theta", SWEEP_VALUES),
    ("dagtheta0", DAGTHETA_VALUES),
    ("dagtheta1", DAGTHETA_VALUES),
)

# Full Cartesian grid: one tuple per job index (3**4 * 2**2 = 324 with the
# values above).
param_list = list(product(*(axis_values for _axis_name, axis_values in _SWEEP_AXES)))

# ==== Select this job's sweep point ====
# The single positional argument is a 0-based index into param_list, so one
# process can be launched per combination.
# Errors are reported with sys.exit(message), which prints the message to
# stderr and exits with status 1 (the same status as before).
if len(sys.argv) < 2:
    sys.exit(f"Usage: python {sys.argv[0]} INDEX")

try:
    idx = int(sys.argv[1])
except ValueError:
    # Previously an uncaught traceback; give a clear message instead.
    sys.exit(f"INDEX must be an integer, got {sys.argv[1]!r}")

if not 0 <= idx < len(param_list):
    sys.exit(f"Index {idx} out of range (max {len(param_list)-1})")

# Unpack in the same axis order used by product(...) to build param_list.
# rho, eta, rho_theta = param_list[idx]
# dagtheta00, dagtheta01, dagtheta02, dagtheta03, dagtheta04, dagtheta10, dagtheta11, dagtheta12, dagtheta13, dagtheta14 = param_list[idx]
rho0, rho1, eta, rho_theta, dagtheta0, dagtheta1 = param_list[idx]


# Format args
# Each value is rendered as a bracketed, space-free list literal and passed
# to dataset_generation.py as a single argv string.


def _as_list_literal(items):
    """Render *items* as "[a,b,...]" using each item's str() form."""
    return "[" + ",".join(str(item) for item in items) + "]"


def _repeat_literal(value, count):
    """Render *count* copies of *value* as a list literal."""
    return _as_list_literal([value] * count)


rho_arg = _as_list_literal([rho0, rho1])
dag_ul = f"[[{dagul},{0.0}],[{0.0},{dagul}]]"
# These vectors are sized by N_U_THETA (also sent as --N_u_theta) rather than
# a hard-coded 5, so editing N_U_THETA keeps them consistent.
# NOTE(review): assumes the generator expects N_U_THETA entries per vector /
# per DAG_theta row -- confirm against dataset_generation.py.
eta_arg = _repeat_literal(eta, N_U_THETA)
rho_theta_arg = _repeat_literal(rho_theta, N_U_THETA)
dag_theta = _as_list_literal([
    _repeat_literal(dagtheta0, N_U_THETA),
    _repeat_literal(dagtheta1, N_U_THETA),
])
# dag_theta = f"[[{dagtheta00},{dagtheta01},{dagtheta02},{dagtheta03},{dagtheta04}],[{dagtheta10},{dagtheta11},{dagtheta12},{dagtheta13},{dagtheta14}]]"

# Output folder: one nested directory per parameter combination, built from
# the swept values so every job index writes to its own location.
_SWEEP_ROOT = "./output_dir/datasets/hyper_sweep"
outdir = "/".join([
    _SWEEP_ROOT,
    f"rho0_{rho0}",
    f"rho1_{rho1}",
    f"eta_{eta}",
    f"rhotheta_{rho_theta}",
    f"dagtheta0_{dagtheta0}",
    f"dagtheta1_{dagtheta1}",
])
# Layouts used by earlier sweep modes:
# outdir = f"./output_dir/datasets/hyper_sweep/rho_{rho}/dagul_{dagul}/eta_{eta}/rhotheta_{rho_theta}/dagtheta_{dagtheta}/alpha_{alpha}/beta_{beta}"
# dagtheta_str = ''.join(str(int(x)) for x in [
#     dagtheta00, dagtheta01, dagtheta02, dagtheta03, dagtheta04,
#     dagtheta10, dagtheta11, dagtheta12, dagtheta13, dagtheta14
# ])
# outdir = f"./output_dir/datasets/hyper_sweep/DAGtheta_{dagtheta_str}"

os.makedirs(outdir, exist_ok=True)

# Optional resume: treat the combination as done when "<DATASET_NAME>_0.npz"
# is already present in its folder, and exit successfully without rerunning.
result_file = os.path.join(outdir, f"{DATASET_NAME}_0.npz")
already_done = os.path.exists(result_file)
if already_done:
    print("Skipping " + outdir + " (already exists)")
    sys.exit(0)

# Command line for the generator. Each element is a separate argv entry (no
# shell), so list literals such as rho_arg reach the child unchanged.
# sys.executable runs the child on the same interpreter/environment as this
# sweep script; a bare "python" could resolve to a different one on PATH.
args = [
    sys.executable or "python", "dataset_generation.py",
    "--dataset_size", str(DATASET_SIZE),
    "--batch_size", str(BATCH_SIZE),
    "--labels", LABELS,
    "--dataset_name", DATASET_NAME,
    "--dataset_loc", outdir,
    "--checkpoint_path", CHECKPOINT_PATH,
    "--timestep_size", str(TIMESTEP_SIZE),
    "--sample_seed", str(SAMPLE_SEED),
    "--debug", "0",
    "--generation_paradigm", GEN_PARADIGM,
    "--calculate_MI",
    "--N_mod", str(N_MOD),
    "--N_u", str(N_U),
    "--rho_arg", rho_arg,
    "--DAG_ul", dag_ul,
    "--N_u_theta", str(N_U_THETA),
    "--eta_arg", eta_arg,
    "--rho_theta_arg", rho_theta_arg,
    "--DAG_theta", dag_theta,
    "--T_vectors_method", T_VECTORS_METHOD,
    "--alpha", str(alpha),
    "--beta", str(beta),
    "--percentile_to_align", str(PERCENTILE),
]

# The child's stdout/stderr go to per-combination log files inside outdir.
stdout_log = os.path.join(outdir, "stdout.log")
stderr_log = os.path.join(outdir, "stderr.log")

print(f"Running combo {idx+1}/{len(param_list)}: {outdir}")
with open(stdout_log, "w") as fout, open(stderr_log, "w") as ferr:
    result = subprocess.run(args, stdout=fout, stderr=ferr)

# Propagate failure instead of always exiting 0, so whatever launched this
# job can tell that generation failed. A negative returncode means the child
# was killed by a signal; map it to 1 so the exit status is still non-zero.
if result.returncode != 0:
    print(
        f"Combo {idx+1} failed with exit code {result.returncode}; see {stderr_log}",
        file=sys.stderr,
    )
    sys.exit(result.returncode if result.returncode > 0 else 1)
