# training
train:
  eval_freq: 250  # How often (in epochs) to evaluate the model
  batch_size: 500  # Number of samples per gradient update
  epochs: 5000  # Total number of epochs for training
  dt: 0.01  # Discretization parameter (presumably the time step — confirm)
  save_locally: true  # If true, plots will be saved locally
  # NOTE(review): undocumented. Presumably the weight of a gradient-penalty
  # regularizer on the energy model, with 0 disabling it — confirm against
  # the training code before relying on this description.
  energy_gp_reg: 0

metrics:
  wasserstein_order: 1  # Order of the Wasserstein error to be computed
  # Metric lists use block-sequence style so that individual entries can be
  # commented in/out safely. In a flow sequence ([...]) uncommenting an entry
  # after the last active item also requires adding a separating comma.
  compute_one_ahead:
    enabled: true
    types:
      - "MMD"
      - "EMD"
      - "BW2_UVP"
      - "L2_UVP_potential_backward"
      - "EMD_Tong"
      - "MMD_DMSB"
      # - "L2_UVP_interaction_backward"
      # - "L2_UVP_beta"
  compute_cumulative:
    enabled: false
    types:
      - "wasserstein"

# WandB
wandb:
  save_plots: true  # If true, plots will be saved in wandb
  save_model: false  # If true, the model will be saved in wandb

# models
energy:
  # optimization
  optim:
    weight_decay: 0.00000484012
    optimizer: Adam  # Choice of optimizer for updating model parameters

    # Adam optimizer parameters.
    # Keep float literals in plain decimal form (or with a dot before a signed
    # exponent, e.g. 1.0e-8): YAML 1.1 loaders such as PyYAML resolve `1e-8`
    # as a string, not a float.
    lr: 0.000271951
    beta1: 0.58736
    beta2: 0.928005
    eps: 0.00000001

    grad_clip: 18.7228

    # scheduler: piecewise_constant_schedule  # specify parameters below or see default parameters in code

  # model architecture
  model:
    layers: [128, 128, 128]  # Number of units in each layer of the neural network

  # feature selection for linear parametrization
  linear:
    reg: 0.01  # Regularization coefficient for the linear parametrization
    features:
      polynomials:
        degree: 4  # Degree of polynomial features
        sines: false  # Add sines of the polynomials as additional features
        cosines: false  # Add cosines of the polynomials as additional features
      rbfs:
        n_centers_per_dim: 10  # Number of radial basis function centers per dimension
        domain: [-4, 4]  # Domain for radial basis functions
        sigma: 0.5  # Spread (sigma) for radial basis functions
        # Types of RBFs to include; uncomment an entry to enable it.
        types:
          # - "linear"
          # - "thin_plate_spline"
          # - "cubic"
          # - "quintic"
          - "const"
          # - "multiquadric"
          # - "inverse_multiquadric"
          # - "inverse_quadratic"
