import os
# Pin the OpenMP / MKL / OpenBLAS thread pools to a single thread, and let
# PyTorch fall back to the CPU for operations the MPS backend does not
# implement (e.g. Dirichlet sampling).
# These are assigned before `train_eval` is imported below. Variables like
# these are generally read when the numerical libraries initialise, so
# setting them later may have no effect. (Presumably train_eval pulls in
# torch/numpy — confirm.)
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
    }
)

from train_eval import train_and_eval
import json

if __name__ == "__main__":
    # Hyper-parameters handed to train_and_eval. The key names are the
    # contract with train_eval; what each one controls is defined there.
    experiment_config = {
        # Batching and validation/test split fractions.
        "batch_size": 16,
        "val_ratio": 0.15,
        "test_ratio": 0.15,
        # Dataset sizing (semantics defined in train_eval — TODO confirm).
        "num_tasks": 3,
        "seq_len": 36,
        "num_seq_per_task": 100,
        # Model sizes (semantics defined in train_eval — TODO confirm).
        "d": 6,
        "K": 4,
        "T": 48,
        "top_k": 2,
        # Optimisation.
        "epochs": 200,
        "lr": 3e-4,
        "weight_decay": 1e-5,
        # Presumably auxiliary loss weights — confirm in train_eval.
        "lambda_ortho": 1e-4,
        "lambda_gate_align": 0.1,
        # Logging and reproducibility.
        "progress_bar": True,
        "sanity_every": 10,  # set 0 to disable; prints val metrics every N epochs
        "seed": 123,
    }
    outcome = train_and_eval(experiment_config)

    # Pretty-print the returned results. json.dumps raises TypeError if the
    # returned object contains anything that is not JSON-serialisable.
    summary = json.dumps(outcome, indent=2)
    print(summary)
