from collections import OrderedDict
import argparse

def get_default_config():
    """Build and return the default nested configuration.

    A brand-new tree of ``OrderedDict`` objects is created on every call, so
    callers may freely mutate the result without affecting later calls.

    Top-level sections, in insertion order:
    ``general``, ``dataset``, ``model``, ``trainer``, ``simulator``.
    The ``model`` section nests one more level, holding one sub-section per
    architecture: ``GraphTransformer`` then ``MLP``.

    Returns:
        OrderedDict: section name -> OrderedDict of setting name -> value.
    """
    general_section = OrderedDict([
        ("cg_level", "low"),  # "high" or "low"
        ("t0", 0.0),
        ("t1", 1.0),
        ("mu", 0.0),
        ("sigma", 1.0),
    ])

    dataset_section = OrderedDict([
        ("mb_datafile", ""),
        ("ala2_datafile", ""),
        ("num_samples", 50000),
        ("shuffle", False),
        ("drop_last", True),
    ])

    graph_transformer_section = OrderedDict([
        ("hidden_layers", 128),
        ("embedding_layers", 16),
        ("max_z", [10]),  # Maximum number of different atom types
        ("n_layers", 3),
        ("rescale_time", False),
        ("clip_time", False),
        ("use_intrinsic_coords", True),
        ("use_abs_coords", True),
        ("use_distances", True),
        ("dropout", 0),
        ("feat_type", "distinguish"),
    ])

    mlp_section = OrderedDict([
        ("hidden_layers", 96),
        ("embedding_layers", 16),
        ("n_layers", 3),
    ])

    model_section = OrderedDict([
        ("GraphTransformer", graph_transformer_section),
        ("MLP", mlp_section),
    ])

    trainer_section = OrderedDict([
        ("name", "adamw"),
        ("learning_rate", 3e-4),
        ("min_learning_rate", 1e-5),
        ("clip", 1000.0),
        ("schedule", "cosine"),
        ("weight_decay", 1e-5),
        ("epochs", 5000),
        ("batch", 256),
    ])

    simulator_section = OrderedDict([
        ("method", "dopri5"),  # "dopri5" or "tsit5"
        ("num_batches", 1000),
        ("batch", 1000),
        ("dt0", 5e-3),
        ("num_z", 0),
        ("mean", True),
    ])

    # Section order here determines the iteration order of the result.
    return OrderedDict([
        ("general", general_section),
        ("dataset", dataset_section),
        ("model", model_section),
        ("trainer", trainer_section),
        ("simulator", simulator_section),
    ])