
import random
from typing import Dict, Literal
from tap import Tap

class Args(Tap):
    """Command-line configuration for this script.

    Each annotated class attribute becomes a ``--<name>`` flag parsed by Tap.
    The annotation decides how the command-line string is converted, and the
    assigned value is the default used when the flag is omitted.
    """

    # Run-level settings.
    no_wandb: bool = False
    total_timesteps: int = 100_000_000
    # The default expression runs once, when the class body is evaluated at
    # import time, so every process gets a fresh random seed. Pass --seed
    # explicitly for reproducible runs.
    seed: int = random.randint(0, 2**32 - 1)
    n_envs: int = 32
    device: Literal['cpu', 'mps', 'cuda'] = 'cuda'
    feature_extractor: Literal['none', 'tfh_small', 'tfh_big', 'tf', 'tfh_fast'] = 'tfh_fast'

    # Environment parameters. Their exact meaning is defined by the
    # environment code, which is not visible in this file.
    env_num_inserts: int = 6
    env_num_deletes: int = 6
    env_max_tree_values: int = 24
    env_max_values_per_node: int = 4
    checkpoint_callback_freq: int = 50_000

    # `s_`-prefixed hyperparameters. The names mirror Stable-Baselines3 PPO /
    # policy kwargs (net_arch with 'pi'/'vf' keys, n_epochs, n_steps,
    # batch_size), so they are presumably forwarded there.
    # TODO(review): confirm against the caller.
    # NOTE(review): this dict default is a class-level mutable object; treat it
    # as read-only rather than mutating it in place.
    s_net_arch: Dict[str, list[int]] = {'pi': [512, 512], 'vf': [512, 512]}
    s_transformer_features_dim: int = 64
    s_transformer_num_layers: int = 2
    s_transformer_nhead: int = 2
    s_n_epochs: int = 10
    s_learning_rate: float = 1e-4
    # Presumably the entropy-bonus coefficient (PPO's `ent_coef`). It is
    # annotated as float because Tap converts the CLI string with the annotated
    # type; an `int` annotation would reject fractional values such as
    # `--s_entropy_coef 0.01`.
    s_entropy_coef: float = 0.0
    s_n_eval_episodes: int = 10_000
    s_eval_freq: int = 50_000
    s_n_seeds: int = 40
    s_batch_size: int = 512
    s_n_steps: int = 1000


if __name__ == '__main__':
    # Script entry point: build the Tap parser, populate it from the command
    # line, then echo the resolved configuration to stdout.
    parser = Args()
    args = parser.parse_args()
    print(args)

