#!/bin/bash

## This runs a sweep that reproduces Table 1 for iter_dist methods

# === Run bookkeeping (Weights & Biases + date tag) ===
WANDB_ENTITY='YOUR_ENTITY'
WANDB_PROJECT='YOUR_PROJECT'
# Abbreviated month name followed by the two-digit day of month; this tag
# becomes one component of the Slurm log directory path.
RUN_DATE="$(date +%b%d)"

# === Backbone and training schedule ===
# Other backbone names noted by the author: vit_small, convnext_small
model='resnet50'
batch_size=128
max_epochs=1000

# === Hyperparameters held constant for every run of the sweep ===
sim_loss_weight=25.0
var_loss_weight=0.0
cov_loss_weight=0.0
weight_decay=1e-4
projector_type='mlp3_with_one_more_relu'

# === Swept values ===
list_of_ssl_methods=(
  'iter_dist_imagenet'
)

# Paired by index: entry i of the weight list is used together with
# entry i of the target-distribution list.
list_of_target_distribution=(
  'rectified_product_laplace'
  'rectified_gauss'
)
list_of_one_d_dist_loss_weight=(
  125.0
  125.0
)

# Paired by index: entry i of the classifier lr list goes with entry i of
# the lr list.
list_of_lr=(
  0.165
)
list_of_classifier_lr=(
  0.055
)

list_of_one_d_dist_loss_choice=(
  'sliced_wasserstein_distance'
)
list_of_projection_sampling_mode=(
  'random'
)
list_of_swd_num_projections=(
  8192
)
list_of_proj_dim=(
  2048
)

# Every value below is crossed with every target distribution above.
list_of_mean_shift_scalar_for_rectified_gauss=(
  -3.0 -2.75 -2.50 -2.25
  -2.0 -1.75 -1.50 -1.25
  -1.0 -0.75 -0.5 -0.25
  0.0 0.25 0.5 0.75
  1.0
)
# Single-value alternative (swap with the list above to submit only one shift):
# list_of_mean_shift_scalar_for_rectified_gauss=(0.0)

# --- HELPERS ---

#######################################
# Write the Slurm batch script for the current sweep point.
# The heredoc delimiter is unquoted on purpose: every $variable is expanded
# now, at generation time, so the written file is fully self-contained, and
# each "\\" is emitted as a single "\" line continuation in the output.
# Globals (read): job_name, slurm_log_dir, ssl_method, lr, classifier_lr,
#   weight_decay, proj_dim, var_loss_weight, cov_loss_weight,
#   sim_loss_weight, projector_type, target_distribution,
#   one_d_dist_loss_choice, one_d_dist_loss_weight, swd_num_projections,
#   projection_sampling_mode, mean_shift_scalar_for_rectified_gauss,
#   WANDB_ENTITY, WANDB_PROJECT, max_epochs, model, batch_size
# Arguments: $1 - path of the script file to create (overwritten if present)
# Returns:   exit status of the write (non-zero if the file cannot be written)
#######################################
write_slurm_script() {
  local out_path=$1

  cat <<EOT > "$out_path"
#!/bin/bash
#SBATCH --job-name=${job_name}
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=9
#SBATCH --gres=gpu:1
#SBATCH --time=47:59:59
#SBATCH --mem=80G
#SBATCH --account=YOUR_ACCOUNT
#SBATCH --error=${slurm_log_dir}/%j.err
#SBATCH --output=${slurm_log_dir}/%j.out

singularity exec --nv \\
  --overlay /your/path/to/imagenet100.sqf:ro \\
  --overlay /your/path/to/singularity/overlay-50G-10M-vision.ext3:ro \\
  /your/path/to/cuda12.1.1-cudnn8.9.0-devel-ubuntu22.04.2.sif \\
  /bin/bash -c "
  source /ext3/env.sh
  conda activate base
  cd /your/path/to/solo-learn-iter-gauss
  export WANDB__SERVICE_WAIT=300
  
  python3 main_pretrain.py \\
    --config-path scripts/pretrain/imagenet-100/ \\
    --config-name=${ssl_method}.yaml \\
    ++optimizer.lr=$lr \\
    ++optimizer.classifier_lr=$classifier_lr \\
    ++optimizer.weight_decay=$weight_decay \\
    ++method_kwargs.proj_hidden_dim=$proj_dim \\
    ++method_kwargs.proj_output_dim=$proj_dim \\
    ++method_kwargs.var_loss_weight=$var_loss_weight \\
    ++method_kwargs.cov_loss_weight=$cov_loss_weight \\
    ++method_kwargs.sim_loss_weight=$sim_loss_weight \\
    ++method_kwargs.projector_type=$projector_type \\
    ++method_kwargs.target_distribution=$target_distribution \\
    ++method_kwargs.one_d_dist_loss_choice=$one_d_dist_loss_choice \\
    ++method_kwargs.one_d_dist_loss_weight=$one_d_dist_loss_weight \\
    ++method_kwargs.swd_num_projections=$swd_num_projections \\
    ++method_kwargs.projection_sampling_mode=$projection_sampling_mode \\
    ++method_kwargs.mean_shift_scalar_for_rectified_gauss=$mean_shift_scalar_for_rectified_gauss \\
    ++method_kwargs.logging_interval=50 \\
    ++data.num_workers=8 \\
    ++wandb.entity=$WANDB_ENTITY \\
    ++wandb.project=$WANDB_PROJECT \\
    ++wandb.enabled=true \\
    ++auto_resume.enabled=true \\
    ++max_epochs=$max_epochs \\
    ++name=$job_name \\
    ++backbone.name=$model \\
    ++optimizer.batch_size=$batch_size
  "
EOT
}

