import logging

import einops
import flax.nnx as nnx
import flax.nnx.bridge as nnx_bridge
import jax
import jax.numpy as jnp
from typing_extensions import override

from openpi.models import model as _model
from openpi.models import pi0_config
import openpi.models.gemma as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
from openpi.models.action_comparator import ActionComparator

# Module-level logger registered under the "openpi" name.
logger = logging.getLogger("openpi")


def make_attn_mask(input_mask, mask_ar):
    """Build a boolean attention mask from a padding mask and an AR block mask.

    Adapted from big_vision.

    The running sum of `mask_ar` along the sequence assigns every token a block
    index. A query token may attend to a key token when the key's block index is
    less than or equal to the query's and both tokens are valid inputs. This lets
    `mask_ar` bool[?B, N] express several attention patterns, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens attend to each
          other and the last 3 tokens attend causally. The first entry could
          also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block attend to all previous blocks and to every token of their own block.

    Args:
      input_mask: bool[B, N], true for real input tokens and false for padding.
      mask_ar: bool[?B, N], true where previous tokens cannot depend on this token
        and false where it shares the attention mask of the previous token.

    Returns:
      bool[B, N, N] where entry [b, i, j] is true when token i may attend to token j.
    """
    ar_flags = jnp.broadcast_to(mask_ar, input_mask.shape)
    block_index = jnp.cumsum(ar_flags, axis=1)
    # Axis 1 indexes the attending (query) token, axis 2 the attended (key) token.
    key_block = block_index[:, None, :]
    query_block = block_index[:, :, None]
    may_attend = key_block <= query_block
    # Padding can neither attend nor be attended to.
    both_valid = input_mask[:, None, :] * input_mask[:, :, None]
    return jnp.logical_and(may_attend, both_valid)


