import os
import json
from datetime import datetime


def generate_slurm_script(config_dir, output_dir="slurm_scripts/big_batch", num_seeds=5):
    """Write an executable SLURM array-job script covering every experiment variation.

    Every ``*.json`` file in ``config_dir`` supplies ``num_experiences``,
    ``integrated_features`` and ``concatenated`` (missing keys fall back to
    ``10``, ``"True"`` and ``"False"``).  Each file is expanded into one
    comma-separated config row per (CL algorithm, selection strategy,
    feature/weighting variant) combination hard-coded below.  The generated
    bash script maps ``SLURM_ARRAY_TASK_ID`` onto a (config row, buffer size,
    seed) triple and launches ``CL/MERS/main.py`` with the matching flags.

    Config files are processed in sorted order so that the task-index to
    config-row mapping is reproducible across runs.

    Args:
        config_dir: Directory containing the JSON configuration files.
        output_dir: Directory that receives the script; created if missing.
        num_seeds: Number of seeds (``0 .. num_seeds - 1``) run for every
            config row and buffer size.

    Raises:
        ValueError: If ``num_seeds`` is smaller than 1, if the pinned TEAL
            type is not a supported value, or if ``config_dir`` contains no
            ``.json`` file.
    """
    if num_seeds < 1:
        raise ValueError(f"num_seeds must be at least 1, got {num_seeds}")

    # Timestamp suffix in DDMM_HHMM form; runs within the same minute reuse
    # (and overwrite) the same script file.
    formatted = datetime.now().strftime("%d%m_%H%M")
    print(formatted)

    dataset = "cifar100"
    increase_factor = 1.0
    valid_teal_types = ["one_time", "log_iterative"]
    # NOTE: the TEAL type is pinned here rather than read from the config
    # files, so this check only guards against a bad hand edit of the pin.
    teal_type = "one_time"
    if teal_type not in valid_teal_types:
        raise ValueError(f"Invalid 'teal_type': {teal_type} (must be one of {valid_teal_types})")

    cl_algorithms = ["er_ace"]
    sel_strategies = ["probcover"]
    features_ss_options = ["vicreg"]
    buffer_sizes = [2000]
    nvidia = False
    order = None

    config_files = sorted(
        os.path.join(config_dir, name)
        for name in os.listdir(config_dir)
        if name.endswith(".json")
    )
    if not config_files:
        raise ValueError(f"No JSON config files found in {config_dir}")

    # Row layout (must match the `read -r` field list in the bash template):
    # dataset, num_experiences, algorithm, sel_strategy, features_ss,
    # weight_method, alpha, teal_type, integrated_features, concatenated,
    # sigma_mb, sigma_ss, delta_mb, delta_ss, nvidia, order
    configs = []
    for config_path in config_files:
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = json.load(f)

        num_experiences = config_data.get("num_experiences", 10)
        integrated_features = config_data.get("integrated_features", "True")
        concatenated = config_data.get("concatenated", "False")

        middle = (teal_type, integrated_features, concatenated, "median_cosine", "median_cosine", "1nn")
        tail = (nvidia, order)

        for cl_algorithm in cl_algorithms:
            for sel_strategy in sel_strategies:
                head = (dataset, num_experiences, cl_algorithm, sel_strategy)

                if sel_strategy == "random":
                    # Random selection ignores the feature space: one row only.
                    configs.append(head + ("model_based", "", 0) + middle + ("knn",) + tail)
                    continue

                # alpha = 0 variant, no weighting method.  (An alpha = 1
                # 'simclr' variant existed but is currently disabled.)
                for features_ss in features_ss_options:
                    configs.append(head + (features_ss, "", 0) + middle + ("1nn",) + tail)

                # alpha = 0.5 variant with the density-ratio weighting method.
                for features_ss in features_ss_options:
                    configs.append(
                        head + (features_ss, "ratio_median_knn_density_k_1", 0.5) + middle + ("knn",) + tail
                    )

    total_configs = len(configs)
    variations_per_config = len(buffer_sizes) * num_seeds
    total_tasks = total_configs * variations_per_config

    buffer_sizes_str = " ".join(map(str, buffer_sizes))
    config_block = "\n".join(
        '    "' + ",".join(str(field) for field in cfg) + '"' for cfg in configs
    )

    script_text = f"""#!/bin/bash
#SBATCH --mem=10g
#SBATCH -c 2
#SBATCH --time=0-9
#SBATCH --gres gg:g0:1
#SBATCH --mail-type=END,FAIL,TIME_LIMIT
#SBATCH --exclude=dumfries-009,dumfries-008,dumfries-001,dumfries-002,dumfries-003,dumfries-004,dumfries-005,dumfries-006,dumfries-007,dumfries-010
#SBATCH --array=0-{total_tasks - 1}


echo "Starting job with ID: $SLURM_JOB_ID, Task ID: $SLURM_ARRAY_TASK_ID"
# Optional context for debugging
echo "Host: $(hostname)"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
nvidia-smi -L || true

# ---- GPU preflight (auto-requeue on bad/wedged nodes) ----
bin/python3.9 - <<'PY'
import sys
try:
    import torch
except Exception as e:
    print("Failed to import torch:", repr(e))
    sys.exit(1)

print("torch:", getattr(torch, "__version__", "?"), "cuda_rt:", getattr(torch.version, "cuda", "?"))
if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
    print("No visible CUDA device")
    sys.exit(1)
try:
    x = torch.zeros(1, device="cuda")
    torch.cuda.synchronize()
    print("GPU 0:", torch.cuda.get_device_name(0))
    print("CUDA smoke test: OK")
except Exception as e:
    print("Preflight exception:", repr(e))
    sys.exit(2)
PY
preflight_status=$?
if [[ $preflight_status -ne 0 ]]; then
    echo "GPU preflight failed (code $preflight_status) → requeue this task"
    scontrol requeue $SLURM_JOB_ID
    exit 0
fi
# ----------------------------------------------------------

# Define buffer sizes
buffer_sizes=({buffer_sizes_str})

configs=(
{config_block}
)

total_configs=${{#configs[@]}}
variations_per_config={variations_per_config}

config_index=$(( SLURM_ARRAY_TASK_ID / variations_per_config ))
variation_index=$(( SLURM_ARRAY_TASK_ID % variations_per_config ))

cfg=${{configs[$config_index]}}
IFS=',' read -r dataset num_experiences algorithm sel_strategy features_ss weight_method alpha teal_type integrated_features concatenated sigma_mb sigma_ss delta_mb delta_ss nvidia order <<< "$cfg"

# Calculate buffer and seed from variation_index
buffer_index=$(( variation_index / {num_seeds} ))
seed=$(( variation_index % {num_seeds} ))
buffer_size=${{buffer_sizes[$buffer_index]}}

echo "Running config: $cfg, seed: $seed, buffer: $buffer_size"

bin/python3.9 \\
    CL/MERS/main.py \\
    --dataset $dataset \\
    --num_experiences $num_experiences \\
    --algorithm $algorithm \\
    --sel_strategy $sel_strategy \\
    --features_type $features_ss \\
    --teal_type $teal_type \\
    --buffer $buffer_size \\
    --integrated_features $integrated_features \\
    --seed $seed \\
    --batch_id $SLURM_JOB_ID \\
    $( [[ "$alpha" != "" ]] && echo "--alpha $alpha" ) \\
    --delta 0.28 \\
    --increase_factor {increase_factor} \\
    --concatenated $concatenated \\
    --exp_name dino_nvidia \\
    --sigma_mb $sigma_mb \\
    --sigma_ss $sigma_ss \\
    --delta_mb $delta_mb \\
    --delta_ss $delta_ss \\
    $( [[ "$nvidia" == "True" ]] && echo "--nvidia" ) \\
    $( [[ "$order" != "" && "$order" != "None" ]] && echo "--order $order" ) \\
    $( [[ "$weight_method" != "" ]] && echo "--weight_method $weight_method" )

"""

    os.makedirs(output_dir, exist_ok=True)
    script_path = os.path.join(output_dir, f"all_exp_{dataset}_{formatted}.sh")
    # Explicit UTF-8 because the template contains non-ASCII text; explicit
    # "\n" newlines because bash rejects CRLF line endings.
    with open(script_path, "w", encoding="utf-8", newline="\n") as f:
        f.write(script_text)

    os.chmod(script_path, 0o755)
    print(
        f"Generated {script_path} with {total_tasks} tasks ({total_configs} configs × {variations_per_config} variations).")
    print(f"Buffer sizes: {buffer_sizes}")


if __name__ == "__main__":
    # Script entry point: build the CIFAR-100 sweep from the config directory
    # located relative to the current working directory.
    cifar100_config_dir = "../configuration_files/cifar100"
    generate_slurm_script(cifar100_config_dir)