# Base configuration for JKOnet, JKOnet-Star, and InverseJKO models.
# Model-specific settings should be overridden in their corresponding configuration files.
# Note: For JKOnet-Star, this base configuration is complete—no additional config files are required.

# training
train:
  # Evaluate the model once every this many epochs.
  eval_freq: 250
  # Number of samples per gradient update.
  batch_size: 250
  # Total number of training epochs.
  epochs: 15000
  # NOTE(review): undocumented in the original; presumably the time step
  # between consecutive snapshots — confirm against the consumer.
  dt: 0.01
  # When true, plots are saved locally.
  save_locally: true

metrics:
  # Order of the Wasserstein error to be computed.
  wasserstein_order: 1
  # NOTE(review): presumably metrics evaluated on one-step-ahead
  # predictions — confirm against the consumer.
  compute_one_ahead:
    enabled: true
    types:
      - "MMD"
      - "EMD"
      - "BW2_UVP"
      - "L2_UVP_potential_backward"
      - "L2_UVP_interaction_backward"
      - "L2_UVP_beta"
      - "EMD_Tong"
      - "MMD_DMSB"
  # NOTE(review): presumably metrics accumulated over the whole
  # trajectory — confirm against the consumer.
  compute_cumulative:
    enabled: false
    types:
      - "wasserstein"

# Weights & Biases (wandb)
wandb:
  # When true, plots are saved in wandb.
  save_plots: true
  # When true, the model is saved in wandb.
  save_model: false

# models
energy:
  # Optimization settings.
  optim:
    # NOTE(review): weight-decay coefficient; how it is combined with the
    # optimizer below is not visible in this file — confirm in the consumer.
    weight_decay: 0.01
    # Choice of optimizer for updating the model parameters.
    optimizer: Adam

    # Adam optimizer parameters.
    lr: 0.01
    beta1: 0.9
    beta2: 0.999
    # Written with an explicit dot and signed exponent so that it resolves
    # to a float under both YAML 1.1 and YAML 1.2 loaders.
    eps: 1.0e-8

    # Gradient clipping threshold.
    grad_clip: 10.0

  # Model architecture.
  model:
    # Number of units in each layer of the neural network.
    layers: [64, 64]

  # Feature selection for the linear parametrization.
  linear:
    # Regularization term for the linear parametrization.
    reg: 0.01
    features:
      polynomials:
        # Degree of the polynomial features.
        degree: 4
        # Enable sine functions of the polynomials as additional features.
        sines: false
        # Enable cosine functions of the polynomials as additional features.
        cosines: false
      rbfs:
        # Number of radial basis function centers per dimension.
        n_centers_per_dim: 10
        # Domain for the radial basis functions.
        domain: [-4, 4]
        # Spread (sigma) of the radial basis functions.
        sigma: 0.5
        # Types of RBFs to include; commented-out entries are disabled.
        types:
          # - "linear"
          # - "thin_plate_spline"
          # - "cubic"
          # - "quintic"
          - "const"
          # - "multiquadric"
          # - "inverse_multiquadric"
          # - "inverse_quadratic"
