import dataclasses
import logging

import einops
import flax.nnx as nnx
import flax.nnx.bridge as nnx_bridge
import jax
import jax.numpy as jnp
from typing_extensions import override

from openpi.models import model as _model
import openpi.models.gemma as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
import openpi.shared.nnx_utils as nnx_utils
from openpi.models.transformer import TransformerBlock

logger = logging.getLogger("openpi")


def make_attn_mask(input_mask, mask_ar):
    """Build a [B, N, N] boolean attention mask from padding and block-autoregressive flags.

    Adapted from big_vision.

    Each token is assigned a block id equal to the running sum of `mask_ar` up to and
    including that token. A query token may attend to a key token when the key's block
    id is no larger than the query's, and both tokens are real (non-padding) inputs.
    Examples of `mask_ar`:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      input_mask: bool[B, N] true if it is part of the input, false if padding.
      mask_ar: bool[?B, N] true where previous tokens cannot depend on the token, false
        where it shares the same attention mask as the previous token. Broadcast to
        `input_mask.shape`.

    Returns:
      bool[B, N, N] mask indexed as [batch, query, key].
    """
    block_ids = jnp.cumsum(jnp.broadcast_to(mask_ar, input_mask.shape), axis=1)
    query_blocks = block_ids[:, :, None]
    key_blocks = block_ids[:, None, :]
    # A query sees every key whose block started no later than its own block.
    block_visible = key_blocks <= query_blocks
    # Both ends of the attention edge must be real tokens.
    both_valid = input_mask[:, :, None] * input_mask[:, None, :]
    return jnp.logical_and(block_visible, both_valid)


