import argparse
import ast
import datetime
import itertools
import os
import pathlib
import subprocess
import textwrap

# Command-line interface for the CRPS fine-tuning sweep launcher.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--root_dir",
    type=str,
    # Must be given explicitly: both the base checkpoint and
    # extended_config.yaml are read from this directory.
    required=True,
    help="Path to deterministic base checkpoint",
)
parser.add_argument(
    "--experiment_name",
    type=str,
    required=False,
    default="CRPS",
    help="Value passed to train.py as the `experiment=` override.",
)
parser.add_argument(
    "--dataset",
    type=str,
    choices=[
        "RB",
        "trl",
        "viscoelastic",
    ],
    required=True,
    help="Dataset to fine-tune on.",
)
parser.add_argument(
    "--local",
    action="store_true",
    help="Run each job in the foreground with bash instead of submitting it via sbatch.",
)
parser.add_argument(
    "--single-run",
    action="store_true",
    help="Launch only the first combination of the sweep.",
)
args = parser.parse_args()

# Loss used for fine-tuning: a short label for run names plus the import path
# handed to the trainer as trainer.loss_fn._target_.
loss_fn = "CRPS"
loss_fn_target = "crps_retrofitting.metrics.crps.CRPS"

# Deterministic base checkpoint, expected inside --root_dir.
checkpoint_name = "halfwalrus_step160.pt"

# --dataset choice -> (data config name, value for trainer.max_num_samples).
_DATASET_SETTINGS = {
    "RB": ("rayleigh_benard", 10),
    "trl": ("turbulent_radiative_layer_2D", 10),
    "viscoelastic": ("viscoelastic_instability", 6),
}
if args.dataset not in _DATASET_SETTINGS:
    raise ValueError(f"Invalid dataset {args.dataset}")
data_name, max_num_samples_dataset = _DATASET_SETTINGS[args.dataset]

# Trailing "_"-separated piece of the checkpoint name ("step160.pt").
# NOTE(review): not referenced anywhere else in the visible script.
init_step = checkpoint_name.rsplit("_", 1)[-1]

coalesced_checkpoint_path = f"{args.root_dir}/{checkpoint_name}"
config_override = f"{args.root_dir}/extended_config.yaml"

# Hyperparameter sweep: every key maps to a list of candidate values and Step 2
# launches one job per element of their Cartesian product. Composite entries
# are flattened there: "batch_num_samples_grad_acc" holds
# (batch_size, num_samples, grad_acc_steps) tuples, and "noise_dim" also
# determines the processor's noise-conditioning width.
# List-valued settings are kept as strings because they are spliced verbatim
# into the command-line overrides.
sweep = {
    "experiment": [args.experiment_name],
    "lr_scheduler": ["inv_sqrt_w_sqrt_ramps"],
    "optimizer_lr": [0.0005],
    "common_optimizer_lr": [0.0001],
    "max_samples": [2000],
    "loss_multiplier": [100],
    "max_epoch": [50],
    "hidden_dim": [1088],
    "processor_blocks": [30],
    "data": [
        {
            "name": data_name,
            "min_dt_stride": 1,
            "max_dt_stride": 1,
        },
    ],
    "distribution": ["local"],
    "coalesced_checkpoint_path": [coalesced_checkpoint_path],
    "config_override": [config_override],
    "batch_num_samples_grad_acc": [
        (1, 2, 4),
    ],
    "noise_mode": ["global"],
    "noise_dim": [32],
    "mlp_layers": [2],
    "noise_layernorm": [True],
    # Trainer
    "validation_ensemble_size": [16],
    "ensemble_sizes_to_save": ["[1, 2, 4, 8, 16]"],
    # Alternative: inject noise into all 30 processor blocks ("[0, 1, ..., 29]").
    # Currently every other block is used.
    "noise_blocks": ["[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28]"],
    "max_num_samples": [max_num_samples_dataset],
}

