# Global random seed for the run.
# NOTE(review): presumably consumed by the Lightning CLI `seed_everything`
# option — confirm against the entry-point script.
seed_everything: 42

# Data module: `class_path` names the class, `init_args` holds the keyword
# arguments supplied to it.
data:
  class_path: data_model.datamodel_tri_test.MyDataModule
  init_args:
    # Dataset selection and location.
    # NOTE(review): 'FasionMnist' is spelled without the 'h'. Presumably the
    # data module matches on this exact string — confirm before changing it.
    data_name: 'FasionMnist'
    data_path: '/root/data/'

    # Loader settings.
    batch_size: 5000
    num_workers: 20

    # Remaining hyperparameters; their semantics are defined by MyDataModule
    # and are not visible from this file.
    pca_dim: 64
    K: 30
    n_f_per_cluster: 3
    l_token: 10
    len_multiple: 50


# Model: `class_path` names the class, `init_args` holds the keyword
# arguments supplied to it. Lines that are commented out are options the
# author left disabled; they are kept here for reference.
model:
  class_path: model.DiffTreeVQ.DMTEVT_model
  init_args:
    # NOTE(review): `data.init_args.pca_dim` is 64 while this is 784 —
    # confirm which dimensionality the model actually receives.
    num_input_dim: 784
    # sigma: 0.05

    # Optimisation. Other learning rates noted by the author:
    # 0.00002, 0.001, 0.005, 0.01.
    lr: 0.005
    # w_nb: 1  # not change
    # w_fp: 1  # not change
    # weight_nepo: 0.9  # 0.1, 0.2, 0.5, 1, 2, 5
    # weight_mse: 0  # 0.1, 0.2, 0.5, 1, 2, 5
    # sample_rate_feature: 0.9

    # A value of -1 is used for both `*_emb` settings below; what -1 means
    # is defined by the model class (TODO confirm).
    nu_lat: 0.5
    nu_emb: -1
    exaggeration_lat: 0.2
    exaggeration_emb: -1
    sample_rate_feature: 0.99

    # T_num_attention_heads: 48
    T_hidden_size: 512
    t_output_dim: 784
    weight_decay: 0.0000000001  # i.e. 1e-10
    num_use_moe: 1
    T_num_layers: 3

    # NOTE(review): `trainer.max_epochs` is 300 but this says 1000 — confirm
    # that the mismatch is intentional.
    max_epochs: 1000
    tree_depth: 10
    step2_epoch: 40
    step2_r_epoch: 100
    use_tree_rout: true
    gen_data_bool: false
    ec_ce_weight: 2
    weightrout: 0.2

    # p3: 1
    # use_orthogonal: True
    # T_intermediate_size: 300
    # T_hidden_dropout_prob: 0.0
    # T_attention_probs_dropout_prob: 0.0
    # num_muti_mask: 10
    # vis_dim: 2
    # trans_out_dim: 2048
    # n_neg_sample: 4
# Trainer settings: run length, hardware, validation cadence, then the
# logger and callback objects.
trainer:
  max_epochs: 300
  devices: 1
  accelerator: gpu
  # strategy: ddp_find_unused_parameters_true
  check_val_every_n_epoch: 20
  enable_checkpointing: true

  # Weights & Biases logger; `save_dir` is a relative path.
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      name: difftree_example
      project: DiffTree
      save_dir: wandb

  callbacks:
    # Evaluation callback.
    # NOTE(review): `dataset` is 'mnist_4gpu' while `devices` is 1 and the
    # data section names 'FasionMnist' — confirm this label is the intended
    # one for this run.
    - class_path: call_backs.eval_difftree_vq_validate.EvalCallBack
      init_args:
        inter: 2
        dirpath: 'zzl_checkpoints/'
        fully_eval: false
        dataset: 'mnist_4gpu'
        vis_rout: false
    # - class_path: call_backs.MaskExp.MaskExpCallBack
    #   init_args:
    #     inter: 100