from typing import Optional
from torch.utils.checkpoint import checkpoint
import torch.nn as nn
from torch import Tensor

def checkpointed_forward(
    module: nn.Module, 
    use_checkpointing: Optional[bool|int],
    *args, 
    **kwargs,
    ):
    if use_checkpointing:
        return checkpoint(module, *args, use_reentrant=False, **kwargs)
    return module(*args, **kwargs)
    
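# Example usage. `use_checkpointing` is the second positional parameter, so
# it must be passed before the module's own arguments; passing it as a
# keyword after positional inputs raises a "multiple values" TypeError.
#
# for block in blocks:
#     x = checkpointed_forward(
#         block,
#         use_checkpointing,
#         x,
#         emb,
#     )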