from torch import Tensor
import transformers
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def llama_ckpt_mlp_forward(
    self: "transformers.models.llama.modeling_llama.LlamaMLP",
    x: Tensor,
    compress_kwargs: dict | None = None,
) -> Tensor:
    def mlp_intermediate(x_):
        return self.act_fn(self.gate_proj(x_)) * self.up_proj(x_)

    intermediate = checkpoint(mlp_intermediate, x, use_reentrant=True)
    out = self.down_proj(intermediate)
    return out