import argparse
import datetime
import itertools
import os
import pathlib
import subprocess
import textwrap

parser = argparse.ArgumentParser(
    description=(
        "Generate one sbatch script per sweep combination and submit it with "
        "sbatch (or run it with bash when --local is given)."
    )
)
# These two previously passed their description as `default=`, so omitting the
# flag silently produced a bogus checkpoint path / config override. They have no
# meaningful default, so require them explicitly.
parser.add_argument(
    "--root_dir",
    type=str,
    required=True,
    help="Path to deterministic base checkpoint",
)
parser.add_argument(
    "--config_file",
    type=str,
    required=True,
    help="Config file for deterministic base checkpoint",
)
parser.add_argument(
    "--experiment_name",
    type=str,
    required=False,
    default="Smallrus",
    help="Prefix for the run name and the Slurm log sub-folder.",
)
parser.add_argument(
    "--experiment",
    type=str,
    required=False,
    default="no_overrides",
    help="Value passed to train.py as `experiment=`.",
)
parser.add_argument(
    "--local",
    action="store_true",
    help="Run each generated script with bash instead of submitting it via sbatch.",
)
parser.add_argument(
    "--single-run",
    action="store_true",
    help="Stop after the first sweep combination.",
)
parser.add_argument(
    "--dataset",
    type=str,
    choices=[
        "euler",
        "RB",
        "shear_flow",
    ],
    required=True,
    help="Dataset whose training settings are used for the sweep.",
)
args = parser.parse_args()

# Settings shared by every dataset; individual datasets override the rest below.
_SHARED_SETTINGS = {
    "batch_size": 8,
    "grad_acc_steps": 1,
    "max_epoch": 50,
    "max_samples": 2000,
    "checkpoint_name": "step_200",
}
# Per-dataset values keyed by the --dataset choice: the data config name
# (later passed to train.py as `data=`) and the rollout length.
_PER_DATASET_SETTINGS = {
    "RB": {"data_name": "rayleigh_benard", "max_rollout_steps": 200},
    "euler": {
        "data_name": "euler_multi_quadrants_periodicBC",
        "max_rollout_steps": 100,
    },
    "shear_flow": {"data_name": "shear_flow", "max_rollout_steps": 200},
}

_dataset_overrides = _PER_DATASET_SETTINGS.get(args.dataset)
if _dataset_overrides is None:
    # Defensive only: argparse `choices` already rejects other values.
    raise NotImplementedError(f"Dataset {args.dataset} not implemented.")
_settings = {**_SHARED_SETTINGS, **_dataset_overrides}

data_name = _settings["data_name"]
batch_size = _settings["batch_size"]
grad_acc_steps = _settings["grad_acc_steps"]
max_rollout_steps = _settings["max_rollout_steps"]
max_epoch = _settings["max_epoch"]
max_samples = _settings["max_samples"]
checkpoint_name = _settings["checkpoint_name"]


# Coalesced weights file inside the base checkpoint directory; it is forwarded
# to train.py as `checkpoint.coalesced_checkpoint_path`.
coalesced_checkpoint_path = "/".join(
    (args.root_dir, "checkpoints", checkpoint_name, "coalesced.pth")
)

# Each entry maps a parameter to the list of values to try. One job is
# generated per element of the Cartesian product of all these lists.
sweep = dict(
    experiment=[args.experiment],
    lr_scheduler=["inv_sqrt_w_sqrt_ramps"],
    optimizer_lr=[0.001],
    finetune_lr=[0.00005],
    max_samples=[max_samples],
    loss_multiplier=[100],
    max_epoch=[max_epoch],
    batch_size=[batch_size],
    grad_acc_steps=[grad_acc_steps],
    hidden_dim=[1088],
    processor_blocks=[8],
    data=[
        dict(name=data_name, min_dt_stride=1, max_dt_stride=1),
    ],
    distribution=["local"],
    coalesced_checkpoint_path=[coalesced_checkpoint_path],
    config_override=[args.config_file],
)

# Step 2: expand the sweep into the Cartesian product of its value lists,
# yielding one {parameter: value} dict per job.
keys = tuple(sweep.keys())
values = tuple(sweep.values())
combinations = [dict(zip(keys, chosen)) for chosen in itertools.product(*values)]

experiment_name = args.experiment_name
# Local-time stamp (day-month_hour-minute-second) appended to log file names.
date_string = f"{datetime.datetime.now():%d-%m_%H-%M-%S}"

