# Name of the model class this config is written for.
model_class: NDT1

# NDT1 stitching baseline: temporal masking/stitching
encoder:

  # NOTE(review): presumably a path to pretrained encoder weights, with null
  # meaning nothing is loaded — confirm against the loader.
  from_pt: null

  # Stitching is switched on for this baseline (see the header comment above
  # ``encoder``). NOTE(review): exact effect is defined by the model code.
  stitching: true

  # Mask spikes
  masker:
    force_active: true         # NOTE(review): not documented here — confirm against the masker code
    mode: temporal      # masking mode
    ratio: 0.3          # ratio of data to predict
    zero_ratio: 1.0     # of the data to predict, ratio that is zeroed out
    random_ratio: 1.0   # of the data not zeroed, ratio that is randomly replaced
    expand_prob: 0.0    # probability of expanding the mask in ``temporal`` mode
    max_timespan: 1     # max span of the mask if expanded
    channels: null      # neurons to mask in ``co-smooth`` mode
    # Commented-out alternative value for ``channels`` (60 entries; note that
    # 592 and 280 each appear twice):
    # [579, 21, 486, 316, 521, 71, 32, 561, 445, 470, 205, 41, 230, 306, 592, 148, 484, 592, 181, 94, 120, 601, 414, 497, 263, 447, 16, 537, 396, 323, 68, 444, 325, 137, 223, 434, 131, 363, 280, 419, 332, 218, 130, 580, 381, 560, 576, 145, 280, 234, 258, 37, 9, 324, 199, 610, 529, 196, 164, 348]
    timesteps: null             # time steps to mask in ``forward-pred`` mode
    mask_regions: ['all']       # brain regions to mask in ``inter-region`` mode
    target_regions: ['all']    # brain regions to predict in ``intra-region`` mode
    n_mask_regions: 1       # number of regions to choose from the list of mask_regions or target_regions
    
  # Context available for each timestep
  # NOTE(review): -1 presumably means the context is unrestricted in that
  # direction — confirm against the attention-mask code.
  context:
    forward: -1
    backward: -1

  # Normalize and add noise
  norm_and_noise:
    active: false               # this stage is disabled in this config
    smooth_sd: 2                # gaussian smoothing
    norm: "zscore"              # which normalization layer to use (null/layernorm/scalenorm/zscore)
    eps: 1.e-7                  # avoid dividing by zero when normalizing padded spikes
    white_noise_sd: 1.0         # gaussian noise added to the inputs (1.0 originally)
    constant_offset_sd: 0.2     # gaussian noise added to the inputs but constant in the time dimension (0.2 originally)
    

  # Embedding layer
  embedder:
    n_channels: 668       # number of neurons recorded
    n_blocks: 24          # number of blocks of experiments
    n_dates: 24           # number of days of experiments
    max_F: 100           # max feature length in timesteps

    mode: linear          # linear/embed/identity
    # NOTE(review): n_channels * mult = 668 * 2 = 1336, which does not match
    # transformer.hidden_size (512) in this file — confirm which one applies.
    mult: 2               # embedding multiplier. hidden_size = n_channels * mult
    adapt: false          # adapt the embedding layer for each day
    pos: true             # embed position
    act: softsign         # activation for the embedding layers
    scale: 1              # scale the embedding by multiplying by this number
    bias: true            # use bias in the embedding layer
    dropout: 0.2          # dropout in embedding layer
    
    fixup_init: false     # modify weight initialization
    init_range: 0.1       # initialization range for embeddings
    spike_log_init: false # special initialization
    max_spikes: 0         # max number of spikes in a single time bin

    # NOTE(review): the three switches below are not documented in this file;
    # confirm their meaning against the embedder implementation.
    tokenize_binary_mask: false
    use_prompt: false
    use_session: true

    stack:
      active: false        # whether to stack consecutive timesteps
      size: 32            # number of consecutive timesteps to stack
      stride: 4           # stacking stride


  # Transformer
  transformer:
    n_layers: 5           # number of transformer layers
    hidden_size: 512     # hidden space of the transformer
    use_scalenorm: false  # use scalenorm instead of layernorm
    use_rope: false       # use rotary positional encoding
    rope_theta: 10000.0   # rope angle of rotation


    n_heads: 8            # number of attention heads
    attention_bias: true  # learn bias in the attention layers

    act: gelu             # activation function in mlp layers
    inter_size: 1024      # intermediate dimension in the mlp layers
    mlp_bias: true        # learn bias in the mlp layers
    
    dropout: 0.4          # dropout in transformer layers
    fixup_init: true      # modify weight initialization

  # Projection to factor space
  factors:  
    active: false  # project from hidden_size to factors (disabled in this config)
    size: 8                  # factors size
    act: relu                 # activation function after projecting to factors
    bias: true                # use bias in projection to factors
    dropout: 0.0              # dropout in projection to factors
    fixup_init: false         # modify weight initialization
    init_range: 0.1           # initialization range for factors projection
    
# Decoder settings.
# NOTE(review): from_pt is presumably a path to pretrained decoder weights,
# with null meaning nothing is loaded — confirm against the loader.
decoder:
  from_pt: null






