"""Implementation of Future Predictor in JAX/NNX."""
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from typing import Any, Callable, Optional
import math


def apply_rotary_pos_emb_2d(q, k, cos, sin):
    """Rotate query and key tensors with rotary position embeddings.

    The last axis of ``q``/``k`` is treated as ``head_dim // 2`` consecutive
    (even, odd) pairs; each pair is rotated by the angle whose cosine/sine
    is given for that sequence position and pair index.

    Args:
        q: Query tensor [batch, heads, seq_len, head_dim]
        k: Key tensor [batch, heads, seq_len, head_dim]
        cos: Cosine table with one row per sequence position; it is
            reshaped to [seq_len, head_dim // 2] before use
        sin: Sine table, same layout as ``cos``

    Returns:
        Tuple ``(q_rot, k_rot)`` with the same shapes as ``q`` and ``k``.
    """
    head_dim = q.shape[-1]
    assert head_dim % 2 == 0, "Head dimension must be even for RoPE"
    half = head_dim // 2

    def _split_pairs(x):
        # [..., head_dim] -> two [..., head_dim // 2] halves (even / odd slots)
        pairs = x.reshape(*x.shape[:-1], half, 2)
        return pairs[..., 0], pairs[..., 1]

    q_even, q_odd = _split_pairs(q)
    k_even, k_odd = _split_pairs(k)

    # Leading singleton axes let the tables broadcast over batch and heads.
    cos_b = cos.reshape(cos.shape[0], half)[None, None, :, :]
    sin_b = sin.reshape(sin.shape[0], half)[None, None, :, :]

    def _rotate(even, odd, out_shape):
        # Standard 2D rotation of each (even, odd) pair, then re-interleave.
        rotated = jnp.stack(
            [even * cos_b - odd * sin_b, even * sin_b + odd * cos_b],
            axis=-1,
        )
        return rotated.reshape(*out_shape)

    return _rotate(q_even, q_odd, q.shape), _rotate(k_even, k_odd, k.shape)


def create_2d_rope_embeddings(max_time, embed_dim, base=10000):
    """Build cosine/sine RoPE tables indexed by time step.

    Despite the "2d" in the name, angles depend on the time position only.

    Args:
        max_time: Number of time steps (rows) in the tables
        embed_dim: Embedding dimension (must be even)
        base: Base of the geometric frequency progression

    Returns:
        Tuple ``(cos, sin)``, each of shape [max_time, embed_dim // 2].
    """
    assert embed_dim % 2 == 0, "Embed dim must be even for RoPE"

    # One inverse frequency per rotated pair: base ** (-2i / embed_dim).
    exponents = jnp.arange(0, embed_dim, 2).astype(jnp.float32) / embed_dim
    inv_freq = 1.0 / (base ** exponents)

    # Angle for (time step t, pair i) is t * inv_freq[i].
    steps = jnp.arange(max_time, dtype=jnp.float32)
    phase = steps[:, None] * inv_freq[None, :]  # [max_time, embed_dim//2]

    return jnp.cos(phase), jnp.sin(phase)


class MLP(nnx.Module):
    """Two-layer feed-forward network: expand 4x, GELU, project back."""

    def __init__(self, n_embd, *, rngs):
        """Create the expansion and contraction layers.

        Args:
            n_embd: Input and output feature size
            rngs: NNX RNG container used to initialise both layers
        """
        inner_dim = 4 * n_embd
        self.dense1 = nnx.Linear(n_embd, inner_dim, rngs=rngs)
        self.dense2 = nnx.Linear(inner_dim, n_embd, rngs=rngs)

    def __call__(self, x):
        """Apply dense1 -> GELU -> dense2 over the last axis of ``x``."""
        hidden = nnx.gelu(self.dense1(x))
        return self.dense2(hidden)
    

