import argparse
import datetime
import itertools
import os
import pathlib
import subprocess
import textwrap

parser = argparse.ArgumentParser(
    description="Generate sbatch scripts for a Poseidon fine-tuning sweep and "
    "submit them with sbatch (or run them locally with bash)."
)
# The text below used to be passed as `default=`, so omitting the flag silently
# produced paths like "Path to deterministic base checkpoint/checkpoints/...".
# It is the flag's help text; the flag itself is required because every job
# reads its checkpoint and config override from this directory.
parser.add_argument(
    "--root_dir",
    "--root-dir",  # hyphenated alias, consistent with the other flags
    type=str,
    required=True,
    help="Path to deterministic base checkpoint",
)
# NOTE(review): args.experiment_name is not read anywhere in this script; the
# run name is derived from the dataset instead ("FT_<dataset short name>").
parser.add_argument("--experiment-name", type=str, required=False, default="poseidon")
parser.add_argument(
    "--local",
    action="store_true",
    help="Run each generated script with bash instead of submitting it via sbatch.",
)
parser.add_argument(
    "--single-run",
    action="store_true",
    help="Stop after launching the first sweep combination.",
)
parser.add_argument(
    "--dataset",
    type=str,
    choices=[
        "euler",
        "RB",
        "shear_flow",
    ],
    required=True,
)
args = parser.parse_args()


# Per-dataset settings. Every branch must define the names read further down:
# data_short, num_channels, sweep_data, checkpoint_name, batch_size, grad_acc,
# max_samples, max_epoch (validation_frequency may be overridden).
# NOTE(review): data_name, max_num_samples_dataset, image_size and
# data_workers are not read elsewhere in this script as shown.
validation_frequency = 50
if args.dataset == "RB":
    data_name = "rayleigh_benard"
    data_short = "RB"
    num_channels = 4
    max_num_samples_dataset = 10
    sweep_data = [
        {
            "name": "rayleigh_benard",
            "n_steps_input": 6,
            "n_steps_output": 1,
            "max_rollout_steps": 200,
            "image_size": [128, 128],
        },
    ]
    checkpoint_name = "step_200"
    batch_size = 64
    grad_acc = 1
    max_samples = 500
    max_epoch = 100
elif args.dataset == "euler":
    data_name = "euler_multi_quadrants_periodicBC"
    data_short = "euler"
    num_channels = 5
    max_num_samples_dataset = 6
    image_size = [128, 128]
    checkpoint_name = "step_200"
    batch_size = 64
    grad_acc = 1
    max_samples = 500
    # NOTE(review): the generated command hard-codes data_workers=10 rather
    # than reading this value.
    data_workers = 10
    max_epoch = 100
    validation_frequency = 100
    sweep_data = [
        {
            "name": "euler_multi_quadrants_periodicBC",
            "n_steps_input": 6,
            "n_steps_output": 1,
            "max_rollout_steps": 100,
            "image_size": [128, 128],
            # NOTE(review): these two keys are not forwarded to the training
            # command by the job loop.
            "prefetch_factor": 2,
            "persistent_workers": True,
        },
    ]
elif args.dataset == "shear_flow":
    data_name = "shear_flow"
    data_short = "shear"
    num_channels = 4
    max_num_samples_dataset = 10
    image_size = [128, 128]
    sweep_data = [
        {
            "name": "shear_flow",
            "n_steps_input": 6,
            "n_steps_output": 1,
            "max_rollout_steps": 200,
            "image_size": image_size,
        },
    ]
    # Previously missing: the checkpoint path below needs checkpoint_name, so
    # --dataset shear_flow crashed with NameError. The value matches the other
    # datasets -- TODO confirm that this checkpoint exists for shear_flow.
    checkpoint_name = "step_200"
    batch_size = 64
    grad_acc = 1
    max_samples = 500
    max_epoch = 100
else:
    # Unreachable from the CLI (argparse `choices` restricts --dataset), but it
    # fails clearly here instead of with a NameError further down.
    raise ValueError(f"Unsupported dataset: {args.dataset!r}")

# Epoch/step number encoded in the checkpoint name, e.g. "step_200" -> "200".
init_epoch = checkpoint_name.rsplit("_", 1)[-1]
# Files taken from the deterministic base run under --root_dir.
coalesced_checkpoint_path = "/".join(
    (args.root_dir, "checkpoints", checkpoint_name, "coalesced.pth")
)
config_override = "/".join((args.root_dir, "extended_config.yaml"))

# Sweep grid: every value is a list of candidates, and the job loop launches one
# run per element of the cartesian product (in this key order).
# NOTE(review): "distribution", "validation_ensemble_size" and
# "ensemble_sizes_to_save" are not read when building the training command.
sweep = dict(
    model_name=["inv_poseidon_L_FT"],
    lr_scheduler=["inv_sqrt_w_sqrt_ramps"],
    optimizer=["adam"],
    optimizer_lr=[0.00005],
    max_samples=[max_samples],
    loss_multiplier=[1],
    max_epoch=[max_epoch],
    batch_size=[batch_size],
    grad_acc_steps=[grad_acc],
    resize=[True],
    data=sweep_data,
    distribution=["local"],
    coalesced_checkpoint_path=[coalesced_checkpoint_path],
    config_override=[config_override],
    # Trainer
    validation_ensemble_size=[1],
    ensemble_sizes_to_save=["[1]"],
    weight_decay=[0.0001],
)

# Step 2: expand the sweep grid into one flat dict per job -- the cartesian
# product of all candidate lists, keyed in `sweep` insertion order.
keys = tuple(sweep)
values = tuple(sweep.values())
combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)]

