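"""Tabular embedder using additive label encoding like TabICL.

This module provides:
1. TabularEmbedder: Row-wise embedder for tabular features
   - Embeds each scalar feature to a high dimension
   - Processes every feature column across rows with induced set attention
     (ISAB); the inducing points only attend to training (context + buffer) rows
   - Applies row-wise attention across features together with learned CLS tokens
   - Holds the target encoder and AR buffer tokens; targets/labels are added to
     the row embeddings additively (not concatenated) by TabularACE

2. TabularACE: ACE model specialized for tabular data
   - Uses TabularEmbedder instead of the standard Embedder
   - Compatible with the ACE training pipeline
   - Requires passing block_mask to forward (like standard ACE)
"""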

from typing import Optional, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.utils import DataAttr
from src.models.ace import AmortizedConditioningEngine
from src.models.modules import Transformer, MixtureGaussian
from src.models.row_attention import Encoder as RoPEEncoder
from src.models.masks import create_training_block_mask


class InducedSetAttentionBlock(nn.Module):
    """Pre-norm induced set attention block (ISAB) for column processing.

    A small set of learned inducing vectors first summarises the input rows
    (optionally only the leading ``train_size`` rows), and then every input row
    attends back to that summary. Each of the two stages is a residual
    attention step followed by a residual pre-norm feed-forward network.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_inducing_points: int,
        dim_feedforward: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.num_inducing_points = num_inducing_points

        # Learned inducing vectors, drawn from N(0, 0.02^2).
        self.inducing_points = nn.Parameter(torch.empty(num_inducing_points, embed_dim))
        nn.init.normal_(self.inducing_points, mean=0.0, std=0.02)

        def make_attention() -> nn.MultiheadAttention:
            return nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, batch_first=True
            )

        def make_ffn() -> nn.Sequential:
            # Linear -> GELU -> Linear -> Dropout; the same shape for both stages.
            return nn.Sequential(
                nn.Linear(embed_dim, dim_feedforward),
                nn.GELU(),
                nn.Linear(dim_feedforward, embed_dim),
                nn.Dropout(dropout),
            )

        # Stage 1: inducing points (queries) attend to the input rows.
        self.mab1 = make_attention()
        self.norm1_q = nn.LayerNorm(embed_dim)
        self.norm1_kv = nn.LayerNorm(embed_dim)
        self.ff1 = make_ffn()
        self.norm1_ff = nn.LayerNorm(embed_dim)

        # Stage 2: input rows (queries) attend to the processed inducing points.
        self.mab2 = make_attention()
        self.norm2_q = nn.LayerNorm(embed_dim)
        self.norm2_kv = nn.LayerNorm(embed_dim)
        self.ff2 = make_ffn()
        self.norm2_ff = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, train_size: Optional[int] = None) -> torch.Tensor:
        """Apply the two-stage ISAB to ``x``.

        Args:
            x: batch-first 3-D input (batch, rows, embed_dim).
            train_size: when given, the inducing points only attend to
                ``x[:, :train_size]``; otherwise they attend to every row.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, _, _ = x.shape
        inducing = self.inducing_points.unsqueeze(0).expand(batch_size, -1, -1)

        # Stage 1: summarise the (training part of the) input into the inducing set.
        keys = self.norm1_kv(x)
        if train_size is not None:
            keys = keys[:, :train_size, :]
        attended, _ = self.mab1(self.norm1_q(inducing), keys, keys)
        summary = inducing + attended
        summary = summary + self.ff1(self.norm1_ff(summary))

        # Stage 2: every row reads back from the summary.
        summary_kv = self.norm2_kv(summary)
        readback, _ = self.mab2(self.norm2_q(x), summary_kv, summary_kv)
        y = x + readback
        return y + self.ff2(self.norm2_ff(y))


class ColumnProcessor(nn.Module):
    """Column-wise encoder in the spirit of TabICL's TFcol.

    Each feature column is treated as a set of rows. The column passes through
    a stack of ISAB blocks whose inducing points only see the training rows.
    The result is mapped to a per-cell scale ``W`` and shift ``B``, which are
    applied to the raw scalar value ``c`` as ``c * W + B``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_blocks: int = 1,
        num_heads: int = 4,
        num_inducing_points: int = 64,
        dim_feedforward: int = 256,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            InducedSetAttentionBlock(
                embed_dim, num_heads, num_inducing_points, dim_feedforward, dropout
            )
            for _ in range(num_blocks)
        )

        # Scale head (Linear then LayerNorm), applied per cell.
        self.out_weight = nn.Linear(embed_dim, embed_dim)
        self.ln_weight = nn.LayerNorm(embed_dim)

        # Shift head (Linear then LayerNorm), applied per cell.
        self.out_bias = nn.Linear(embed_dim, embed_dim)
        self.ln_bias = nn.LayerNorm(embed_dim)

    def forward(self, U_bfrE: torch.Tensor, c_bfr1: torch.Tensor, train_size: int) -> torch.Tensor:
        """Encode column sequences and return ``c * W + B`` per cell.

        Args:
            U_bfrE: (B*F, R, E) embeddings of each column across its rows.
            c_bfr1: (B*F, R, 1) raw scalar values of each column across its rows.
            train_size: number of leading training rows (e.g. Nc + Nb) that the
                ISAB inducing points may attend to.

        Returns:
            (B*F, R, E) value-scaled embeddings.
        """
        hidden = U_bfrE
        for isab in self.blocks:
            hidden = isab(hidden, train_size=train_size)

        scale = self.ln_weight(self.out_weight(hidden))
        shift = self.ln_bias(self.out_bias(hidden))
        return c_bfr1 * scale + shift


class TabularEmbedder(nn.Module):
    """Tabular row embedder: column-wise ISAB followed by row-wise attention.

    Every scalar cell is embedded and refined by a ``ColumnProcessor`` across
    rows. Each row's feature tokens are then prefixed with learned CLS tokens
    and encoded by a RoPE row encoder. The CLS outputs become the row
    embedding: they are projected to ``embed_dim`` or, with ``concat_cls``,
    kept concatenated. The target encoder and AR tokens are stored here but
    are applied by the owning model.
    """

    def __init__(
        self,
        num_features: int,
        embed_dim: int,
        num_cls_tokens: int = 4,
        nhead: int = 4,
        num_layers: int = 1,
        dim_feedforward: int = 256,
        dropout: float = 0.0,
        max_buffer_size: int = 32,
        ar_token_init_std: float = 0.02,
        num_isab_blocks: int = 1,
        num_inducing_points: int = 64,
        row_rope_base: int = 30000,
        # Optional per-encoder head counts (column vs row); both fall back to nhead.
        col_nhead: int | None = None,
        row_nhead: int | None = None,
        # If True, row embeddings are the concatenated CLS outputs (C*E wide).
        concat_cls: bool = False,
    ):
        super().__init__()
        self.num_features = num_features
        self.embed_dim = embed_dim
        self.num_cls_tokens = num_cls_tokens
        self.max_buffer_size = max_buffer_size
        self.concat_cls = concat_cls

        # Scalar -> vector encoders for feature cells and for targets.
        self.feature_embedding = nn.Linear(1, embed_dim)
        self.target_encoder = nn.Linear(1, embed_dim)
        self.cls_tokens = nn.Parameter(torch.randn(num_cls_tokens, embed_dim) * 0.02)

        col_heads = nhead if col_nhead is None else col_nhead
        row_heads = nhead if row_nhead is None else row_nhead

        self.column_processor = ColumnProcessor(
            embed_dim=embed_dim,
            num_blocks=num_isab_blocks,
            num_heads=col_heads,
            num_inducing_points=num_inducing_points,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        # Row-wise encoder with rotary position embeddings.
        self.row_encoder = RoPEEncoder(
            num_blocks=num_layers,
            d_model=embed_dim,
            nhead=row_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
            norm_first=True,
            use_rope=True,
            rope_base=row_rope_base,
        )

        # Output head over the flattened CLS outputs.
        cls_width = embed_dim * num_cls_tokens
        if concat_cls:
            self.out_norm_concat = nn.LayerNorm(cls_width)
        else:
            self.out_proj = nn.Linear(cls_width, embed_dim)
            self.out_norm = nn.LayerNorm(embed_dim)
        self.ar_tokens = nn.Parameter(
            torch.randn(max_buffer_size, embed_dim) * ar_token_init_std
        )

    def forward(self, batch: DataAttr) -> torch.Tensor:
        """Embed every row of ``batch`` (context, then buffer, then target rows).

        Args:
            batch: DataAttr whose ``xc``, ``xb`` and ``xt`` are concatenated along
                dim 1. They are assumed to be (batch, rows, features); TODO confirm.

        Returns:
            (batch, rows, embed_dim) row embeddings, or
            (batch, rows, num_cls_tokens * embed_dim) when ``concat_cls`` is set.
        """
        n_ctx = batch.xc.shape[1]
        n_buf = batch.xb.shape[1]
        n_cls = self.num_cls_tokens
        dim = self.embed_dim

        rows = torch.cat([batch.xc, batch.xb, batch.xt], dim=1)  # (B, R, F)
        n_batch, n_rows, n_feat = rows.shape

        # One sequence per (batch, feature) column: (B*F, R, 1).
        columns = rows.transpose(1, 2).reshape(n_batch * n_feat, n_rows, 1)
        col_embed = self.feature_embedding(columns.reshape(-1, 1)).reshape(
            n_batch * n_feat, n_rows, dim
        )

        # Column ISAB; only the context + buffer rows feed the inducing points.
        col_out = self.column_processor(col_embed, columns, int(n_ctx + n_buf))
        cells = col_out.reshape(n_batch, n_feat, n_rows, dim).transpose(1, 2)  # (B, R, F, E)

        # Prefix each row's feature tokens with the shared CLS tokens.
        cls_tok = self.cls_tokens.expand(n_batch, n_rows, -1, -1)  # (B, R, C, E)
        tokens = torch.cat([cls_tok, cells], dim=2).reshape(
            n_batch * n_rows, n_cls + n_feat, dim
        )
        encoded = self.row_encoder(tokens)
        cls_flat = encoded[:, :n_cls, :].reshape(n_batch * n_rows, n_cls * dim)

        if not self.concat_cls:
            return self.out_norm(self.out_proj(cls_flat)).reshape(n_batch, n_rows, dim)
        return self.out_norm_concat(cls_flat).reshape(n_batch, n_rows, n_cls * dim)


class TabularACE(AmortizedConditioningEngine):
    """ACE model for tabular data using additive target embeddings.

    Rows are embedded by ``TabularEmbedder``. Context and buffer rows then
    receive an additive target encoding, and buffer rows also receive a
    learned AR token. The sequence goes through a ``Transformer`` backbone,
    and a ``MixtureGaussian`` head is evaluated on the target rows.
    """

    def __init__(
        self,
        num_features: int,
        embed_dim: int = 64,
        transformer_layers: int = 6,
        nhead: int = 4,
        dim_feedforward: int = 256,
        num_components: int = 20,
        max_buffer_size: int = 32,
        num_target_points: int = 256,
        targets_block_size_for_buffer_attend: int = 5,
        dropout: float = 0.0,
        num_isab_blocks: int = 1,
        num_inducing_points: int = 64,
        row_rope_base: int = 30000,
        # Optional per-encoder head counts for the embedder (default: nhead).
        col_nhead: int | None = None,
        row_nhead: int | None = None,
        # Depth of the embedder's row encoder (default: 1 block).
        row_num_blocks: int | None = None,
        # If True, the backbone works on concatenated CLS embeddings (C*E wide).
        concat_cls: bool = False,
        num_cls_tokens: int = 4,
        # If set, backbone/head FF width = int(model_dim * ff_factor).
        ff_factor: float | None = None,
    ):
        embedder = TabularEmbedder(
            num_features=num_features,
            embed_dim=embed_dim,
            num_cls_tokens=num_cls_tokens,
            nhead=nhead,
            num_layers=row_num_blocks if row_num_blocks is not None else 1,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            max_buffer_size=max_buffer_size,
            num_isab_blocks=num_isab_blocks,
            num_inducing_points=num_inducing_points,
            row_rope_base=row_rope_base,
            col_nhead=col_nhead,
            row_nhead=row_nhead,
            concat_cls=concat_cls,
        )

        # Backbone width: E, or C*E when CLS outputs are concatenated.
        eff_dim_model = embed_dim * num_cls_tokens if concat_cls else embed_dim
        if ff_factor is not None:
            ff_backbone = int(eff_dim_model * float(ff_factor))
        else:
            ff_backbone = dim_feedforward

        backbone = Transformer(
            num_layers=transformer_layers,
            dim_model=eff_dim_model,
            num_head=nhead,
            dim_feedforward=ff_backbone,
            dropout=dropout,
        )

        head = MixtureGaussian(
            dim_y=1,
            dim_model=eff_dim_model,
            dim_feedforward=ff_backbone,
            num_components=num_components,
        )

        super().__init__(
            embedder=embedder,
            backbone=backbone,
            head=head,
            max_buffer_size=max_buffer_size,
            num_target_points=num_target_points,
            targets_block_size_for_buffer_attend=targets_block_size_for_buffer_attend,
        )

        self.tabular_embedder = embedder
        self.concat_cls = concat_cls
        self.num_cls_tokens = num_cls_tokens
        self.effective_dim_model = eff_dim_model

    def forward(self, batch: DataAttr, block_mask):
        """Embed rows, add target/AR encodings, run the backbone, score targets.

        Args:
            batch: DataAttr with context (xc, yc), buffer (xb, yb) and target
                (xt, yt) tensors. Rows are concatenated in that order.
            block_mask: attention mask passed unchanged to the backbone.

        Returns:
            The result of ``self.head(zt, batch.yt)`` on the target rows.

        Raises:
            ValueError: if the buffer has more rows than available AR tokens.
        """
        Nc, Nb, Nt = batch.xc.shape[1], batch.xb.shape[1], batch.xt.shape[1]

        # Slicing ar_tokens[:Nb] silently truncates when Nb exceeds the
        # buffer capacity; fail early with a clear message instead of an
        # opaque broadcast error further down.
        max_buffer = self.tabular_embedder.ar_tokens.shape[0]
        if Nb > max_buffer:
            raise ValueError(
                f"Buffer has {Nb} rows but only {max_buffer} AR tokens are "
                f"available; increase max_buffer_size."
            )

        # Base row embeddings (E or C*E wide).
        embeddings = self.tabular_embedder(batch)

        target_encoder = self.tabular_embedder.target_encoder
        yc_enc = target_encoder(batch.yc)
        yb_enc = target_encoder(batch.yb)
        ar_tokens = self.tabular_embedder.ar_tokens[:Nb]
        if self.concat_cls:
            # NOTE(review): repeat_interleave repeats each of the E channels C
            # times consecutively. The embedder flattens CLS outputs as C blocks
            # of E channels, so a tiled `.repeat(1, 1, C)` would add the same
            # encoding to every CLS block. Kept as-is for checkpoint
            # compatibility; confirm intent.
            C = self.num_cls_tokens
            yc_enc = yc_enc.repeat_interleave(C, dim=-1)
            yb_enc = yb_enc.repeat_interleave(C, dim=-1)
            ar_tokens = ar_tokens.repeat_interleave(C, dim=-1)

        # Build the backbone input out-of-place instead of mutating the
        # embedder's output through slice assignment (autograd-safe).
        context = embeddings[:, :Nc] + yc_enc
        buffer = embeddings[:, Nc:Nc + Nb] + yb_enc + ar_tokens  # AR tokens broadcast over batch
        targets = embeddings[:, Nc + Nb:]
        embeddings = torch.cat([context, buffer, targets], dim=1)

        z, _ = self.backbone(embeddings, block_mask=block_mask)

        # Only the target rows are scored.
        _, zt = torch.split(z, [Nc + Nb, Nt], dim=1)

        return self.head(zt, batch.yt)


def test_tabular_embedder():
    """Smoke-test TabularEmbedder on random data and print a component summary.

    Returns:
        The constructed embedder and the random batch, for interactive use.
    """
    torch.manual_seed(42)

    batch_size = 2
    nc, nb, nt = 10, 8, 5
    num_features = 7
    embed_dim = 32
    num_row_layers = 2

    batch = DataAttr(
        xc=torch.randn(batch_size, nc, num_features),
        yc=torch.randn(batch_size, nc, 1),
        xb=torch.randn(batch_size, nb, num_features),
        yb=torch.randn(batch_size, nb, 1),
        xt=torch.randn(batch_size, nt, num_features),
        yt=None
    )

    print("="*60)
    print("Testing Tabular Embedder")
    print("="*60)

    embedder = TabularEmbedder(
        num_features=num_features,
        embed_dim=embed_dim,
        num_cls_tokens=4,
        nhead=4,
        num_layers=num_row_layers,
        dim_feedforward=128,
        dropout=0.0,
        max_buffer_size=8,
        ar_token_init_std=0.02
    )

    print("Testing forward pass...")
    output = embedder.forward(batch)
    expected_shape = (batch_size, nc + nb + nt, embed_dim)
    print(f"Output shape: {output.shape}")
    print(f"Expected: ({batch_size}, {nc + nb + nt}, {embed_dim})")
    print(f"Mean: {output.mean():.4f}, Std: {output.std():.4f}")
    assert tuple(output.shape) == expected_shape, (
        f"unexpected output shape {tuple(output.shape)}, expected {expected_shape}"
    )

    print("\nModel components:")
    print(f"  Feature embedding: Linear(1 → {embed_dim})")
    print(f"  Target encoder: Linear(1 → {embed_dim})")
    print(f"  CLS tokens shape: {embedder.cls_tokens.shape}")
    print(f"  Buffer position tokens shape: {embedder.ar_tokens.shape}")
    print(f"  Column processor: {len(embedder.column_processor.blocks)} ISAB block(s) with {embedder.column_processor.blocks[0].num_inducing_points} inducing points")
    # The embedder exposes `row_encoder` (there is no `row_transformer`); report
    # the configured depth rather than relying on the encoder's internals.
    print(f"  Row transformer: {num_row_layers} layer(s)")
    print(f"  Transformer input: {embedder.num_cls_tokens} CLS + {num_features} features")

    print("\nTest completed successfully!")
    print("\nProcess: Embed → Column ISAB → Concatenate CLS → Row Transform")

    return embedder, batch
    
def test_tabular_ace():
    """Smoke-test TabularACE: build a model and a training block mask, then run one forward pass."""
    torch.manual_seed(42)

    banner = "=" * 60
    print("\n" + banner)
    print("Testing TabularACE")
    print(banner)

    batch_size, num_features = 2, 7
    nc, nb, nt = 10, 8, 5

    # Argument order matters for reproducibility: each randn draws from the seeded RNG.
    batch = DataAttr(
        xc=torch.randn(batch_size, nc, num_features),
        yc=torch.randn(batch_size, nc, 1),
        xb=torch.randn(batch_size, nb, num_features),
        yb=torch.randn(batch_size, nb, 1),
        xt=torch.randn(batch_size, nt, num_features),
        yt=torch.randn(batch_size, nt, 1),
    )

    model_config = dict(
        num_features=num_features,
        embed_dim=32,
        transformer_layers=2,
        nhead=4,
        dim_feedforward=128,
        num_components=3,
        max_buffer_size=8,
        num_target_points=5,
        targets_block_size_for_buffer_attend=2,
    )
    model = TabularACE(**model_config)

    print("Model initialized:")
    print("  Type: TabularACE")
    print(f"  Features: {num_features}")
    print(f"  Transformer layers: {len(model.backbone.layers)}")
    print(f"  Mixture components: {model.head.num_components}")

    print("\nCreating training mask...")
    seq_len = nc + nb + nt
    mask_config = dict(
        current_total_q_len=seq_len,
        current_total_kv_len=seq_len,
        current_context_section_len=nc,
        current_buffer_section_len=nb,
        device='cpu',
        attending_chunks=2,  # number of target chunks allowed to see the buffer
    )
    mask = create_training_block_mask(**mask_config)
    print(f"  Mask created for: context={nc}, buffer={nb}, target={nt}")

    print("\nTesting forward pass...")
    result = model(batch, mask)

    print("Loss:")
    print(f"  Loss: {result.loss.mean().item():.4f}")
    print(f"  Log likelihood: {result.log_likelihood.mean().item():.4f}")

    print("\nTabularACE test completed!")


if __name__ == "__main__":
    # Run both smoke tests, embedder first, when executed as a script.
    for _smoke_test in (test_tabular_embedder, test_tabular_ace):
        _smoke_test()
