# Locations of the metadata inputs.
# NOTE(review): the directory that the relative entries are resolved against
# is decided by the loading code, which is not visible here — confirm before
# moving any of these files.
meta_data:
  basin_map_csv_path: "meta_data/mapping_basin_to_state.csv"
  population_csv_path: "meta_data/populations.csv"
  srcPath: ".."
  dataPath: "./data"
  metaPath: "./meta_data"
  df_meta_all_path: "df_meta_all.csv"

# Model registry for automatic loading.
# Each key under `models` names one registered checkpoint; `default_model`
# holds one of those keys ("v1.0" here).
model_registry:
  default_model: "v1.0"
  models:
    v1.0:
      # Both artifact file names carry the same iteration tag (0000000200).
      checkpoint_path: "model_weights/checkpoints/checkpoint_iter_0000000200.pt"
      y_stats_path: "model_weights/active_logs/y_stats_iter_0000000200.parquet"
      description: "Default STNP model trained for 200 iterations"
      # Quoted so the value stays the string "1.0" instead of the float 1.0.
      version: "1.0"
      # NOTE(review): identical to the v0.9 date below — looks like a
      # placeholder; confirm the real creation date.
      created_date: "2024-01-01"
    v0.9:
      # Both artifact file names carry the same iteration tag (0000000100).
      checkpoint_path: "model_weights/checkpoints/checkpoint_iter_0000000100.pt"
      y_stats_path: "model_weights/active_logs/y_stats_iter_0000000100.parquet"
      description: "STNP model trained for 100 iterations"
      # Quoted so the value stays the string "0.9" instead of the float 0.9.
      version: "0.9"
      created_date: "2024-01-01"

# Model dimensions and layer sizes.
model:
  NUM_COMP: 4  # number of output compartments
  # Written with plain digits: the `1_000_000` spelling is an integer only
  # under YAML 1.1; a YAML 1.2 core-schema parser loads it as a string.
  POPULATION_SCALER: 1000000
  MAX_SEQ_LENGTH: 245
  # Total dimension including the temporal dim.
  # `define` declares the &x_dim anchor that x_dim and in_channels alias;
  # it is also loaded as an ordinary key with value 7.
  define: &x_dim 7
  x_dim: *x_dim
  # Just the temporal dimension.
  xt_dim: 1
  y_dim: 51
  num_nodes: 51
  z_dim: 100
  r_dim: 100
  # NOTE(review): seq_len equals MAX_SEQ_LENGTH - 1 — confirm this offset is
  # intentional.
  seq_len: 244
  in_channels: *x_dim
  embed_out_dim: 256
  out_channels: 32
  max_diffusion_step: 1
  # encoder
  encoder_num_rnn: 1
  # decoder
  decoder_num_rnn: 1
  decoder_hidden_dims: [1024, 1024, 512]
  # z_encoder
  context_percentage: 0.3

# Training-run settings: learning rates, LR schedule, regularisation,
# batch sizes and worker counts.
train:
  max_epochs: 120
  lr: 0.0001
  lr_encoder: 0.0001
  lr_decoder: 0.0001
  lr_milestones: [40, 50, 60]
  lr_gamma: 0.8
  patience: 10
  gradient_clip_val: 1
  l1_reg: 0.001
  train_batch_size: 16
  train_num_workers: 8
  val_num_workers: 8
  val_batch_size: 16

# Active learner settings.
active_learner:
  max_iter: 200
  initial_train_size: 20
  batch_size_stat_compute: 64
  retriever_num_workers: 8
  pool_num_workers: 8
  every_step: 5
  # Distinct from train.lr_gamma, which is 0.8.
  lr_gamma: 0.9


# `mstd` and `lig` carry the same three acquisition keys with their own values.
# NOTE(review): presumably one section per acquisition strategy — confirm
# against the code that reads them.
mstd:
  acquisition_size: 10
  acquisition_pool_fraction: 0.05
  pool_loader_batch_size: 24

lig:
  acquisition_size: 10
  pool_loader_batch_size: 40
  acquisition_pool_fraction: 0.05

# Column-name lists and initial validation/test set sizes.
data:
  x_col_names:
    - "R0"
    - "days"
  frac_pops_names:
    - "Susceptible"
    - "Latent"
    - "Recovered"
  initial_col_names:
    - "Latent"
    - "Latent_vax"
    - "Hospitalized_symptomatic"
    - "Hospitalized_symptomatic_vax"
  output_compartments:
    - "Hospitalized"
    - "Latent"
  initial_val_size: 120
  initial_test_size: 120

# Short and descriptive names for the output compartments.
output:
  # NOTE: These names must match the hardcoded values in
  # gleam_ai/models/stnp.py
  compartment_names:
    - "hosp_inc"
    - "hosp_prev"
    - "latent_inc"
    - "latent_prev"
  # Full descriptive names for documentation, listed in the same order as
  # compartment_names.
  compartment_full_names:
    - "Hospitalized Incidence"
    - "Hospitalized Prevalence"
    - "Latent Incidence"
    - "Latent Prevalence"

# temporal seasonality will be added to the first dim of x
# frac_pops are fractions
# population/1M will be added as a feature
# loss function: crps_gaussian


