# REINVENT4 TOML input example for reinforcement/curriculum learning
#
#
# Curriculum learning in REINVENT4 is a multi-stage reinforcement learning
# run.  One or more stages (auto CL) can be defined.  But it is also
# possible to continue a run from any checkpoint file that is generated
# during the run (manual CL).  Currently, a checkpoint is written at the end
# of a run, including when the run is forcefully terminated with Ctrl-C.


# Kind of run described by this file
run_type = "staged_learning"
# Torch device, e.g. "cuda:0" or "cpu"
device = "cuda:0"
# Name of the TensorBoard logging directory
tb_logdir = "tb_logs"
# This TOML configuration is also written out as JSON to this file
json_out_config = "staged_learning.json"

[parameters]
# --- General settings ---

# Prefix for the CSV file
summary_csv_prefix = "staged_learning"
# If true, read the diversity filter from agent_file
use_checkpoint = false
# If true, purge all diversity filter memories after each stage
purge_memories = false
# Batch size (original note: "network")
batch_size = 64
# If true, remove all duplicate raw sequences in each step;
# only here for backward compatibility
unique_sequences = true
# If true, shuffle atoms in SMILES randomly
randomize_smiles = false

# --- Generator selection ---
#
# Keep exactly one of the generator blocks below active (Reinvent is the
# one enabled here; comment it out before uncommenting another).  Each
# generator needs a model file and possibly a SMILES file with seed
# structures.  If the run is to be continued after termination, replace
# agent_file with the checkpoint file.

## Reinvent
prior_file = "./reinvent.prior"
agent_file = "./reinvent.prior"

## LibInvent
#prior_file = "priors/libinvent.prior"
#agent_file = "priors/libinvent.prior"
#smiles_file = "<scaffold>.smi"  # 1 scaffold per line with attachment points

## LinkInvent
#prior_file = "priors/linkinvent.prior"
#agent_file = "priors/linkinvent.prior"
#smiles_file = "<scaffold>.smi"  # 2 warheads per line separated with '|'

## Mol2Mol
#prior_file = "priors/mol2mol_similarity.prior"
#agent_file = "priors/mol2mol_similarity.prior"
#smiles_file = "<scaffold>.smi"  # 1 compound per line
#sample_strategy = "multinomial"  # multinomial or beamsearch (deterministic)
#distance_threshold = 100


[learning_strategy]
# Strategy type; "dap" is the only one supported
type = "dap"
# Sigma of the RL reward function
sigma = 128
# Rate handed to torch.optim (presumably the learning rate -- confirm)
rate = 0.0005


# Optional section: comment it out or remove it if unneeded.
# NOTE: the diversity filter also memorizes all seen SMILES.
[diversity_filter]
# Filter type, e.g. IdenticalTopologicalScaffold, ScaffoldSimilarity,
# PenalizeSameSmiles
type = "PenalizeSameSmiles"
# Memory size in number of compounds
bucket_size = 25
# Only memorize if this threshold is exceeded
minscore = 0.4
# Minimum similarity for ScaffoldSimilarity
minsimilarity = 0.4
# Penalty factor for PenalizeSameSmiles
penalty_multiplier = 0.5


# Reinvent only: guide RL in the initial phase
#[inception]  # optional, comment section out or remove if unneeded

#smiles_file = "sampled.smi"  # "good" SMILES for guidance
#memory_size = 100  # number of total SMILES held in memory
#sample_size = 10  # number of SMILES randomly chosen each epoch

### Stage 1
[[stage]]
# Name of the checkpoint file; it can be reused as the agent
chkpt_file = "staged_learning.chkpt"
# Termination criterion for this stage
termination = "simple"
# Terminate if this total score is exceeded
max_score = 0.99
# Run for at least this number of steps
min_steps = 25
# Terminate the entire run when exceeded
max_steps = 1000

[stage.scoring]
# Aggregation function applied to the component scores listed below
type = "geometric_mean"

# MolecularWeight component: one endpoint labelled "MolWt" with a
# double_sigmoid transform (low = 200, high = 700).
[[stage.scoring.component]]
[stage.scoring.component.MolecularWeight]
[[stage.scoring.component.MolecularWeight.endpoint]]
name = "MolWt"
weight = 2
transform.type = "double_sigmoid"
transform.low = 200
transform.high = 700
transform.coef_div = 9.0
transform.coef_si = 20.0
transform.coef_se = 20.0

# SlogP component: one endpoint labelled "MolLogP" with a
# double_sigmoid transform (low = 2, high = 8).
[[stage.scoring.component]]
[stage.scoring.component.SlogP]
[[stage.scoring.component.SlogP.endpoint]]
name = "MolLogP"
weight = 2
transform.type = "double_sigmoid"
transform.low = 2
transform.high = 8
transform.coef_div = 9.0
transform.coef_si = 20.0
transform.coef_se = 20.0