#######################################
# Submit the same batch script twice: Job 2 depends on Job 1 with
# "afterany", so it starts once Job 1 has ended, whatever its exit state.
# NOTE(review): presumably Job 2 picks the run back up via
# ++auto_resume.enabled=true if Job 1 hit the time limit - confirm.
# Arguments: $1 - path to the Slurm script
#            $2 - job name (used in messages only)
# Outputs:   progress lines on stdout, failure messages on stderr
# Returns:   0 if both submissions succeeded, 1 otherwise
#######################################
submit_job_chain() {
  local script=$1
  local name=$2
  local first_id second_id

  # Keep declaration and assignment separate so sbatch's exit status is not
  # masked by 'local'.
  if ! first_id=$(sbatch --parsable "$script") || [[ -z "$first_id" ]]; then
    echo "  Failed to submit Job 1 for $name" >&2
    return 1
  fi
  # --parsable prints "jobid;cluster" when a cluster name is present; only
  # the numeric id is valid inside a --dependency specification.
  first_id=${first_id%%;*}
  echo "  Submitted Job 1: $first_id"

  if ! second_id=$(sbatch --parsable "--dependency=afterany:${first_id}" "$script") \
      || [[ -z "$second_id" ]]; then
    echo "  Failed to submit Job 2 for $name (Job 1: $first_id was submitted)" >&2
    return 1
  fi
  second_id=${second_id%%;*}
  echo "  Submitted Job 2: $second_id (dependency: $first_id)"
}

# --- SANITY CHECKS ---
# The loop below reads these lists pairwise by index. A list that is too short
# would silently produce an empty override value, so refuse to start instead.
if (( ${#list_of_one_d_dist_loss_weight[@]} < ${#list_of_target_distribution[@]} )); then
  echo "Error: list_of_one_d_dist_loss_weight needs one entry per entry of list_of_target_distribution" >&2
  exit 1
fi
if (( ${#list_of_classifier_lr[@]} < ${#list_of_lr[@]} )); then
  echo "Error: list_of_classifier_lr needs one entry per entry of list_of_lr" >&2
  exit 1
fi

# --- MAIN LOOP ---
for ssl_method in "${list_of_ssl_methods[@]}"; do
  for dist_idx in "${!list_of_target_distribution[@]}"; do
    # Distribution and its loss weight are paired by index.
    target_distribution=${list_of_target_distribution[$dist_idx]}
    one_d_dist_loss_weight=${list_of_one_d_dist_loss_weight[$dist_idx]}

    for idx in "${!list_of_lr[@]}"; do
      # Backbone lr and classifier lr are paired by index.
      lr=${list_of_lr[$idx]}
      classifier_lr=${list_of_classifier_lr[$idx]}

      for idxs in "${!list_of_proj_dim[@]}"; do
        proj_dim=${list_of_proj_dim[$idxs]}

        for projection_sampling_mode in "${list_of_projection_sampling_mode[@]}"; do
          for one_d_dist_loss_choice in "${list_of_one_d_dist_loss_choice[@]}"; do
            for swd_num_projections in "${list_of_swd_num_projections[@]}"; do
              for mean_shift_scalar_for_rectified_gauss in "${list_of_mean_shift_scalar_for_rectified_gauss[@]}"; do

                # 1. Define the job name.
                # NOTE: the name encodes only model, target distribution and
                # mean shift. If any other swept list gains a second entry,
                # distinct runs will share this name and log directory.
                job_name="${model}-${target_distribution}-${mean_shift_scalar_for_rectified_gauss}"

                # 2. Define and create the Slurm log directory
                slurm_log_dir="/your/path/to/solo-learn-iter-gauss/slurm/${RUN_DATE}/${job_name}"
                if ! mkdir -p "$slurm_log_dir"; then
                  echo "  Failed to create $slurm_log_dir; skipping $job_name" >&2
                  continue
                fi

                # 3. Generate the SLURM script inside the log directory
                slurm_script="${slurm_log_dir}/job.slurm"
                if ! write_slurm_script "$slurm_script"; then
                  echo "  Failed to write $slurm_script; skipping $job_name" >&2
                  continue
                fi

                echo "Submitting run: ${job_name}"

                # 4. Submit chain of 2 jobs; a failure is reported by the
                #    helper and the sweep moves on to the next configuration.
                submit_job_chain "$slurm_script" "$job_name"

              done
            done
          done
        done
      done
    done
  done
done