import logging
from typing import Literal

import einops
import flax.nnx as nnx
import flax.nnx.bridge as nnx_bridge
import jax
import jax.numpy as jnp
from typing_extensions import override
from einops import rearrange, reduce

from openpi.models import model as _model
from openpi.models import pi0_config
import openpi.models.gemma as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
from openpi.training.utils import concat_observations


# Module logger; uses the shared "openpi" logger name rather than this module's __name__.
logger = logging.getLogger("openpi")


def make_attn_mask(input_mask, mask_ar):
    """Build a [B, N, N] boolean attention mask from padding and block-causal flags (after big_vision).

    Every position gets a block id equal to the running sum of ``mask_ar`` along the sequence.
    A query at position i may attend to a key at position j when the key's block id is not larger
    than the query's, and both positions are real (non-padding) inputs.

    Examples of ``mask_ar`` patterns:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens attend to each other
          bidirectionally, and the last 3 tokens are causal. Setting the first entry
          to 1 gives the same mask.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. A token sees every
          token of its own block and of all earlier blocks.

    Args:
      input_mask: bool[B, N], True for real inputs, False for padding.
      mask_ar: bool[?B, N], broadcast to ``input_mask.shape``. True starts a new block that
        earlier tokens cannot see. False keeps the token in the previous token's block.

    Returns:
      bool[B, N, N], where entry [b, i, j] says whether query i may attend to key j.
    """
    flags = jnp.broadcast_to(mask_ar, input_mask.shape)
    block_ids = jnp.cumsum(flags, axis=1)

    # Rows index queries and columns index keys.
    query_ids = block_ids[:, :, None]
    key_ids = block_ids[:, None, :]
    causal = key_ids <= query_ids

    both_real = jnp.logical_and(input_mask[:, :, None], input_mask[:, None, :])
    return jnp.logical_and(causal, both_real)


