import dataclasses
import logging

import einops
import flax.nnx as nnx
import flax.nnx.bridge as nnx_bridge
import jax
import jax.numpy as jnp
from typing_extensions import override

from openpi.models import model as _model
import openpi.models.gemma as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import copy
import numpy as np

# Shared "openpi" logger used by this module.
logger = logging.getLogger(name="openpi")


def make_attn_mask(input_mask, mask_ar):
    """Build a [B, N, N] boolean attention mask (adapted from big_vision).

    Each position is assigned a "block id" equal to the running sum of `mask_ar`.
    A query may attend to any valid key whose block id is less than or equal to
    its own. Examples of `mask_ar` patterns:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens attend among
          themselves; the last 3 are causal. The first entry could also be 1
          without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block attend to all previous blocks and to every token of their own block.

    Args:
      input_mask: bool[B, N], True for real input tokens, False for padding.
      mask_ar: bool[?B, N], True where earlier tokens must not attend to this
        token (starts a new block), False where the token shares the previous
        token's attention pattern.

    Returns:
      bool[B, N, N] where entry [b, i, j] says whether query i may attend to key j.
    """
    ar = jnp.broadcast_to(mask_ar, input_mask.shape)
    block_ids = jnp.cumsum(ar, axis=1)
    # Query i (axis 1) sees key j (axis 2) iff block_ids[j] <= block_ids[i].
    causal = block_ids[:, :, None] >= block_ids[:, None, :]
    # Both the query and the key must be real (non-padding) tokens.
    both_valid = jnp.logical_and(input_mask[:, :, None], input_mask[:, None, :])
    return jnp.logical_and(causal, both_valid)


