# Device configuration
# NOTE(review): looks like a torch-style device string; if so, "cuda:3"
# selects GPU index 3 and will fail on hosts with fewer than four GPUs —
# confirm against the consumer.
device: "cuda:3"  # Change as needed

# NOTE(review): presumably seeds the random number generators for
# reproducibility — confirm which generators the consumer actually seeds.
experiment_seed: 2

# Data configuration
# Names the source and target distributions and the directory holding their
# model checkpoints.
data:
  source:
    name: "gaussian0"  # Source distribution
    # Checkpoint to load; "final" is the value used here. The original note
    # says None selects "random" (presumably a randomly initialised model —
    # confirm).
    # NOTE(review): in YAML a bare None parses as the string "None"; the null
    # value is spelled null. Confirm which form the loader expects.
    checkpoint: "final"
  target:
    name: "gaussian1"  # Target distribution
    # Same convention as data.source.checkpoint.
    checkpoint: "final"
  # Directory containing model checkpoints.
  # NOTE(review): absolute path at the filesystem root — confirm it exists on
  # the machine running the experiment.
  model_dir: "/models/"

# Architecture configuration
# NOTE(review): presumably the hyperparameters of the neural network built by
# the consumer — confirm which model reads this stanza.
architecture:
  # Same value as prior.dimension; presumably the two must agree — TODO confirm.
  input_dim: 2
  hidden_dim: 128
  num_layers: 4
  activation: "softplus"  # Options: softplus, relu, etc.
  # Written as a canonical lowercase YAML boolean.
  time_varying: true

# Prior distribution configuration
prior:
  type: "gaussian"  # Options: gaussian, uniform
  # Same value as architecture.input_dim.
  # NOTE(review): presumably must match the dimensionality of the data —
  # confirm, and keep the two keys in sync when changing either.
  dimension: 2
  

# Spline configuration
spline:
  type: "cubic"  # Options: linear, cubic
  num_collocation: 3  # Number of collocation points
  t_partition: 60  # Time partition points
  # NOTE(review): the meaning of det_time is not evident from this file —
  # document it once confirmed against the consumer.
  det_time: 1

# Potential functions
# The uncommented list entries are the active potentials; commented entries
# are alternatives kept for quick toggling. This list is also interpolated
# into wandb.run_name.
potential_functions:
  # - "geodesic"
  - "fisher_information"
  - "entropy"
  - "obstacle_cost_vneck"
  # NOTE(review): "quadartic_well" below looks like a misspelling of
  # "quadratic_well" — check the name the code expects before enabling it.
  # - "quadartic_well"
  # - "congestion_cost"
  # - "obstacle_cost_stunnel"
  # - "obstacle_cost_gmm"
# Coefficients for the potential terms.
# NOTE(review): presumably each weight_* value scales the potential of the
# matching kind (e.g. weight_obstacle for the obstacle_cost_* entries) and
# sigma is a noise/diffusion scale — confirm against the consumer.
coefficients_potentials:
  sigma: 1.0
  weight_obstacle: 1500
  weight_interaction: 50


  

# Optimization configuration
optimization:
  # NOTE(review): presumably runs geodesic_warmup_steps warm-up steps before
  # the main optimization when enabled — confirm.
  geodesic_warmup: true
  geodesic_warmup_steps: 100

  coupling:
    # NOTE(review): described as a step count but given as a list —
    # presumably one entry per optimization stage; confirm.
    steps: [20]  # Number of coupling optimization steps
    learning_rate: 0.0001
    weight_decay: 0.001
    # NOTE(review): no `type` key here, unlike path.scheduler; step_size and
    # gamma suggest a StepLR-style schedule — confirm.
    scheduler:
      step_size: 5
      gamma: 0.1
  path:
    steps: [20]  # Number of path optimization steps
    learning_rate: 0.005
    scheduler:
      type: "StepLR"  # "cosine" or "StepLR"
      step_size: 5  # for StepLR
      gamma: 0.25  # for StepLR
      T_0: 5  # for cosine
      T_mult: 2  # for cosine
      eta_min: 0.000001  # for cosine

  # NOTE(review): list-valued like coupling.steps / path.steps; presumably the
  # number of outer iterations per stage — confirm.
  optimization_steps: [30]
  t_node: 10
  batch_size: 1500
  weight_boundary: 1000
  num_samples: 1000

# Trace estimation
# NOTE(review): "hutchinson" presumably refers to Hutchinson's stochastic
# trace estimator and "exact" to a full trace computation — confirm.
trace_estimator: "exact"  # Options: exact, hutchinson

# Visualization configuration
visualization:
  # Axis limits for the plots. All four bounds are written as explicit floats
  # so they load with the same type.
  plot_bounds:
    x_min: -15.0
    x_max: 15.0
    y_min: -15.0
    y_max: 15.0
  num_plot_samples: 1000
  num_time_steps: 10
  # NOTE(review): presumably the ODE integration method used when generating
  # trajectories for plotting — confirm.
  solver: "midpoint"

# EMA configuration
# NOTE(review): presumably an exponential moving average of model weights,
# with beta as the decay, update_after_step as the first step at which
# averaging starts, and update_every as the update interval in steps —
# confirm against the EMA implementation in use.
ema:
  beta: 0.99
  update_after_step: 10
  update_every: 10

# Logging configuration
wandb:
  project: "Parametric Density Path Optimization"
  # NOTE(review): hard-coded personal account — other users will need to
  # override this.
  entity: "sebas_g"
  # Built by interpolation from data.source.name, data.target.name and
  # potential_functions.
  # NOTE(review): potential_functions is a list, so the rendered name will
  # presumably contain its list representation (brackets and quotes) —
  # confirm the result is acceptable as a run name.
  run_name: "${data.source.name}_to_${data.target.name}_pot_${potential_functions}"  # Include potentials
  group: "density_path_experiments"
  # NOTE(review): the mappings below presumably pair an internal identifier
  # (key) with the name it is logged under (value) — confirm.
  logging:
    metrics:
      lagrangian: "lagrangian_value"
      coupling_loss: "coupling_loss"
      path_loss: "path_loss"
      boundary_loss: "boundary_loss"
      total_loss: "total_loss"
    plots:
      path_visualization: "path_plot"
      potential_visualization: "potential_plot"
      loss_curve: "loss_history"
    model:
      spline_checkpoint: "spline_checkpoint"
      ema_checkpoint: "ema_checkpoint"
    run_info:
      iteration: "current_iteration"
      coupling_step: "coupling_phase"
      optimization_phase: "optimization_type"  # path or coupling
      learning_rate: "current_lr"