# @package _global_
# Hydra defaults list. Each entry has the form `override /<group>@<package>: <option>`:
# the chosen option of the config group is placed under the given package (key).
defaults:
  # Disabled overrides, kept for reference.
  # NOTE(review): the group name differs between the two (/models vs /model);
  # confirm the correct one before re-enabling either line.
  # - override /models: Encoder_v2
  # - override /model: ShortCircuit
  - override /data@dataset: aig_dataset
  - override /loss@policy_loss: KLDivLoss

# Training hyperparameters
epochs: 400
batch_size: 2048
# NOTE(review): presumably toggles training of the value head; confirm against the trainer.
train_value: false
# Negation / permutation probabilities, set separately for training and evaluation.
# NOTE(review): presumably data-augmentation probabilities; confirm against the dataset code.
train_negation_prob: 0.5
train_permutation_prob: 0.5
eval_negation_prob: 0.5
eval_permutation_prob: 0.5


# device
# Backend switches; only `use_ddp` is enabled in this configuration.
use_amp: false
use_accelerator: false
use_ddp: true
# NOTE(review): presumably the maximum gradient norm used for clipping; confirm in the trainer.
grad_norm_clip: 1.0

# logging
save_every: 1  # how often the model is saved (unit not shown here; presumably epochs, confirm)
mlflow_uri: 'http://localhost:5000'
# log_dir: set this to the path of an existing run if you want to load a model
# model_file: set this to the name of the model you want to continue from
step_logging: 500  # log the loss every this many steps
master_rank: 0


# model
# Architecture settings grouped under the `model` key.
model:
  position_embeddings: true
  # Must be at least 2**num_inputs.
  # NOTE(review): 256 = 2**8, presumably matching the 8-input AIGs under dataset.aigs; confirm.
  embedding_size: 256
  n_heads: 16
  n_layers: 4
  n_policy_layers: 3
  n_value_layers: 3
  intermediate_size: 4096

# Data loading
dataloader_workers: 4
# NOTE(review): presumably the fraction of samples used for training (rest held out); confirm.
train_split: 0.9

# Values set under the `dataset` package (the package targeted by the
# `/data@dataset` entry in the defaults list).
dataset:
  # Glob pattern(s) selecting the .aig files to load.
  aigs:
    - 'data/unoptimized/8_inputs/**/*.aig'  # Example

  return_action_mask: true
  gamma: 0.99
  const_node: false
  reward_type: simple
  num_workers: 16
  fragments: 1


# evaluation
test_generation: false
test_frequency: 20  # how often (in epochs) to test generation
test_limit: 100  # number of truth tables to test
max_nodes: 30
AZ_workers: 4  # number of workers that perform AZ in parallel

AZ: # AlphaZero parameters
  num_simulations: 10
  simulation_max_steps: 20 # 30 for chess
  max_steps: 20
  # NOTE(review): presumably the PUCT exploration constant; confirm against the search code.
  c_puct: 1.0
  # NOTE(review): presumably the alpha of the Dirichlet exploration noise; confirm.
  dirichlet_alpha: 0.03
  use_value_network: false

# environment
# NOTE(review): const_node, return_action_mask and reward_type also appear under
# `dataset` with the same values; presumably they must be kept in sync, confirm.
const_node: false  # include the const node in the graphs
return_action_mask: true  # filter illegal action logits before the loss calculation
get_causal_mask: true  # pass the causal mask instead of generating it within the model
reward_type: simple
