from tap import Tap


class Args(Tap):
    """Command-line arguments for training, parsed by typed-argument-parser (Tap).

    Tap turns each annotated class attribute into a ``--<name>`` option and uses
    the trailing same-line ``#`` comment (or a string placed right after the
    field) as that option's help text, so keep each help comment on the same
    line as its field. Several ``int`` options are used as 0/1 switches
    (e.g. ``debug``, ``do_convert_checkpoint``).
    """

    debug: int = 0  # If 1, will print a lot of debug info.

    # Model
    model: str = 'mamba2'  # One of ['mamba', 'rwkv5', 'mamba2']
    model_config: str = "configs/model/mamba_130m.json"  # Only used when model="mamba".
    dtype: str = 'bf16'

    # Tokenizer
    tok_path: str = "./tokenizers/mamba-tok"

    # Checkpoint
    pretrained_path: str = "../../ckpts/mamba/mamba2-370m"
    '''
    The path to the state dict of the pretrained checkpoint.
    '''
    rand_init: int = 0
    do_convert_checkpoint: int = 1  # Whether to convert keys in the checkpoint.
    resume_path: str = "none"  # The directory to resume from.
    load_grad: int = 0
    # Kept on one line: Tap only picks up a help comment on the field's own line.
    grad_ckpt_num: int = 0  # Number of gradient files; only used with load_grad when loading from fewer files than the world size.
    load_start_step: int = 0  # Whether to skip to the step the checkpoint was saved at.
    output_dir: str = './output'  # The output directory to save checkpoints to.
    save_interval: int = 100

    # Logging
    tensorboard: str = "./tensorboard"  # Directory of tensorboard
    log_interval: int = 1
    inspect_interval: int = 1000  # Number of batches between each parameter inspection.
    project_name: str = 'LongRNN'

    # Training
    stop_when_end: int = 1
    clip_grad: float = 1.0
    grad_accum: int = 4
    n_train_steps: int = 100000
    max_length: int = 8192
    min_length: int = -1  # Not used yet.
    seed: int = 0
    weight_decay: float = 1e-1
    chunk_size: int = 128
    state_reset_interval: int = 1  # The number of steps between each state reset.
    reset_interval_warmup_steps: int = 1  # The number of optimization steps where the state_reset_interval is gradually increased.
    grad_ckpt: str = 'layer'
    scan_impl: str = 'seq_triton'

    # Optimizer
    offload: int = 0  # Whether to use offloaded Adam. NOTE: currently not working.

    # Loss
    loss_scale: float = 64 * 1024
    max_loss_scale: float = float("inf")
    min_loss_scale: float = 1.0
    loss_scale_steps: float = 1024

    # LR scheduler
    lr: float = 1e-4
    lr_scheduler: str = "stabledrop"
    n_warmup_steps: int = 100
    n_drop_steps: int = 100
    resume_no_optimize: int = 0

    # Resumption
    start_step: int = 0

    # Data config
    dataset: str = "dataset.json"

    # Model config
    eps: float = 1e-5

    # Data format
    # Alternative data config: 'configs/data/slimpajama.json'.
    data_config: str = 'configs/data/redpajama_4k.json'
    one_sequence_batch: int = 0
    packing_count: int = 1  # The number of sequences to pack together (concatenate).
    batch_size: int = 2  # The number of sequences in each batch

    # Data Loading
    dataloader_num_threads: int = 3
    dataloader_prefetch: int = 200
    dataloader_num_workers: int = 4
    dataloader_prefetch_factor: int = 50
    dataloader: str = "indexed"
    data_len_threshold: int = 512
    only_run_dataloader: int = 0
    only_load_model: int = 0
    load_dataloader_ckpt: int = 0
    parallel_load_datastate: int = 256
    repeat_data: int = 0
    device: str = 'cuda'