@at.typecheck
def posemb_sincos(
    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
    """Embed scalar positions with sine/cosine features.

    Periods are spaced geometrically between `min_period` and `max_period`; the
    first half of the output holds the sines and the second half the cosines.

    Raises:
      ValueError: if `embedding_dim` is odd.
    """
    if embedding_dim % 2:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")

    exponents = jnp.linspace(0.0, 1.0, embedding_dim // 2)
    periods = min_period * (max_period / min_period) ** exponents
    angular_freq = 1.0 / periods * 2 * jnp.pi
    # Outer product positions x frequencies, at full precision to avoid bf16/tf32 rounding.
    angles = jnp.einsum("i,j->ij", pos, angular_freq, precision=jax.lax.Precision.HIGHEST)
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


@dataclasses.dataclass(frozen=True)
class Pi0Config(_model.BaseModelConfig):
    """Configuration for the Pi0 model, optionally extended with a chain-of-thought (CoT) stream."""

    dtype: str = "bfloat16"
    paligemma_variant: _gemma.Variant = "gemma_2b"
    action_expert_variant: _gemma.Variant = "gemma_300m"

    # Set the model specific defaults.
    action_dim: int = 32
    action_horizon: int = 50
    max_token_len: int = 48

    # ===== CoT (chain-of-thought) controls =====
    # When True, the CoT sequence is used in both training and inference.
    use_cot: bool = True
    # Number of CoT tokens generated at inference time.
    cot_len: int = 180
    # Weight of the CoT language-modelling loss added to the action loss during training.
    cot_loss_weight: float = 0.05
    # Token forced as the first generated CoT token (BOS); override in the config if it differs.
    cot_bos_id: int = 2

    @property
    @override
    def model_type(self) -> _model.ModelType:
        return _model.ModelType.PI0

    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0":
        rngs = nnx.Rngs(rng)
        return Pi0(self, rngs=rngs)

    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        """Return shape/dtype specs for one observation batch and the matching action batch."""
        camera_names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
        # Prompt and CoT token sequences share the same padded length here.
        token_shape = [batch_size, self.max_token_len]

        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images=dict.fromkeys(camera_names, image_spec),
                image_masks=dict.fromkeys(camera_names, image_mask_spec),
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct(token_shape, jnp.int32),
                tokenized_prompt_mask=jax.ShapeDtypeStruct(token_shape, bool),
                tokenized_cot_info=jax.ShapeDtypeStruct(token_shape, jnp.int32),
                tokenized_cot_info_mask=jax.ShapeDtypeStruct(token_shape, bool),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
        return observation_spec, action_spec

    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Return the filter selecting parameters to freeze.

        - PaliGemma LoRA: freeze every `.*llm.*` param. If the action expert is not
          LoRA as well, its params (`.*llm.*_1.*`) stay trainable.
        - Only the action expert is LoRA: freeze the action-expert params.
        - In both LoRA cases, params matching `.*lora.*` stay trainable.
        - Without any LoRA variant nothing is frozen (`nnx.Nothing`).
        """
        llm_filter = nnx_utils.PathRegex(".*llm.*")
        expert_filter = nnx_utils.PathRegex(".*llm.*_1.*")
        paligemma_lora = "lora" in self.paligemma_variant
        expert_lora = "lora" in self.action_expert_variant

        if paligemma_lora:
            parts = [llm_filter]
            if not expert_lora:
                parts.append(nnx.Not(expert_filter))
        elif expert_lora:
            parts = [expert_filter]
        else:
            return nnx.Nothing

        # LoRA adapter weights themselves are always trainable.
        parts.append(nnx.Not(nnx_utils.PathRegex(".*lora.*")))
        return nnx.All(*parts)


class Pi0(_model.BaseModel):
    """Pi0 flow-matching policy extended with an optional chain-of-thought (CoT) text stream.

    Expert 0 (PaliGemma) consumes image tokens, the tokenized prompt and, when CoT is enabled,
    CoT tokens. Expert 1 (the action expert) consumes the state token and the noisy action
    tokens; its outputs are projected to flow-matching velocities.
    """

    def __init__(self, config: Pi0Config, rngs: nnx.Rngs):
        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
        self.config = config
        paligemma_config = _gemma.get_config(config.paligemma_variant)
        action_expert_config = _gemma.get_config(config.action_expert_variant)
        # TODO: rewrite gemma in NNX. For now, use bridge.
        llm = nnx_bridge.ToNNX(
            _gemma.Module(
                configs=[paligemma_config, action_expert_config],
                embed_dtype=config.dtype,
            )
        )
        llm.lazy_init(rngs=rngs, method="init")
        img = nnx_bridge.ToNNX(
            _siglip.Module(
                num_classes=paligemma_config.width,
                variant="So400m/14",
                pool_type="none",
                scan=True,
                dtype_mm=config.dtype,
            )
        )
        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
        self.PaliGemma = nnx.Dict(llm=llm, img=img)
        self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
        self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
        self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
        self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
        self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)

    @at.typecheck
    def embed_prefix(
        self, obs: _model.Observation
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
        """Embed all images (SigLIP) and the tokenized prompt (llm embed), concatenated along seq.

        All prefix positions get ar_mask=False, i.e. full bidirectional attention among them.
        """
        input_mask = []
        ar_mask = []
        tokens = []
        # embed images
        for name in obs.images:
            image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)

            tokens.append(image_tokens)
            input_mask.append(
                einops.repeat(
                    obs.image_masks[name],
                    "b -> b s",
                    s=image_tokens.shape[1],
                )
            )
            # image tokens attend to each other
            ar_mask += [False] * image_tokens.shape[1]

        # add language (aka tokenized inputs)
        if obs.tokenized_prompt is not None:
            tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
            tokens.append(tokenized_inputs)
            input_mask.append(obs.tokenized_prompt_mask)
            # full attention between image and language inputs
            ar_mask += [False] * tokenized_inputs.shape[1]
        tokens = jnp.concatenate(tokens, axis=1)
        input_mask = jnp.concatenate(input_mask, axis=1)
        ar_mask = jnp.array(ar_mask)
        return tokens, input_mask, ar_mask

    @at.typecheck
    def embed_cot_info(
        self, obs: _model.Observation
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
        """Embed the supervised CoT tokens for the expert-0 stream.

        Every CoT position gets ar_mask=True. When appended after a prefix whose ar_mask is all
        False, `make_attn_mask` therefore lets CoT token i see the whole prefix plus CoT tokens
        0..i (causal), while prefix tokens cannot see any CoT token.

        Raises:
          ValueError: if `obs.tokenized_cot_info` is None.
        """
        if obs.tokenized_cot_info is None:
            raise ValueError("embed_cot_info requires obs.tokenized_cot_info to be set")
        tokens = self.PaliGemma.llm(obs.tokenized_cot_info, method="embed")
        input_mask = jnp.asarray(obs.tokenized_cot_info_mask)
        ar_mask = jnp.ones(tokens.shape[1], dtype=jnp.bool_)
        return tokens, input_mask, ar_mask

    @at.typecheck
    def embed_suffix(
        self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
        """Embed the state token and the noisy action tokens (each action token fused with the timestep)."""
        input_mask = []
        ar_mask = []
        tokens = []
        # add a single state token
        state_token = self.state_proj(obs.state)[:, None, :]
        tokens.append(state_token)
        input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
        # image/language inputs do not attend to state or actions
        ar_mask += [True]

        # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
        # mix timestep + action information using an MLP
        action_tokens = self.action_in_proj(noisy_actions)
        time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
        action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
        action_time_tokens = self.action_time_mlp_in(action_time_tokens)
        action_time_tokens = nnx.swish(action_time_tokens)
        action_time_tokens = self.action_time_mlp_out(action_time_tokens)
        tokens.append(action_time_tokens)
        input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtype=jnp.bool_))
        # image/language/state inputs do not attend to action tokens
        ar_mask += [True] + ([False] * (self.action_horizon - 1))
        tokens = jnp.concatenate(tokens, axis=1)
        input_mask = jnp.concatenate(input_mask, axis=1)
        ar_mask = jnp.array(ar_mask)
        return tokens, input_mask, ar_mask

    @override
    def compute_loss(
        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
    ) -> at.Float[at.Array, "*b ah"]:
        """Per-step training loss, shape [*B, action_horizon].

        The base term is the flow-matching MSE between predicted and target velocities. When
        `config.use_cot` is set and the observation carries `tokenized_cot_info`, the CoT tokens
        are fed (teacher-forced) after the prefix. A masked next-token cross-entropy over them,
        scaled by `config.cot_loss_weight`, is then added to every action step.
        """
        preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
        observation = _model.preprocess_observation(preprocess_rng, observation, train=train)

        batch_shape = actions.shape[:-2]
        noise = jax.random.normal(noise_rng, actions.shape)
        time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
        time_expanded = time[..., None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        # Expert-0 stream: prefix (images + prompt), optionally followed by the supervised CoT tokens.
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        cot_offset = prefix_tokens.shape[1]
        have_cot = self.config.use_cot and observation.tokenized_cot_info is not None
        if have_cot:
            cot_tokens, cot_mask, cot_ar_mask = self.embed_cot_info(observation)
            cot_len = cot_tokens.shape[1]
            expert0_tokens = jnp.concatenate([prefix_tokens, cot_tokens], axis=1)
            expert0_mask = jnp.concatenate([prefix_mask, cot_mask], axis=1)
            expert0_ar = jnp.concatenate([prefix_ar_mask, cot_ar_mask], axis=0)
        else:
            cot_len = 0
            expert0_tokens, expert0_mask, expert0_ar = prefix_tokens, prefix_mask, prefix_ar_mask

        # Expert-1 stream: state + (time-fused) noisy action tokens.
        suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(observation, x_t, time)

        # The llm expects mask/positions over the total length (expert0_len + expert1_len).
        input_mask = jnp.concatenate([expert0_mask, suffix_mask], axis=1)  # [B, N]
        ar_mask = jnp.concatenate([expert0_ar, suffix_ar_mask], axis=0)  # [N]
        attn_mask = make_attn_mask(input_mask, ar_mask)  # [B, N, N]
        positions = jnp.cumsum(input_mask, axis=1) - 1  # [B, N]

        (expert0_out, suffix_out), _ = self.PaliGemma.llm(
            [expert0_tokens, suffix_tokens], mask=attn_mask, positions=positions
        )

        v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
        action_loss = jnp.mean(jnp.square(v_t - u_t), axis=-1)  # [*B, AH]

        if not have_cot:
            return action_loss

        # Next-token prediction over the CoT segment: the hidden state at CoT position t predicts
        # token t + 1, so logits[:, :-1] are scored against tokens[:, 1:].
        cot_hidden = expert0_out[:, cot_offset : cot_offset + cot_len, :]  # [B, Lc, D]
        logits = self.PaliGemma.llm(cot_hidden, method="decode_logits")[:, :-1, :]  # [B, Lc-1, V]
        targets = observation.tokenized_cot_info[:, 1:cot_len]  # [B, Lc-1]
        target_mask = observation.tokenized_cot_info_mask[:, 1:cot_len]  # [B, Lc-1]

        # Cross-entropy in float32 for numerical stability.
        log_probs = jax.nn.log_softmax(logits.astype(jnp.float32), axis=-1)
        nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1).squeeze(-1)  # [B, Lc-1]
        nll = jnp.where(target_mask, nll, 0.0)
        # Mean over valid target tokens; guard against rows with no valid target.
        denom = jnp.maximum(jnp.sum(target_mask, axis=-1), 1)  # [B]
        cot_loss = jnp.sum(nll, axis=-1) / denom  # [B]
        # [B, 1] broadcasts against action_loss [B, AH].
        return action_loss + self.config.cot_loss_weight * cot_loss[..., None]

    def _generate_cot(self, prefix_pre_logits, prefix_mask, prefix_positions, kv_cache, cot_len: int):
        """Greedily decode `cot_len` CoT tokens with expert 0, extending the prefix KV cache.

        Step 0 always emits `config.cot_bos_id`. Afterwards each step feeds the argmax of the
        previous logits.

        Args:
          prefix_pre_logits: expert-0 prefix outputs from the prefill pass.
          prefix_mask: bool[B, P] prefix validity mask.
          prefix_positions: int[B, P] prefix position ids (cumsum of prefix_mask - 1).
          kv_cache: prefill cache; its leaves are padded along axis 2 (time).
            NOTE(review): leaves are assumed to be 5-D [L, B, T, K, H] — matches the pad spec below.
          cot_len: number of tokens to generate (static Python int).

        Returns:
          (kv_cache, cot_token_ids): cache with P + cot_len time slots, where slot P + i holds
          the KV of the i-th generated token; and the generated ids int32[B, cot_len].

        NOTE(review): decoding never stops at EOS. All cot_len tokens stay visible to the action
        expert, and suffix positions start at valid_prefix + cot_len. In training, CoT padding is
        masked out and positions count only valid CoT tokens. Confirm this mismatch is acceptable.
        """
        batch_size, prefix_len = prefix_mask.shape

        # Logits at the last valid prefix position. Step 0 forces BOS, so these only fix the
        # loop carry's shape/dtype.
        last_prefix_indices = jnp.sum(prefix_mask, axis=-1).astype(jnp.int32) - 1
        last_prefix_pre = jnp.take_along_axis(prefix_pre_logits, last_prefix_indices[:, None, None], axis=1)
        last_logit = self.PaliGemma.llm(last_prefix_pre, method="decode_logits")  # [B, 1, V]

        # Reserve cot_len empty time slots (axis 2) after the prefix in every cache leaf.
        kv_cache = jax.tree.map(
            lambda x: jnp.pad(x, ((0, 0), (0, 0), (0, cot_len), (0, 0), (0, 0))),
            kv_cache,
        )

        # Fixed-shape single-query mask (B, 1, P + cot_len + 1). The columns are:
        #   - P prefix slots;
        #   - cot_len reserved slots, revealed one at a time as they are filled;
        #   - a trailing column for the current token, whose KV the llm appends after the cache.
        attn_mask = jnp.pad(prefix_mask[:, None, :], ((0, 0), (0, 0), (0, cot_len + 1)))
        attn_mask = attn_mask.at[:, :, -1].set(True)

        output_tokens = jnp.zeros((batch_size, cot_len), dtype=jnp.int32)
        last_position = prefix_positions[:, -1:]  # [B, 1]
        bos_id = int(self.config.cot_bos_id)

        @at.typecheck
        def _wrap_cache(
            cache_appended: at.Float[at.Array, "l b t k h"], step: at.Int[at.Array, ""]
        ) -> at.Float[at.Array, "l b t-1 k h"]:
            # Move the entry the llm appended at the end into reserved slot prefix_len + step.
            new_value = cache_appended[:, :, -1]
            cache = cache_appended[:, :, :-1]
            return jax.lax.dynamic_update_index_in_dim(cache, new_value, prefix_len + step, axis=2)

        def decode_step(carry):
            last_logit, output_tokens, kv_cache, attn_mask, step = carry

            # Greedy sampling; step 0 is forced to BOS.
            token = jnp.argmax(last_logit, axis=-1).astype(jnp.int32)  # [B, 1]
            token = jnp.where(step == 0, jnp.full_like(token, bos_id), token)
            output_tokens = output_tokens.at[:, step].set(token[:, 0])

            # One decode step: append the token's KV and get the next-token logits.
            token_embedding = self.PaliGemma.llm(token, method="embed")
            positions_step = last_position + step + 1  # [B, 1]
            (last_pre_logit, _), kv_cache_appended = self.PaliGemma.llm(
                [token_embedding, None], mask=attn_mask, positions=positions_step, kv_cache=kv_cache
            )
            next_logit = self.PaliGemma.llm(last_pre_logit, method="decode_logits")
            kv_cache = jax.tree.map(lambda x: _wrap_cache(x, step), kv_cache_appended)

            # The slot just filled becomes visible to subsequent steps.
            attn_mask = attn_mask.at[:, :, prefix_len + step].set(True)
            return next_logit, output_tokens, kv_cache, attn_mask, step + 1

        def decode_cond(carry):
            return carry[-1] < cot_len

        _, cot_token_ids, kv_cache, _, _ = jax.lax.while_loop(
            decode_cond,
            decode_step,
            (last_logit, output_tokens, kv_cache, attn_mask, jnp.array(0, dtype=jnp.int32)),
        )
        return kv_cache, cot_token_ids

    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        num_steps: int | at.Int[at.Array, ""] = 10,
    ) -> _model.Actions:
        """Sample an action chunk by Euler-integrating the learned velocity field from t=1 to t=0.

        The prefix is prefilled into the KV cache first. When `config.use_cot` is set and
        `config.cot_len > 0`, `config.cot_len` CoT tokens are then generated greedily into that
        cache, and every denoising step attends to prefix + CoT.
        """
        observation = _model.preprocess_observation(None, observation, train=False)
        dt = -1.0 / num_steps
        batch_size = observation.state.shape[0]
        noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

        # Prefill the KV cache with the prefix (images + prompt).
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
        positions = jnp.cumsum(prefix_mask, axis=1) - 1
        (prefix_pre_logits, _), kv_cache = self.PaliGemma.llm(
            [prefix_tokens, None], mask=prefix_attn_mask, positions=positions
        )

        # Optionally generate CoT tokens with expert 0, extending the cache.
        cot_hist_len = 0
        if self.config.use_cot and self.config.cot_len > 0:
            cot_hist_len = int(self.config.cot_len)
            # The generated ids are not returned by this API; they are kept here for inspection.
            kv_cache, _cot_token_ids = self._generate_cot(
                prefix_pre_logits, prefix_mask, positions, kv_cache, cot_hist_len
            )

        # Keys visible to the suffix from the cache: the prefix plus every generated CoT slot.
        if cot_hist_len > 0:
            context_mask = jnp.concatenate(
                [prefix_mask, jnp.ones((batch_size, cot_hist_len), dtype=jnp.bool_)], axis=1
            )
        else:
            context_mask = prefix_mask
        context_len = jnp.sum(context_mask, axis=-1)[:, None]  # [B, 1]

        # Flow-matching denoising, reusing the prefix (+ CoT) KV cache.
        def step(carry):
            x_t, time = carry
            suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(
                observation, x_t, jnp.broadcast_to(time, batch_size)
            )
            # Attention within the suffix follows its own ar_mask.
            suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
            context_attn_mask = einops.repeat(context_mask, "b p -> b s p", s=suffix_tokens.shape[1])
            full_attn_mask = jnp.concatenate([context_attn_mask, suffix_attn_mask], axis=-1)
            positions = context_len + jnp.cumsum(suffix_mask, axis=-1) - 1
            (prefix_out, suffix_out), _ = self.PaliGemma.llm(
                [None, suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache
            )
            assert prefix_out is None
            v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
            return x_t + dt * v_t, time + dt

        def cond(carry):
            x_t, time = carry
            return time >= -dt / 2

        x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
        return x_0