class CrossAttentionBlock(nnx.Module):
    """Pre-norm multi-head cross-attention with a residual connection.

    Queries come from the features being conditioned; keys and values come
    from the conditioning features (e.g. action or language tokens).
    """

    def __init__(self, hidden_dim, n_heads, *, rngs):
        """Create projections and layer norms.

        Args:
            hidden_dim: Feature size of both query and key/value inputs
            n_heads: Number of attention heads; must divide ``hidden_dim``
            rngs: NNX RNG container used to initialise the layers
        """
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads
        assert hidden_dim % n_heads == 0, f"Hidden dim {hidden_dim} not divisible by n_heads {n_heads}"

        # Projections for queries, keys, values and the merged heads.
        self.q_proj = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
        self.k_proj = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
        self.v_proj = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
        self.out_proj = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)

        # Separate norms for the query stream and the conditioning stream.
        self.ln_q = nnx.LayerNorm(hidden_dim, rngs=rngs)
        self.ln_kv = nnx.LayerNorm(hidden_dim, rngs=rngs)

    def __call__(self, query_features, key_value_features, attention_mask=None):
        """Attend from ``query_features`` to ``key_value_features``.

        Args:
            query_features: [batch, query_len, hidden_dim] - features to be conditioned
            key_value_features: [batch, kv_len, hidden_dim] - conditioning features
            attention_mask: [batch, kv_len] - 1 for valid tokens, 0 for padding
        Returns:
            Conditioned features: [batch, query_len, hidden_dim]
        """
        batch, q_len, q_dim = query_features.shape
        kv_batch, kv_len, kv_dim = key_value_features.shape

        assert batch == kv_batch, f"Batch size mismatch: {batch} vs {kv_batch}"
        assert q_dim == kv_dim == self.hidden_dim, f"Hidden dim mismatch: {q_dim}, {kv_dim} vs {self.hidden_dim}"

        def split_heads(tensor, length):
            # [batch, length, hidden_dim] -> [batch, n_heads, length, head_dim]
            per_head = tensor.reshape(batch, length, self.n_heads, self.head_dim)
            return per_head.transpose(0, 2, 1, 3)

        context = self.ln_kv(key_value_features)
        q = split_heads(self.q_proj(self.ln_q(query_features)), q_len)
        k = split_heads(self.k_proj(context), kv_len)
        v = split_heads(self.v_proj(context), kv_len)

        # Scaled dot-product scores: [batch, n_heads, query_len, kv_len]
        scale = 1.0 / jnp.sqrt(self.head_dim)
        scores = (q @ jnp.swapaxes(k, -1, -2)) * scale

        if attention_mask is not None:
            # Broadcast the per-token mask over heads and query positions;
            # padded keys get a large negative score before the softmax.
            is_padding = attention_mask[:, None, None, :] == 0
            scores = jnp.where(is_padding, -1e10, scores)

        probs = jax.nn.softmax(scores, axis=-1)

        # Weighted sum of values, then merge heads back into hidden_dim.
        attended = (probs @ v).transpose(0, 2, 1, 3)
        attended = attended.reshape(batch, q_len, self.hidden_dim)

        # Output projection plus residual on the un-normalised queries.
        return query_features + self.out_proj(attended)


class TransformerBlock(nnx.Module):
    """Pre-norm causal transformer block with time-only rotary embeddings.

    The [batch, time, space, channels] input is flattened time-major to a
    single sequence of ``time * space`` tokens. Self-attention is causal
    over that flattened order, and RoPE rotates queries/keys by each
    token's *time step* (all tokens of one frame share the same angle).
    """

    def __init__(self, n_embd, n_head, max_time=5, *, rngs):
        """Create the attention and MLP sub-layers.

        Args:
            n_embd: Channel size of the block
            n_head: Number of attention heads; must divide ``n_embd``
            max_time: Minimum number of time steps the RoPE table covers
            rngs: NNX RNG container used to initialise the layers

        Raises:
            ValueError: If ``n_embd`` is not divisible by ``n_head``.
        """
        if n_embd % n_head != 0:
            raise ValueError(
                f"n_embd {n_embd} not divisible by n_head {n_head}"
            )
        self.ln1 = nnx.LayerNorm(n_embd, rngs=rngs)
        self.qkv_proj = nnx.Linear(n_embd, 3 * n_embd, rngs=rngs)
        self.out_proj = nnx.Linear(n_embd, n_embd, rngs=rngs)
        self.ln2 = nnx.LayerNorm(n_embd, rngs=rngs)
        self.mlp = MLP(n_embd, rngs=rngs)
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.max_time = max_time

    def __call__(self, x):
        """Apply causal self-attention and the MLP, each with a residual.

        Args:
            x: Input of shape [batch, time, space, channels]

        Returns:
            Array of the same shape as ``x``.
        """
        B, T, S, C = x.shape
        seq_len = T * S

        # Time-major flatten: token i belongs to time step i // S.
        x_flat = x.reshape(B, seq_len, C)

        # --- Causal self-attention sub-layer ---
        qkv = self.qkv_proj(self.ln1(x_flat))
        q, k, v = jnp.split(qkv, 3, axis=-1)

        def split_heads(t):
            # [B, seq_len, n_embd] -> [B, n_head, seq_len, head_dim]
            return t.reshape(B, seq_len, self.n_head, self.head_dim).transpose(0, 2, 1, 3)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Build the RoPE table over time. Cover at least T rows so inputs
        # longer than max_time get real angles instead of silently clamped
        # out-of-bounds lookups.
        table_len = max(self.max_time, T)
        cos, sin = create_2d_rope_embeddings(table_len, self.head_dim, base=10000)

        # Each time step's row is shared by the S spatial tokens of that
        # frame, matching the time-major flatten above.
        cos_pos = jnp.repeat(cos[:T], S, axis=0)  # [seq_len, head_dim//2]
        sin_pos = jnp.repeat(sin[:T], S, axis=0)  # [seq_len, head_dim//2]

        q, k = apply_rotary_pos_emb_2d(q, k, cos_pos, sin_pos)

        # Lower-triangular mask over the flattened sequence: a token sees
        # itself and every earlier token in flattened order.
        causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]

        scale = 1.0 / jnp.sqrt(self.head_dim)
        attn_weights = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) * scale
        # Disallowed positions get a large negative score before softmax.
        attn_weights = jnp.where(causal, attn_weights, -1e10)
        attn_weights = jax.nn.softmax(attn_weights, axis=-1)

        # Weighted values, heads merged back into n_embd channels.
        y = jnp.matmul(attn_weights, v)
        y = y.transpose(0, 2, 1, 3).reshape(B, seq_len, self.n_embd)
        x_flat = x_flat + self.out_proj(y)

        # --- MLP sub-layer ---
        x_flat = x_flat + self.mlp(self.ln2(x_flat))

        # Restore [batch, time, space, channels].
        return x_flat.reshape(B, T, S, C)


