# Data module configuration. Each `target` is a dotted class path and `params`
# holds the values listed beneath it (presumably passed as constructor keyword
# arguments — confirm against the config loader).
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2048
    num_workers: 4
    wrap: false
    # Training split. Apart from `target`, its params are identical to those of
    # the `test` split below (presumably they are meant to stay in sync — confirm).
    train:
      target: celldiff.data.cellxgene.CellXGeneTopKFewShotTrain
      params:
        dataset: HLCA_sub
        post_cond_flag: true
        save_processed: false
        text_cond_flag: false  ## star config
        threshold: 1000
        choice: top
        num_cell_types: 3
    # Test split: same dataset and selection settings as `train`, different
    # target class.
    test:
      target: celldiff.data.cellxgene.CellXGeneTopKFewShotTest
      params:
        dataset: HLCA_sub
        post_cond_flag: true
        save_processed: false
        text_cond_flag: false  ## star config
        threshold: 1000
        choice: top
        num_cell_types: 3
# Specify pretrained_ckpt_path here or on the command line.
# Model configuration: a DDPM target with three nested sub-configs
# (`classifier_config`, `guidance_config`, `model_config`), each of which has
# its own `target` class path and `params` mapping.
model:
  # NOTE(review): unusually small learning rate; confirm it is intended
  # (e.g. for an evaluation-only run).
  base_learning_rate: 5.0e-09
  target: celldiff.models.diffusion.ddpm.DDPM
  params:
    mask_context: false
    mask_noised_context: false
    balance_loss: false
    # Noise-schedule and sampling settings (presumably `linear_start` and
    # `linear_end` bound a linear beta schedule over `timesteps` — confirm).
    linear_start: 0.0001
    linear_end: 0.02
    log_every_t: 200
    timesteps: 1000
    t_sample: 1000
    # Batch keys (presumably the dictionary keys read from each batch — confirm).
    input_key: input
    pe_input_key: null
    cond_key: cond
    loss_type: l2
    loss_strategy: recon_full  # recon_full | recon_masked
    parameterization: x0
    monitor: val/loss_simple_ema
    cond_to_ignore: null
    clip_denoised: true
    latent_flag: false
    latent_mask_ratio: 1.0
    # Task switches. In this file only `classify_flag` and `text_cond_flag`
    # are enabled.
    recon_flag: false
    recon_sample: true  # one step generation if set to False
    integrate_flag: false
    denoise_flag: false
    denoise_t_sample: 1000
    impute_flag: false
    pert_flag: false
    # NOTE(review): this is `true` while `text_cond_flag` is `false` in both
    # data splits of this file — confirm the mismatch is intentional.
    text_cond_flag: true  ## star config
    classify_flag: true
    # Classifier sub-config. `n_samples_list` and `to_keep_list` each hold a
    # single entry here (presumably one entry per pruning stage — confirm).
    classifier_config:
      target: celldiff.models.diffusion.classifier.DiffusionClassifier
      params:
        query_mode: seen
        n_samples_list:
          - 50
        to_keep_list:
          - 1  # last to_keep must be 1
        n_trials: 1
        loss: l2
        time_step_sampler: IterativeUniform
    guidance_flag: false # according to task, true for recon rmse
    # Guidance sub-config; `guidance_flag` above is false in this file
    # (presumably these params are then unused — confirm).
    guidance_config:
      target: celldiff.models.diffusion.guidance.ClassifierFreeGuidance
      params:
        w: 0.2
        p_uncond: 0.2
    # Network sub-config for the MaskedAutoencoder target.
    model_config:
      target: celldiff.modules.diffusionmodules.model.MaskedAutoencoder
      params:
        mask_context: false  # DEPRECATED
        activation: gelu  # relu | gelu | sigmoid
        cell_mask_ratio: 0.25
        mask_mode: v2
        # Encoder dimensions.
        embed_dim: 512
        dim_head: 64
        num_heads: 8
        norm_layer: groupnorm  # layernorm | batchnorm | groupnorm
        # Decoder dimensions (same values as the encoder in this file).
        decoder_embed_dim: 512
        decoder_dim_head: 64
        decoder_num_heads: 8
        depth: 6
        mlp_time_embed: false
        dropout: 0.
        decoder_embed_type: embedder
        decoder_mask: inv_enc
        encoder_type: mlp  # attn | mlp | stackffn
        # Conditioning settings.
        cond_type: crossattn  # crossattn | mlp | stackffn
        cond_emb_type: embedding
        cond_tokens: 1  # number of tokens to expand
        cond_mask_ratio: 0.1  # masking portion of conditions during training
        post_cond_layers: 1
        post_cond_type: add  # only support 'add' now
        post_cond_norm: batchnorm  # layernorm | batchnorm | groupnorm  (not used when post_cond_layers=1)
        post_cond_mask_ratio: 0.0
        # Decoder condition-masking options; all disabled here except
        # `mask_dec_cond_ratio` (a boolean despite its name).
        mask_dec_cond: false
        mask_dec_cond_se: false  # se and semlp are mutually exclusive options
        mask_dec_cond_semlp: false  # se and semlp are mutually exclusive options
        mask_dec_cond_ratio: true  # only valid if se and semlp are turned off
        mask_dec_cond_concat: false
        cond_emb_norm: groupnorm # layernorm | batchnorm | groupnorm | null
# Trainer settings (presumably forwarded to the PyTorch Lightning Trainer —
# confirm against the training entry point).
lightning:
  trainer:
    log_every_n_steps: 100
    max_epochs: 50
    # Single entry 0 (presumably the accelerator device index — confirm).
    # The item is indented under its key to match the other sequences in this
    # file; the parsed value is unchanged.
    devices:
      - 0
