[wandb]
use_wandb = false

[benchmark]
model_type = "EGNN_S"
benchmark_name = "egnn_s_tg80_multitask_muon_fold2"
compile = true
compile_trace = false
runs = 3
log_weights = false

[dataloader]
multitask = true
delta_T = [3000, 24_000]
dataset = "tg80"
num_timesteps = 8

# Single-task dataloader parameters (presumably only read when multitask = false — confirm)
# molecule_type = "aspirin"  # Molecule(s) for single-task learning; each listed molecule is trained in turn

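# Multitask dataloader parameters
# NOTE(review): some closely related entries are split across sets ("ethanol1" in train
# vs "ethanol" in test; "tropane1"/"tropane3" in train vs "tropane2" in test). Confirm
# these are distinct systems and not a source of train/test leakage for this fold.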
train_molecules = [
    "1.3-cyclohexadiene",
    "1.4-dioxane",
    "2-butanone",
    "acetamide",
    "aceticacid",
    "acetonitrile",
    "aniline",
    "anthracene",
    "aspirin",
    "benzaldehyde",
    "benzene",
    "benzene1",
    "benzene2",
    "benzoicacid",
    "benzothiophene",
    "benzylamine",
    "biphenyl",
    "chlorobenzene",
    "chloroform",
    "citricacid",
    "coumarin",
    "cyclobutane",
    "cyclohexane",
    "cyclohexanol",
    "cyclohexanone",
    "cyclopentadiene",
    "cyclopentanone",
    "cyclopropane",
    "ethanol1",
    "furan",
    "furfural",
    "imidazole",
    "indole",
    "isobutane",
    "isopropanol",
    "isoquinoline",
    "malondialdehyde1",
    "malondialdehyde2",
    "malonicacid",
    "methanol",
    "naphthalene",
    "nitrobenzene",
    "oxalicacid",
    "p-cresol",
    "p-xylene",
    "paracetamol",
    "pyrimidine",
    "quinoline",
    "salicylicacid1",
    "salicylicacid2",
    "salicylicacid3",
    "styrene",
    "succinicacid",
    "tetrahydrofuran",
    "thymine",
    "toluene",
    "trimethylamine",
    "tropane1",
    "tropane3",
    "uracil",
    "uracil1",
]

validation_molecules = [
    "1.3-butadiene",
    "acetaldehyde",
    "ethylene",
    "formaldehyde",
    "formamide",
    "formicacid",
    "propylene",
]

test_molecules = [
    "1.2-dichloroethane",
    "butane",
    "butanol",
    "butylamine",
    "ethanethiol",
    "ethanol",
    "ethylamine",
    "heptanol",
    "hexanol",
    "pentanol",
    "propane",
    "tropane2",
]

# Other dataloader parameters
explicit_hydrogen = false
explicit_hydrogen_gradients = false
radius_graph_threshold = 1.6
rrwp_length = 8  # 0 for no RRWPs
time_lag_mode = "uniform"
normalize_z = false
persistent_workers = true
num_workers = 6
pin_memory = true
prefetch_factor = 2
force_regenerate = true

[training]
device = "cuda"
seed = 42
batch_size = 192
epochs = 50
use_amp = false
amp_dtype = "bfloat16"
max_grad_norm = 1.0
label_noise_std = 0.0

[optimizer]
type = "muon"
learning_rate = 0.002
weight_decay = 1e-5
adam_betas = [0.95, 0.999]
adam_eps = 1e-10

[scheduler]
type = "none"

[atom_config]
# Architecture parameters
num_layers = 8
num_heads = 8
lifting_dim = 256
# Output parameters
output_heads = 1
delta_update = false
# Attention parameters
heterogenous_attention_type = "ghca"
positional_encoding = "trope"
rope_base = 1000
rope_tau = 24_000
learnable_attention_denom = false
# Feature parameters
lifting_type = "quasi_equivariant"
projection_type = "equivariant"
# Layer parameters
norm = "rms"
activation = "swiglu"
value_residual_type = "learnable"

[egno_config]
num_layers = 8
lifting_dim = 64
activation = "silu"
normalise_scalars = true
use_time_conv = true
num_fourier_modes = 2
time_embed_dim = 32

[egnn_config]
num_layers = 8
lifting_dim = 64
activation = "silu"
time_embed_dim = 32