# Run-name prefix shared by every job, e.g. "FT_euler".
# NOTE(review): --experiment-name is parsed but not used for this.
experiment_name = "FT_" + data_short
# Launch timestamp (day-month_hour-minute-second) used in SLURM log file names.
date_string = datetime.datetime.now().strftime("%d-%m_%H-%M-%S")

username = os.environ["USER"]
slurm_folder = f"/mnt/home/{username}/crps_retrofitting/slurm_outputs/poseidon"
pathlib.Path(slurm_folder).mkdir(parents=True, exist_ok=True)
current_path = pathlib.Path.cwd()  # NOTE(review): not referenced below
# Step 3: for each sweep combination, render an sbatch script, submit it
# (or run it locally with --local), then delete the temporary script file.
for i, params in enumerate(combinations):
    # LR warmup and cooldown each span 10% of the epoch budget, capped at 10.
    warmup_epochs = min(int(0.1 * params["max_epoch"]), 10)
    cooldown_epochs = min(int(0.1 * params["max_epoch"]), 10)
    job_name = f"{params['data']['name']}_job_{i}"
    script_filename = f"{job_name}.sbatch"

    # Run-name suffix encoding the swept hyperparameters.
    ablate_string = f"{params['max_epoch']}_ckpt{init_epoch}_MS{params['max_samples']}_{params['optimizer_lr']}"

    # The --output path below points into per-dataset/per-experiment
    # subdirectories. SLURM does not create missing directories for --output,
    # so without this the job cannot open its log file and fails.
    log_dir = os.path.join(slurm_folder, params["data"]["name"], experiment_name)
    os.makedirs(log_dir, exist_ok=True)

    # Build the python command string. This is a regular (non-raw) f-string,
    # so Python itself removes each trailing backslash-newline: the result is
    # ONE command line (with runs of spaces between overrides), not a bash
    # line-continued command.
    python_cmd = textwrap.dedent(
        f"""\
        python crps_retrofitting/train.py \
        distribution=local \
        finetune=True \
        frozen_components=[] \
        trainer=defaults_mean \
        trainer.enable_amp=False \
        ++trainer.max_spectral_val_samples=5 \
        data={params['data']['name']} \
        model={params['model_name']} \
        model.resize={params['resize']} \
        model.num_channels={num_channels} \
        model.num_out_channels={num_channels} \
        ++model.image_size='{params['data']['image_size']}' \
        trainer.grad_acc_steps={params['grad_acc_steps']} \
        data_workers=10 \
        logger.wandb=True \
        trainer.max_epoch={params['max_epoch']} \
        optimizer={params['optimizer']} \
        optimizer.lr={params['optimizer_lr']} \
        lr_scheduler={params['lr_scheduler']} \
        lr_scheduler.warmup_epochs={warmup_epochs} \
        lr_scheduler.cooldown_epochs={cooldown_epochs} \
        trainer.loss_multiplier={params['loss_multiplier']} \
        name={experiment_name}_{ablate_string} \
        auto_resume=False \
        trainer.prediction_type="full" \
        ++data.image_size="{params['data']['image_size']}" \
        data.module_parameters.batch_size={params['batch_size']} \
        data.module_parameters.max_samples={params['max_samples']} \
        data.module_parameters.n_steps_input={params['data']['n_steps_input']} \
        data.module_parameters.n_steps_output={params['data']['n_steps_output']} \
        data.module_parameters.inner_dset_type._target_=crps_retrofitting.data.inflated_dataset.BatchWellDataset \
        trainer.max_rollout_steps={params['data']['max_rollout_steps']} \
        ++trainer.reuse_batches=True \
        trainer.val_frequency={validation_frequency} \
        trainer.rollout_val_frequency={validation_frequency} \
        logger.wandb_project_name=Poseidon \
        checkpoint.coalesced_checkpoint_path='{params['coalesced_checkpoint_path']}' \
        config_override='{params['config_override']}' \
        ++optimizer.new_params_lr={params['optimizer_lr']} \
        ++optimizer.new_params_kwargs.weight_decay={params['weight_decay']} \
        ++optimizer.new_params_kwargs.eps=1e-10 \
        """
    )
    script_content = textwrap.dedent(
        f"""\
        #!/bin/bash
        #SBATCH --time=7-00:00:00
        #SBATCH --partition=gpu
        #SBATCH --nodes=1
        #SBATCH --gres=gpu:1
        #SBATCH -C h100|a100-80gb
        #SBATCH --ntasks-per-node=1
        #SBATCH --gpus-per-node=1
        #SBATCH --cpus-per-gpu=16
        #SBATCH --job-name={job_name}
        #SBATCH --output={log_dir}/%j_{job_name}_{date_string}.log

        export OMP_NUM_THREADS=${{SLURM_CPUS_ON_NODE}}
        export HDF5_USE_FILE_LOCKING=FALSE
        export HYDRA_FULL_ERROR=1
        export NCCL_DEBUG=WARN
        export LD_LIBRARY_PATH=""
        module load ffmpeg
        source ./venvs/walrus/bin/activate

        {python_cmd}
        """
    )

    with open(script_filename, "w", encoding="utf-8") as f:
        f.write(script_content)

    # With --local the script runs directly under bash, where the #SBATCH lines
    # are ordinary comments. Otherwise it is submitted to SLURM.
    launcher = "bash" if args.local else "sbatch"
    try:
        result = subprocess.run([launcher, script_filename])
    finally:
        # sbatch reads the script when the job is submitted, so deleting it
        # right afterwards is fine. `finally` also removes it if the launcher
        # cannot be started or the local run is interrupted.
        os.remove(script_filename)
    if result.returncode != 0:
        # Best-effort: report the failure and continue with the remaining jobs.
        print(
            f"Warning: `{launcher} {script_filename}` exited with code "
            f"{result.returncode}"
        )

    if args.single_run:
        break