@at.typecheck
def posemb_sincos(
    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
    """Embed scalar positions as concatenated [sin, cos] features.

    The first half of the output channels are sines and the second half cosines. The periods
    are spaced geometrically from ``min_period`` to ``max_period``: the exponent runs over
    ``jnp.linspace(0, 1, embedding_dim // 2)``.

    Raises:
        ValueError: if ``embedding_dim`` is odd.
    """
    if embedding_dim % 2:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")

    half_dim = embedding_dim // 2
    exponents = jnp.linspace(0.0, 1.0, half_dim)
    periods = min_period * (max_period / min_period) ** exponents
    angular_freqs = 1.0 / periods * 2 * jnp.pi
    # Outer product pos[i] * angular_freqs[j], computed at full precision.
    phases = jnp.einsum("i,j->ij", pos, angular_freqs, precision=jax.lax.Precision.HIGHEST)
    return jnp.concatenate([jnp.sin(phases), jnp.cos(phases)], axis=-1)


class Pi0IQL(_model.BaseModel):
    """Pi0 / Pi0.5 flow-matching policy extended with IQL critics.

    On top of the PaliGemma VLM and action expert used for action generation, this model adds
    two MLP heads that read pooled VLM prefix features:

    * ``value_head``: V(s). It is trained with expectile regression (``compute_iql_value_loss``)
      or with a PPO-style clipped loss (``compute_value_loss``).
    * ``q_head``: Q(s, a). Its input is the pooled prefix features concatenated with the
      flattened action chunk. It is trained against a one-step Bellman target
      (``compute_iql_q_loss``).

    The policy is trained with advantage-weighted flow matching (``compute_iql_policy_loss``).
    Every critic path applies ``stop_gradient`` to the VLM output, so critic losses only update
    the heads.
    """

    def __init__(self, config: pi0_config.Pi0IQLConfig, rngs: nnx.Rngs):
        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
        self.pi05 = config.pi05
        paligemma_config = _gemma.get_config(config.paligemma_variant)
        action_expert_config = _gemma.get_config(config.action_expert_variant)

        # Discount factor from the config.
        # NOTE(review): compute_iql_q_loss takes its own `gamma` argument (default 0.99) and does not read this.
        self.gamma = config.gamma
        # Number of value-ensemble members: read from the config if present, otherwise 1.
        # The value head below is a single head with output_dim=1. compute_value_loss averages over
        # whatever trailing ensemble axis the predictions actually have.
        self.value_ensemble_size = getattr(config, "value_ensemble_size", 1)

        # Initialize main model components
        # TODO: rewrite gemma in NNX. For now, use bridge.
        llm = nnx_bridge.ToNNX(
            _gemma.Module(
                configs=[paligemma_config, action_expert_config],
                embed_dtype=config.dtype,
                adarms=config.pi05,
            )
        )
        llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False])
        img = nnx_bridge.ToNNX(
            _siglip.Module(
                num_classes=paligemma_config.width,
                variant="So400m/14",
                pool_type="none",
                scan=True,
                dtype_mm=config.dtype,
            )
        )
        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
        self.PaliGemma = nnx.Dict(llm=llm, img=img)
        self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
        if config.pi05:
            # pi05: time is fed to the action expert as an adaRMS condition through this MLP.
            self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
            self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
        else:
            # pi0: the state is embedded as one token, and time is mixed into the action tokens by an MLP.
            self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
            self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
            self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
        self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)

        # V(s): MLP over pooled VLM prefix features.
        self.value_head = _gemma.ValueHead(
            input_dim=paligemma_config.width,
            hidden_sizes=(1024, 512, 256, 128),
            output_dim=1,
            activation="relu",
            bias_last=True,
            rngs=rngs,
        )
        # Q(s, a): MLP over [pooled VLM prefix features, flattened (action_horizon * action_dim) actions].
        self.q_head = _gemma.ValueHead(
            input_dim=paligemma_config.width + config.action_dim * self.action_horizon,
            hidden_sizes=(1024, 512, 256, 128),
            output_dim=1,
            activation="relu",
            bias_last=True,
            rngs=rngs,
        )

    def embed_prefix(
        self, obs: _model.Observation
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
        """Embed the images and the optional tokenized prompt into the VLM prefix sequence.

        Returns:
            tokens: [b, s, emb]. Image tokens (in ``obs.images`` order) followed by prompt tokens.
            input_mask: [b, s]. Each image's mask repeated over that image's tokens, then the prompt mask.
            ar_mask: [s]. All False, so the prefix attends to itself bidirectionally.
        """
        input_mask = []
        ar_mask = []
        tokens = []
        # embed images
        for name in obs.images:
            image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)

            tokens.append(image_tokens)
            input_mask.append(
                einops.repeat(
                    obs.image_masks[name],
                    "b -> b s",
                    s=image_tokens.shape[1],
                )
            )
            # image tokens attend to each other
            ar_mask += [False] * image_tokens.shape[1]

        # add language (aka tokenized inputs)
        if obs.tokenized_prompt is not None:
            tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
            tokens.append(tokenized_inputs)
            input_mask.append(obs.tokenized_prompt_mask)
            # full attention between image and language inputs
            ar_mask += [False] * tokenized_inputs.shape[1]
        tokens = jnp.concatenate(tokens, axis=1)
        input_mask = jnp.concatenate(input_mask, axis=1)
        ar_mask = jnp.array(ar_mask)
        return tokens, input_mask, ar_mask

    @at.typecheck
    def embed_suffix(
        self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
    ) -> tuple[
        at.Float[at.Array, "b s emb"],
        at.Bool[at.Array, "b s"],
        at.Bool[at.Array, " s"],
        at.Float[at.Array, "b emb"] | None,
    ]:
        """Embed the noisy action chunk as the action-expert suffix; pi0 also adds a state token.

        Returns:
            tokens: [b, s, emb] suffix tokens.
            input_mask: [b, s], all True.
            ar_mask: [s]. True at the state token (pi0 only) and at the first action token. The prefix
                therefore cannot attend to the suffix, and the state token cannot attend to actions.
            adarms_cond: [b, emb] time conditioning for adaRMS when pi05, otherwise None.
        """
        input_mask = []
        ar_mask = []
        tokens = []
        if not self.pi05:
            # add a single state token
            state_token = self.state_proj(obs.state)[:, None, :]
            tokens.append(state_token)
            input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
            # image/language inputs do not attend to state or actions
            ar_mask += [True]

        action_tokens = self.action_in_proj(noisy_actions)
        # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
        if self.pi05:
            # time MLP (for adaRMS)
            time_emb = self.time_mlp_in(time_emb)
            time_emb = nnx.swish(time_emb)
            time_emb = self.time_mlp_out(time_emb)
            time_emb = nnx.swish(time_emb)
            action_expert_tokens = action_tokens
            adarms_cond = time_emb
        else:
            # mix timestep + action information using an MLP (no adaRMS)
            time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
            action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
            action_time_tokens = self.action_time_mlp_in(action_time_tokens)
            action_time_tokens = nnx.swish(action_time_tokens)
            action_time_tokens = self.action_time_mlp_out(action_time_tokens)
            action_expert_tokens = action_time_tokens
            adarms_cond = None
        tokens.append(action_expert_tokens)
        input_mask.append(jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_))
        # image/language/state inputs do not attend to action tokens
        ar_mask += [True] + ([False] * (self.action_horizon - 1))
        tokens = jnp.concatenate(tokens, axis=1)
        input_mask = jnp.concatenate(input_mask, axis=1)
        ar_mask = jnp.array(ar_mask)
        return tokens, input_mask, ar_mask, adarms_cond

    def _encode_prefix(self, observation: _model.Observation) -> at.Float[at.Array, "b s emb"]:
        """Run an already-preprocessed observation's prefix through the VLM and return its hidden states.

        Gradients are NOT stopped here; callers decide where to apply ``jax.lax.stop_gradient``.
        """
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
        prefix_positions = jnp.cumsum(prefix_mask, axis=1) - 1
        # Only the VLM expert runs (second slot None), so the suffix output is discarded.
        (prefix_output, _), _ = self.PaliGemma.llm(
            [prefix_tokens, None],
            mask=prefix_attn_mask,
            positions=prefix_positions,
        )
        return prefix_output

    def _pool_prefix_output(
        self,
        prefix_output: at.Float[at.Array, "b s emb"],
        mode: Literal["mean_token", "last_token", "first_token"],
    ) -> at.Float[at.Array, "b emb"]:
        """Reduce VLM prefix hidden states [b, s, emb] to one feature vector per sample.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        if mode == "mean_token":
            # Averages two fixed slices of the prefix: tokens [0, 512) and a slice starting at 768
            # whose length is 200 (pi05) or 48 (pi0).
            # NOTE(review): this presumably assumes 256 tokens per image with 3 images. It would then pool the
            # first two images plus the prompt (padded to max_token_len). Padding positions are included in
            # the mean because the prompt mask is not applied. Confirm this layout matches the data config.
            first_part = prefix_output[:, :512]
            second_part = prefix_output[:, 768:768 + (200 if self.pi05 else 48)]
            return jnp.mean(jnp.concatenate([first_part, second_part], axis=1), axis=1)
        if mode == "last_token":
            return prefix_output[:, -1]
        if mode == "first_token":
            return prefix_output[:, 0]
        raise ValueError(mode)

    def get_value_from_vlm(
        self,
        prefix_output: at.Float[at.Array, "b s emb"],
        mode: Literal["mean_token", "last_token", "first_token"] = "mean_token",
    ):
        """Pool the VLM prefix hidden states (see ``mode``) and apply ``value_head``.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        pooled = self._pool_prefix_output(prefix_output, mode)
        return self.value_head(pooled)

    def compute_values(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        mode: Literal["mean_token", "last_token", "first_token"] = "mean_token",
    ) -> at.Float[at.Array, "b"]:
        """Compute value estimates V(s) for observations.

        Used for PPO advantage estimation and as the V(s) term in the IQL losses. The VLM output is
        gradient-stopped, so gradients only reach ``value_head``.

        Args:
            rng: Random key for preprocessing.
            observation: Observations to compute values for.
            mode: Pooling mode for the value head (default: "mean_token").

        Returns:
            Value estimates, one per batch element.
        """
        observation = _model.preprocess_observation(rng, observation, train=False)
        prefix_output = jax.lax.stop_gradient(self._encode_prefix(observation))
        return self.get_value_from_vlm(prefix_output, mode=mode)

    def compute_q_values(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        actions: _model.Actions,
        mode: Literal["mean_token", "last_token", "first_token"] = "mean_token",
    ) -> at.Float[at.Array, "b"]:
        """Compute the state-action value Q(s, a).

        The pooled, gradient-stopped VLM features are concatenated with the flattened action chunk and
        fed to ``q_head``.

        Args:
            rng: Random key for preprocessing.
            observation: Current state observation.
            actions: Action chunk [B, action_horizon, action_dim].
            mode: Pooling mode for the VLM features.

        Returns:
            Q values, one per batch element.
        """
        observation = _model.preprocess_observation(rng, observation, train=False)

        # State representation from the (frozen) VLM.
        prefix_output = jax.lax.stop_gradient(self._encode_prefix(observation))
        state_features = self._pool_prefix_output(prefix_output, mode)

        # Flatten the whole action chunk (no pooling) to [B, action_horizon * action_dim].
        action_features = actions.reshape(actions.shape[0], self.action_horizon * self.action_dim)

        sa_features = jnp.concatenate([state_features, action_features], axis=-1)
        return self.q_head(sa_features)

    def compute_value_loss(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        prev_values_all: at.Float[at.Array, "b e"],
        returns: at.Float[at.Array, "b"],
        value_clip: float = 0.2,
        huber_delta: float = 10.0,
        mode: Literal["mean_token", "last_token", "first_token"] = "mean_token",
        loss_mask: at.Bool[at.Array, "b"] | None = None,
    ):
        """PPO-style clipped value loss with a Huber penalty.

        The new prediction is clipped to within ``value_clip`` of ``prev_values_all``. The loss is the
        element-wise max of the Huber losses of the unclipped and clipped predictions against ``returns``.
        It is averaged over the ensemble axis first, then over the (optionally masked) batch entries.

        Args:
            rng: Random key for observation preprocessing.
            observation: Observations to evaluate.
            prev_values_all: Value predictions recorded at rollout time, [b, e].
            returns: Regression targets, [b].
            value_clip: Maximum allowed change relative to ``prev_values_all``.
            huber_delta: Point where the Huber loss switches from quadratic to linear.
            mode: Pooling mode for the VLM features.
            loss_mask: Optional [b] mask of samples that contribute to the loss and statistics.

        Returns:
            Tuple of (scalar loss, info dict).
        """
        observation = _model.preprocess_observation(rng, observation, train=False)
        prefix_output = self._encode_prefix(observation)

        # [B, E]; the VLM output is gradient-stopped, so only the value head is trained.
        current_values_all = self.get_value_from_vlm(
            jax.lax.stop_gradient(prefix_output),
            mode=mode,
        )

        # PPO value clipping
        value_delta = current_values_all - prev_values_all
        value_pred_clipped_all = prev_values_all + jnp.clip(value_delta, -value_clip, value_clip)

        # -------- vectorized Huber loss --------
        def huber(error, delta):
            abs_e = jnp.abs(error)
            quad = jnp.minimum(abs_e, delta)
            lin = abs_e - quad
            return 0.5 * quad**2 + delta * lin

        loss_orig = huber(returns[:, None] - current_values_all, huber_delta)
        loss_clip = huber(returns[:, None] - value_pred_clipped_all, huber_delta)
        value_loss = jnp.maximum(loss_orig, loss_clip)  # [B, E]

        # Average over the ensemble axis first, so both the masked and unmasked paths give the
        # per-member mean loss.
        per_sample_loss = reduce(value_loss, "b e -> b", "mean")
        # A prediction counts as clipped when the *unclipped* update exceeds value_clip. The clipped
        # prediction can never differ from prev by more than value_clip, so it cannot be used here.
        clipped = (jnp.abs(value_delta) > value_clip).astype(per_sample_loss.dtype)
        per_sample_clip = reduce(clipped, "b e -> b", "mean")

        if loss_mask is not None:
            weights = loss_mask.astype(per_sample_loss.dtype)
            # Guard against an all-False mask, which would otherwise divide by zero and give NaN.
            denom = jnp.maximum(jnp.sum(weights), 1.0)
            critic_loss = jnp.sum(per_sample_loss * weights) / denom
            clip_fraction = jnp.sum(per_sample_clip * weights) / denom
        else:
            critic_loss = jnp.mean(per_sample_loss)
            clip_fraction = jnp.mean(per_sample_clip)

        info = {
            "value_loss": critic_loss,
            "value_clip_fraction": clip_fraction,
            "returns": jnp.mean(returns),
            "current_values_mean": jnp.mean(current_values_all),
            "prev_values_mean": jnp.mean(prev_values_all),
        }

        return critic_loss, info

    def compute_iql_value_loss(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        target_values: at.Float[at.Array, "b"],
        expectile: float = 0.7,
        mode: Literal["mean_token", "last_token", "first_token"] = "mean_token",
    ) -> tuple[at.Float[at.Array, ""], dict]:
        """Train the value function with expectile regression (IQL).

        Positive errors (target above prediction) are weighted by ``expectile`` and negative errors by
        ``1 - expectile``. With expectile > 0.5, V(s) is pushed toward an upper expectile of the targets.

        Args:
            rng: Random key.
            observation: Current state.
            target_values: Regression targets (typically Q(s, a)).
            expectile: Expectile parameter (e.g. 0.7).
            mode: Pooling mode.

        Returns:
            Tuple of (loss, info_dict).
        """
        current_values = self.compute_values(rng, observation, mode=mode)  # [B]

        # Expectile loss
        errors = target_values - current_values  # [B]
        weights = jnp.where(errors > 0, expectile, 1.0 - expectile)
        value_loss = jnp.mean(weights * (errors**2))

        info = {
            "iql_value_loss": value_loss,
            "value_mean": jnp.mean(current_values),
            "target_value_mean": jnp.mean(target_values),
            "value_error_mean": jnp.mean(jnp.abs(errors)),
        }

        return value_loss, info

    def compute_iql_q_loss(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        actions: _model.Actions,
        next_observation: _model.Observation,
        rewards: at.Float[at.Array, "b"],
        gamma: float = 0.99,
        mode: Literal["mean_token", "last_token", "first_token"] = "mean_token",
    ) -> tuple[at.Float[at.Array, ""], dict]:
        """Train the Q-function with a one-step Bellman target.

        Loss: mean over the batch of (Q(s, a) - (r + gamma * V(s')))^2. V(s') is gradient-stopped.
        NOTE(review): there is no terminal/done masking, and ``gamma`` comes from this argument,
        not from ``self.gamma``.

        Args:
            rng: Random key.
            observation: Current state.
            actions: Current action chunk.
            next_observation: Next state.
            rewards: Rewards.
            gamma: Discount factor.
            mode: Pooling mode.

        Returns:
            Tuple of (loss, info_dict).
        """
        rng1, rng2 = jax.random.split(rng)

        current_q_values = self.compute_q_values(rng1, observation, actions, mode=mode)  # [B]

        next_values = jax.lax.stop_gradient(
            self.compute_values(rng2, next_observation, mode=mode)
        )  # [B]
        targets = rewards + gamma * next_values  # [B]

        # TD error
        td_errors = targets - current_q_values  # [B]
        q_loss = jnp.mean(td_errors**2)

        info = {
            "iql_q_loss": q_loss,
            "q_mean": jnp.mean(current_q_values),
            "target_q_mean": jnp.mean(targets),
            "td_error_mean": jnp.mean(jnp.abs(td_errors)),
            "reward_mean": jnp.mean(rewards),
            "next_value_mean": jnp.mean(next_values),
        }

        return q_loss, info

    def compute_iql_policy_loss(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        actions: _model.Actions,
        beta: float = 3.0,
        mode: Literal["mean_token", "last_token", "first_token"] = "mean_token",
        max_weight: float | None = 100.0,
    ) -> tuple[at.Float[at.Array, ""], dict]:
        """Train the policy with advantage-weighted regression on the flow-matching loss.

        Policy loss = mean(w(s, a) * flow-matching MSE), where w = min(exp(beta * A(s, a)), max_weight)
        and A(s, a) = Q(s, a) - V(s). Both Q and V are gradient-stopped.

        Args:
            rng: Random key.
            observation: Current state.
            actions: Dataset (expert) action chunks.
            beta: AWR inverse temperature (controls how sharply advantages are weighted).
            mode: Pooling mode.
            max_weight: Upper bound on the exponential weights. This prevents overflow to inf/NaN when
                advantages are large. Pass None to disable clipping.

        Returns:
            Tuple of (loss, info_dict).
        """
        rng1, rng2, rng3 = jax.random.split(rng, 3)

        q_values = jax.lax.stop_gradient(
            self.compute_q_values(rng1, observation, actions, mode=mode)
        )  # [B]
        values = jax.lax.stop_gradient(
            self.compute_values(rng2, observation, mode=mode)
        )  # [B]
        advantages = q_values - values  # [B]

        # Exponential weighting, clipped to keep the loss finite.
        weights = jnp.exp(advantages * beta)
        if max_weight is not None:
            weights = jnp.minimum(weights, max_weight)

        # Diffusion (flow-matching) loss weighted per sample by `weights`.
        policy_loss, _ = self.compute_sft_loss(
            rng3, observation, actions, use_advantages=True, advantages=weights, train=True
        )

        info = {
            "iql_policy_loss": policy_loss,
            "advantages_mean": jnp.mean(advantages),
            "advantages_std": jnp.std(advantages),
            "awr_weights_mean": jnp.mean(weights),
            "awr_weights_max": jnp.max(weights),
            "q_values_mean": jnp.mean(q_values),
            "v_values_mean": jnp.mean(values),
        }

        return policy_loss, info

    @override
    def compute_loss(
        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
    ) -> at.Float[at.Array, ""]:
        """Unweighted flow-matching loss from one joint prefix+suffix forward pass.

        Returns:
            Scalar MSE between predicted and target velocities, averaged over batch, horizon and
            action dimensions.
        """
        preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
        observation = _model.preprocess_observation(preprocess_rng, observation, train=train)

        batch_shape = actions.shape[:-2]
        noise = jax.random.normal(noise_rng, actions.shape)
        time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
        time_expanded = time[..., None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        # one big forward pass of prefix + suffix at once
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)
        input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
        ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
        attn_mask = make_attn_mask(input_mask, ar_mask)
        positions = jnp.cumsum(input_mask, axis=1) - 1
        (_, suffix_out), _ = self.PaliGemma.llm(
            [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]
        )
        v_t = self.action_out_proj(suffix_out[:, -self.action_horizon:])

        return jnp.mean(jnp.square(v_t - u_t))

    def compute_sft_loss(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        actions: _model.Actions,
        use_advantages: bool = False,
        advantages: at.Float[at.Array, "*b"] | None = None,
        train: bool = False
    ) -> tuple[at.Float[at.Array, ""], dict]:
        """Compute the (optionally per-sample weighted) flow-matching loss.

        The prefix KV cache is computed once and gradient-stopped. Gradients therefore do not flow
        back through the prefix computation.

        Args:
            rng: Random key for sampling.
            observation: Current observation.
            actions: Expert action chunks [*b, action_horizon, action_dim].
            use_advantages: Whether to weight the loss per sample.
            advantages: Per-sample weights. Any shape that is a prefix of the batch axes works
                (e.g. [b] or [b, 1]); the weights are aligned with the leading (batch) axes. If None while
                ``use_advantages`` is True, ``self.get_advantages`` is called when a subclass provides it.
            train: Whether we're in training mode (controls preprocessing augmentation).

        Returns:
            Tuple of (loss, info_dict).

        Raises:
            ValueError: if ``use_advantages`` is True, ``advantages`` is None, and no ``get_advantages``
                method is available.
        """
        preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)

        # Preprocess observation
        observation = _model.preprocess_observation(preprocess_rng, observation, train=train)

        # Generate noise and time for diffusion process
        batch_shape = actions.shape[:-2]
        noise = jax.random.normal(noise_rng, actions.shape)
        time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
        time_expanded = time[..., None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        # Get prefix embeddings and compute kv_cache (frozen)
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
        prefix_positions = jnp.cumsum(prefix_mask, axis=1) - 1
        # PaliGemma has two configs, so we need to pass [tokens, None] for the two experts
        _, prefix_kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=prefix_positions)
        prefix_kv_cache = jax.lax.stop_gradient(prefix_kv_cache)

        # Compute suffix tokens for noisy actions
        suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)
        suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
        prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
        full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
        suffix_positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

        # Run the model to predict the noise using kv_cache
        (_, suffix_out), _ = self.PaliGemma.llm(
            [None, suffix_tokens],
            mask=full_attn_mask,
            positions=suffix_positions,
            kv_cache=prefix_kv_cache,
            adarms_cond=[None, adarms_cond]
        )
        v_t = self.action_out_proj(suffix_out[:, -self.action_horizon:])

        if not use_advantages:
            advantages = jnp.ones_like(v_t[..., 0])
        elif advantages is None:
            get_advantages = getattr(self, "get_advantages", None)
            if get_advantages is None:
                raise ValueError(
                    "compute_sft_loss(use_advantages=True) requires `advantages` because this model "
                    "defines no get_advantages method."
                )
            advantages = jax.lax.stop_gradient(get_advantages(observation, actions))
        else:
            advantages = jax.lax.stop_gradient(advantages)

        mse_loss = jnp.mean(jnp.square(v_t - u_t), axis=-1)  # [*b, ah]
        # Align per-sample weights with the leading batch axes. A [b] weight would otherwise broadcast
        # against the trailing action_horizon axis.
        advantages = jnp.reshape(advantages, advantages.shape + (1,) * (mse_loss.ndim - advantages.ndim))
        weighted_loss = jnp.mean(advantages * mse_loss)

        info = {
            "loss": weighted_loss,
            "advantages": jnp.mean(advantages),
        }

        return weighted_loss, info

    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        num_steps: int | at.Int[at.Array, ""] = 10,
        noise: at.Float[at.Array, "b ah ad"] | None = None,
    ) -> _model.Actions:
        """Sample action chunks by Euler-integrating the learned velocity field from t=1 (noise) to t=0."""
        observation = _model.preprocess_observation(None, observation, train=False)
        # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
        # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
        dt = -1.0 / num_steps
        batch_size = observation.state.shape[0]
        if noise is None:
            noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

        # first fill KV cache with a forward pass of the prefix
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
        positions = jnp.cumsum(prefix_mask, axis=1) - 1
        _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)

        def step(carry):
            x_t, time = carry
            suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
                observation, x_t, jnp.broadcast_to(time, batch_size)
            )
            # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each
            # other
            suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
            # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the
            # prefix tokens
            prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
            # `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
            # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
            full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
            assert full_attn_mask.shape == (
                batch_size,
                suffix_tokens.shape[1],
                prefix_tokens.shape[1] + suffix_tokens.shape[1],
            )
            # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
            positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

            (prefix_out, suffix_out), _ = self.PaliGemma.llm(
                [None, suffix_tokens],
                mask=full_attn_mask,
                positions=positions,
                kv_cache=kv_cache,
                adarms_cond=[None, adarms_cond],
            )
            assert prefix_out is None
            v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])

            return x_t + dt * v_t, time + dt

        def cond(carry):
            x_t, time = carry
            # robust to floating-point error
            return time >= -dt / 2

        x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
        return x_0