# Model Configuration
model:
  model_name: "facebook/contriever"
  revision: "main"
  use_fast_tokenizer: true
  predictor:
    num_layers: 12
    input_dim: 768
    hidden_dim: 256
    output_dim: 1
    max_seq_len: 5000
    interpolate_factor: 1 # used only in action embedder for chunks positional index interpolation
                          # step between idx[i] and idx[i+1] is equal to 1./interpolate_factor
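                          # e.g. interpolate_factor: 2 gives a step of 0.5, so the
                          # indices become 0, 0.5, 1.0, 1.5, ... instead of 0, 1, 2, ...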

# Shared BERT-based predictor, anchored so the PQN state embedders below can
# reuse exactly the same config.
action_model: &bert_predictor
  _target_: rl.bert_predictor.BertPredictor
  bert:
    _target_: transformers.AutoModel.from_pretrained
    pretrained_model_name_or_path: ${algo.model.model_name}
    revision: ${algo.model.revision}
  num_hidden_layers: ${algo.model.predictor.num_layers}
  tokenizer:
    _target_: transformers.AutoTokenizer.from_pretrained
    pretrained_model_name_or_path: ${algo.model.model_name}
    revision: ${algo.model.revision}
    use_fast: ${algo.model.use_fast_tokenizer}
  model_dim: ${algo.model.predictor.input_dim}
  output_size: ${algo.model.predictor.hidden_dim}
  n_output: ${algo.model.predictor.output_dim}
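
# How this node is materialized (a sketch assuming the standard Hydra pattern,
# with cfg as the loaded root config; under recursive instantiation the nested
# bert and tokenizer nodes are built first and passed into BertPredictor):
#
#   from hydra.utils import instantiate
#   action_model = instantiate(cfg.algo.action_model)
#   # roughly equivalent to:
#   #   BertPredictor(
#   #       bert=AutoModel.from_pretrained("facebook/contriever", revision="main"),
#   #       tokenizer=AutoTokenizer.from_pretrained(
#   #           "facebook/contriever", revision="main", use_fast=True),
#   #       num_hidden_layers=12, model_dim=768, output_size=256, n_output=1)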


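# Action-embedder variants; the active entry is selected at runtime through the
# nested interpolation ${algo.action_embed_dict.${envs.positions_processor}}
# used in the pqn block below.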
action_embed_dict:
  absolute:
    _target_: rl.bert_predictor.EmbedderWithAbsoluteEncoding
    model: ${algo.action_model}
    max_seq_len: ${algo.model.predictor.max_seq_len}

  random:
    # Same absolute encoder as above, but with interpolate_factor passed through,
    # so the chunk position indices are interpolated (see model.predictor above).
    _target_: rl.bert_predictor.EmbedderWithAbsoluteEncoding
    model: ${algo.action_model}
    max_seq_len: ${algo.model.predictor.max_seq_len}
    interpolate_factor: ${algo.model.predictor.interpolate_factor}
  
  relative:
    _target_: rl.bert_predictor.EmbedderWithRelativeEncoding
    model: ${algo.action_model}
    max_seq_len: 1000 # note: fixed here, not ${algo.model.predictor.max_seq_len}

  none:
    _target_: rl.bert_predictor.EmbedderNone
    model: ${algo.action_model}


# PQN Configuration
pqn:
  _target_: rl.pqn.PQN
  
  # Reuse the shared BertPredictor config anchored at action_model above.
  state_embed: *bert_predictor

  action_embed: ${algo.action_embed_dict.${envs.positions_processor}}
  
  state_embed_target: *bert_predictor
  action_embed_target: ${algo.action_embed_dict.${envs.positions_processor}}
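
  # The *_target entries mirror the online embedders: the YAML alias and the
  # repeated interpolation yield identical configs, which Hydra instantiates as
  # separate module instances (presumably the target networks updated via tau).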

  hyperparams:
    gamma: 0.99        # discount factor
    alpha: 0.01
    Lambda: 0.65       # lambda for the Q(lambda) return estimate
    tau: 0.1           # update coefficient for the *_target networks above
    max_grad_norm: 2.0 # gradient-clipping threshold
    accumulate_grads: ${accumulate_grads}
    action_embed_length: ${max_action_length}
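
    # If rl.pqn.PQN computes Q(lambda) returns as in the PQN paper (an assumption
    # about this implementation, not read from the code), the target follows the
    # recursion
    #   G_t = r_t + gamma * (Lambda * G_{t+1} + (1 - Lambda) * max_a Q(s_{t+1}, a))
    # blending the sampled return (weight Lambda=0.65) with the bootstrapped
    # max-Q estimate, discounted by gamma=0.99.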

  optimizer:
    _target_: torch.optim.AdamW
    lr: 5e-5            # peak learning rate, reached after warmup (see scheduler)
    betas: [0.9, 0.98]  # first/second-moment decay rates
    eps: 1e-06          # denominator term for numerical stability
    weight_decay: 0.01  # decoupled weight decay
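
  # torch.optim.AdamW also requires the params iterable, which is not part of
  # this config, so the trainer presumably supplies it when instantiating, e.g.:
  #   hydra.utils.instantiate(cfg.algo.pqn.optimizer, params=model.parameters())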
    
  scheduler:
    _target_: rl.optim.WarmupLinearScheduler
    total: ${steps_count} # total number of scheduler steps
    ratio: 0.0            # presumably the final LR as a fraction of the peak
    warmup: 1000          # linear warmup steps before decay begins
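
# Assuming WarmupLinearScheduler has the usual warmup-then-linear-decay shape,
# the LR ramps linearly from 0 to the peak 5e-5 over the first 1000 steps
# (e.g. 2.5e-5 at step 500), then decays linearly toward ratio * peak = 0 by
# step ${steps_count}.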