import torch
from transformers.feature_extraction_utils import BatchFeature
def action_head_pytorch_forward(self, backbone_output, action_input):
    """Sample actions by Euler-integrating the action head's predicted velocity.

    Takes ``self`` explicitly, so it is presumably bound or patched onto the
    action-head class. Starting from Gaussian noise, or from
    ``self.init_actions`` when that attribute exists, it runs
    ``self.num_inference_timesteps`` Euler steps of size ``1 / num_steps``.
    Each step does the following:

    1. Encode the current actions together with a discretized timestep.
    2. Concatenate the result with the state features and the learned
       future tokens.
    3. Run ``self.model`` with the backbone features passed as
       ``encoder_hidden_states``.
    4. Decode the output into a velocity that updates the actions.

    Args:
        self: Action-head object providing ``process_backbone_output``,
            ``state_encoder``, ``action_encoder``, ``action_decoder``,
            ``position_embedding``, ``future_tokens``, ``model``,
            ``num_inference_timesteps``, ``num_timestep_buckets``,
            ``action_horizon`` and ``config``.
        backbone_output: Backbone output. After ``process_backbone_output``
            it must expose ``backbone_features``, whose first dimension is
            the batch size.
        action_input: Object exposing ``embodiment_id`` and ``state``.

    Returns:
        BatchFeature with a single key ``"action_pred"`` holding the final
        action tensor. It presumably has shape
        ``(batch, action_horizon, action_dim)``, like the initial noise.
    """
    backbone_output = self.process_backbone_output(backbone_output)
    vl_embs = backbone_output.backbone_features
    embodiment_id = action_input.embodiment_id
    state_features = self.state_encoder(action_input.state, embodiment_id)

    batch_size = vl_embs.shape[0]
    device = vl_embs.device

    # NOTE(review): noise is drawn even when init_actions overrides it below.
    # This keeps global RNG consumption the same in both cases, which matters
    # when comparing against another implementation under a fixed seed.
    # It does waste one sample; confirm this is intended before removing it.
    actions = torch.randn(
        size=(batch_size, self.config.action_horizon, self.config.action_dim),
        dtype=vl_embs.dtype,
        device=device,
    )
    if hasattr(self, "init_actions"):
        # A caller-supplied starting point replaces the random noise.
        # expand() broadcasts a leading size-1 dim across the batch without
        # copying; later updates create new tensors, so init_actions is
        # never mutated.
        actions = self.init_actions.expand((batch_size, -1, -1))

    # Loop-invariant: the learned future tokens do not depend on the step,
    # so broadcast them across the batch once instead of on every iteration.
    future_tokens = self.future_tokens.weight.unsqueeze(0).expand(batch_size, -1, -1)
    # Positional embeddings are also step-independent. Their length comes
    # from the encoder output, so they are computed on the first step and
    # reused afterwards.
    pos_embs = None

    num_steps = self.num_inference_timesteps
    dt = 1.0 / num_steps
    for t in range(num_steps):
        # Map the continuous time t / num_steps (in [0, 1)) onto the integer
        # bucket index passed to the encoder and model as the timestep.
        t_cont = t / float(num_steps)
        t_discretized = int(t_cont * self.num_timestep_buckets)
        timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)

        action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
        if self.config.add_pos_embed:
            if pos_embs is None:
                pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
                pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs

        # Action tokens go last, so the velocity can be read from the
        # trailing positions of the decoder output below.
        sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
        model_output = self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=vl_embs,
            timestep=timesteps_tensor,
        )
        pred = self.action_decoder(model_output, embodiment_id)
        # Keep only the last action_horizon positions (presumably the action
        # tokens, assuming the model preserves sequence length).
        # NOTE(review): this uses self.action_horizon, while the noise above
        # uses self.config.action_horizon. They are assumed equal; confirm.
        pred_velocity = pred[:, -self.action_horizon :]
        # Explicit Euler step: x_{t+dt} = x_t + dt * v(x_t, t).
        actions = actions + dt * pred_velocity
    return BatchFeature(data={"action_pred": actions})
