import timeit
from functools import partial

import torch
from modeling.mamba import MambaModel, MambaConfig


def benchmark(fn, *, repeat=5, warmup=1, sync=None):
    """Time ``fn`` with ``timeit``, optionally synchronizing after each call.

    CUDA kernels launch asynchronously, so timing a GPU call without a
    device synchronization mostly measures launch overhead. Pass
    ``sync=torch.cuda.synchronize`` so each timed run blocks until the
    queued device work has finished.

    Args:
        fn: Zero-argument callable to time.
        repeat: Number of timed runs; each run calls ``fn`` exactly once.
        warmup: Number of untimed calls made first, so one-off startup costs
            do not land in the measurements.
        sync: Optional zero-argument callable invoked after every call of
            ``fn`` (warm-up and timed).

    Returns:
        List of ``repeat`` wall-clock durations in seconds.
    """
    def run():
        fn()
        if sync is not None:
            sync()

    for _ in range(warmup):
        run()
    return timeit.repeat(run, repeat=repeat, number=1)


if __name__ == "__main__":
    B = 4      # batch size: rows of input_ids
    T = 2048   # sequence length: columns of input_ids
    L = 4      # number of layers (MambaConfig.n_layers)
    D = 1024   # model width (MambaConfig.d_model)
    C = 256    # chunk_size passed to MambaModel
    scan_impl = 'pscan_torch'
    act_checkpointing = 'layer'

    # NOTE(review): token ids are drawn from [0, 8192); this assumes the
    # config's vocabulary has at least 8192 entries. Confirm.
    input_ids = torch.randint(low=0, high=8192, size=(B, T), dtype=torch.long, device='cuda')
    config = MambaConfig(d_model=D, n_layers=L, scan_impl=scan_impl)
    model = MambaModel(config, chunk_size=C).cuda()
    model.train()

    # NOTE(review): only the forward pass is timed. If act_checkpointing
    # controls activation checkpointing, its recompute cost would show up
    # only in a backward pass. Confirm this is the intended measurement.
    forward = partial(model, input_ids, act_checkpointing=act_checkpointing)
    # Synchronize after every call so each sample covers the GPU work itself,
    # and warm up once so startup costs do not skew the mean.
    times = benchmark(forward, repeat=5, warmup=1, sync=torch.cuda.synchronize)
    print(times)
    print(sum(times) / len(times))
