# Device configuration
# Device identifier for the run. The value looks like a PyTorch-style
# "cuda:<index>" string — confirm the accepted forms against the loader.
device: "cuda:2"  # Change as needed

# Seed for the experiment; presumably used to seed RNGs for reproducibility — TODO confirm.
experiment_seed: 0
# Data configuration
# Names the source and target distributions and the directory holding their
# model checkpoints.
data:
  source:
    name: "gauss0_opinion_2d"  # Source distribution
    # Checkpoint to load: "final", or no checkpoint for a random one.
    # NOTE(review): a bare `None` in YAML is read as the string "None", not a
    # null value; write `null` if the loader expects "no checkpoint" — confirm.
    checkpoint: "final"
  target:
    name: "gauss1_opinion_2d"  # Target distribution
    checkpoint: "final"  # Same options as data.source.checkpoint
  model_dir: "/models/"  # Directory containing model checkpoints

# Architecture configuration
# Network hyperparameters. Booleans are written as canonical lowercase
# `true`/`false` so every YAML loader and linter reads them the same way.
architecture:
  input_dim: 2  # Same value as prior.dimension in this file; presumably must match — confirm
  hidden_dim: 128
  num_layers: 4
  activation: "softplus"  # Options: softplus, relu, etc.
  time_varying: true

# Prior distribution configuration
prior:
  type: "gaussian"  # Options: gaussian, uniform
  dimension: 2  # Same value as architecture.input_dim in this file; presumably must match — confirm


# Spline configuration
spline:
  type: "cubic"  # Options: linear, cubic
  num_collocation: 3  # Number of collocation points
  t_partition: 20  # Time partition points
  det_time: 1  # NOTE(review): undocumented; meaning is not evident from this file — confirm with the consumer

# Potential functions
# Active entries are the uncommented list items. Commented-out entries are
# alternatives that can be enabled by removing the leading "#".
potential_functions:
  - "geodesic"
  # - "fisher_information"
  # - "entropy"
  # - "obstacle_cost_vneck"
  # NOTE(review): "quadartic_well" may be a misspelling of "quadratic_well";
  # check the name the consuming code expects before enabling it.
  # - "quadartic_well"
  - "congestion_cost"
  # - "obstacle_cost_stunnel"
  # - "obstacle_cost_gmm"
# Coefficients for the potentials (presumably weights used by the entries in
# potential_functions — confirm with the consumer).
coefficients_potentials:
  sigma: 0.0  # Float zero, written with an explicit leading digit for portability across YAML loaders


# Opinion dynamics configuration
# NOTE(review): S, strength and m_coeff are not documented in this file —
# confirm their meaning with the consuming code.
opinion_dynamics:
  active: true  # Canonical lowercase boolean
  S: 500
  strength: 3.0
  m_coeff: 8.0

# Optimization configuration
# Holds the geodesic warm-up switch, separate optimizer/scheduler settings for
# the coupling and path phases, and shared sizes for the outer loop.
optimization:
  geodesic_warmup: true  # Canonical lowercase boolean
  geodesic_warmup_steps: 200

  coupling:
    steps: [20]  # Number of coupling optimization steps
    learning_rate: 0.0001
    weight_decay: 0.0001
    scheduler:
      type: "StepLR"
      step_size: 100
      gamma: 0.9
  path:
    steps: [20]  # Number of path optimization steps
    learning_rate: 0.001
    scheduler:
      type: "StepLR"  # cosine or StepLR
      step_size: 2  # for StepLR
      gamma: 0.1  # for StepLR
      T_0: 5  # for cosine
      T_mult: 2  # for cosine
      # Kept in plain decimal form: exponent notation without a dot (1e-6) is
      # read as a string by YAML 1.1 loaders such as PyYAML.
      eta_min: 0.000001  # for cosine

  optimization_steps: [10]  # Number of optimization steps
  t_node: 10
  batch_size: 5000
  weight_boundary: 1000
  num_samples: 1000

# Trace estimation
# Selects which trace estimator is used; one of the options listed inline.
trace_estimator: "exact"  # Options: exact, hutchinson

# Visualization configuration
visualization:
  # Plot limits on each axis. All four bounds are written as explicit floats
  # (digit after the decimal point) so they share one type and parse the same
  # way under every YAML loader.
  plot_bounds:
    x_min: -8.0
    x_max: 8.0
    y_min: -8.0
    y_max: 8.0
  num_plot_samples: 1000
  num_time_steps: 10
  solver: "midpoint"

# EMA configuration
ema:
  beta: 0.99  # Presumably the moving-average decay factor — confirm
  update_after_step: 10  # Presumably the step count before EMA updates begin — confirm
  update_every: 10  # Presumably the interval, in steps, between EMA updates — confirm

# Logging configuration
wandb:
  project: "Parametric Density Path Optimization"
  entity: "sebas_g"
  # Run name assembled from `${...}` references to other keys in this file
  # (source name, target name, and the potential_functions list). Plain YAML
  # does not expand these; a loader that resolves interpolation is required
  # (e.g. OmegaConf/Hydra — confirm which one the project uses).
  run_name: "${data.source.name}_to_${data.target.name}_pot_${potential_functions}"  # Include potentials
  group: "density_path_experiments"
  # NOTE(review): the entries below look like a mapping from an internal
  # quantity (key) to the name it is logged under (value) — confirm.
  logging:
    metrics:
      lagrangian: "lagrangian_value"
      coupling_loss: "coupling_loss"
      path_loss: "path_loss"
      boundary_loss: "boundary_loss"
      total_loss: "total_loss"
    plots:
      path_visualization: "path_plot"
      potential_visualization: "potential_plot"
      loss_curve: "loss_history"
    model:
      spline_checkpoint: "spline_checkpoint"
      ema_checkpoint: "ema_checkpoint"
    run_info:
      iteration: "current_iteration"
      coupling_step: "coupling_phase"
      optimization_phase: "optimization_type"  # path or coupling
      learning_rate: "current_lr"