import argparse
import datetime
import itertools
import os
import pathlib
import shlex
import subprocess
import textwrap

# Command-line interface of this sweep launcher.
parser = argparse.ArgumentParser()

# NOTE(review): args.experiment_name is never read in this script; run names
# are derived from norm_type/data_short instead. Confirm whether it should be.
parser.add_argument(
    "--experiment-name",
    default="poseidon",
    required=False,
    type=str,
)

# Execute each generated script with `bash` here instead of submitting it via `sbatch`.
parser.add_argument("--local", action="store_true")

# Stop after the first parameter combination has been launched.
parser.add_argument("--single-run", action="store_true")

# Which dataset preset to use; selects data config, channel count and rollout length.
parser.add_argument(
    "--dataset",
    required=True,
    type=str,
    choices=["euler", "RB", "shear_flow"],
)

args = parser.parse_args()

# Defaults shared by every dataset preset (data_workers is overridden for euler below).
data_workers = 4
validation_frequency = 20
norm_type = "Samplewise"
batch_size = 64
grad_acc_steps = 1
max_samples = 500

# Per-dataset settings keyed by the --dataset choice:
#   (data config name, short tag used in the experiment name,
#    number of model input/output channels, trainer max_rollout_steps)
_DATASET_PRESETS = {
    "RB": ("rayleigh_benard", "RB", 4, 200),
    "euler": ("euler_multi_quadrants_periodicBC", "euler", 5, 100),
    "shear_flow": ("shear_flow", "shear", 4, 200),
}
_preset = _DATASET_PRESETS[args.dataset]
data_name, data_short, num_channels, _max_rollout_steps = _preset

# All presets share the same windowing and data-loader settings; only the
# config name and the rollout length differ between them.
sweep_data = [
    {
        "name": data_name,
        "n_steps_input": 6,
        "n_steps_output": 1,
        "max_rollout_steps": _max_rollout_steps,
        "image_size": [128, 128],
        "prefetch_factor": 2,
        "persistent_workers": False,
    },
]

if args.dataset == "euler":
    # The euler preset loads with more data workers than the default of 4.
    data_workers = 10


# Step 1: Hyperparameter grid. Every entry is a list of candidate values; the
# runs are the Cartesian product of all lists.
sweep = {
    "model_name": ["inv_poseidon_L"],
    "lr_scheduler": ["inv_sqrt_w_sqrt_ramps"],
    "optimizer": ["adam"],
    "optimizer_lr": [0.0001],
    "max_samples": [max_samples],
    "loss_multiplier": [1],
    "max_epoch": [200],
    "batch_size": [batch_size],
    "grad_acc_steps": [grad_acc_steps],
    "resize": [True],
    "data": sweep_data,
    "distribution": ["local"],
}

# Step 2: Expand the grid into one dict of params per run (keys kept in grid order).
keys, values = zip(*sweep.items())
combinations = [
    dict(zip(keys, chosen)) for chosen in itertools.product(*values)
]

# Hydra trainer config used for every run.
trainer = "defaults_mean"

# Base run name and a timestamp used to tag Slurm log files.
experiment_name = f"{norm_type}FT_{data_short}"
date_string = f"{datetime.datetime.now():%d-%m_%H-%M-%S}"

# Root directory for Slurm logs, created up front.
username = os.environ["USER"]
slurm_folder = f"/mnt/home/{username}/crps_retrofitting/slurm_outputs/poseidon"
os.makedirs(slurm_folder, exist_ok=True)

current_path = pathlib.Path.cwd()  # NOTE(review): not referenced below
# Step 3: Generate, submit, and delete one launcher script per parameter combination.

# Local runs use a fixed name prefix. This does not depend on the combination,
# so it is applied once instead of on every iteration.
if args.local:
    experiment_name = "local_"