# Root folder for Slurm logs: /mnt/home/$USER/crps_retrofitting/slurm_outputs.
username = os.environ["USER"]
slurm_folder = "/mnt/home/{}/crps_retrofitting/slurm_outputs".format(username)
os.makedirs(slurm_folder, exist_ok=True)
# NOTE(review): not referenced anywhere else in the visible code.
current_path = pathlib.Path.cwd()
# Step 3: for every sweep combination, write an sbatch script, submit it with
# sbatch (or run it with bash under --local), then delete the script file.
for i, params in enumerate(combinations):
    # Warm-up and cool-down each span 10% of training, capped at 10 epochs.
    ramp_epochs = min(int(0.1 * params["max_epoch"]), 10)
    warmup_epochs = ramp_epochs
    cooldown_epochs = ramp_epochs
    data_cfg = params["data"]
    job_name = f"{data_cfg['name']}_job_{i}"
    script_filename = f"{job_name}.sbatch"

    # Slurm does not create missing directories for --output; previously only
    # `slurm_folder` itself existed, so the job log had nowhere to go.
    log_dir = os.path.join(slurm_folder, data_cfg["name"], experiment_name)
    os.makedirs(log_dir, exist_ok=True)

    ablate_string = (
        f"FT_{params['finetune_lr']}_E{params['max_epoch']}_MS{params['max_samples']}"
    )
    # Overrides passed to train.py, in their original order. Every swept value
    # (including the checkpoint path and config override) is read from
    # `params`, so extending a sweep list actually changes the generated jobs.
    overrides = [
        "model.gradient_checkpointing_freq=0",
        "model.override_dimensionality=0",
        "auto_resume=True",
        "checkpoint=defaults",
        f"name={experiment_name}_{ablate_string}",
        f"distribution={params['distribution']}",
        "trainer=defaults",
        "trainer.enable_amp=False",
        f"trainer.grad_acc_steps={params['grad_acc_steps']}",
        f"trainer.max_epoch={params['max_epoch']}",
        "trainer.log_interval=200",
        "trainer.clip_gradient=10",
        "trainer.short_validation_length=20",
        f"trainer.max_rollout_steps={max_rollout_steps}",
        "trainer.val_frequency=20",
        "trainer.rollout_val_frequency=20",
        "trainer.prediction_type=delta",
        "+trainer.skip_spectral_metrics=True",
        "trainer.video_validation=True",
        f"data={data_cfg['name']}",
        f"data.module_parameters.batch_size={params['batch_size']}",
        f"data.module_parameters.max_samples={params['max_samples']}",
        "data.module_parameters.n_steps_input=6",
        "data.module_parameters.n_steps_output=1",
        f"data.module_parameters.min_dt_stride={data_cfg['min_dt_stride']}",
        f"data.module_parameters.max_dt_stride={data_cfg['max_dt_stride']}",
        "optimizer=adam",
        f"optimizer.lr={params['optimizer_lr']}",
        "model/processor/space_mixing=full_spatial_attention",
        "model.projection_dim=48",
        "model.intermediate_dim=352",
        f"model.hidden_dim={params['hidden_dim']}",
        "model.groups=16",
        "model.drop_path=0.05",
        f"model.processor_blocks={params['processor_blocks']}",
        "model.processor.space_mixing.num_heads=16",
        "model.processor.time_mixing.num_heads=16",
        "model.causal_in_time=True",
        "model.jitter_patches=True",
        "+model.use_periodic_fixed_jitter=True",
        "+model.input_field_drop=0",
        f"lr_scheduler={params['lr_scheduler']}",
        f"lr_scheduler.warmup_epochs={warmup_epochs}",
        f"lr_scheduler.cooldown_epochs={cooldown_epochs}",
        "data_workers=10",
        "logger.wandb_project_name=CRPS",
        f"checkpoint.coalesced_checkpoint_path='{params['coalesced_checkpoint_path']}'",
        f"config_override='{params['config_override']}'",
        f"experiment={params['experiment']}",
        "checkpoint.checkpoint_frequency=10",
        f"++finetune_lr={params['finetune_lr']}",
        "finetune=True",
    ]
    # One argument per line, joined with genuine shell line continuations.
    # (The old triple-quoted f-string's trailing backslashes were Python escapes
    # that collapsed everything onto one line, not shell continuations.)
    python_cmd = " \\\n    ".join(["python crps_retrofitting/train.py", *overrides])

    script_header = textwrap.dedent(
        f"""\
        #!/bin/bash
        #SBATCH --time=7-00:00:00
        #SBATCH --partition=gpu
        #SBATCH --nodes=1
        #SBATCH --gres=gpu:1
        #SBATCH -C h100
        #SBATCH --ntasks-per-node=1
        #SBATCH --gpus-per-node=1
        #SBATCH --cpus-per-gpu=16
        #SBATCH --job-name={job_name}
        #SBATCH --output={log_dir}/%j_{job_name}_{date_string}.log

        export OMP_NUM_THREADS=${{SLURM_CPUS_ON_NODE}}
        export HDF5_USE_FILE_LOCKING=FALSE
        export HYDRA_FULL_ERROR=1
        export NCCL_DEBUG=WARN
        export LD_LIBRARY_PATH=""
        module load ffmpeg
        source ./venvs/walrus/bin/activate
        """
    )
    # The multi-line command is appended after dedent so that its continuation
    # lines cannot change dedent's common-prefix computation for the header.
    script_content = f"{script_header}\n{python_cmd}\n"

    with open(script_filename, "w", encoding="utf-8") as f:
        f.write(script_content)

    launcher = "bash" if args.local else "sbatch"
    try:
        result = subprocess.run([launcher, script_filename])
    finally:
        # Remove the generated script whether or not the launch succeeded,
        # including when a local run is interrupted.
        os.remove(script_filename)

    # Previously a failed submission (or local run) went unnoticed.
    if result.returncode != 0:
        print(
            f"WARNING: `{launcher} {script_filename}` exited with status "
            f"{result.returncode}"
        )

    if args.single_run:
        break
