# Top-level model config. `_target_` names the class to build
# (presumably via Hydra-style instantiation; confirm against the caller).
_target_: src.models.InContextLearnerV2

# `${...}` interpolation: value is read from `encoding_extractor.embedding_size`.
embedding_size: ${encoding_extractor.embedding_size}

# Backbone network: `GPTJModelV2` receives a `transformers.GPTJConfig`
# built from the keys under `config`.
network:
  _target_: src.models.incontext_learner_v2.GPTJModelV2
  config:
    _target_: transformers.GPTJConfig
    # Can fit context_class_size = 4096 / 3 ~= 1365.
    # NOTE(review): the earlier note computed this from 2048 (~= 682);
    # confirm the /3 ratio still applies to the current data layout.
    n_positions: 4096
    n_embd: 768
    n_layer: 12
    n_head: 12
    n_inner: 3072  # convention: 4 * n_embd
    # All dropout probabilities are set to zero.
    resid_pdrop: 0.0
    embd_pdrop: 0.0
    attn_pdrop: 0.0
    use_cache: false

# Active loss: binary cross-entropy computed on raw logits.
loss_fn:
  _target_: torch.nn.BCEWithLogitsLoss
# Disabled alternative: project-specific loss with a query weight.
# NOTE(review): the class name is spelled `ContextBNEWithLogits` here; verify it
# matches the actual class in src.utils.custom_loss_functions before re-enabling.
# loss_fn:
#   _target_: src.utils.custom_loss_functions.ContextBNEWithLogits
#   query_loss_weight: 1

# Read from the datamodule config: null if there is a single val set,
# otherwise the list of names of the val sets.
val_sets: ${datamodule.val_sets}

# Dataset name, read from the datamodule config.
dataset_name: ${datamodule.name}

# Presumably toggles a layer norm on the model inputs; confirm in
# InContextLearnerV2. Disabled here.
input_layer_norm: false