# Step 2: Expand the sweep into the Cartesian product of all parameter values,
# flattening composite entries into the individual keys Step 3 reads.
keys, values = zip(*sweep.items())
combinations = []
for combo in itertools.product(*values):
    params = dict(zip(keys, combo))
    # Unpack the (batch_size, num_samples, grad_acc_steps) tuple into separate keys.
    batch_size, num_samples, grad_acc_steps = params.pop("batch_num_samples_grad_acc")
    params["batch_size"] = batch_size
    params["num_samples"] = num_samples
    params["grad_acc_steps"] = grad_acc_steps
    # itertools.product yields scalar values, so params["noise_dim"] is already
    # an int. It is copied, not popped, because the launch command also passes
    # it on as model.noise_dim.
    params["processor_noise_cond_dim"] = params["noise_dim"]
    # Skip duplicate combinations (possible if a sweep list repeats a value).
    if params not in combinations:
        combinations.append(params)

# Per-user directory that collects SLURM log files; created up front.
username = os.environ["USER"]
slurm_folder = f"/mnt/home/{username}/crps_retrofitting/slurm_outputs"
os.makedirs(slurm_folder, exist_ok=True)
current_path = pathlib.Path.cwd()

# Run-name prefix shared by every job in this sweep, e.g. "CRPSHalfWalrus_RB".
experiment_name = f"{loss_fn}HalfWalrus_{args.dataset}"
# Launch timestamp (day-month_hour-minute-second) embedded in log filenames.
date_string = datetime.datetime.now().strftime("%d-%m_%H-%M-%S")
# Step 3: For each combination, write a launch script, run it (bash locally,
# sbatch otherwise), and delete it afterwards.
for i, params in enumerate(combinations):
    # Warm-up and cool-down each take 10% of training, capped at 10 epochs.
    warmup_epochs = min(int(0.1 * params["max_epoch"]), 10)
    cooldown_epochs = min(int(0.1 * params["max_epoch"]), 10)
    job_name = f"{params['data']['name']}_job_{i}"
    script_filename = f"{job_name}.sbatch"

    # Suffix encoding the ablation axes; it becomes part of the `name=` override.
    noise_blocks = ast.literal_eval(params["noise_blocks"])
    ablate_string = (
        f"{params['noise_mode']}_MLP{params['mlp_layers']}LN"
        f"_P{params['processor_noise_cond_dim']}_L{len(noise_blocks)}"
        f"_B{params['batch_size']}_NS{params['num_samples']}_{params['max_epoch']}"
    )
    if "common_optimizer_lr" in params:
        ablate_string += f"_{params['common_optimizer_lr']}"
    if args.local:
        ablate_string = "local_" + ablate_string

    # Command-line overrides for train.py, in their original order. Values
    # containing spaces (names, list literals, paths) are single-quoted for bash.
    overrides = [
        "model.gradient_checkpointing_freq=0",
        "model.override_dimensionality=0",
        "auto_resume=True",
        "checkpoint=defaults",
        f"name='{experiment_name}_{ablate_string}'",
        f"distribution={params['distribution']}",
        "trainer=defaults",
        "trainer.enable_amp=False",
        f"trainer.grad_acc_steps={params['grad_acc_steps']}",
        f"trainer.max_epoch={params['max_epoch']}",
        "trainer.log_interval=200",
        "trainer.clip_gradient=10",
        "trainer.short_validation_length=20",
        "trainer.max_rollout_steps=200",
        "trainer.val_frequency=50",
        "trainer.rollout_val_frequency=50",
        "trainer.prediction_type=delta",
        "+trainer.skip_spectral_metrics=True",
        "trainer.video_validation=True",
        f"++trainer.ensemble_sizes_to_save='{params['ensemble_sizes_to_save']}'",
        f"++trainer.validation_ensemble_size={params['validation_ensemble_size']}",
        f"++trainer.max_num_samples={params['max_num_samples']}",
        "++trainer.max_spectral_val_samples=5",
        "++trainer.common_params_warmup_epochs=5",
        f"data={params['data']['name']}",
        f"data.module_parameters.batch_size={params['batch_size']}",
        f"data.module_parameters.max_samples={params['max_samples']}",
        "data.module_parameters.n_steps_input=6",
        "data.module_parameters.n_steps_output=1",
        f"data.module_parameters.min_dt_stride={params['data']['min_dt_stride']}",
        f"data.module_parameters.max_dt_stride={params['data']['max_dt_stride']}",
        "optimizer=adam",
        f"optimizer.lr={params['optimizer_lr']}",
        "model/processor/space_mixing=full_spatial_attention",
        "model.projection_dim=48",
        "model.intermediate_dim=352",
        f"model.hidden_dim={params['hidden_dim']}",
        "model.groups=16",
        "model.drop_path=0.05",
        f"model.processor_blocks={params['processor_blocks']}",
        "model.processor.space_mixing.num_heads=16",
        "model.processor.time_mixing.num_heads=16",
        "model.causal_in_time=True",
        "model.jitter_patches=True",
        "+model.use_periodic_fixed_jitter=True",
        "+model.input_field_drop=0",
        f"lr_scheduler={params['lr_scheduler']}",
        f"lr_scheduler.warmup_epochs={warmup_epochs}",
        f"lr_scheduler.cooldown_epochs={cooldown_epochs}",
        "data_workers=10",
        "logger.wandb_project_name=CRPS_HalfWalrus",
        f"experiment={params['experiment']}",
        f"checkpoint.coalesced_checkpoint_path='{params['coalesced_checkpoint_path']}'",
        f"config_override='{params['config_override']}'",
        f"model.num_samples={params['num_samples']}",
        "checkpoint.checkpoint_frequency=10",
        f"model.noise_dim={params['noise_dim']}",
        "trainer.enable_staged_learning=False",
        f"trainer.loss_fn._target_={loss_fn_target}",
        f"model.noise_mode={params['noise_mode']}",
        f"model.processor.noise_cond_dim={params['processor_noise_cond_dim']}",
        f"model.mlp_layers={params['mlp_layers']}",
        f"model.noise_layernorm={params['noise_layernorm']}",
        f"+model.noise_blocks='{params['noise_blocks']}'",
        f"++optimizer.new_params_lr={params['optimizer_lr']}",
        f"++optimizer.common_params_lr={params['common_optimizer_lr']}",
        "++optimizer.new_params_kwargs.weight_decay=0.0001",
        "++optimizer.new_params_kwargs.eps=1e-10",
        "++optimizer.common_params_kwargs.weight_decay=0.0001",
        "++optimizer.common_params_kwargs.eps=1e-10",
    ]
    # One argument per line, joined with bash line continuations.
    python_cmd = " \\\n    ".join(["python crps_retrofitting/train.py", *overrides])

    if args.local:
        header = textwrap.dedent(
            """\
            #!/bin/bash
            export HYDRA_FULL_ERROR=1
            export NCCL_DEBUG=WARN
            export LD_LIBRARY_PATH=""
            module load ffmpeg
            """
        )
    else:
        log_dir = os.path.join(slurm_folder, params["data"]["name"], experiment_name)
        # Slurm does not create missing parent directories for --output, so
        # make sure the per-dataset/per-experiment log directory exists.
        os.makedirs(log_dir, exist_ok=True)
        header = textwrap.dedent(
            f"""\
            #!/bin/bash
            #SBATCH --time=7-00:00:00
            #SBATCH --partition=gpu
            #SBATCH --nodes=1
            #SBATCH --gres=gpu:1
            #SBATCH -C "h100"
            #SBATCH --ntasks-per-node=1
            #SBATCH --gpus-per-node=1
            #SBATCH --cpus-per-gpu=16
            #SBATCH --job-name={job_name}
            #SBATCH --output={log_dir}/%j_{job_name}_{date_string}.log

            export OMP_NUM_THREADS=${{SLURM_CPUS_ON_NODE}}
            export HDF5_USE_FILE_LOCKING=FALSE
            export HYDRA_FULL_ERROR=1
            export NCCL_DEBUG=WARN
            export LD_LIBRARY_PATH=""
            export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
            module load ffmpeg
            source ./venvs/walrus/bin/activate
            """
        )
    # Append the (multi-line) command after dedenting the header. Interpolating
    # it into the template would defeat dedent and indent the shebang/#SBATCH lines.
    script_content = f"{header}\n{python_cmd}\n"

    with open(script_filename, "w", encoding="utf-8") as f:
        f.write(script_content)

    launcher = "bash" if args.local else "sbatch"
    try:
        result = subprocess.run([launcher, script_filename])
    finally:
        # The script is only needed for this launch; remove it even on failure.
        os.remove(script_filename)
    if result.returncode != 0:
        print(f"{launcher} {script_filename} exited with code {result.returncode}")

    if args.single_run:
        break
