# Data pipeline configuration.
# Each `target` is a dotted path with a sibling `params` mapping; presumably the
# loader instantiates `target` with `params` as keyword arguments — confirm
# against the consuming code.
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2048
    num_workers: 12
    wrap: false
    # The train / validation / test entries below carry identical `params`;
    # only the `target` class differs. Keep the three in sync when editing.
    # `dataset: ???` is a placeholder that must be replaced or overridden with
    # one of the names listed beside it (looks like OmegaConf's mandatory-value
    # marker — TODO confirm).
    train:
      target: celldiff.data.gene_pert.GenePerturbationTrain
      params:
        dataset: ???  # adamsn | dixit | norman
        post_cond_flag: true
        coexpress_threshold: 0.4
        num_similar_genes_go_graph: 20
    validation:
      target: celldiff.data.gene_pert.GenePerturbationValidation
      params:
        dataset: ???  # adamsn | dixit | norman
        post_cond_flag: true
        coexpress_threshold: 0.4
        num_similar_genes_go_graph: 20
    test:
      target: celldiff.data.gene_pert.GenePerturbationTest
      params:
        dataset: ???  # adamsn | dixit | norman
        post_cond_flag: true
        coexpress_threshold: 0.4
        num_similar_genes_go_graph: 20
# Diffusion model configuration. Uses the same `target` / `params` layout for
# the top-level model and for the nested guidance, classifier and backbone
# (`model_config`) entries.
model:
  base_learning_rate: 1.0e-08
  target: celldiff.models.diffusion.ddpm.DDPM
  params:
    mask_context: false
    mask_noised_context: false
    # Loss
    balance_loss: false
    loss_type: l2
    loss_strategy: recon_full  # recon_full | recon_masked
    parameterization: x0
    # Noise scheduler
    linear_start: 0.0001
    linear_end: 0.02
    log_every_t: 200
    timesteps: 1000
    # Inputs
    input_key: input
    pe_input_key: null
    cond_key: cond
    cond_to_ignore: null
    # Others
    monitor: val/corr_delta
    monitor_mode: max
    clip_denoised: true
    guidance_flag: false  # task-dependent; set to true for reconstruction RMSE
    guidance_config:
      target: celldiff.models.diffusion.guidance.ClassifierFreeGuidance
      params:
        w: 0.2
        p_uncond: 0.2
    # ----------------------------- Task related ------------------------------
    # NOTE(review): t_sample and denoise_t_sample below both equal `timesteps`
    # (1000); confirm whether the consumer expects a 1-based step count or a
    # 0-based index here.
    t_sample: 1000  # time step for reconstruction and integration
    # Reconstruction
    recon_flag: true
    recon_sample: false  # one-step generation if set to false
    # Integration
    integrate_flag: false
    latent_flag: false  # use latent embeddings for integration, otherwise use reconstructed cell
    latent_mask_ratio: 1.0  # 1.0 only uses denoising encoder, 0.0 only uses context encoder
    # Denoising
    denoise_flag: false
    denoise_t_sample: 1000
    # Imputation
    impute_flag: false
    # Classification
    classify_flag: false
    classifier_config:
      target: celldiff.models.diffusion.classifier.DiffusionClassifier
      params:
        query_mode: seen
        n_samples_list:
          - 5
        to_keep_list:
          - 1  # the last entry of to_keep_list must be 1
        n_trials: 1
        loss: l2
        time_step_sampler: IterativeUniform
    # Perturbation
    pert_flag: true
    pert_target_key: gene_pert_target
    path_to_save_fig: null
    # -------------------------------------------------------------------------
    model_config:
      target: celldiff.modules.diffusionmodules.model.MaskedAutoencoder
      params:
        # Global settings
        activation: gelu  # relu | gelu | sigmoid
        norm_layer: layernorm  # layernorm | batchnorm | groupnorm
        depth: 6
        dropout: 0.
        cell_mask_ratio: 0.25  # ratio of cells to be fully masked
        mask_mode: v2  # v2: fully mask some cells, and randomly mask x% of each remaining cell (x drawn uniformly)
        # Encoder (this group also holds the decoder_* sizes)
        embed_dim: 512
        dim_head: 64
        num_heads: 8
        mask_context: false  # DEPRECATED
        decoder_embed_dim: 512
        decoder_dim_head: 64
        decoder_num_heads: 8
        mlp_time_embed: false
        # Conditioner
        cond_type: crossattn  # crossattn | mlp | stackffn
        cond_strategy: full_mix  # full_mix | pre_mix
        cond_emb_type: embedding
        cond_cat_input: false
        cond_tokens: 1  # number of tokens to expand
        cond_mask_ratio: 0.0  # masking portion of conditions during training
        # Gene perturbation specific configs
        num_go_gnn_layers: 1
        gears_hidden_size: 512
        gears_mode: mlpparallel  # single | parallel | sequential | mlpparallel
        gears_mlp_layers: 1
        gears_norm: layernorm  # null | layernorm | batchnorm | groupnorm
        # Post conditioner
        post_cond_layers: 1
        post_cond_type: add  # only 'add' is supported for now
        post_cond_norm: batchnorm  # layernorm | batchnorm | groupnorm  (not used when post_cond_layers=1)
        post_cond_mask_ratio: 0.0
        # Mask decoder conditioner (applied to denoising encoder)
        mask_dec_cond: false
        mask_dec_cond_se: false  # se and semlp are mutually exclusive options
        mask_dec_cond_semlp: false  # se and semlp are mutually exclusive options
        mask_dec_cond_ratio: true  # only valid if se and semlp are turned off
        mask_dec_cond_concat: false
        # Decoder (encoder_type is also set in this group)
        decoder_embed_type: embedder
        decoder_mask: inv_enc
        encoder_type: mlp  # attn | mlp | stackffn | mlpparallel | ffnparallel
# Trainer settings. The keys under `trainer` look like PyTorch Lightning
# Trainer arguments — confirm against the launcher that consumes this file.
lightning:
  trainer:
    check_val_every_n_epoch: 5
    log_every_n_steps: 10
    max_epochs: 300
    # Presumably a list of device indices (here only index 0) — confirm.
    # The sequence is indented under its key, matching the other lists in
    # this file.
    devices:
      - 0
