# Experiment-tracking settings. `wandb` presumably refers to Weights & Biases;
# tracking is turned off for this benchmark.
[wandb]
use_wandb = false

# Run identity and execution switches for this benchmark.
[benchmark]
# Presumably selects which model table below is read ([atom_config] for
# "atom"); [egno_config] would then be ignored — confirm in the consumer.
model_type = "atom"
# NOTE(review): the name says "fold5", but train_molecules,
# validation_molecules and test_molecules in [dataloader] are identical, so
# this file does not hold out any molecules. Confirm where the fold split is
# actually applied.
benchmark_name = "atom_tg80_multitask_muon_fold5"
# Presumably toggles torch.compile on the model — confirm.
compile = true
compile_trace = false
# Presumably the number of repeated runs of this benchmark. [training].seed
# is fixed at 42 in this file; confirm whether the consumer varies the seed
# between runs, otherwise the repeats may be identical.
runs = 3
log_weights = false

[dataloader]
# true: use the multitask molecule lists below.
# false: presumably use the single-task `molecule_type` key (commented out below).
multitask = true
# Presumably the [min, max] time lag between input and target frames, drawn
# according to `time_lag_mode`. The upper bound equals [atom_config].rope_tau
# (24_000); if you change one, check whether the other should follow.
delta_T = [8, 24_000]
dataset = "tg80"
num_timesteps = 8

# Single-task dataloader parameters (presumably only read when multitask = false)
# molecule_type = "aspirin"  # Molecule(s) to run single-task learning on

# Multitask dataloader parameters.
# NOTE(review): train, validation and test all list the same eight molecules.
# This is fine if the split happens within each molecule's trajectory (e.g.
# by time). It would be train/test leakage if a molecule-level hold-out was
# intended (the benchmark name mentions "fold5") — confirm.
train_molecules = [
    "acetamide",
    "aceticacid",
    "citricacid",
    "malonicacid",
    "oxalicacid",
    "salicylicacid1",
    "salicylicacid2",
    "succinicacid",
]

validation_molecules = [
    "acetamide",
    "aceticacid",
    "citricacid",
    "malonicacid",
    "oxalicacid",
    "salicylicacid1",
    "salicylicacid2",
    "succinicacid",
]

test_molecules = [
    "acetamide",
    "aceticacid",
    "citricacid",
    "malonicacid",
    "oxalicacid",
    "salicylicacid1",
    "salicylicacid2",
    "succinicacid",
]

# Other dataloader parameters
explicit_hydrogen = false
explicit_hydrogen_gradients = false
# Distance cutoff for building the radius graph; units presumably Angstrom —
# confirm against the dataset.
radius_graph_threshold = 1.6
rrwp_length = 8  # 0 for no RRWPs
# Presumably how time lags within delta_T are sampled.
time_lag_mode = "uniform"
normalize_z = false
# The next four keys mirror torch.utils.data.DataLoader arguments and are
# presumably passed straight through — confirm.
persistent_workers = true
num_workers = 6
pin_memory = true
prefetch_factor = 2
# Presumably rebuilds the processed dataset on every run instead of reusing a
# cached copy; set to false to skip regeneration — confirm.
force_regenerate = true

# Training-loop settings.
[training]
device = "cuda"
seed = 42
batch_size = 192
epochs = 50
# Mixed-precision training in bfloat16 (presumably via torch.amp autocast).
use_amp = true
amp_dtype = "bfloat16"
# Presumably the threshold for gradient-norm clipping.
max_grad_norm = 1.0
# Presumably the std of noise added to training targets; 0.0 turns it off.
label_noise_std = 0.0

# Optimizer settings.
[optimizer]
type = "muon"
learning_rate = 0.002
weight_decay = 1e-5
# The adam_* keys presumably configure the Adam-style update used for
# parameters Muon does not handle (typically non-matrix parameters such as
# biases and norms) — confirm in the optimizer factory.
# beta1 = 0.95 differs from the common Adam default of 0.9.
adam_betas = [0.95, 0.999]
adam_eps = 1e-10

# Learning-rate schedule. "none" presumably keeps learning_rate constant for
# all epochs.
[scheduler]
type = "none"

# Model hyperparameters for model_type = "atom".
[atom_config]
# Architecture parameters
num_layers = 6
num_heads = 8
lifting_dim = 256
# Output parameters
output_heads = 1
delta_update = false
# Attention parameters
# The key is spelled "heterogenous" (not "heterogeneous"). It must match what
# the consumer reads, so do not rename it here on its own.
heterogenous_attention_type = "ghca"
positional_encoding = "trope"
rope_base = 1000
# Equal to the upper bound of [dataloader].delta_T (24_000), presumably so the
# rotary time scale covers the full lag range. If you change one, check
# whether the other should follow.
rope_tau = 24_000
learnable_attention_denom = false
# Feature parameters
lifting_type = "quasi_equivariant"
projection_type = "equivariant"
# Layer parameters
norm = "rms"
activation = "swiglu"
value_residual_type = "learnable"

# Model hyperparameters for the EGNO model. [benchmark].model_type is "atom"
# in this file, so this table is presumably not read for this run. Kept for
# switching models; confirm before deleting.
[egno_config]
num_layers = 5
lifting_dim = 64
activation = "silu"
normalise_scalars = true
use_time_conv = true
num_fourier_modes = 2
time_embed_dim = 32