# Experiment-tracking toggle (presumably Weights & Biases, going by the key
# name). Currently switched off.
[wandb]
use_wandb = false

# Benchmark run settings.
[benchmark]
# model_type is "atom"; presumably this selects [atom_config] over
# [egno_config] below -- confirm against the loader.
model_type = "atom"
# The name appears to encode other settings in this file: model "atom",
# dataset "tg80", multitask = true, optimizer "muon", and "12moe" (compare
# output_heads = 12 in [atom_config]). "fold1" has no matching key in this
# file. Keep the name in sync by hand when any of those values change.
benchmark_name = "atom_tg80_multitask_muon_fold1_12moe"
compile = true
compile_trace = false
# Number of runs of this benchmark (presumably independent repetitions).
runs = 1
log_weights = false

# Dataset selection, train/validation/test molecule splits, and data-loading
# options. Every key up to the next table header belongs to [dataloader],
# including the ones that follow the three molecule arrays.
[dataloader]
multitask = true
# NOTE(review): the unit of delta_T is not stated here (presumably a
# simulation-step offset) -- confirm.
delta_T = 10_000
dataset = "tg80"
num_timesteps = 8

# Single-task dataloader parameters
# molecule_type = "aspirin"  # Molecule(s) to run single-task learning on

# Multitask dataloader parameters
# The splits below hold 37 training, 8 validation and 8 test entries.
# NOTE(review): several names differ only by a numeric suffix. "benzene" is in
# train while "benzene2" is in test, and "uracil1" is in validation while
# "uracil" is in test. If the suffix denotes another variant or trajectory of
# the same molecule, the splits are not molecule-disjoint -- confirm this is
# intended.
train_molecules = [
    "isopropanol",
    "benzoicacid",
    "cyclohexanol",
    "ethylamine",
    "heptanol",
    "butanol",
    "butylamine",
    "benzaldehyde",
    "formamide",
    "anthracene",
    "butane",
    "p-cresol",
    "ethanethiol",
    "oxalicacid",
    "indole",
    "aniline",
    "1.2-dichloroethane",
    "malonicacid",
    "salicylicacid2",
    "malondialdehyde1",
    "ethylene",
    "furfural",
    "acetamide",
    "quinoline",
    "furan",
    "methanol",
    "tropane2",
    "benzene",
    "tropane3",
    "pentanol",
    "cyclohexanone",
    "formaldehyde",
    "benzothiophene",
    "succinicacid",
    "propane",
    "chlorobenzene",
    "acetaldehyde",
]

validation_molecules = [
    "benzylamine",
    "nitrobenzene",
    "uracil1",
    "1.3-butadiene",
    "trimethylamine",
    "acetonitrile",
    "naphthalene",
    "coumarin",
]

test_molecules = [
    "2-butanone",
    "cyclopropane",
    "propylene",
    "uracil",
    "benzene2",
    "cyclohexane",
    "biphenyl",
    "1.3-cyclohexadiene",
]


# Other dataloader parameters
explicit_hydrogen = false
explicit_hydrogen_gradients = false
# NOTE(review): the unit of the threshold is not stated (presumably a distance
# cutoff in angstroms) -- confirm.
radius_graph_threshold = 1.6
rrwp_length = 8  # 0 for no RRWPs

normalize_z = false
# The four keys below look like worker/prefetch settings for the data loader
# (presumably passed through to a PyTorch DataLoader -- confirm).
persistent_workers = true
num_workers = 6
pin_memory = true
prefetch_factor = 2
# NOTE(review): presumably forces the processed dataset to be rebuilt on each
# run instead of reusing a cache; confirm that this is wanted here.
force_regenerate = true

# Training-loop settings.
[training]
device = "cuda"
seed = 42
batch_size = 192
epochs = 250
# Mixed precision is enabled with dtype "bfloat16" (presumably applied via
# autocast -- confirm in the training code).
use_amp = true
amp_dtype = "bfloat16"
# Presumably the gradient-norm clipping threshold -- confirm.
max_grad_norm = 1.0
# NOTE(review): the scale of this noise (raw vs. normalised label units) is not
# stated here -- confirm.
label_noise_std = 0.1

# Optimizer settings.
[optimizer]
type = "muon"
learning_rate = 1e-3
weight_decay = 1e-5
# NOTE(review): adam_* keys are set although type is "muon"; presumably they
# configure an Adam-style fallback for some parameters -- confirm they are
# read when type = "muon".
adam_betas = [0.9, 0.999]
adam_eps = 1e-10

# Learning-rate scheduler settings.
[scheduler]
# "none" is a string sentinel (TOML has no null); presumably it disables
# scheduling so that learning_rate stays constant -- confirm.
type = "none"

# Hyperparameters for the "atom" model (model_type = "atom" in [benchmark]).
[atom_config]
# Architecture parameters
num_layers = 5
num_heads = 8
lifting_dim = 128
# Output parameters
# 12 matches the "12moe" tag in benchmark_name.
output_heads = 12
delta_update = false
# Attention parameters
# NOTE: the key is spelled "heterogenous" (not "heterogeneous"); it must match
# whatever name the consumer reads, so do not correct it here alone.
heterogenous_attention_type = "ghca"
positional_encoding = "trope"
# rope_base and rope_tau presumably parameterise the "trope" positional
# encoding -- confirm.
rope_base = 1000
rope_tau = 10000
learnable_attention_denom = false
# Feature parameters
lifting_type = "quasi_equivariant_tp"
projection_type = "equivariant"
# Layer parameters
norm = "rms"
activation = "swiglu"
value_residual_type = "learnable"

# Hyperparameters for an "egno" model. model_type in [benchmark] is "atom", so
# this table is presumably not used by this run -- confirm.
[egno_config]
num_layers = 5
lifting_dim = 64
activation = "silu"
# NOTE: spelled "normalise" here while [dataloader] uses "normalize_z"; each
# key must match the spelling its consumer expects.
normalise_scalars = true
use_time_conv = true
num_fourier_modes = 2
time_embed_dim = 32