# NumRotBond component: one endpoint labelled "NumRotatableBonds" with a
# double_sigmoid transform (low = -1, high = 14).
# NOTE(review): low is below zero, presumably so that a count of 0 is not
# penalised -- confirm.
[[stage.scoring.component]]
[stage.scoring.component.NumRotBond]
[[stage.scoring.component.NumRotBond.endpoint]]
name = "NumRotatableBonds"
weight = 1
transform.type = "double_sigmoid"
transform.low = -1
transform.high = 14
transform.coef_div = 9.0
transform.coef_si = 20.0
transform.coef_se = 20.0

# TPSA component: one endpoint labelled "CalcTPSA" with a
# double_sigmoid transform (low = -5, high = 135).
# NOTE(review): low is below zero, presumably so that a value of 0 is not
# penalised -- confirm.
[[stage.scoring.component]]
[stage.scoring.component.TPSA]
[[stage.scoring.component.TPSA.endpoint]]
name = "CalcTPSA"
weight = 1
transform.type = "double_sigmoid"
transform.low = -5
transform.high = 135
transform.coef_div = 9.0
transform.coef_si = 20.0
transform.coef_se = 20.0

# NumRings component: one endpoint labelled "RingCount" with a
# double_sigmoid transform (low = 2, high = 7).
# NOTE(review): the coefficients here (15 / 30 / 30) differ from the
# 9 / 20 / 20 used by most other components in this stage -- confirm that
# this is intended.
[[stage.scoring.component]]
[stage.scoring.component.NumRings]
[[stage.scoring.component.NumRings.endpoint]]
name = "RingCount"
weight = 2
transform.type = "double_sigmoid"
transform.low = 2
transform.high = 7
transform.coef_div = 15.0
transform.coef_si = 30.0
transform.coef_se = 30.0

# NumHeavyAtoms component: one endpoint labelled "GetNumHeavyAtoms" with a
# double_sigmoid transform (low = 15, high = 50).
[[stage.scoring.component]]
[stage.scoring.component.NumHeavyAtoms]
[[stage.scoring.component.NumHeavyAtoms.endpoint]]
name = "GetNumHeavyAtoms"
weight = 1
transform.type = "double_sigmoid"
transform.low = 15
transform.high = 50
transform.coef_div = 9.0
transform.coef_si = 20.0
transform.coef_se = 20.0

# NumAromaticRings component: one endpoint labelled "CalcNumAromaticRings"
# with a double_sigmoid transform (low = 1, high = 5).
# NOTE(review): coef_si (40.0) differs from coef_se (20.0) here, whereas
# every other component in this stage uses equal values for the two --
# confirm that the asymmetry is intended.
[[stage.scoring.component]]
[stage.scoring.component.NumAromaticRings]
[[stage.scoring.component.NumAromaticRings.endpoint]]
name = "CalcNumAromaticRings"
weight = 1
transform.type = "double_sigmoid"
transform.low = 1
transform.high = 5
transform.coef_div = 9.0
transform.coef_si = 40.0
transform.coef_se = 20.0

# Csp3 component: one endpoint labelled "CalcFractionCSP3" with a
# double_sigmoid transform (low = -0.05, high = 0.65).
# NOTE(review): low is below zero and coef_div is 2.0 rather than the 9.0
# used by most other components in this stage -- confirm.
[[stage.scoring.component]]
[stage.scoring.component.Csp3]
[[stage.scoring.component.Csp3.endpoint]]
name = "CalcFractionCSP3"
weight = 1
transform.type = "double_sigmoid"
transform.low = -0.05
transform.high = 0.65
transform.coef_div = 2.0
transform.coef_si = 20.0
transform.coef_se = 20.0

# HBondAcceptors component: one endpoint labelled "NumHAcceptors" with a
# double_sigmoid transform (low = 0.58, high = 8.18).
# NOTE(review): the bounds are fractional although the endpoint name
# suggests a count -- confirm the values are intended.
[[stage.scoring.component]]
[stage.scoring.component.HBondAcceptors]
[[stage.scoring.component.HBondAcceptors.endpoint]]
name = "NumHAcceptors"
weight = 1
transform.type = "double_sigmoid"
transform.low = 0.58
transform.high = 8.18
transform.coef_div = 9.0
transform.coef_si = 20.0
transform.coef_se = 20.0

# HBondDonors component: one endpoint labelled "NumHDonors" with a
# double_sigmoid transform (low = -0.36, high = 4.16).
# NOTE(review): the bounds are fractional (and low is below zero) although
# the endpoint name suggests a count -- confirm the values are intended.
[[stage.scoring.component]]
[stage.scoring.component.HBondDonors]
[[stage.scoring.component.HBondDonors.endpoint]]
name = "NumHDonors"
weight = 1
transform.type = "double_sigmoid"
transform.low = -0.36
transform.high = 4.16
transform.coef_div = 9.0
transform.coef_si = 20.0
transform.coef_se = 20.0