@at.typecheck
def posemb_sincos(
    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
    """Embed scalar positions as concatenated sine and cosine features.

    The `embedding_dim // 2` periods are spaced geometrically from `min_period` to
    `max_period` (inclusive). Output column layout is [sin(all periods) | cos(all periods)].

    Raises:
      ValueError: if `embedding_dim` is odd.
    """
    if embedding_dim % 2:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")

    num_freqs = embedding_dim // 2
    exponents = jnp.linspace(0.0, 1.0, num_freqs)
    periods = min_period * (max_period / min_period) ** exponents
    angular_freqs = 1.0 / periods * 2 * jnp.pi
    # Outer product pos x frequency, at full precision so large positions stay accurate.
    angles = jnp.einsum("i,j->ij", pos, angular_freqs, precision=jax.lax.Precision.HIGHEST)
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


@dataclasses.dataclass(frozen=True)
class Pi0PredConfig(_model.BaseModelConfig):
    """Configuration for `Pi0Pred`: a Pi0 model with an auxiliary hidden-state prediction head."""

    dtype: str = "bfloat16"
    paligemma_variant: _gemma.Variant = "gemma_2b"
    action_expert_variant: _gemma.Variant = "gemma_300m"

    # Set the model specific defaults.
    action_dim: int = 32
    action_horizon: int = 50
    max_token_len: int = 48
    # Input/output width of the `hidden_pred` head; presumably must equal the width of the
    # backbone hidden states it consumes — confirm when changing `paligemma_variant`.
    hidden_dim: int = 2048
    # Layer index of the next-observation hidden state used as the prediction target.
    hidden_idx: int = 9
    # How hidden states are reduced before prediction: "token_mean", "last_token" or "mean_token".
    hidden_mode: str = "token_mean"
    # Layer index of the current-observation hidden state fed to `hidden_pred`.
    hidden_input_idx: int = 17
    # Use a TransformerBlock instead of a single Linear layer as the `hidden_pred` head.
    use_transformer: bool = False

    @property
    @override
    def model_type(self) -> _model.ModelType:
        return _model.ModelType.PI0_PRED

    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0Pred":
        return Pi0Pred(self, rngs=nnx.Rngs(rng))

    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        """Return shape/dtype specs for a batch of (observation, actions) with three camera views."""
        camera_names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
        prompt_shape = (batch_size, self.max_token_len)

        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images={name: image_spec for name in camera_names},
                image_masks={name: image_mask_spec for name in camera_names},
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct(prompt_shape, jnp.int32),
                tokenized_prompt_mask=jax.ShapeDtypeStruct(prompt_shape, bool),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
        return observation_spec, action_spec

    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Return the filter selecting parameters to freeze, based on which variants use LoRA.

        - No LoRA anywhere: nothing is frozen (`nnx.Nothing`).
        - LoRA on PaliGemma: all `llm` params are frozen; if the action expert has no LoRA,
          its params (paths matching `.*llm.*_1.*`) are excluded from the freeze.
        - LoRA only on the action expert: only the action-expert params are frozen.
        In every LoRA case, params whose path contains `lora` stay trainable.
        """
        paligemma_uses_lora = "lora" in self.paligemma_variant
        expert_uses_lora = "lora" in self.action_expert_variant
        if not (paligemma_uses_lora or expert_uses_lora):
            return nnx.Nothing

        gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
        action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
        if paligemma_uses_lora:
            filters = [gemma_params_filter]
            if not expert_uses_lora:
                # Only PaliGemma is frozen; leave the action expert trainable.
                filters.append(nnx.Not(action_expert_params_filter))
        else:
            filters = [action_expert_params_filter]

        # LoRA adapter weights themselves are always trainable.
        filters.append(nnx.Not(nnx_utils.PathRegex(".*lora.*")))
        return nnx.All(*filters)


class Pi0Pred(_model.BaseModel):
    """Pi0 variant trained with an auxiliary hidden-state prediction objective.

    In addition to the flow-matching action loss, a small head `hidden_pred` learns to map a
    backbone hidden state of the current observation (layer `hidden_input_idx`) onto the
    gradient-stopped backbone hidden state of the next observation (layer `hidden_idx`).
    """

    def __init__(self, config: Pi0PredConfig, rngs: nnx.Rngs):
        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
        paligemma_config = _gemma.get_config(config.paligemma_variant)
        action_expert_config = _gemma.get_config(config.action_expert_variant)
        # TODO: rewrite gemma in NNX. For now, use bridge.
        llm = nnx_bridge.ToNNX(
            _gemma.Module(
                configs=[paligemma_config, action_expert_config],
                embed_dtype=config.dtype,
            )
        )
        llm.lazy_init(rngs=rngs, method="init")
        img = nnx_bridge.ToNNX(
            _siglip.Module(
                num_classes=paligemma_config.width,
                variant="So400m/14",
                pool_type="none",
                scan=True,
                dtype_mm=config.dtype,
            )
        )
        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
        self.PaliGemma = nnx.Dict(llm=llm, img=img)
        # NOTE(review): `state_proj` is not used by any method of this class (no state token is
        # embedded); presumably kept so the parameter tree matches existing checkpoints.
        self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
        self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
        self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
        self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
        self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
        self.use_transformer = config.use_transformer
        if self.use_transformer:
            self.hidden_pred = TransformerBlock(config.hidden_dim, 8, 5, rngs=rngs)
        else:
            self.hidden_pred = nnx.Linear(config.hidden_dim, config.hidden_dim, rngs=rngs)
        self.hidden_idx = config.hidden_idx
        self.hidden_mode = config.hidden_mode
        self.hidden_input_idx = config.hidden_input_idx

    @at.typecheck
    def embed_prefix(
        self, obs: _model.Observation
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
        """Embed image tokens (in `obs.images` order) followed by the optional prompt tokens.

        Returns:
          tokens: the concatenated prefix embeddings.
          input_mask: True for real tokens; every token of an image inherits that image's mask.
          ar_mask: all False, so prefix tokens attend to each other bidirectionally.
        """
        input_mask = []
        ar_mask = []
        tokens = []
        # embed images
        for name in obs.images:
            image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)

            tokens.append(image_tokens)
            input_mask.append(
                einops.repeat(
                    obs.image_masks[name],
                    "b -> b s",
                    s=image_tokens.shape[1],
                )
            )
            # image tokens attend to each other
            ar_mask += [False] * image_tokens.shape[1]

        # add language (aka tokenized inputs)
        if obs.tokenized_prompt is not None:
            tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
            tokens.append(tokenized_inputs)
            input_mask.append(obs.tokenized_prompt_mask)
            # full attention between image and language inputs
            ar_mask += [False] * tokenized_inputs.shape[1]
        tokens = jnp.concatenate(tokens, axis=1)
        input_mask = jnp.concatenate(input_mask, axis=1)
        ar_mask = jnp.array(ar_mask)
        return tokens, input_mask, ar_mask

    @at.typecheck
    def embed_suffix(
        self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
        """Embed the noisy action chunk, fused with a sin-cos timestep embedding, as suffix tokens.

        Unlike upstream Pi0, no state token is added; the suffix is exactly `action_horizon`
        tokens. The first suffix token opens a new attention block so prefix tokens cannot
        attend to the actions, while action tokens attend to each other and to the prefix.
        """
        input_mask = []
        ar_mask = []
        tokens = []

        # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
        # mix timestep + action information using an MLP
        action_tokens = self.action_in_proj(noisy_actions)
        time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
        action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
        action_time_tokens = self.action_time_mlp_in(action_time_tokens)
        action_time_tokens = nnx.swish(action_time_tokens)
        action_time_tokens = self.action_time_mlp_out(action_time_tokens)
        tokens.append(action_time_tokens)
        input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtype=jnp.bool_))
        # image/language inputs do not attend to action tokens
        ar_mask += [True] + ([False] * (self.action_horizon - 1))
        tokens = jnp.concatenate(tokens, axis=1)
        input_mask = jnp.concatenate(input_mask, axis=1)
        ar_mask = jnp.array(ar_mask)
        return tokens, input_mask, ar_mask

    def _embed_and_forward(self, obs: _model.Observation, noisy_actions: _model.Actions, timestep):
        """Run one joint prefix+suffix forward pass; return (suffix_out, all_xs, all_masks).

        `all_xs` / `all_masks` are the extra outputs of `PaliGemma.llm`; this class indexes them
        as `[expert][layer]` — presumably per-layer activations and token masks, confirm
        against `_gemma.Module`.
        """
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(obs)
        suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(obs, noisy_actions, timestep)
        input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
        ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
        attn_mask = make_attn_mask(input_mask, ar_mask)
        positions = jnp.cumsum(input_mask, axis=1) - 1
        (_, suffix_out), _, all_xs, all_masks = self.PaliGemma.llm(
            [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions
        )
        return suffix_out, all_xs, all_masks

    @staticmethod
    def _masked_token_mean(hidden, mask):
        """Average `hidden` (b, s, d) over the token axis using only positions where `mask` (b, s) is True.

        Rows with no valid token yield zeros.
        """
        mask_expanded = mask[:, :, None]
        total = jnp.sum(hidden * mask_expanded, axis=1)  # (b, d)
        count = jnp.sum(mask_expanded, axis=1)  # (b, 1)
        # Divide by max(count, 1): the untaken `where` branch must stay finite, otherwise 0/0
        # produces NaN gradients even though its value is discarded.
        return jnp.where(count > 0, total / jnp.maximum(count, 1), 0.0)

    @staticmethod
    def _last_valid_token(hidden, mask):
        """Select, for each row, the hidden state of the last position where `mask` is True.

        If a row has no True entry, argmax over the all-False reversed row returns 0, so the
        final position is selected.
        """
        last_idx = mask.shape[1] - 1 - jnp.argmax(mask[:, ::-1], axis=1)
        return hidden[jnp.arange(mask.shape[0]), last_idx]

    @staticmethod
    def _masked_token_mse(pred, target, mask):
        """Per-example MSE averaged over the tokens where `mask` (b, s) is True; shape (b,).

        `pred` and `target` are (b, s, d); the squared error is first averaged over `d`.
        Examples with no valid token get a loss of 0.
        """
        per_token = jnp.mean(jnp.square(pred - target), axis=-1)  # (b, s)
        total = jnp.sum(per_token * mask, axis=1)  # (b,)
        count = jnp.sum(mask, axis=1)  # (b,) — matches `total`, so `where` keeps shape (b,)
        return jnp.where(count > 0, total / jnp.maximum(count, 1), 0.0)

    @override
    def compute_loss(
        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
    ) -> tuple[at.Float[at.Array, "*b ah"], at.Float[at.Array, "*b"]]:
        """Compute the flow-matching action loss and the hidden-state prediction loss.

        The current observation (base image keys) and the next observation (`next_*` image keys)
        are each run through the backbone with the same noisy actions. The action loss uses the
        current pass only. The hidden loss trains `hidden_pred` to map layer `hidden_input_idx` of
        the current pass onto the gradient-stopped layer `hidden_idx` of the next pass, reduced
        according to `hidden_mode`:
          - "token_mean": per-token prediction on padding-zeroed states, MSE averaged over the
            current pass's valid tokens.
          - "last_token": the last valid token of each sequence.
          - "mean_token": the masked mean over valid tokens.

        Returns:
          action_loss: per-example, per-step velocity MSE, shape (*b, ah).
          hidden_pred_loss: per-example hidden-state MSE, shape (b,).

        Raises:
          ValueError: if `hidden_mode` is not one of the modes above.
        """
        if self.hidden_mode not in ("token_mean", "last_token", "mean_token"):
            raise ValueError(f"Unknown hidden_mode: {self.hidden_mode!r}")

        preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
        curr_image_keys = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
        next_image_keys = ("next_base_0_rgb", "next_left_wrist_0_rgb", "next_right_wrist_0_rgb")
        # Both views are preprocessed with the same key.
        curr_observation = _model.preprocess_observation(
            preprocess_rng, observation, train=train, image_keys=curr_image_keys
        )
        next_observation = _model.preprocess_observation(
            preprocess_rng, observation, train=train, image_keys=next_image_keys
        )

        batch_shape = actions.shape[:-2]
        noise = jax.random.normal(noise_rng, actions.shape)
        time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
        time_expanded = time[..., None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        curr_suffix_out, curr_all_xs, curr_all_masks = self._embed_and_forward(curr_observation, x_t, time)
        _, next_all_xs, next_all_masks = self._embed_and_forward(next_observation, x_t, time)

        v_t = self.action_out_proj(curr_suffix_out[:, -self.action_horizon :])
        action_loss = jnp.mean(jnp.square(v_t - u_t), axis=-1)

        curr_xs = curr_all_xs[0][self.hidden_input_idx]
        curr_mask = curr_all_masks[0][self.hidden_input_idx]
        next_xs = next_all_xs[0][self.hidden_idx]
        next_mask = next_all_masks[0][self.hidden_idx]

        if self.hidden_mode == "token_mean":
            # Token-wise prediction on hidden states with padding positions zeroed out.
            curr_hidden = curr_xs * curr_mask[:, :, None]
            next_hidden = jax.lax.stop_gradient(next_xs * next_mask[:, :, None])
            if self.use_transformer:
                hidden_output = self.hidden_pred(jnp.expand_dims(curr_hidden, axis=1))
                next_hidden_pred = jnp.squeeze(hidden_output, axis=1)
            else:
                next_hidden_pred = self.hidden_pred(curr_hidden)
            hidden_pred_loss = self._masked_token_mse(next_hidden_pred, next_hidden, curr_mask)
            return action_loss, hidden_pred_loss

        if self.hidden_mode == "last_token":
            curr_hidden = self._last_valid_token(curr_xs, curr_mask)
            next_hidden = self._last_valid_token(next_xs, next_mask)
        else:  # "mean_token"
            curr_hidden = self._masked_token_mean(curr_xs, curr_mask)
            next_hidden = self._masked_token_mean(next_xs, next_mask)
        next_hidden = jax.lax.stop_gradient(next_hidden)
        next_hidden_pred = self.hidden_pred(curr_hidden)
        hidden_pred_loss = jnp.mean(jnp.square(next_hidden_pred - next_hidden), axis=-1)
        return action_loss, hidden_pred_loss

    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        num_steps: int | at.Int[at.Array, ""] = 10,
    ) -> tuple[_model.Actions, at.Array]:
        """Sample an action chunk by Euler integration and pool per-layer prefix hidden states.

        Returns:
          x_0: the sampled actions, shape (b, action_horizon, action_dim).
          layer_hidden: masked token-means of `all_xs[0][idx]` for idx in 0..17, concatenated
            along axis 0 (layer-major: rows [l*b:(l+1)*b] belong to layer l).
        """
        observation = _model.preprocess_observation(None, observation, train=False)
        # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
        # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
        dt = -1.0 / num_steps
        batch_size = observation.state.shape[0]
        noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

        # first fill KV cache with a forward pass of the prefix
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
        positions = jnp.cumsum(prefix_mask, axis=1) - 1
        _, kv_cache, all_xs, all_masks = self.PaliGemma.llm(
            [prefix_tokens, None], mask=prefix_attn_mask, positions=positions
        )

        def step(carry):
            x_t, time = carry
            suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(
                observation, x_t, jnp.broadcast_to(time, batch_size)
            )
            # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each
            # other
            suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
            # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the
            # prefix tokens
            prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
            # `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
            # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
            full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
            assert full_attn_mask.shape == (
                batch_size,
                suffix_tokens.shape[1],
                prefix_tokens.shape[1] + suffix_tokens.shape[1],
            )
            # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
            positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

            (prefix_out, suffix_out), _, _, _ = self.PaliGemma.llm(
                [None, suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache
            )
            assert prefix_out is None
            v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])

            return x_t + dt * v_t, time + dt

        def cond(carry):
            x_t, time = carry
            # robust to floating-point error
            return time >= -dt / 2

        x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))

        # NOTE(review): 18 is hard-coded; presumably the layer count of the gemma_2b backbone —
        # confirm before using a different `paligemma_variant`.
        layer_hidden = jnp.concatenate(
            [self._masked_token_mean(all_xs[0][idx], all_masks[0][idx]) for idx in range(18)],
            axis=0,
        )
        return x_0, layer_hidden
