# Shared top-level scalars. The anchors defined here are referenced through
# aliases by later sections of this file.

# Aliased by `trainer.output_dir`.
output_dir: &output_dir outputs/

# Model identifier; aliased by `tokenizer`, `mkglconfig` and `mkgl`.
model_name: &model_name meta-llama/Llama-2-7b-chat-hf
# NOTE(review): not referenced anywhere else in this file; its meaning is
# defined by the consuming code — confirm there.
kgl_token_length: 10


# Dataset selection. `class` and `path` are presumably resolved by the
# consuming code (class lookup by name, data directory) — confirm.
dataset:
  class: WN18RRInductive
  path: data/datasets/

# Tokenizer options. Uses the same model identifier as the other sections
# via the `model_name` anchor.
tokenizer:
  pretrained_model_name_or_path: *model_name
  # Canonical booleans: `yes`/`no` are only booleans under YAML 1.1; a
  # YAML 1.2 loader would read `no` as the (truthy) string "no".
  use_fast: false
  add_eos_token: false

# Model config section; only carries the shared model identifier
# (alias of the top-level `model_name` anchor).
mkglconfig:
  pretrained_model_name_or_path: *model_name

# Model loading section. Uses the shared model identifier via the
# `model_name` anchor.
mkgl:
  pretrained_model_name_or_path: *model_name
  # Canonical boolean: plain `no` is a boolean only under YAML 1.1 and
  # would load as the truthy string "no" under a YAML 1.2 loader.
  load_in_8bit: false

# LoRA adapter settings. Key names look like peft's `LoraConfig` fields —
# presumably passed through unchanged; confirm against the loader.
loraconfig:
  # Anchored as `&r`; the retriever sections alias this value for their
  # `r`, `input_dim`, `output_dim` and `query_input_dim` keys.
  r: &r 32
  lora_alpha: 8
  lora_dropout: 0.05
  # Active entries.
  target_modules:
  - q_proj
  - v_proj
  # Disabled entries, kept for reference:
  # - embed_tokens
  # - lm_head

# Training-loop settings. Key names look like Hugging Face
# `TrainingArguments` fields — presumably passed through; confirm against
# the loader.
# Booleans are written as `true`/`false`: plain `yes`/`no` are booleans only
# under YAML 1.1 and would load as non-empty (truthy) strings under YAML 1.2.
trainer:
  output_dir: *output_dir
  num_train_epochs: 30
  # NOTE(review): likely has no effect while `save_strategy` is 'no' — confirm.
  save_total_limit: 1
  per_device_train_batch_size: 8
  per_device_eval_batch_size: 32
  evaluation_strategy: epoch
  # Quoted on purpose so the value stays the string 'no' instead of being
  # resolved to a boolean.
  save_strategy: 'no'  # alternative: epoch
  warmup_steps: 50
  bf16: true
  # logging_steps: 10
  logging_strategy: epoch
  learning_rate: 1.0e-3
  gradient_accumulation_steps: 4
  eval_accumulation_steps: 64
  save_safetensors: false
  remove_unused_columns: false
  label_names:
  - label
  optim: adamw_8bit
  # Same float as the former `1.`, written out in full.
  max_grad_norm: 1.0
  ddp_find_unused_parameters: true
  report_to:
  - tensorboard

# Task-level settings (loss criterion and negative sampling). Exact
# semantics are defined by the consuming code — confirm there.
mkgl4kgc:
  criterion: bce
  num_negative: 32
  # Canonical boolean: plain `yes` is a boolean only under YAML 1.1.
  strict_negative: true
  adversarial_temperature: 1

# Context retriever: a single-layer `PNA` encoder built from `PNALayer`.
# Every `*r` alias resolves to the `&r` anchor under `loraconfig` (32).
context_retriever:
  # Anchored as `&llm_hidden_dim`; aliased by `score_retriever`.
  llm_hidden_dim: &llm_hidden_dim 4096
  r: *r
  text_encoder: pna
  kg_encoder:
    class: PNA
    base_layer:
      class: PNALayer
      input_dim: *r
      output_dim: *r
      query_input_dim: *r
      message_func: distmult
      aggregate_func: pna
      # Canonical booleans: `yes`/`no` are booleans only under YAML 1.1; a
      # YAML 1.2 loader would read `no` as the truthy string "no".
      layer_norm: true
      dependent: false
    num_layer: 1
    node_ratio: 0.05
    test_node_ratio: 1

# Score retriever: a six-layer `ConditionedPNA` encoder built from
# `PNALayer`. `*llm_hidden_dim` resolves to the anchor under
# `context_retriever` (4096); every `*r` resolves to `loraconfig.r` (32).
score_retriever:
  llm_hidden_dim: *llm_hidden_dim
  r: *r
  text_encoder: pna
  kg_encoder:
    class: ConditionedPNA
    base_layer:
      class: PNALayer
      input_dim: *r
      output_dim: *r
      query_input_dim: *r
      message_func: distmult
      aggregate_func: pna
      # Canonical booleans: `yes`/`no` are booleans only under YAML 1.1; a
      # YAML 1.2 loader would read `no` as the truthy string "no".
      layer_norm: true
      dependent: false
    num_layer: 6
    node_ratio: 0.05
    test_node_ratio: 1