# Experiment identity. Both values are plain strings; their consumer is not
# visible in this file (presumably a method/runner lookup in the launcher —
# confirm before renaming either one).
method_name: mlem
runner_name: UnsupervisedRunner

# Loader entries for the unsupervised stage. Each one is an interpolation
# pointing at a loader defined outside this file (`data.loaders.train` and
# `data.loaders.train_val`), so no loader settings are duplicated here.
data:
  loaders:
    unsupervised_train: "${data.loaders.train}"
    unsupervised_train_val: "${data.loaders.train_val}"

# Loss for the unsupervised stage, referenced by name.
unsupervised_loss:
  name: ModelLoss
# Metrics for the unsupervised stage, each referenced by name.
# NOTE(review): "sparcity" looks like a misspelling of "sparsity", but the
# value is a runtime identifier that presumably has to match the name the
# metric is registered under — do not correct it here without changing the
# code side as well.
unsupervised_metrics:
  - name: MLEM_total_mse_loss
  - name: MLEM_total_CE_loss
  - name: MLEM_sparcity_loss
  - name: MLEM_reconstruction_loss

# Trainer settings for the unsupervised stage.
unsupervised_trainer:
  # Checkpoint tracking uses the literal metric name `loss` here, whereas the
  # `trainer` section tracks `${main_metric}`.
  ckpt_track_metric: loss
  metrics_on_train: true
  # Resolved from the top-level `patience` key, which is defined outside this file.
  patience: ${patience}
  # 100k iterations. Written without the `_` digit separator: `100_000` is an
  # integer only under YAML 1.1; YAML 1.2 core-schema parsers read it as a
  # string.
  total_iters: 100000

# Trainer settings for the main (non-unsupervised) stage.
trainer:
  # Checkpoint tracking follows the top-level `main_metric` key, which is
  # defined outside this file.
  ckpt_track_metric: ${main_metric}
  metrics_on_train: true
  # Resolved from the top-level `patience` key, which is defined outside this file.
  patience: ${patience}
  # 100k iterations. Written without the `_` digit separator: `100_000` is an
  # integer only under YAML 1.1; YAML 1.2 core-schema parsers read it as a
  # string.
  total_iters: 100000

# Model for the unsupervised stage. Its preprocess module is `MLEMPretrainer`,
# and its params are an interpolation of `model.preprocess.params`, so both
# stages read one shared parameter mapping.
unsupervised_model:
  preprocess:
    name: MLEMPretrainer
    params: "${model.preprocess.params}"

# Lookup table of time-processing modes, keyed by the value of `log_dir`.
# It is read through `${coles_time_process.${log_dir}}` in
# `model.preprocess.params.time_process`.
# Values are quoted so they are unmistakably strings: in particular 'none' is
# the string "none" (one of the "diff" / "cat" / "none" modes), not a YAML null.
coles_time_process:
  log/age: 'cat'
  log/taobao: 'diff'
  log/physionet2012: 'cat'
  log/mimic3: 'diff'
  log/pendulum: 'cat'
  log/x5: 'none'
  log/mbd: 'cat'

# Model definition: an `MLEMEncoder` preprocess module followed by a linear head.
# `model.preprocess.params` is also referenced by `unsupervised_model`, so
# every change here applies to the pretrainer as well.
model:
  preprocess:
    name: MLEMEncoder
    params:
      # Folder under `log_dir` for the contrastive (CoLES) model; how it is
      # used is not visible here (presumably a checkpoint location — confirm).
      contr_model_folder: "${log_dir}/coles/correlation/seed_0/"
      normalize_z: false
      # Preprocess
      # `cc` and `nn` are top-level keys defined outside this file.
      cat_cardinalities: ${cc}
      num_features: ${nn}
      cat_emb_dim: 16
      num_emb_dim: 16
      # Nested interpolation: the entry of `coles_time_process` whose key
      # equals the value of `log_dir`.
      time_process: ${coles_time_process.${log_dir}}
      num_norm: true
      # Encoder:
      enc_num_layers: 1
      enc_aggregation: TakeLastHidden
      # Decoder:
      dec_hidden_size: 128
      dec_num_layers: 3
      dec_num_heads: 8
      dec_scale_hidden: 2
      max_len: ${data.preprocessing.common_pipeline.max_seq_len}
      # Loss weights:
      l1_weight: 0.001
      contrastive_weight: 10
  head:
    name: nn.Linear
    params:
      # Placeholder string, presumably replaced with the preprocess module's
      # output dimension when the model is built — confirm against the builder.
      in_features: output_dim_from_prev
      out_features: ${n_classes}

