import argparse
import datetime
import itertools
import os
import pathlib
import subprocess
import textwrap

# Per-dataset training settings, keyed by the value accepted for --dataset.
DATASET_SETTINGS = {
    "RB": {
        "data_name": "rayleigh_benard",
        "batch_size": 8,
        "grad_acc_steps": 1,
        "max_rollout_steps": 200,
        "max_epoch": 200,
    },
    "euler": {
        "data_name": "euler_multi_quadrants_periodicBC",
        "batch_size": 8,
        "grad_acc_steps": 4,
        "max_rollout_steps": 100,
        "max_epoch": 200,
    },
    "shear_flow": {
        "data_name": "shear_flow",
        "batch_size": 8,
        "grad_acc_steps": 1,
        "max_rollout_steps": 100,
        "max_epoch": 200,
    },
}


def build_parser():
    """Return the command-line parser for this launcher script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment_name", type=str, required=False, default="Smallrus")
    parser.add_argument("--experiment", type=str, required=False, default="no_overrides")
    parser.add_argument("--local", action="store_true")
    parser.add_argument("--single-run", action="store_true")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=[
            "euler",
            "RB",
            "shear_flow",
        ],
        required=True,
    )
    return parser


def get_dataset_settings(dataset):
    """Return a copy of the settings dict for ``dataset``.

    Raises:
        NotImplementedError: if ``dataset`` has no entry in DATASET_SETTINGS.
    """
    try:
        return dict(DATASET_SETTINGS[dataset])
    except KeyError:
        raise NotImplementedError(f"Dataset {dataset} not implemented.") from None


def build_sweep(experiment, settings):
    """Return the sweep definition: parameter name -> list of values to try.

    ``settings`` is a dict as returned by :func:`get_dataset_settings`.
    """
    return {
        "experiment": [experiment],
        "lr_scheduler": ["inv_sqrt_w_sqrt_ramps"],
        "optimizer_lr": [0.001],
        "max_samples": [2000],
        "loss_multiplier": [100],
        "max_epoch": [settings["max_epoch"]],
        "batch_size": [settings["batch_size"]],
        "grad_acc_steps": [settings["grad_acc_steps"]],
        "hidden_dim": [1088],
        "processor_blocks": [8],
        "data": [
            {
                "name": settings["data_name"],
                "min_dt_stride": 1,
                "max_dt_stride": 1,
            },
        ],
        "distribution": ["local"],
        "coalesced_checkpoint_path": [""],
        "config_override": [""],
    }


def expand_sweep(sweep):
    """Return one params dict per element of the Cartesian product of ``sweep``.

    An empty sweep yields a single empty params dict instead of raising.
    """
    keys = list(sweep)
    return [dict(zip(keys, combo)) for combo in itertools.product(*sweep.values())]


def resolve_run_name(experiment_name, local):
    """Return the run name, prefixed with ``local_`` for local runs."""
    return "local_" + experiment_name if local else experiment_name


def build_python_cmd(params, experiment_name, max_rollout_steps):
    """Return the training command for one sweep combination.

    The trailing backslashes in the template are Python string line
    continuations (the literal is not raw), so the result is a single
    physical line. The script templates in :func:`build_script` depend on
    that: a multi-line command would defeat their ``textwrap.dedent``.
    """
    return textwrap.dedent(
        f"""\
        python crps_retrofitting/train.py \
            model.gradient_checkpointing_freq=0 \
            model.override_dimensionality=0 \
            auto_resume=True \
            checkpoint=defaults \
            name={experiment_name}_{params['max_epoch']} \
            distribution={params['distribution']} \
            trainer=defaults \
            trainer.enable_amp=False \
            trainer.grad_acc_steps={params['grad_acc_steps']} \
            trainer.max_epoch={params['max_epoch']} \
            trainer.log_interval=200 \
            trainer.clip_gradient=10 \
            trainer.short_validation_length=20 \
            trainer.max_rollout_steps={max_rollout_steps} \
            trainer.val_frequency=10 \
            trainer.rollout_val_frequency=10 \
            trainer.prediction_type=delta \
            +trainer.skip_spectral_metrics=True \
            trainer.video_validation=True \
            data={params['data']['name']} \
            data.module_parameters.batch_size={params['batch_size']} \
            data.module_parameters.max_samples={params['max_samples']} \
            data.module_parameters.n_steps_input=6 \
            data.module_parameters.n_steps_output=1 \
            data.module_parameters.min_dt_stride={params['data']['min_dt_stride']} \
            data.module_parameters.max_dt_stride={params['data']['max_dt_stride']} \
            optimizer=adam \
            optimizer.lr={params["optimizer_lr"]} \
            model/processor/space_mixing=full_spatial_attention \
            model.projection_dim=48 \
            model.intermediate_dim=352 \
            model.hidden_dim={params["hidden_dim"]} \
            model.groups=16 \
            model.drop_path=0.05 \
            model.processor_blocks={params["processor_blocks"]} \
            model.processor.space_mixing.num_heads=16 \
            model.processor.time_mixing.num_heads=16 \
            model.causal_in_time=True \
            model.jitter_patches=True \
            +model.use_periodic_fixed_jitter=True \
            +model.input_field_drop=0 \
            lr_scheduler={params['lr_scheduler']} \
            data_workers=10 \
            logger.wandb_project_name=CRPS \
            experiment={params['experiment']} \
            checkpoint.checkpoint_frequency=10 \
        """
    )


def build_script(python_cmd, local, job_name="", log_path=""):
    """Return the text of the shell script that runs ``python_cmd``.

    Args:
        python_cmd: single-line command from :func:`build_python_cmd`.
        local: if true, emit a plain bash script; otherwise an sbatch script.
        job_name: value for ``#SBATCH --job-name`` (sbatch script only).
        log_path: value for ``#SBATCH --output`` (sbatch script only).
    """
    if local:
        return textwrap.dedent(
            f"""\
            #!/bin/bash
            export HYDRA_FULL_ERROR=1
            export NCCL_DEBUG=WARN
            export LD_LIBRARY_PATH=""
            module load ffmpeg

            {python_cmd}
            """
        )
    return textwrap.dedent(
        f"""\
        #!/bin/bash
        #SBATCH --time=7-00:00:00
        #SBATCH --partition=gpu
        #SBATCH --nodes=1
        #SBATCH --gres=gpu:1
        #SBATCH -C h100 
        #SBATCH --ntasks-per-node=1
        #SBATCH --gpus-per-node=1
        #SBATCH --cpus-per-gpu=16
        #SBATCH --job-name={job_name}
        #SBATCH --output={log_path}

        export OMP_NUM_THREADS=${{SLURM_CPUS_ON_NODE}}
        export HDF5_USE_FILE_LOCKING=FALSE
        export HYDRA_FULL_ERROR=1
        export NCCL_DEBUG=WARN
        export LD_LIBRARY_PATH=""
        module load ffmpeg
        source ./venvs/walrus/bin/activate

        {python_cmd}
        """
    )


def main(argv=None):
    """Generate, run/submit, and delete one script per sweep combination.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    settings = get_dataset_settings(args.dataset)
    combinations = expand_sweep(build_sweep(args.experiment, settings))

    # Resolve the run name once: prefixing inside the loop would stack
    # "local_" once per combination.
    run_name = resolve_run_name(args.experiment_name, args.local)
    date_string = datetime.datetime.now().strftime("%d-%m_%H-%M-%S")

    username = os.environ["USER"]
    slurm_folder = f"/mnt/home/{username}/crps_retrofitting/slurm_outputs"
    os.makedirs(slurm_folder, exist_ok=True)

    for i, params in enumerate(combinations):
        data_name = params["data"]["name"]
        job_name = f"{data_name}_job_{i}"
        script_filename = f"{job_name}.sbatch"
        python_cmd = build_python_cmd(params, run_name, settings["max_rollout_steps"])

        if args.local:
            launcher = "bash"
            script_content = build_script(python_cmd, local=True)
        else:
            launcher = "sbatch"
            # Slurm does not create missing directories for --output, so
            # the log directory has to exist before the job is submitted.
            log_dir = f"{slurm_folder}/{data_name}/{run_name}"
            os.makedirs(log_dir, exist_ok=True)
            script_content = build_script(
                python_cmd,
                local=False,
                job_name=job_name,
                log_path=f"{log_dir}/%j_{job_name}_{date_string}.log",
            )

        with open(script_filename, "w") as f:
            f.write(script_content)

        try:
            result = subprocess.run([launcher, script_filename])
        finally:
            # Always clean up, even if the launcher could not be started.
            os.remove(script_filename)

        if result.returncode != 0:
            # Report but keep going so one failure does not abort the sweep.
            print(
                f"Warning: '{launcher} {script_filename}' exited with "
                f"code {result.returncode}."
            )

        if args.single_run:
            break


if __name__ == "__main__":
    main()
