import einops
import flax.nnx as nnx
import jax
import jax.numpy as jnp

from openpi.shared import array_typing as at


class _LayerNorm(nnx.Module):
    """Layer normalization over the last axis with a learnable scale and shift.

    Parameters are created in float32 with an identity initialization (``gamma`` = 1,
    ``beta`` = 0). ``rngs`` is accepted so the constructor matches other modules' call
    shape, but it is not used.
    """

    def __init__(self, dim: int, eps: float = 1e-5, rngs: nnx.Rngs | None = None):
        self.dim = dim
        self.eps = eps
        # Identity init: unit scale, zero shift.
        self.gamma = nnx.Param(jnp.ones((dim,), dtype=jnp.float32))
        self.beta = nnx.Param(jnp.zeros((dim,), dtype=jnp.float32))

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Normalize ``x`` to zero mean / unit variance along its last axis, then scale and shift."""
        centered = x - jnp.mean(x, axis=-1, keepdims=True)
        # Biased (population) variance, matching mean((x - mean) ** 2).
        variance = jnp.mean(centered ** 2, axis=-1, keepdims=True)
        normalized = centered * jax.lax.rsqrt(variance + self.eps)
        return normalized * self.gamma.value + self.beta.value


class _Attention(nnx.Module):
    """Multi-head scaled dot-product attention with separate query and key/value inputs.

    Queries are projected from the comparator hidden space (width ``dim``). Keys and values
    are projected from a source space of width ``kv_dim``: ``dim`` for self-attention, the VLM
    feature width for cross-attention. All projections map into ``dim`` channels, which are
    split into ``num_heads`` heads of ``dim // num_heads`` channels each.

    Note: ``dropout`` is stored and ``__call__`` accepts ``deterministic``, but this module
    does not apply any dropout.
    """

    def __init__(self, dim: int, num_heads: int, kv_dim: int, rngs: nnx.Rngs, dropout: float = 0.0):
        # Fail fast with a clear message. A non-divisible head split would otherwise only
        # surface as an opaque einops error on the first forward pass, and num_heads == 0
        # as a ZeroDivisionError.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads}), and num_heads must be positive."
            )
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Q comes from comparator hidden space
        self.q_proj = nnx.Linear(dim, dim, rngs=rngs)
        # K/V come from source space (self: dim; cross: vlm_dim)
        self.k_in = nnx.Linear(kv_dim, dim, rngs=rngs)
        self.v_in = nnx.Linear(kv_dim, dim, rngs=rngs)
        self.out_proj = nnx.Linear(dim, dim, rngs=rngs)
        self.dropout = dropout

    def _shape_heads(self, x):
        """Split the channel axis into heads: ``[B, T, H*D] -> [B, H, T, D]``."""
        return einops.rearrange(x, "B T (H D) -> B H T D", H=self.num_heads)

    def __call__(self, q_in, kv_in, *, mask: jnp.ndarray | None = None, deterministic: bool = True, return_stats: bool = False):
        """Attend from ``q_in`` to ``kv_in``.

        Args:
          q_in: query tokens, rank 3 ``[B, T, dim]``.
          kv_in: key/value tokens, rank 3 ``[B, S, kv_dim]``.
          mask: optional ``[B, T, S]`` mask broadcast across heads. Positions where it is
            False receive a very large negative logit before the softmax.
          deterministic: accepted for interface compatibility; no dropout is applied.
          return_stats: if True, also return a dict of float32 RMS diagnostics.

        Returns:
          The ``[B, T, dim]`` attention output, or ``(output, stats)`` when
          ``return_stats`` is True.
        """

        def _rms(x):
            # Root-mean-square over all elements, computed in float32 for stable logging.
            x = x.astype(jnp.float32)
            return jnp.sqrt(jnp.mean(jnp.square(x) + 1e-12))

        q = self._shape_heads(self.q_proj(q_in))
        k = self._shape_heads(self.k_in(kv_in))
        v = self._shape_heads(self.v_in(kv_in))
        attn = jnp.einsum("BHTD,BHSD->BHTS", q * self.scale, k)
        if mask is not None:
            # Large negative logit (rather than -inf) so masked positions get ~0 weight
            # after softmax without producing NaNs.
            big_neg = -2.3819763e38
            attn = jnp.where(mask[:, None, :, :], attn, big_neg)
        attn = jax.nn.softmax(attn, axis=-1)
        out = jnp.einsum("BHTS,BHSD->BHTD", attn, v)
        out = einops.rearrange(out, "B H T D -> B T (H D)")
        out_proj = self.out_proj(out)
        if return_stats:
            stats = {
                "q_in_rms": _rms(q_in),
                "kv_in_rms": _rms(kv_in),
                "q_proj_rms": _rms(q),
                "k_proj_rms": _rms(k),
                "v_proj_rms": _rms(v),
                "attn_out_rms": _rms(out),
                "attn_out_proj_rms": _rms(out_proj),
            }
            return out_proj, stats
        return out_proj


class ActionComparator(nnx.Module):
    """Transformer-like comparator that scores whether action A is better than B.

    Layer-aligned variant: each comparator layer cross-attends to the corresponding VLM layer features.

    Four tokens are built from the inputs, stacked in this order:
    ``[proj(action_a), proj(action_b), proj(action_a - action_b), proj(state)]``.
    Each of the ``depth`` pre-norm blocks does the following:
      1. Self-attention over the four tokens.
      2. Cross-attention to ``vlm_raw_by_layer[l]``.
      3. Cross-attention to ``vlm_core_by_layer[l]``, gated by ``tanh(alpha_core[l])``.
         ``alpha_core`` is initialized to 0, so the core branch starts switched off.
      4. The three branch outputs are mixed, layer-normalized and added residually.
      5. A residual feed-forward layer is applied.
    After the last block, token 0 is layer-normalized and scored by the head.

    Inputs:
      - action_a: [B, H, Da]  (action chunk)
      - action_b: [B, H, Da]
      - state:    [B, Ds]
      - vlm_raw_by_layer:  list with at least ``depth`` entries; each [B, N_raw, D]
      - vlm_core_by_layer: list with at least ``depth`` entries; each [B, N_core, D]

    Output:
      - logits: [B, 1]  (sigmoid -> P(A>B))
    """

    def __init__(
        self,
        *,
        hidden_dim: int,
        num_heads: int,
        depth: int,
        dropout: float,
        action_dim: int,
        action_horizon: int,
        vlm_dim: int,
        state_dim: int,
        rngs: nnx.Rngs,
        use_mlp_projections: bool = False,
    ):
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.depth = depth
        self.dropout = dropout
        self.action_dim = action_dim
        self.action_horizon = action_horizon
        self.vlm_dim = vlm_dim
        self.state_dim = state_dim

        # Projectors
        flat_dim = action_dim * action_horizon

        def _make_proj(in_dim: int) -> nnx.Sequential:
            # Build Linear+SiLU, plus a second Linear+SiLU when use_mlp_projections is set.
            # Layers are created in the same order as the explicit per-branch definitions,
            # so parameter initialization drawn from `rngs` is unchanged.
            layers = [nnx.Linear(in_dim, hidden_dim, rngs=rngs), nnx.silu]
            if use_mlp_projections:
                layers += [nnx.Linear(hidden_dim, hidden_dim, rngs=rngs), nnx.silu]
            return nnx.Sequential(*layers)

        self.action_a_proj = _make_proj(flat_dim)
        self.action_b_proj = _make_proj(flat_dim)
        self.delta_proj = _make_proj(flat_dim)
        self.state_proj = _make_proj(state_dim)

        # Per-layer attentions
        self.self_attn = [
            _Attention(hidden_dim, num_heads, kv_dim=hidden_dim, rngs=rngs, dropout=dropout) for _ in range(depth)
        ]
        self.cross_raw = [
            _Attention(hidden_dim, num_heads, kv_dim=vlm_dim, rngs=rngs, dropout=dropout) for _ in range(depth)
        ]
        self.cross_core = [
            _Attention(hidden_dim, num_heads, kv_dim=vlm_dim, rngs=rngs, dropout=dropout) for _ in range(depth)
        ]
        # Layer norms: pre-attention and pre-ffn (pre-norm transformer)
        self.attn_ln = [
            _LayerNorm(hidden_dim) for _ in range(depth)
        ]
        self.ffn_ln = [
            _LayerNorm(hidden_dim) for _ in range(depth)
        ]
        # Scalar gate per layer for the core cross-attention branch; tanh(0) = 0 at init.
        self.alpha_core = [
            nnx.Param(jnp.array(0.0, dtype=jnp.float32)) for _ in range(depth)
        ]
        # Maps the concatenated [self, raw, core] branch outputs back to hidden_dim.
        self.mix_proj = [
            nnx.Linear(3 * hidden_dim, hidden_dim, rngs=rngs) for _ in range(depth)
        ]
        # Post-concat normalization to stabilize scale before residual
        self.post_concat_ln = [
            _LayerNorm(hidden_dim) for _ in range(depth)
        ]
        self.ffn = [
            nnx.Sequential(
                nnx.Linear(hidden_dim, 4 * hidden_dim, rngs=rngs),
                nnx.silu,
                nnx.Linear(4 * hidden_dim, hidden_dim, rngs=rngs),
            )
            for _ in range(depth)
        ]

        # Head
        self.head = nnx.Sequential(
            nnx.Linear(hidden_dim, hidden_dim, rngs=rngs), nnx.silu, nnx.Linear(hidden_dim, 1, rngs=rngs)
        )
        # Summary normalization before head
        self.summary_ln = _LayerNorm(hidden_dim)

    def __call__(
        self,
        *,
        action_a: jnp.ndarray,
        action_b: jnp.ndarray,
        state: jnp.ndarray,
        vlm_raw_by_layer: list[jnp.ndarray],
        vlm_core_by_layer: list[jnp.ndarray],
        deterministic: bool = True,
        return_stats: bool = False,
    ) -> jnp.ndarray | tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """Score action_a against action_b.

        Only the first ``depth`` entries of ``vlm_raw_by_layer`` / ``vlm_core_by_layer`` are
        read; fewer entries raise IndexError. ``deterministic`` is forwarded to the attention
        layers. With ``return_stats=True`` the result is ``(logits, stats)``, where ``stats``
        maps diagnostic names to scalar RMS/summary values.
        """

        def _rms(x):
            # Root-mean-square over all elements, computed in float32 for stable logging.
            x = x.astype(jnp.float32)
            return jnp.sqrt(jnp.mean(jnp.square(x) + 1e-12))

        # Flatten action chunks: [B, H, Da] -> [B, H*Da]
        a_flat = einops.rearrange(action_a, "B H D -> B (H D)")
        b_flat = einops.rearrange(action_b, "B H D -> B (H D)")
        d_flat = a_flat - b_flat
        a = self.action_a_proj(a_flat)
        b = self.action_b_proj(b_flat)
        d = self.delta_proj(d_flat)
        s = self.state_proj(state)
        # Token sequence [B, 4, hidden_dim] in order (a, b, a-b, state).
        x = jnp.stack([a, b, d, s], axis=1)

        stats: dict[str, jnp.ndarray] = {}
        if return_stats:
            stats.update({
                "in_a_flat_rms": _rms(a_flat),
                "in_b_flat_rms": _rms(b_flat),
                "in_d_flat_rms": _rms(d_flat),
                "proj_a_rms": _rms(a),
                "proj_b_rms": _rms(b),
                "proj_d_rms": _rms(d),
                "proj_s_rms": _rms(s),
                "x_init_rms": _rms(x),
            })

        # No mask within 4 tokens self-attn
        self_mask = None
        for l in range(self.depth):
            x_attn = self.attn_ln[l](x)
            if return_stats:
                y_self, s_self = self.self_attn[l](x_attn, x_attn, mask=self_mask, deterministic=deterministic, return_stats=True)
                y_raw, s_raw = self.cross_raw[l](x_attn, vlm_raw_by_layer[l], mask=None, deterministic=deterministic, return_stats=True)
                y_core, s_core = self.cross_core[l](x_attn, vlm_core_by_layer[l], mask=None, deterministic=deterministic, return_stats=True)
                stats.update({
                    f"l{l}_x_attn_rms": _rms(x_attn),
                    f"l{l}_self_q_in_rms": s_self["q_in_rms"],
                    f"l{l}_self_kv_in_rms": s_self["kv_in_rms"],
                    f"l{l}_self_out_proj_rms": s_self["attn_out_proj_rms"],
                    f"l{l}_raw_kv_in_rms": s_raw["kv_in_rms"],
                    f"l{l}_raw_out_proj_rms": s_raw["attn_out_proj_rms"],
                    f"l{l}_core_kv_in_rms": s_core["kv_in_rms"],
                    f"l{l}_core_out_proj_rms": s_core["attn_out_proj_rms"],
                })
            else:
                y_self = self.self_attn[l](x_attn, x_attn, mask=self_mask, deterministic=deterministic)
                y_raw = self.cross_raw[l](x_attn, vlm_raw_by_layer[l], mask=None, deterministic=deterministic)
                y_core = self.cross_core[l](x_attn, vlm_core_by_layer[l], mask=None, deterministic=deterministic)
            y = jnp.concatenate([y_self, y_raw, y_core * jnp.tanh(self.alpha_core[l].value)], axis=-1)
            # Scale by sqrt(3) due to concatenation of three branches, then project and normalize
            y = self.mix_proj[l](y / jnp.sqrt(3.0))
            y = self.post_concat_ln[l](y)
            x = x + y
            x = x + self.ffn[l](self.ffn_ln[l](x))
            if return_stats:
                stats.update({
                    f"l{l}_y_after_mix_rms": _rms(y),
                    f"l{l}_x_after_block_rms": _rms(x),
                })

        # Use the first token (the action_a slot) as the summary.
        summary_raw = x[:, 0]
        summary = self.summary_ln(summary_raw)
        logits = self.head(summary)
        if return_stats:
            stats.update({
                # Pre-norm summary scale; previously this duplicated head_in_rms.
                "summary_rms": _rms(summary_raw),
                # Post-summary_ln value actually fed to the head.
                "head_in_rms": _rms(summary),
                "logits_mean": jnp.mean(logits),
                "logits_abs_max": jnp.max(jnp.abs(logits)),
                "logits_rms": _rms(logits),
            })
            return logits, stats
        return logits

 