# Data-loading configuration. Most keys are split per phase
# (test / train / validation); keys without that split apply to all phases.
# NOTE(review): the `!!python/tuple` tags used throughout this file are
# PyYAML Python-specific tags; yaml.safe_load rejects them, so the consumer
# must load this file with a non-safe loader that builds Python tuples —
# confirm which loader is used before editing tags.
dataset:
  # Per-phase extra keys; presumably each names an entry of `files_to_load`
  # below (here: drift_at_observations) — confirm against the loader code.
  add_dim_keys:
    test: !!python/tuple
    - drift_at_observations
    train: !!python/tuple
    - drift_at_observations
    validation: !!python/tuple
    - drift_at_observations
  add_paths_keys:
    test: !!python/tuple
    - drift_at_observations
    train: !!python/tuple
    - drift_at_observations
    validation: !!python/tuple
    - drift_at_observations
  batch_size:
    test: 32
    train: 64
    validation: 32
  # Each phase reads three directories under a common prefix, ordered
  # degree 3, 2, 1 (suffixes *_deg_3, *_deg_2, *_deg_1).
  data_dirs:
    test: !!python/tuple
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_3
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_2
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_1
    train: !!python/tuple
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_3
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_2
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_1
    validation: !!python/tuple
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_3
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_2
    - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_1
  # Dataset class name per phase: test uses a different class than
  # train/validation.
  dataset_name:
    test: HeterogeneousFIMSDEDataset
    train: StreamingFIMSDEDataset
    validation: StreamingFIMSDEDataset
  # Logical name -> HDF5 file name; presumably resolved relative to each
  # directory in `data_dirs` — confirm.
  files_to_load:
    drift_at_locations: drift_at_locations.h5
    drift_at_observations: drift_at_observations.h5
    locations: locations.h5
    obs_mask: obs_mask.h5
    obs_times: obs_times.h5
    obs_values: obs_values.h5
  # NOTE(review): presumably must agree with
  # model.model_config.dim_max_trajectory (also 3) — confirm.
  max_dim: 3
  name: FIMSDEDataloaderIterableDataset
  # `null` for test presumably means "no limit / use all" — confirm.
  num_locations:
    test: null
    train: 2000
    validation: 10000
  # Two-element tuples; these look like (lower, upper) bounds on the number of
  # observations — TODO confirm semantics and whether upper is exclusive.
  num_observations:
    test: null
    train: !!python/tuple
    - 0
    - 1801
    validation: !!python/tuple
    - 1799
    - 1801
  num_workers:
    test: 0
    train: 7
    validation: 5
  shard:
    test: false
    train: true
    validation: true
  shuffle_elements: true
  shuffle_locations:
    test: false
    train: true
    validation: true
  shuffle_paths: true
# Distributed-training settings.
distributed:
  # NOTE(review): key is misspelled ("chekpoint"); kept verbatim because the
  # consuming code presumably looks up this exact name — rename both together.
  activation_chekpoint: false
  checkpoint_type: full_state
  enabled: true
  # Written as a plain integer on purpose: PyYAML's YAML 1.1 float resolver
  # only recognises floats containing a '.', so the previous `1e5` loaded as
  # the *string* "1e5", not a number. `100000` is an int under every parser.
  min_num_params: 100000
  sharding_strategy: NO_SHARD
  # NOTE(review): "BAZED" looks like a typo for "BASED"; kept verbatim since
  # the consumer presumably matches it against fixed names — confirm first.
  wrap_policy: SIZE_BAZED
# Experiment identification and reproducibility settings.
experiment:
  device_map: cuda
  name: big_model_l1_600k_examples
  # Presumably appends a date/timestamp to `name` for the run — confirm.
  name_add_date: true
  seed: 10
# Model architecture (`model_config`) and training-objective options
# (`train_config`), consumed by the class named in `model_type`.
model:
  model_config:
    attention_map: softmax
    attention_method: linear
    dim_embed: 256
    dim_feedforward: 1024
    dim_ffn_u_model: 1024
    dim_hidden_u_model: 256
    # NOTE(review): presumably must agree with dataset.max_dim (also 3).
    dim_max_trajectory: 3
    dropout: 0.1
    num_context_encoder_layers: 2
    num_heads: 8
    num_res_layer_u_model: 6
    num_res_layers_functional_decoder: 8
    use_bias_for_projection: true
    use_bias_in_attention: true
    use_query_residual_in_attention: true
  model_type: TrainingWrapper
  train_config:
    corruption_model_type: odeformer
    loss_filter_nans: true
    loss_type: l1
    max_sigma_trajectory_noise: 0.06
    # NOTE(review): key is misspelled ("ration" for "ratio"); kept verbatim
    # because the consumer presumably reads this exact name.
    max_subsampling_ration: 0.5
    train_type: vector_field
    train_with_normalized_head: true
# Tuple with a single optimizer entry, keyed `optimizer_d`.
optimizers: !!python/tuple
- optimizer_d:
    gradient_norm_clipping: 10
    # Contains a '.', so YAML 1.1 loaders (PyYAML) resolve it as a float.
    lr: 1.0e-05
    name: torch.optim.AdamW
    weight_decay: 0.0001
# Training-loop settings.
trainer:
  best_metric: loss
  debug_iterations: null
  detect_anomaly: false
  epochs: 2500
  # Relative path; presumably resolved against the working directory.
  experiment_dir: ./results/
  gradient_accumulation_steps: 1
  # Python logging format string. `rank` is not a standard LogRecord field,
  # so the consumer presumably injects it (e.g. via a filter) — confirm.
  logging_format: RANK_%(rank)s - %(asctime)s - %(name)s - %(levelname)s - %(message)s
  name: Trainer
  precision: bf16mixed
  # Units not visible here (presumably epochs) — confirm.
  save_every: 1
  # Single constant scheduler for the value labelled `drift_loss_scale`.
  schedulers: !!python/tuple
  - beta: 1.0
    label: drift_loss_scale
    name: fim.utils.param_scheduler.ConstantScheduler