# Default optimizer settings. Both `lr` and `weight_decay` also appear as
# search dimensions under `optuna.suggestions`.
optimizer:
  name: Adam
  params: {lr: 1.0e-3, weight_decay: 1.0e-8}

# Hyper-parameter search settings. Keys inside `request_list` and
# `suggestions` are dotted paths into this configuration.
optuna:
  params:
    n_trials: 70
    n_startup_trials: 5
    # Fixed parameter set(s); presumably enqueued as the first trial(s) before
    # sampling starts — confirm against the runner.
    # NOTE(review): this entry uses enc_num_layers: 3 and dec_scale_hidden: 4,
    # while the defaults under `model.preprocess.params` are 1 and 2, and it
    # gives no value for enc_aggregation even though that key is listed in
    # `suggestions`. Confirm both are intentional.
    request_list:
      - optimizer.params.weight_decay: 1.e-8
        optimizer.params.lr: 1.e-3
        model.preprocess.params.normalize_z: false
        # model.preprocess.params.time_process: 'diff'
        model.preprocess.params.num_norm: true
        model.preprocess.params.cat_emb_dim: 16
        model.preprocess.params.num_emb_dim: 16
        model.preprocess.params.enc_num_layers: 3
        model.preprocess.params.dec_hidden_size: 128
        model.preprocess.params.dec_num_layers: 3
        model.preprocess.params.dec_num_heads: 8
        model.preprocess.params.dec_scale_hidden: 4
        model.preprocess.params.contrastive_weight: 10
        # unsupervised_trainer.total_iters: 100_000
    target_metric: ${main_metric}
  # Search space. Each value is a pair `[suggest_method, kwargs]`; the method
  # names match Optuna's `Trial.suggest_*` API (dispatch presumably happens in
  # the runner — confirm).
  suggestions:
    optimizer.params.weight_decay: [suggest_float, {low: 1.e-10, high: 1.e-2, log: true}]
    optimizer.params.lr: [suggest_float, {low: 1.e-5, high: 1.e-2, log: true}]
    # MLEM:
    model.preprocess.params.normalize_z: [suggest_categorical, {choices: [false, true]}]
    # Preprocess
    # model.preprocess.params.time_process: [suggest_categorical, {choices: ["diff", "cat", "none"]}]
    model.preprocess.params.num_norm: [suggest_categorical, {choices: [true, false]}]
    model.preprocess.params.cat_emb_dim: [suggest_int, {low: 8, high: 128, log: false, step: 8}]
    model.preprocess.params.num_emb_dim: [suggest_int, {low: 8, high: 128, log: false, step: 8}]
    # Encoder:
    model.preprocess.params.enc_num_layers: [suggest_int, {low: 1, high: 3, log: false}]
    model.preprocess.params.enc_aggregation: [suggest_categorical, {choices: ["TakeLastHidden", "ValidHiddenMean"]}]
    # Decoder:
    model.preprocess.params.dec_hidden_size: [suggest_int, {low: 32, high: 512, log: false, step: 32}]
    model.preprocess.params.dec_num_layers: [suggest_int, {low: 1, high: 3, log: false}]
    model.preprocess.params.dec_num_heads: [suggest_categorical, {choices: [1, 2, 4, 8]}]
    model.preprocess.params.dec_scale_hidden: [suggest_int, {low: 1, high: 16, log: false}]
    # Loss weights:
    # model.preprocess.params.l1_weight: [suggest_float, {low: 1.e-10, high: 1.e-2, log: True}]
    model.preprocess.params.contrastive_weight: [suggest_float, {low: 1, high: 100, log: false}]

    # unsupervised_trainer.total_iters: [suggest_categorical, {choices: [0, 100_000]}]