@at.typecheck
def posemb_sincos(
    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
    """Embed scalar positions as concatenated sine and cosine features.

    The first `embedding_dim // 2` output channels are sines and the remaining ones
    are cosines. Both halves are evaluated at the same set of periods, which are
    spaced geometrically from `min_period` to `max_period`.

    Args:
      pos: one scalar position per batch element.
      embedding_dim: size of the returned embedding; must be even.
      min_period: shortest period used.
      max_period: longest period used.

    Raises:
      ValueError: if `embedding_dim` is odd.
    """
    half_dim, remainder = divmod(embedding_dim, 2)
    if remainder:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")

    exponent = jnp.linspace(0.0, 1.0, half_dim)
    period = min_period * (max_period / min_period) ** exponent
    angular_frequency = 1.0 / period * 2 * jnp.pi
    # Outer product of positions and frequencies, computed at the highest precision.
    angles = jnp.einsum(
        "i,j->ij",
        pos,
        angular_frequency,
        precision=jax.lax.Precision.HIGHEST,
    )
    features = (jnp.sin(angles), jnp.cos(angles))
    return jnp.concatenate(features, axis=-1)


class Pi0(_model.BaseModel):
    def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs):
        """Build the PaliGemma backbone, the action/time projections and the optional comparator.

        Args:
          config: model configuration. This constructor reads its action dimensions, the two
            Gemma variants, `dtype`, the `pi05` flag and the comparator settings.
          rngs: NNX RNG streams used to initialise the submodules and the comparison queries.
        """
        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
        self.pi05 = config.pi05
        paligemma_config = _gemma.get_config(config.paligemma_variant)
        action_expert_config = _gemma.get_config(config.action_expert_variant)
        # TODO: rewrite gemma in NNX. For now, use bridge.
        # One Gemma module holding two configs: index 0 is the PaliGemma config, index 1 the action expert.
        llm = nnx_bridge.ToNNX(
            _gemma.Module(
                configs=[paligemma_config, action_expert_config],
                embed_dtype=config.dtype,
                adarms=config.pi05,
            )
        )
        # With pi05, adaRMS is enabled for the second config (the action expert) only.
        llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False])
        # Image encoder whose output size (`num_classes`) is set to the PaliGemma width.
        img = nnx_bridge.ToNNX(
            _siglip.Module(
                num_classes=paligemma_config.width,
                variant="So400m/14",
                pool_type="none",
                scan=True,
                dtype_mm=config.dtype,
            )
        )
        # Initialise the image encoder using the first image of the config's fake observation.
        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
        self.PaliGemma = nnx.Dict(llm=llm, img=img)
        # Projects actions (action_dim) into the action expert width.
        self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
        if config.pi05:
            # pi05: the timestep embedding goes through its own two-layer MLP (used as adaRMS conditioning).
            self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
            self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
        else:
            # non-pi05: a state projection plus an MLP that takes the concatenated action and time
            # embeddings (hence the 2 * width input).
            self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
            self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
            self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
        # Projects action expert outputs back to action_dim.
        self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)

        # This attribute gets automatically set by model.train() and model.eval().
        self.deterministic = True

        # Comparator-related (optional)
        self.enable_action_comparator = config.enable_action_comparator
        if self.enable_action_comparator:
            # Learnable queries used only for comparator feature aggregation.
            # Shape is (num_comparison_queries, PaliGemma width); initialised from a normal scaled by 0.01.
            q_shape = (config.num_comparison_queries, paligemma_config.width)
            self.comparison_queries = nnx.Param(
                0.01 * jax.random.normal(rngs.params(), q_shape, dtype=jnp.float32)
            )

            # Comparator module; its depth and vlm_dim are taken from the PaliGemma config, the
            # remaining hyper-parameters from the comparator_* config fields.
            self.action_comparator = ActionComparator(
                hidden_dim=config.comparator_hidden_dim,
                num_heads=config.comparator_num_heads,
                depth=paligemma_config.depth,
                dropout=config.comparator_dropout,
                action_dim=config.comparator_action_dim,
                action_horizon=self.action_horizon,
                vlm_dim=paligemma_config.width,
                state_dim=config.comparator_state_dim,
                rngs=rngs,
            )
            # Cache comparator input dims to avoid fragile module introspection
            self.comparator_action_dim = config.comparator_action_dim
            self.comparator_state_dim = config.comparator_state_dim

    @at.typecheck
    def embed_prefix(
        self, obs: _model.Observation
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
        """Embed the images and the optional tokenized prompt into one prefix sequence.

        Returns:
          A `(tokens, input_mask, ar_mask)` tuple. `tokens` concatenates the image tokens
          of every entry in `obs.images` followed by the prompt embeddings, `input_mask`
          marks which positions are valid, and `ar_mask` is all False so that the whole
          prefix shares a single attention block.
        """
        token_chunks = []
        mask_chunks = []
        ar_flags = []

        # Images first: every image contributes one chunk of tokens.
        for name, image in obs.images.items():
            image_tokens, _ = self.PaliGemma.img(image, train=False)
            num_image_tokens = image_tokens.shape[1]
            token_chunks.append(image_tokens)
            # The per-image mask is broadcast over all of that image's tokens.
            mask_chunks.append(einops.repeat(obs.image_masks[name], "b -> b s", s=num_image_tokens))
            # Image tokens attend to each other.
            ar_flags.extend([False] * num_image_tokens)

        # Then the language (tokenized) inputs, when present.
        prompt = obs.tokenized_prompt
        if prompt is not None:
            prompt_tokens = self.PaliGemma.llm(prompt, method="embed")
            token_chunks.append(prompt_tokens)
            mask_chunks.append(obs.tokenized_prompt_mask)
            # Full attention between image and language inputs.
            ar_flags.extend([False] * prompt_tokens.shape[1])

        return (
            jnp.concatenate(token_chunks, axis=1),
            jnp.concatenate(mask_chunks, axis=1),
            jnp.array(ar_flags),
        )

    def _embed_prefix_with_queries(
        self, obs: _model.Observation
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"], int]:
        """Prefix embedding for the comparator: append the learnable queries at the end.

        When the comparator is disabled this is exactly `embed_prefix` plus a query
        count of 0.

        Returns:
          A `(tokens, input_mask, ar_mask, num_q)` tuple where the last `num_q` positions
          of each array belong to the appended queries. The queries are always marked as
          valid inputs and form their own attention block.
        """
        tokens, input_mask, ar_mask = self.embed_prefix(obs)
        if not self.enable_action_comparator:
            return tokens, input_mask, ar_mask, 0

        queries = self.comparison_queries.value
        num_q = queries.shape[0]
        bsz = tokens.shape[0]

        # Repeat queries across batch.
        # NOTE(review): the queries are created as float32; if the prefix tokens use a
        # lower-precision dtype this concatenation promotes them — confirm that is intended.
        q_tokens = einops.repeat(queries, "q d -> b q d", b=bsz)
        tokens = jnp.concatenate([tokens, q_tokens], axis=1)
        input_mask = jnp.concatenate([input_mask, jnp.ones((bsz, num_q), dtype=jnp.bool_)], axis=1)

        # Start a new AR block for queries so earlier tokens cannot attend to them:
        # first flag True, the rest False so all queries share the same block.
        q_ar = [True] + [False] * (num_q - 1) if num_q > 0 else []
        # The explicit dtype matters for the empty case: a bare `jnp.array([])` is float32
        # and would silently promote the concatenated AR mask away from bool.
        ar_mask = jnp.concatenate([ar_mask, jnp.array(q_ar, dtype=jnp.bool_)], axis=0)
        return tokens, input_mask, ar_mask, num_q

    def _prefill_vlm_with_queries(self, obs: _model.Observation):
        """Run a single VLM prefix pass with the learnable queries appended.

        The prefix is embedded once. Because the queries are appended after the
        original prefix, the no-query mask, features and KV cache are all obtained by
        slicing off the trailing `num_q` positions.

        Returns:
          context: dict with keys
            - kv_cache_no_queries: KV cache with queries stripped for action expert reuse
            - raw_features_by_layer: [L][B, N_raw, D]
            - core_features_by_layer: [L][B, Nq, D]
            - prefix_mask: mask of original (no queries) prefix for later use
        """
        tokens, input_mask, ar_mask, num_q = self._embed_prefix_with_queries(obs)
        # The leading columns of `input_mask` are exactly the mask `embed_prefix` produced.
        prefix_mask_noq = input_mask[:, : input_mask.shape[1] - num_q]

        attn_mask = make_attn_mask(input_mask, ar_mask)
        positions = jnp.cumsum(input_mask, axis=1) - 1
        # Forward with queries and request per-layer features for expert 0.
        (prefix_out, _), kv_cache, layer_feats = self.PaliGemma.llm(
            [tokens, None], mask=attn_mask, positions=positions, return_layer_features=True
        )
        assert prefix_out is not None
        raw_len = prefix_out.shape[1] - num_q

        # Split every layer's features into the original-prefix part and the query part.
        # With no queries the second slice is empty along the sequence axis.
        per_layer = [layer_feats[idx] for idx in range(layer_feats.shape[0])]
        raw_features_by_layer = [feats[:, :raw_len] for feats in per_layer]
        core_features_by_layer = [feats[:, raw_len:] for feats in per_layer]

        # Strip queries from kv_cache per layer to match masks during action decoding.
        # NOTE(review): assumes the sequence axis of the cached keys/values is axis 2 — confirm
        # against the gemma module. The guard is required because `:-0` would drop everything.
        k_all, v_all = kv_cache
        if num_q > 0:
            k_all = k_all[:, :, :-num_q, :, :]
            v_all = v_all[:, :, :-num_q, :, :]

        return {
            "kv_cache_no_queries": (k_all, v_all),
            "raw_features_by_layer": raw_features_by_layer,
            "core_features_by_layer": core_features_by_layer,
            "prefix_mask": prefix_mask_noq,
        }

    @at.typecheck
    def embed_suffix(
        self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
    ) -> tuple[
        at.Float[at.Array, "b s emb"],
        at.Bool[at.Array, "b s"],
        at.Bool[at.Array, " s"],
        at.Float[at.Array, "b emb"] | None,
    ]:
        """Embed the noisy actions, the timestep and (non-pi05 only) the state into suffix tokens.

        Returns:
          A `(tokens, input_mask, ar_mask, adarms_cond)` tuple. Every suffix position is
          marked valid. For pi05 the tokens are the projected actions and `adarms_cond` is
          the time embedding after its MLP. Otherwise a state token precedes the action
          tokens, the time embedding is mixed into each action token through an MLP, and
          `adarms_cond` is None.
        """
        token_chunks = []
        mask_chunks = []
        ar_flags = []

        if not self.pi05:
            # A single state token leads the suffix.
            state_token = self.state_proj(obs.state)[:, None, :]
            token_chunks.append(state_token)
            mask_chunks.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
            # Image/language inputs do not attend to state or actions.
            ar_flags.append(True)

        action_tokens = self.action_in_proj(noisy_actions)
        # Sine-cosine timestep embedding with sensitivity in the range [0, 1].
        time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)

        if self.pi05:
            # Time MLP whose output conditions adaRMS; the action tokens are used as they are.
            hidden = nnx.swish(self.time_mlp_in(time_emb))
            adarms_cond = nnx.swish(self.time_mlp_out(hidden))
            expert_tokens = action_tokens
        else:
            # No adaRMS: fuse timestep and action information with an MLP instead.
            time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
            fused = jnp.concatenate([action_tokens, time_tokens], axis=-1)
            fused = self.action_time_mlp_in(fused)
            expert_tokens = self.action_time_mlp_out(nnx.swish(fused))
            adarms_cond = None

        token_chunks.append(expert_tokens)
        mask_chunks.append(jnp.ones(expert_tokens.shape[:2], dtype=jnp.bool_))
        # Image/language/state inputs do not attend to action tokens; the action tokens
        # themselves share one block.
        ar_flags.extend([True] + [False] * (self.action_horizon - 1))

        return (
            jnp.concatenate(token_chunks, axis=1),
            jnp.concatenate(mask_chunks, axis=1),
            jnp.array(ar_flags),
            adarms_cond,
        )

    @override
    def compute_loss(
        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
    ) -> at.Float[at.Array, "*b ah"]:
        """Compute the per-step training loss for a batch of target actions.

        A time value in (0, 1] is sampled per batch element, the actions are blended with
        Gaussian noise at that time, and the model's prediction is regressed onto
        `noise - actions`. The squared error is averaged over the action dimension only.
        """
        obs_key, noise_key, time_key = jax.random.split(rng, 3)
        observation = _model.preprocess_observation(obs_key, observation, train=train)

        noise = jax.random.normal(noise_key, actions.shape)
        # Beta(1.5, 1) sample rescaled into [0.001, 1.0]; one value per leading batch index.
        time = jax.random.beta(time_key, 1.5, 1, actions.shape[:-2]) * 0.999 + 0.001
        blend = time[..., None, None]
        noisy_actions = blend * noise + (1 - blend) * actions
        target = noise - actions

        # Prefix and suffix go through the model together in one forward pass.
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, noisy_actions, time)
        full_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
        full_ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
        attn_mask = make_attn_mask(full_mask, full_ar_mask)
        positions = jnp.cumsum(full_mask, axis=1) - 1
        (_, suffix_out), _ = self.PaliGemma.llm(
            [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]
        )

        # Only the trailing `action_horizon` suffix positions are action tokens.
        prediction = self.action_out_proj(suffix_out[:, -self.action_horizon :])
        squared_error = jnp.square(prediction - target)
        return jnp.mean(squared_error, axis=-1)

    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        num_steps: int | at.Int[at.Array, ""] = 10,
        noise: at.Float[at.Array, "b ah ad"] | None = None,
    ) -> _model.Actions:
        """Sample actions by integrating the predicted velocity from t=1 down to t=0.

        The prefix is run through the model once to fill a KV cache; each integration
        step then only processes the suffix tokens against that cache.

        Args:
          rng: key used to draw the initial noise when `noise` is not given.
          observation: raw observation; it is preprocessed here with `train=False`.
          num_steps: number of equally sized integration steps.
          noise: optional initial sample of shape (batch, action_horizon, action_dim).
        """
        observation = _model.preprocess_observation(None, observation, train=False)
        # Convention from the diffusion literature: t=1 is noise and t=0 is the target
        # distribution (the opposite of the pi0 paper), so the step size is negative.
        dt = -1.0 / num_steps
        batch_size = observation.state.shape[0]
        if noise is None:
            noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

        # Fill the KV cache with one forward pass over the prefix.
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        prefill_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
        prefill_positions = jnp.cumsum(prefix_mask, axis=1) - 1
        _, kv_cache = self.PaliGemma.llm(
            [prefix_tokens, None], mask=prefill_attn_mask, positions=prefill_positions
        )

        def denoise_step(carry):
            x_t, time = carry
            suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
                observation, x_t, jnp.broadcast_to(time, batch_size)
            )
            suffix_len = suffix_tokens.shape[1]
            # (b, suffix_len, suffix_len): how the suffix tokens attend to each other.
            suffix_to_suffix = make_attn_mask(suffix_mask, suffix_ar_mask)
            # (b, suffix_len, prefix_len): how the suffix tokens attend to the prefix tokens.
            suffix_to_prefix = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_len)
            # (b, suffix_len, prefix_len + suffix_len): the suffix tokens produce the queries,
            # the full prefix + suffix sequence produces the keys and values.
            full_attn_mask = jnp.concatenate([suffix_to_prefix, suffix_to_suffix], axis=-1)
            assert full_attn_mask.shape == (
                batch_size,
                suffix_len,
                prefix_tokens.shape[1] + suffix_len,
            )
            # (b, suffix_len): suffix positions continue after the last valid prefix position.
            positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

            (prefix_out, suffix_out), _ = self.PaliGemma.llm(
                [None, suffix_tokens],
                mask=full_attn_mask,
                positions=positions,
                kv_cache=kv_cache,
                adarms_cond=[None, adarms_cond],
            )
            assert prefix_out is None
            v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])

            return x_t + dt * v_t, time + dt

        def keep_going(carry):
            _, time = carry
            # Half a step of slack makes the stop condition robust to floating-point error.
            return time >= -dt / 2

        x_0, _ = jax.lax.while_loop(keep_going, denoise_step, (noise, 1.0))
        return x_0

    def sample_actions_with_context(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        context: dict,
        *,
        num_steps: int | at.Int[at.Array, ""] = 10,
        noise: at.Float[at.Array, "b ah ad"] | None = None,
    ) -> _model.Actions:
        """Sample actions while reusing a prefix pass that was computed beforehand.

        Works like `sample_actions`, except that the prefix mask and the KV cache are
        read from `context["prefix_mask"]` and `context["kv_cache_no_queries"]` instead
        of being recomputed.

        Args:
          rng: key used to draw the initial noise when `noise` is not given.
          observation: raw observation; it is preprocessed here with `train=False`.
          context: dict providing the "prefix_mask" and "kv_cache_no_queries" entries.
          num_steps: number of equally sized integration steps.
          noise: optional initial sample of shape (batch, action_horizon, action_dim).
        """
        observation = _model.preprocess_observation(None, observation, train=False)
        # Time runs from 1 (noise) down to 0, hence the negative step size.
        dt = -1.0 / num_steps
        batch_size = observation.state.shape[0]
        if noise is None:
            noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

        kv_cache = context["kv_cache_no_queries"]
        prefix_mask = context["prefix_mask"]

        def denoise_step(carry):
            x_t, time = carry
            suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
                observation, x_t, jnp.broadcast_to(time, batch_size)
            )
            # Suffix tokens see every valid prefix token plus whatever the suffix AR mask allows.
            suffix_to_prefix = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
            suffix_to_suffix = make_attn_mask(suffix_mask, suffix_ar_mask)
            full_attn_mask = jnp.concatenate([suffix_to_prefix, suffix_to_suffix], axis=-1)
            # Suffix positions continue after the last valid prefix position.
            positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

            (prefix_out, suffix_out), _ = self.PaliGemma.llm(
                [None, suffix_tokens],
                mask=full_attn_mask,
                positions=positions,
                kv_cache=kv_cache,
                adarms_cond=[None, adarms_cond],
            )
            assert prefix_out is None
            v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
            return x_t + dt * v_t, time + dt

        def keep_going(carry):
            # Half a step of slack guards against floating-point drift in `time`.
            return carry[1] >= -dt / 2

        x_0, _ = jax.lax.while_loop(keep_going, denoise_step, (noise, 1.0))
        return x_0

    def compare_actions_with_context(
        self,
        observation: _model.Observation,
        context: dict,
        action_a: _model.Actions,
        action_b: _model.Actions,
        *,
        return_stats: bool = False,
    ) -> jnp.ndarray | tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """Score `action_a` against `action_b` with the action comparator.

        Args:
          observation: observation whose `state` is passed to the comparator.
          context: dict providing the "raw_features_by_layer" and
            "core_features_by_layer" entries.
          action_a: first candidate action chunk.
          action_b: second candidate action chunk.
          return_stats: forwarded to the comparator unchanged.

        Returns:
          Whatever the comparator returns for the given `return_stats` flag.

        Raises:
          AssertionError: if the comparator was not enabled in the config.
        """
        assert self.enable_action_comparator, "Action comparator not enabled in config"

        # Prefer the dims cached in `__init__` and only fall back to the comparator
        # module when they are missing. The fallback has to be lazy: passed as a
        # `getattr` default it would be evaluated on every call regardless.
        action_dim = getattr(self, "comparator_action_dim", None)
        if action_dim is None:
            action_dim = self.action_comparator.action_dim
        state_dim = getattr(self, "comparator_state_dim", None)
        if state_dim is None:
            state_dim = self.action_comparator.state_dim

        # Use only the first comparator dims for comparison (e.g., action=7, state=8).
        return self.action_comparator(
            action_a=action_a[..., :action_dim],
            action_b=action_b[..., :action_dim],
            state=observation.state[..., :state_dim],
            vlm_raw_by_layer=context["raw_features_by_layer"],
            vlm_core_by_layer=context["core_features_by_layer"],
            deterministic=self.deterministic,
            return_stats=return_stats,
        )

    def sample_and_compare(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        num_candidates: int = 2,
        num_steps: int | at.Int[at.Array, ""] = 20,
    ) -> _model.Actions:
        """Generate K candidates under a single VLM prefix pass, then pick best by comparator.

        The candidates are shuffled and then reduced by a knock-out tournament: adjacent
        pairs are compared, the winner of each pair advances, and an unpaired last
        candidate advances without a comparison. Winners are chosen independently for
        every batch element.
        """
        assert num_candidates >= 2
        observation = _model.preprocess_observation(None, observation, train=False)
        context = self._prefill_vlm_with_queries(observation)

        # One key per candidate (different noise, same kv_cache) plus one for the shuffle.
        rng_gen, rng_perm = jax.random.split(rng)
        candidate_keys = jax.random.split(rng_gen, num_candidates)

        def gen_action(key):
            return self.sample_actions_with_context(key, observation, context, num_steps=num_steps)

        candidates = jax.vmap(gen_action)(candidate_keys)  # [K, B, H, D]

        # Shuffle so the bracket pairing is not tied to the generation order.
        order = jax.random.permutation(rng_perm, num_candidates)
        candidates = candidates[order]

        def pick_winner(first, second):
            logits = self.compare_actions_with_context(observation, context, first, second)  # [B, 1]
            prob = jax.nn.sigmoid(logits)[..., 0]  # [B]
            keep_first = (prob >= 0.5)[..., None, None]  # [B,1,1]
            return jnp.where(keep_first, first, second)

        def one_round(cands):
            count = cands.shape[0]
            # Compare (0,1), (2,3), ...
            winners = [pick_winner(cands[i], cands[i + 1]) for i in range(0, count - 1, 2)]
            if count % 2 == 1:
                winners.append(cands[-1])  # odd one advances
            return jnp.stack(winners, axis=0)

        survivors = candidates
        while survivors.shape[0] > 1:
            survivors = one_round(survivors)
        return survivors[0]
