# Hydra defaults list: compose the `net` config group from its `mlp` option.
defaults:
  - net: [mlp]

# Class that Hydra instantiates from this config.
_target_: ewfm.models.ewfm_module.EWFMModule

# Optimizer factory. `_partial_: true` makes Hydra return a functools.partial,
# so the remaining arguments (presumably the model parameters) are supplied
# when the module calls it.
optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.0005
  weight_decay: 0.0

# Training algorithm to run.
algorithm: 'baseline'
# NOTE(review): presumably the epoch up to which baseline training is applied
# — confirm in EWFMModule.
baseline_until_epoch: 10

# ---- Algorithm-specific parameters ----
# Bagging: size of the bagging buffer (presumably a number of stored samples).
bagging_buffer_size: 10000

# ---- Training parameters ----
use_train_data: false
num_samples_per_batch: 5000
# null presumably disables vector-field norm clipping — confirm in the module.
vector_field_max_norm: null
# Interpolated from the trainer config.
# NOTE(review): Lightning accelerator values such as "gpu" or "auto" are not
# valid torch device strings; confirm how the module consumes this value.
device: ${trainer.accelerator}

# ---- Batched sampling (optional) ----
# When enabled, sampling presumably proceeds in chunks of sample_batch_size.
batched_sampling: false
sample_batch_size: 1000

# ---- Training-time logging / plotting ----
enable_detailed_train_logging: false

# ---- Validation / evaluation parameters ----
validation_uniform_samples: 5000
likelihood_plot_spacing: 0.01
contour_plot_levels: 20
flow_num_particles: 150
# Batch sizes for evaluation, validation plotting and testing respectively.
eval_batch_size: 1000
val_plot_batch_size: 1000
test_batch_size: 10000

# ---- ODE solver parameters ----
# NOTE(review): with an adaptive method such as dopri5, step_size is presumably
# only used by fixed-step methods — confirm in the solver call.
step_size: 0.01
integration_method: 'dopri5'
# Tolerances use an explicit mantissa (1.0e-5): the YAML 1.1 resolver used by
# plain PyYAML reads a bare `1e-5` as the *string* "1e-5". This form is a float
# under every loader, including OmegaConf's.
atol: 1.0e-5
rtol: 1.0e-5
use_exact_divergence: true

# ---- Metric computation ----
# If set, expensive metrics (ESS, NLL) are computed in batches of this size;
# null presumably computes them in a single pass.
metric_batch_size: null

seed: ${seed}  # reuses the root-level `seed` of the composed config

# LR scheduler factory (a functools.partial via `_partial_`), presumably
# bound to the optimizer by the module.
scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min
  factor: 0.1
  patience: 10

# Step the scheduler at the trainer's validation cadence.
# NOTE(review): patience then counts scheduler steps (validation rounds), not
# epochs — confirm this is intended.
lr_scheduler_update_frequency: ${trainer.check_val_every_n_epoch}

# ---- Annealing (aEWFM) ----
enable_annealing: true
initial_temperature: 10.0
final_temperature: 1.0
# Options: 'geometric' or 'linear'.
temperature_schedule: 'geometric'
annealing_epochs_per_temperature: 10
total_annealing_epochs: 100
# Explicit temperature values for a custom schedule; null when unused.
temperature_values: null

# ---- Sample / weight clipping ----
# Options: null (presumably disables clipping), 'importance_weight',
# 'energy_value', 'modified_energy'.
clipping_method: null
# Percentile (0-100) used as the clipping threshold.
# NOTE(review): reading this as "clip the top x%" would clip 99.9% of values at
# this setting; presumably values above the x-th percentile are clipped
# (99.9 -> top 0.1%) — confirm against the clipping implementation.
clipping_percentile: 99.9

# ---- EMA ----
use_ema: false
ema_beta: 0.999
# NOTE(review): presumably controls a warm-up of the EMA decay — confirm in
# the module.
beta_warmup_denominator: 10

# Copied from the data config by interpolation, presumably so the module can
# read the number of training batches per epoch from its own config.
data_n_train_batches_per_epoch: ${data.n_train_batches_per_epoch}