for i, params in enumerate(combinations):
    data_cfg = params["data"]
    job_name = f"{data_cfg['name']}_job_{i}"
    script_filename = f"{job_name}.sbatch"
    ablate_string = f"{params['max_epoch']}"

    # Hydra overrides for the training entry point, one argv entry each.
    # Every key appears exactly once so its effective value is unambiguous.
    overrides = [
        f"distribution={params['distribution']}",
        f"trainer={trainer}",
        "trainer.enable_amp=False",
        f"trainer.val_frequency={validation_frequency}",
        f"trainer.rollout_val_frequency={validation_frequency}",
        f"data={data_cfg['name']}",
        f"model={params['model_name']}",
        f"model.resize={params['resize']}",
        f"model.num_channels={num_channels}",
        f"model.num_out_channels={num_channels}",
        f"++model.image_size={data_cfg['image_size']}",
        f"trainer.grad_acc_steps={params['grad_acc_steps']}",
        f"data_workers={data_workers}",
        "logger.wandb=True",
        f"trainer.max_epoch={params['max_epoch']}",
        f"optimizer={params['optimizer']}",
        f"optimizer.lr={params['optimizer_lr']}",
        f"lr_scheduler={params['lr_scheduler']}",
        f"trainer.loss_multiplier={params['loss_multiplier']}",
        f"name={experiment_name}_{ablate_string}",
        "auto_resume=True",
        "trainer.prediction_type=full",
        "trainer.video_validation=True",
        f"++data.image_size={data_cfg['image_size']}",
        f"data.module_parameters.batch_size={params['batch_size']}",
        f"data.module_parameters.max_samples={params['max_samples']}",
        f"data.module_parameters.n_steps_input={data_cfg['n_steps_input']}",
        f"data.module_parameters.n_steps_output={data_cfg['n_steps_output']}",
        "data.module_parameters.inner_dset_type._target_="
        "crps_retrofitting.data.inflated_dataset.BatchWellDataset",
        f"++data.module_parameters.prefetch_factor={data_cfg['prefetch_factor']}",
        f"++data.module_parameters.persistent_workers={data_cfg['persistent_workers']}",
        f"trainer.max_rollout_steps={data_cfg['max_rollout_steps']}",
        "++trainer.reuse_batches=True",
        "logger.wandb_project_name=Poseidon",
    ]
    # shlex.quote keeps values containing spaces (e.g. "[128, 128]") as a
    # single shell word. The command stays on one line for the script templates.
    python_cmd = " ".join(
        ["python", "crps_retrofitting/train.py"]
        + [shlex.quote(override) for override in overrides]
    )

    if args.local:
        script_content = textwrap.dedent(
            f"""\
            #!/bin/bash
            export HYDRA_FULL_ERROR=1
            export NCCL_DEBUG=WARN
            export LD_LIBRARY_PATH=""
            export OMP_NUM_THREADS=1
            module load ffmpeg

            {python_cmd}
            """
        )
    else:
        # Slurm does not create missing directories in the --output path, so
        # create the per-dataset/per-experiment log folder before submitting.
        log_dir = pathlib.Path(slurm_folder) / data_cfg["name"] / experiment_name
        log_dir.mkdir(parents=True, exist_ok=True)
        script_content = textwrap.dedent(
            f"""\
            #!/bin/bash
            #SBATCH --time=7-00:00:00
            #SBATCH --partition=gpu
            #SBATCH --nodes=1
            #SBATCH --gres=gpu:1
            #SBATCH -C h100
            #SBATCH --ntasks-per-node=1
            #SBATCH --gpus-per-node=1
            #SBATCH --cpus-per-gpu=16
            #SBATCH --job-name={job_name}
            #SBATCH --output={slurm_folder}/{data_cfg['name']}/{experiment_name}/%j_{job_name}_{date_string}.log

            # export OMP_NUM_THREADS=${{SLURM_CPUS_ON_NODE}}
            export OMP_NUM_THREADS=1
            export HDF5_USE_FILE_LOCKING=FALSE
            export HYDRA_FULL_ERROR=1
            export NCCL_DEBUG=WARN
            export LD_LIBRARY_PATH=""
            module load ffmpeg
            source ./venvs/walrus/bin/activate

            {python_cmd}
            """
        )

    runner = "bash" if args.local else "sbatch"
    with open(script_filename, "w", encoding="utf-8") as f:
        f.write(script_content)
    try:
        result = subprocess.run([runner, script_filename])
    finally:
        # Remove the generated script even if the run/submission raises or is interrupted.
        os.remove(script_filename)
    if result.returncode != 0:
        # Surface failed submissions/runs instead of silently continuing.
        print(
            f"WARNING: `{runner} {script_filename}` exited with code "
            f"{result.returncode}"
        )

    if args.single_run:
        break