class FuturePredictor(nnx.Module):
    """Predict future visual features from current ones.

    Pipeline: input projection -> optional language cross-attention ->
    stack of causal transformer blocks -> layer norm -> output projection.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, n_head, n_layers, 
                 max_time=32, *, rngs):
        """Create all sub-modules.

        Args:
            input_dim: Feature size of the visual input; the language
                projection is built with the same input size
            output_dim: Feature size of the prediction
            hidden_dim: Internal channel size
            n_head: Number of attention heads
            n_layers: Number of transformer blocks
            max_time: Forwarded to every transformer block
            rngs: NNX RNG container used to initialise the layers
        """
        # Core parameters
        self.hidden_dim = hidden_dim
        self.max_time = max_time
        
        # Input projections (visual and language share input_dim)
        self.input_proj = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
        self.lang_input_proj = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
        
        # Cross-attention for language conditioning
        self.lang_cross_attn = CrossAttentionBlock(hidden_dim, n_head, rngs=rngs)
        
        # Transformer blocks, keyed 'block_0' .. 'block_{n_layers-1}'
        self.blocks = {}
        for i in range(n_layers):
            self.blocks[f'block_{i}'] = TransformerBlock(
                hidden_dim, n_head, max_time, rngs=rngs
            )
        
        # Output layers
        self.ln_final = nnx.LayerNorm(hidden_dim, rngs=rngs)
        self.output_proj = nnx.Linear(hidden_dim, output_dim, rngs=rngs)
        
    def __call__(self, current_images, language_features=None, language_mask=None):
        """
        Args:
            current_images: [batch, time_steps, tokens_per_frame, features] - current visual features
            language_features: [batch, lang_tokens, features] - language conditioning features (optional)
            language_mask: [batch, lang_tokens] - mask for language features, 1 for valid tokens, 0 for padding (optional)
        Returns:
            Predicted future features: [batch, time_steps, tokens_per_frame, output_dim]
        """
        B, T, N, _ = current_images.shape
        
        # Linear layers act on the last axis, so no flattening is needed.
        x = self.input_proj(current_images)  # [B, T, N, hidden_dim]
        
        # Apply language conditioning if provided
        if language_features is not None:
            language_features = self.lang_input_proj(language_features)
            # Cross-attention expects one token axis: [batch, time*tokens, hidden_dim]
            x_for_attn = x.reshape(B, T * N, self.hidden_dim)
            x_for_attn = self.lang_cross_attn(x_for_attn, language_features, attention_mask=language_mask)
            x = x_for_attn.reshape(B, T, N, self.hidden_dim)
        
        # Apply transformer blocks in index order
        for i in range(len(self.blocks)):
            x = self.blocks[f'block_{i}'](x)
            
        # Apply final layer norm
        x = self.ln_final(x)
        
        # Project to output_dim. The result keeps the projection's own
        # channel size rather than being reshaped to the input feature
        # size, which only worked when output_dim == input_dim.
        return self.output_proj(x)  # [B, T, N, output_dim]