# import math
# import torch
# import torch.nn as nn
# import logging
# import torch.nn.functional as F
# from utils.data_utils import Example, Feature
# from model.architectures.set_transformer import ISAB, PMA
# from utils.transformer_utils import get_tokenizer, get_model

# logger = logging.getLogger(__name__)

# class FeatureEmbedding(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.dtype_embed = nn.Embedding(2, config.model_dim)
#         self.desc_proj = nn.Linear(config.embedding_dim, config.model_dim)
#         if getattr(config, 'use_set_transformer', False):
#             self.class_pool = nn.Sequential(
#                 ISAB(config.model_dim, config.model_dim, config.num_heads, config.num_inds),
#                 PMA(config.model_dim, config.num_heads, 1)
#             )
#         else:
#             self.class_pool = None
#         self.embedding_cache = {}
#         # Add dummy tensor for device tracking
#         self.register_buffer('_dummy_tensor', torch.zeros(1))
        
#     def forward(self, feature: Feature):
#         # Get the device from the dummy tensor
#         device = self._dummy_tensor.device
        
#         # Check cache first
#         cache_key = id(feature)
#         if cache_key in self.embedding_cache:
#             # Always move cached embeddings to current device
#             return self.embedding_cache[cache_key].to(device)
            
#         # 1) Description embedding -> model_dim (ensure on correct device)
#         # Move embedding to device first, then compute projection
#         embedding = feature.description_embedding[self.config.embedding_model]
#         if isinstance(embedding, torch.Tensor):
#             desc = embedding.to(device)
#         else:
#             desc = torch.tensor(embedding, device=device)
        
#         desc = self.desc_proj(desc)
        
#         # 2) Dtype embedding
#         dtype_idx = 0 if feature.dtype == 'real' else 1
#         dt = self.dtype_embed(torch.tensor(dtype_idx, device=device))
#         out = desc + dt
        
#         # 3) If categorical and pooling enabled, add pooled category prototypes
#         if feature.dtype == 'categorical' and self.class_pool is not None:
#             cat_embs = []
#             for cat in feature.categories:
#                 cat_emb = feature.categories_embedding[self.config.embedding_model].get(cat)
#                 if cat_emb is not None:
#                     # Ensure embeddings are on the correct device
#                     if isinstance(cat_emb, torch.Tensor):
#                         cat_embs.append(cat_emb.to(device))
#                     else:
#                         cat_embs.append(torch.tensor(cat_emb, device=device))
                        
#             if cat_embs:
#                 cat_tensor = torch.stack(cat_embs, dim=0).unsqueeze(0)  # [1, C, emb_dim]
#                 pooled = self.class_pool(cat_tensor).squeeze(0).squeeze(0)  # [model_dim]
#                 out = out + pooled
                
#         # Store detached tensor in cache (on same device)
#         self.embedding_cache[cache_key] = out.detach()
#         return out        

# class FeatureValuePairEmbedding(nn.Module):
#     """
#     Embeds one (feature, value) pair into a token embedding for BERT.
#     Output dimension: config.model_dim
#     """
#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.feature_embed = FeatureEmbedding(config)
#         self.real_proj = nn.Linear(1, config.model_dim)
#         self.missing_embed = nn.Embedding(1, config.model_dim)
#         self.embed_proj = nn.Linear(config.model_dim * 2, config.model_dim)
#         self.register_buffer('_dummy_tensor', torch.zeros(1))
#         # Cache for categorical value embeddings
#         self.cat_value_cache = {}
    
#     @property
#     def device(self):
#         return self._dummy_tensor.device
    
#     def forward(self, feature: Feature, value):
#         # Get device consistently
#         device = self._dummy_tensor.device
        
#         # 1) Feature description + dtype -> model_dim
#         f_emb = self.feature_embed(feature)
#         # Ensure f_emb is on the right device
#         f_emb = f_emb.to(device)

#         # 2) Value embedding
#         if value is None or (isinstance(value, float) and math.isnan(value)):
#             # Missing or NaN value
#             v_emb = self.missing_embed(torch.zeros(1, dtype=torch.long, device=device)).squeeze(0)
#         elif feature.dtype == 'real':
#             # Real value: project scalar
#             v_emb = self.real_proj(torch.tensor([[value]], device=device, dtype=torch.float)).squeeze(0)
#         else:
#             # Categorical value: check cache first
#             cache_key = (id(feature), value)
#             if cache_key in self.cat_value_cache:
#                 v_emb = self.cat_value_cache[cache_key].to(device)
#             else:
#                 # Categorical value: clamp index into valid range
#                 idx = int(value)
#                 num_cats = len(feature.categories)
#                 if num_cats > 0:
#                     # clamp between 0 and num_cats-1
#                     idx = max(0, min(idx, num_cats - 1))
#                     cat = feature.categories[idx]
#                     # Get embedding and explicitly move to device
#                     emb = feature.categories_embedding[self.config.embedding_model][cat]
#                     v_emb = torch.tensor(emb, device=device) if not isinstance(emb, torch.Tensor) else emb.to(device)
#                 else:
#                     # no categories defined, treat as missing
#                     v_emb = self.missing_embed(torch.zeros(1, dtype=torch.long, device=device)).squeeze(0)
                
#                 # Cache the result (on CPU)
#                 self.cat_value_cache[cache_key] = v_emb.detach().cpu()

#         # 3) Combine and project → Feature-Value Pair Embedding Vector
#         combined = torch.cat([f_emb, v_emb], dim=-1)
#         return self.embed_proj(combined)

# class PEQRowEmbedding(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.feature_value_embed = FeatureValuePairEmbedding(config)
#         self.peq_layer = nn.Sequential(
#             nn.LayerNorm(config.model_dim),
#             nn.Linear(config.model_dim, config.model_dim),
#             nn.ReLU()
#         )
#         # Cache for row embeddings
#         self.row_cache = {}
#         # Add dummy tensor for device tracking
#         self.register_buffer('_dummy_tensor', torch.zeros(1))
        
#     def forward(self, features: list[Feature], values: list[float]):
#         # Get device from dummy tensor
#         device = self._dummy_tensor.device
        
#         # Check cache first
#         cache_key = (tuple(id(f) for f in features), tuple(values))
#         if cache_key in self.row_cache:
#             # Move cached result to current device
#             return self.row_cache[cache_key].to(device)
        
#         try:
#             # Process all feature-value pairs and ensure they're on the device
#             embs = []
#             for f, v in zip(features, values):
#                 # Get embedding and ensure it's on the device
#                 emb = self.feature_value_embed(f, v).to(device)
#                 embs.append(emb)
                
#             # Stack all embeddings
#             if embs:
#                 stacked = torch.stack(embs, dim=0)
#                 result = self.peq_layer(stacked)
                
#                 # Cache result (on CPU to save memory)
#                 self.row_cache[cache_key] = result.detach().cpu()
                
#                 return result
#             else:
#                 # Return empty tensor if no embeddings
#                 return torch.zeros((0, self.config.model_dim), device=device)
#         except Exception as e:
#             logger.error(f"Error in PEQRowEmbedding.forward: {e}")
#             return torch.zeros((len(features), self.config.model_dim), device=device)

# class MDNRegressionHead(nn.Module):
#     """
#     Mixture Density Network head producing a mixture of Gaussians.
#     Outputs (pi, mu, sigma) each of shape [batch, num_mixtures].
#     """
#     def __init__(self, hidden_dim, num_mixtures):
#         super().__init__()
#         self.num_mixtures = num_mixtures
#         self.pi_layer = nn.Linear(hidden_dim, num_mixtures)
#         self.mu_layer = nn.Linear(hidden_dim, num_mixtures)
#         self.sigma_layer = nn.Linear(hidden_dim, num_mixtures)

#     def forward(self, h):
#         pi = F.softmax(self.pi_layer(h), dim=-1)
#         mu = self.mu_layer(h)
#         sigma = torch.exp(self.sigma_layer(h))
#         return pi, mu, sigma


# class PermutationEquivariantClassifier(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.embedding_dim = config.embedding_dim
        
#         # Project input features to right dimension
#         self.input_proj = nn.Linear(config.model_dim, self.embedding_dim)
        
#         # Number of layers (with default)
#         num_layers = getattr(config, 'num_equivariant_layers', 2)
        
#         # Build equivariant layers more efficiently
#         first_block = EquivariantBlock(self.embedding_dim * 2, self.embedding_dim)
#         other_blocks = [EquivariantBlock(self.embedding_dim, self.embedding_dim) for _ in range(num_layers - 1)]
#         self.equivariant_layers = nn.ModuleList([first_block] + other_blocks)
        
#         # Final projection to scalar logits
#         self.logit_proj = nn.Linear(self.embedding_dim, 1)
        
#     def forward(self, feature_rep, class_embeddings):
#         # Project feature_rep to embedding_dim
#         feature_rep = self.input_proj(feature_rep)  # [embedding_dim]
        
#         # Add batch dimension if needed
#         if feature_rep.dim() == 1:
#             feature_rep = feature_rep.unsqueeze(0)  # [1, embedding_dim]
        
#         batch_size = feature_rep.size(0)
#         num_classes = class_embeddings.size(0)
        
#         # Expand feature representation for each class - more efficiently
#         expanded_feature = feature_rep.unsqueeze(1).expand(-1, num_classes, -1)
        
#         # Expand class embeddings for each batch - more efficiently
#         expanded_classes = class_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        
#         # Concatenate along feature dimension
#         combined = torch.cat([expanded_feature, expanded_classes], dim=-1)
        
#         # Process through equivariant layers
#         x = combined
#         for layer in self.equivariant_layers:
#             x = layer(x)
        
#         # Project to get logits for each class
#         logits = self.logit_proj(x).squeeze(-1)
        
#         return logits


# class EquivariantBlock(nn.Module):
#     def __init__(self, input_dim, output_dim):
#         super().__init__()
        
#         # Shared MLPs for permutation equivariance
#         self.phi = nn.Sequential(
#             nn.Linear(input_dim, output_dim),
#             nn.LayerNorm(output_dim),
#             nn.ReLU()
#         )
        
#         # Optional: add attention between classes
#         self.attention = nn.MultiheadAttention(
#             embed_dim=output_dim,
#             num_heads=4,
#             batch_first=True
#         )
        
#         # Layer norm and residual connection
#         self.norm = nn.LayerNorm(output_dim)
        
#     def forward(self, x):
#         # Apply shared MLP to each class independently
#         x = self.phi(x)
        
#         # Apply self-attention across classes
#         attn_out, _ = self.attention(x, x, x)
        
#         # Add residual connection and normalize
#         x = self.norm(attn_out + x)
        
#         return x


# # ——— PEQ-BERT layer and integration ———
# class PEQBertTransformerLayer(nn.Module):
#     """
#     Replaces BertSelfAttention with a permutation-equivariant block:
#       ISAB (Induced Self-Attention) → PMA (Pooling by Multihead Attention).
#     """
#     def __init__(self, config):
#         super().__init__()
#         self.isab = ISAB(
#             embed_dim=config.model_dim,
#             num_inds=config.num_inds,
#             num_heads=config.num_heads,
#         )
#         self.pma = PMA(
#             embed_dim=config.model_dim,
#             num_heads=config.num_heads,
#             num_seeds=1,
#         )

#     def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
#         # hidden_states: [batch_size, seq_len, model_dim]
#         x = self.isab(hidden_states)
#         x = self.pma(x)
#         seq_len = hidden_states.size(1)
#         peq_out = x.expand(-1, seq_len, -1)
#         if output_attentions:
#             return (peq_out, None)
#         return (peq_out,)

# class FSLModel(nn.Module):
#     """
#     PEQ-BERT few-shot model that matches the figure:
#     1) Tokenize description into BERT embeddings.
#     2) Append one PEQ token per feature-value for each few-shot row and the target row.
#     3) Run through BERT.
#     4) Extract the hidden state at the target column position.
#     5) Regression via MDN, classification via dot-product with class prototypes.
#     """
#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.row_embed = PEQRowEmbedding(config)
#         self.tokenizer = get_tokenizer(config.model_name)
#         self.model = get_model(config.model_name)
#         # PEQ-BERT hook
#         if getattr(config, 'use_peq_bert', False):
#             self._replace_attention_layers()
#         self.temperature = getattr(config, 'temperature', 1.0)
#         self.class_head = nn.Linear(config.model_dim, config.model_dim)
#         self.reg_head = MDNRegressionHead(config.model_dim, config.num_mixtures)
#         self.register_buffer('_dummy_tensor', torch.zeros(1))
#         self.desc_cache = {}
#         self._embedding_layer = None
#         self.example_cache = {}

#     def _replace_attention_layers(self):
#         """
#         Walk through each Transformer layer and swap out the .self
#         module in BertAttention for our PEQBertTransformerLayer.
#         """
#         for layer in self.model.encoder.layer:
#             layer.attention.self = PEQBertTransformerLayer(self.config)

#     @property
#     def embedding_layer(self):
#         if self._embedding_layer is None:
#             self._embedding_layer = self.model.get_embedding_layer()
#         return self._embedding_layer

#     @property
#     def device(self):
#         return self._dummy_tensor.device

#     def _prepare_example(self, example: Example):
#         device = self._dummy_tensor.device
#         cache_key = id(example)
#         if cache_key in self.example_cache:
#             seq, mask, pos = self.example_cache[cache_key]
#             return seq.to(device), mask.to(device), pos
#         desc_key = example.description
#         if desc_key in self.desc_cache:
#             desc_emb = self.desc_cache[desc_key].to(device)
#         else:
#             toks = self.tokenizer(example.description, return_tensors="pt")
#             input_ids = toks.input_ids.to(device)
#             desc_emb = self.embedding_layer(input_ids).squeeze(0)
#             self.desc_cache[desc_key] = desc_emb.detach().cpu()
#             desc_emb = desc_emb.to(device)
#         fs_embs = []
#         for i, row in enumerate(example.fewshot_rows):
#             try:
#                 emb = self.row_embed(example.features, row).to(device)
#             except Exception as e:
#                 logger.warning(f"Error in few-shot row {i} embedding: {e}")
#                 emb = torch.zeros((len(example.features), self.config.model_dim), device=device)
#             fs_embs.append(emb)
#         excluded = set(example.missing_column_ids + [example.target_column_id])
#         tgt_vals = [v if idx not in excluded else None for idx, v in enumerate(example.target_row)]
#         try:
#             tgt_emb = self.row_embed(example.features, tgt_vals).to(device)
#         except Exception as e:
#             logger.warning(f"Error in target row embedding: {e}")
#             tgt_emb = torch.zeros((len(example.features), self.config.model_dim), device=device)
#         seq = torch.cat([desc_emb] + fs_embs + [tgt_emb], dim=0)
#         attn_mask = torch.ones(1, seq.size(0), device=device)
#         desc_len = desc_emb.size(0)
#         fs_len = sum(m.size(0) for m in fs_embs)
#         target_pos = desc_len + fs_len + example.target_column_id
#         self.example_cache[cache_key] = (
#             seq.unsqueeze(0).detach().cpu(),
#             attn_mask.detach().cpu(),
#             target_pos
#         )
#         max_allowed = self.model.config.max_position_embeddings
#         return seq.unsqueeze(0), attn_mask, target_pos

#     def _prepare_batch(self, batch: list[Example]):
#         device = self._dummy_tensor.device
#         processed = [self._prepare_example(ex) for ex in batch]
#         embeds, masks, positions = zip(*processed)
#         max_len = max(x.size(1) for x in embeds)
#         padded_embeds = torch.zeros((len(batch), max_len, embeds[0].size(2)), device=device)
#         padded_masks = torch.zeros((len(batch), max_len), device=device)
#         for i, (emb, mask) in enumerate(zip(embeds, masks)):
#             seq_len = emb.size(1)
#             padded_embeds[i, :seq_len] = emb.squeeze(0)
#             padded_masks[i, :seq_len] = mask.squeeze(0)
#         if padded_embeds.size(1) > self.model.config.max_position_embeddings:
#             padded_embeds = padded_embeds[:, :self.model.config.max_position_embeddings, :]
#             padded_masks = padded_masks[:, :self.model.config.max_position_embeddings]
#             positions = [min(pos, self.model.config.max_position_embeddings - 1) for pos in positions]
#         return padded_embeds, padded_masks, positions, batch

#     def _create_position_ids(self, batch_size, seq_len, desc_len, device):
#         return torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)

#     def forward(self, batch: list[Example]):
#         device = self._dummy_tensor.device
#         inputs_embeds, attention_mask, positions, examples = self._prepare_batch(batch)
#         if inputs_embeds.size(0) == 0:
#             return torch.tensor(0.0, requires_grad=True, device=device)
#         batch_size, seq_len, _ = inputs_embeds.shape
#         token_type_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
#         ex = examples[0]
#         desc_key = ex.description
#         if desc_key in self.desc_cache:
#             desc_emb = self.desc_cache[desc_key].to(device)
#         else:
#             toks = self.tokenizer(ex.description, return_tensors="pt").to(device)
#             desc_emb = self.embedding_layer(toks.input_ids).squeeze(0)
#             self.desc_cache[desc_key] = desc_emb.detach().cpu()
#         desc_len = desc_emb.size(0)
#         position_ids = self._create_position_ids(batch_size, seq_len, desc_len, device)
#         out = self.model(
#             inputs_embeds=inputs_embeds,
#             attention_mask=attention_mask,
#             token_type_ids=token_type_ids,
#             position_ids=position_ids,
#             return_dict=True
#         )
#         h = out.last_hidden_state
#         target_h = torch.stack([h[i, positions[i]] for i in range(len(positions))], dim=0)
#         pi, mu, sigma = self.reg_head(target_h)
#         prototype_cache = {}
#         losses = []
#         for i, ex in enumerate(examples):
#             feat = ex.features[ex.target_column_id]
#             if feat.dtype == 'real':
#                 y_t = torch.tensor(ex.target_row[ex.target_column_id], device=device).float()
#                 y_p = (pi[i] * mu[i]).sum(dim=-1)
#                 losses.append(F.mse_loss(y_p, y_t))
#             else:
#                 vec = self.class_head(target_h[i])
#                 cats = getattr(feat, 'categories', []) or []
#                 num_cats = len(cats)
#                 if feat.dtype != 'categorical' or num_cats == 0:
#                     logits = vec.new_zeros((1,))
#                 else:
#                     feat_id = id(feat)
#                     if feat_id not in prototype_cache:
#                         protos = torch.stack([
#                             torch.tensor(feat.categories_embedding[self.config.embedding_model][c], device=device)
#                             if not isinstance(feat.categories_embedding[self.config.embedding_model][c], torch.Tensor)
#                             else feat.categories_embedding[self.config.embedding_model][c].to(device)
#                             for c in cats
#                         ], dim=0)
#                         prototype_cache[feat_id] = protos
#                     else:
#                         protos = prototype_cache[feat_id]
#                     logits = vec.unsqueeze(0) @ protos.t()
#                     target_idx = int(ex.target_row[ex.target_column_id])
#                     target_idx = max(0, min(target_idx, num_cats-1))
#                     tgt = torch.tensor([target_idx], device=device)
#                     losses.append(F.cross_entropy(logits, tgt))
#         return torch.stack(losses).mean() if losses else torch.tensor(0.0, requires_grad=True, device=device)
import math
import torch
import numpy as np
import torch.nn as nn
import logging
import torch.nn.functional as F
from utils.data_utils import Example, Feature
from model.architectures.set_transformer import ISAB, PMA
from utils.transformer_utils import get_tokenizer, get_model
import random
logger = logging.getLogger(__name__)
from collections import OrderedDict
import weakref

class SmartCache:
    """LRU cache bounded by an entry count and an estimated memory budget.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. The estimated size of every entry is remembered, so an eviction
    subtracts exactly what that entry added and ``memory_usage`` stays in
    sync with the cache contents.

    Args:
        max_size: Maximum number of entries kept.
        max_memory_gb: Budget, in GB, for the summed size estimates.
    """

    # Flat size estimate (in GB, roughly 1 MB) charged for non-tensor values.
    _NON_TENSOR_GB = 0.001

    def __init__(self, max_size=10000, max_memory_gb=5.0):
        self.cache = OrderedDict()
        self.max_size = max_size
        self.max_memory_gb = max_memory_gb
        self.memory_usage = 0.0  # summed size estimates of live entries, in GB
        self._entry_gb = {}      # key -> size estimate charged for that entry

    def get(self, key):
        """Return the value stored under ``key`` (marking it most recently
        used), or ``None`` when the key is absent.

        A stored ``None`` cannot be told apart from a miss.
        """
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return None

    def put(self, key, value):
        """Insert ``value`` under ``key``, evicting old entries as needed.

        If ``key`` is already present only its recency is refreshed; the
        stored value is kept as is. A value whose estimate alone exceeds
        ``max_memory_gb`` is not cached, so it cannot flush every other
        entry and still leave the cache over budget.
        """
        if key in self.cache:
            self.cache.move_to_end(key)
            return

        size_gb = self._estimate_gb(value)
        if size_gb > self.max_memory_gb:
            return

        # Free least-recently-used entries until the newcomer fits.
        while self.cache and self.memory_usage + size_gb > self.max_memory_gb:
            self._evict_oldest()

        self.cache[key] = value
        self._entry_gb[key] = size_gb
        self.memory_usage += size_gb

        # Enforce the entry-count limit.
        while len(self.cache) > self.max_size:
            self._evict_oldest()

    @classmethod
    def _estimate_gb(cls, value):
        """Estimate the size of ``value`` in GB."""
        if isinstance(value, torch.Tensor):
            return value.element_size() * value.nelement() / 1024**3
        return cls._NON_TENSOR_GB

    def _evict_oldest(self):
        """Drop the least recently used entry and release its estimate."""
        key, _ = self.cache.popitem(last=False)
        self.memory_usage -= self._entry_gb.pop(key, 0.0)
        if not self.cache:
            # Snap back to zero so float rounding error cannot accumulate.
            self.memory_usage = 0.0
            self._entry_gb.clear()
        elif self.memory_usage < 0.0:
            self.memory_usage = 0.0
class FeatureEmbedding(nn.Module):
    """Embed a feature's description and dtype into a ``model_dim`` vector.

    The output is ``desc_proj(description_embedding) + dtype_embed(dtype)``.

    Only the description vector (the constant input) is cached. The trainable
    projection and dtype embedding run on every call, so gradients always
    reach ``desc_proj`` / ``dtype_embed`` and the result always reflects the
    current weights. Caching the projected output would instead return a
    stale, detached tensor on every cache hit.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Index 0 is used for 'real' features, index 1 for every other dtype.
        self.dtype_embed = nn.Embedding(2, config.model_dim)
        self.desc_proj = nn.Linear(config.embedding_dim, config.model_dim)
        # Description vectors already converted to tensors on the module's device.
        self.embedding_cache = SmartCache(max_size=5000, max_memory_gb=2.0)  # 2GB for feature embeddings

    def _description_tensor(self, feature, device):
        """Return the feature's description vector as a tensor on ``device``.

        The converted tensor is cached per feature object.
        NOTE(review): the key is ``id(feature)``; ids can be reused once a
        Feature is garbage-collected, so this presumes Feature objects outlive
        the cache entry — confirm against the data pipeline.
        """
        cache_key = id(feature)
        cached = self.embedding_cache.get(cache_key)
        if cached is not None:
            return cached.to(device)  # Ensure on correct device

        embedding = feature.description_embedding[self.config.embedding_model]
        if isinstance(embedding, torch.Tensor):
            desc = embedding.to(device, non_blocking=True)
        else:
            desc = torch.tensor(embedding, device=device, dtype=torch.float32)

        # Stored detached: this assumes description embeddings are precomputed
        # constants rather than something being optimised.
        self.embedding_cache.put(cache_key, desc.detach())
        return desc

    def forward(self, feature: Feature):
        """Return the embedding of ``feature`` on the module's device.

        Reads ``feature.description_embedding[config.embedding_model]`` and
        ``feature.dtype``.
        """
        device = next(self.parameters()).device

        projected = self.desc_proj(self._description_tensor(feature, device))

        dtype_idx = 0 if feature.dtype == 'real' else 1
        dtype_vec = self.dtype_embed(torch.tensor(dtype_idx, device=device))
        return projected + dtype_vec
class FeatureValuePairEmbedding(nn.Module):
    """Embed one (feature, value) pair into a single ``model_dim`` vector.

    The feature embedding and a value embedding are concatenated and passed
    through ``embed_proj``. The value embedding is:

    * the learned "missing" vector when the value is ``None`` or NaN;
    * ``real_proj(value)`` for features whose dtype is ``'real'``;
    * the precomputed category embedding otherwise.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.feature_embed = FeatureEmbedding(config)
        self.real_proj = nn.Linear(1, config.model_dim)
        self.missing_embed = nn.Embedding(1, config.model_dim)
        self.embed_proj = nn.Linear(config.model_dim * 2, config.model_dim)
        # Currently unused by forward(); kept so the attribute stays available.
        self.cat_value_cache = {}
        # Index tensor for the single "missing" embedding row; being a buffer
        # it follows the module across devices.
        self.register_buffer('_zero_tensor', torch.zeros(1, dtype=torch.long))

    @staticmethod
    def _is_missing(value):
        """Return True for ``None`` and for float NaNs.

        NumPy floating scalars are checked explicitly because types such as
        ``np.float32`` are not ``float`` subclasses; without that check their
        NaNs would be projected as real numbers.
        """
        if value is None:
            return True
        return isinstance(value, (float, np.floating)) and math.isnan(value)

    def _missing_embedding(self):
        """Return the learned embedding used for missing values."""
        return self.missing_embed(self._zero_tensor).squeeze(0)

    def _categorical_embedding(self, feature, value, device):
        """Look up the precomputed embedding of a categorical value.

        The value is matched by its string form against ``feature.categories``.
        An unknown value falls back to the first category (with a warning); a
        feature without categories yields the missing embedding.
        """
        categories = feature.categories
        if not categories:
            return self._missing_embedding()

        value_str = str(value)  # categories are matched as strings
        try:
            idx = categories.index(value_str)
        except ValueError:
            logger.warning("Category '%s' not found in %s categories", value_str, feature.name)
            idx = 0

        emb = feature.categories_embedding[self.config.embedding_model][categories[idx]]
        if isinstance(emb, torch.Tensor):
            return emb.to(device, non_blocking=True)
        return torch.tensor(emb, device=device, dtype=torch.float32)

    def forward(self, feature: Feature, value):
        """Return the projected embedding for ``(feature, value)``.

        NOTE(review): the category embedding is concatenated as is, and
        ``embed_proj`` takes ``2 * model_dim`` inputs, so category embeddings
        must have ``model_dim`` entries — confirm that
        ``embedding_dim == model_dim`` in the configs in use.
        """
        device = next(self.parameters()).device

        f_emb = self.feature_embed(feature)

        if self._is_missing(value):
            v_emb = self._missing_embedding()
        elif feature.dtype == 'real':
            scalar = torch.tensor([[value]], device=device, dtype=torch.float32)
            v_emb = self.real_proj(scalar).squeeze(0)
        else:
            v_emb = self._categorical_embedding(feature, value, device)

        combined = torch.cat([f_emb, v_emb], dim=-1)
        return self.embed_proj(combined)
    
from model.architectures.set_transformer import ISAB
class PEQRowEmbedding(nn.Module):
    """Embed a table row: one token per (feature, value) pair, mixed by ISAB.

    Row embeddings are cached (detached, on CPU) only while the module is in
    eval mode. During training every row is recomputed, so gradients reach
    the embedding layers and no output computed with outdated weights is
    reused. Switching back to training mode discards the cache, because the
    weights are about to change.
    """

    # Stand-in for missing values (None / NaN) inside cache keys; a private
    # object cannot collide with any real cell value.
    _MISSING_KEY = object()
    _CACHE_MAX_SIZE = 20000
    _CACHE_MAX_MEMORY_GB = 5.0  # 5GB for rows

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.feature_value_embed = FeatureValuePairEmbedding(config)
        self.peq = ISAB(config.model_dim, config.model_dim, config.num_heads, config.num_inds)
        self.row_cache = self._new_cache()

    def _new_cache(self):
        """Create an empty row cache."""
        return SmartCache(max_size=self._CACHE_MAX_SIZE, max_memory_gb=self._CACHE_MAX_MEMORY_GB)

    def train(self, mode: bool = True):
        """Set training mode; entering training mode drops cached rows."""
        if mode:
            self.row_cache = self._new_cache()
        return super().train(mode)

    def _cache_key(self, features, values):
        """Build the cache key for a row.

        The key is the tuple itself rather than its ``hash()``, so two
        different rows whose hashes happen to collide can never share an
        entry.
        NOTE(review): features are identified by ``id()``; this presumes the
        Feature objects stay alive for as long as the cache entry — confirm.
        """
        value_key = tuple(
            self._MISSING_KEY
            if v is None or (isinstance(v, float) and math.isnan(v))
            else v
            for v in values
        )
        return (tuple(id(f) for f in features), value_key)

    def forward(self, features: list[Feature], values: list[float]):
        """Return one embedding per (feature, value) pair of the row.

        Features and values are paired with ``zip``. An empty row yields a
        ``(0, model_dim)`` tensor of zeros. In eval mode a cache hit returns
        a detached tensor moved to the module's device.
        """
        device = next(self.parameters()).device

        use_cache = not self.training
        cache_key = None
        if use_cache:
            cache_key = self._cache_key(features, values)
            cached = self.row_cache.get(cache_key)
            if cached is not None:
                return cached.to(device)

        # Mixed precision is enabled only when the module lives on a CUDA
        # device; elsewhere the context is a no-op.
        with torch.autocast(device_type="cuda", enabled=device.type == "cuda"):
            embs = [self.feature_value_embed(f, v) for f, v in zip(features, values)]

            if not embs:
                result = torch.zeros((0, self.config.model_dim), device=device)
            else:
                # ISAB is fed a batch of one set: unsqueeze adds the batch
                # axis and squeeze removes it again.
                x = torch.stack(embs, dim=0).unsqueeze(0)
                result = self.peq(x).squeeze(0)

        if use_cache:
            self.row_cache.put(cache_key, result.detach().cpu())  # Store on CPU
        return result

class MDNRegressionHead(nn.Module):
    """Mixture Density Network head producing a mixture of Gaussians.

    Three linear layers map the hidden state to mixture weights ``pi``
    (softmax over the last axis), means ``mu`` and standard deviations
    ``sigma`` (exponential of a predicted log-sigma), each with
    ``num_mixtures`` entries on the last axis.

    Args:
        hidden_dim: Size of the incoming hidden state.
        num_mixtures: Number of mixture components.
        log_sigma_bounds: ``(min, max)`` range the predicted log-sigma is
            clamped to before exponentiation. The default keeps ``sigma``
            finite and strictly positive in float32 without touching values
            inside the range.
    """

    # Class-level defaults keep instances restored without __init__ working.
    log_sigma_min = -20.0
    log_sigma_max = 20.0

    def __init__(self, hidden_dim, num_mixtures, log_sigma_bounds=(-20.0, 20.0)):
        super().__init__()
        self.num_mixtures = num_mixtures
        self.log_sigma_min, self.log_sigma_max = log_sigma_bounds
        self.pi_layer = nn.Linear(hidden_dim, num_mixtures)
        self.mu_layer = nn.Linear(hidden_dim, num_mixtures)
        self.sigma_layer = nn.Linear(hidden_dim, num_mixtures)

    def forward(self, h):
        """Return ``(pi, mu, sigma)`` for the hidden state ``h``."""
        pi = F.softmax(self.pi_layer(h), dim=-1)
        mu = self.mu_layer(h)
        # An unbounded exp() overflows to inf or underflows to 0 for extreme
        # activations; clamping the exponent prevents both.
        log_sigma = torch.clamp(self.sigma_layer(h), min=self.log_sigma_min, max=self.log_sigma_max)
        sigma = torch.exp(log_sigma)
        return pi, mu, sigma

class PermutationEquivariantClassifier(nn.Module):
    """Score a feature representation against a set of class embeddings.

    Each representation is projected to ``embedding_dim``, paired with every
    class embedding by concatenation, run through a stack of
    ``EquivariantBlock`` layers and reduced to one logit per class.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding_dim = config.embedding_dim
        self.input_proj = nn.Linear(config.model_dim, self.embedding_dim)

        dim = self.embedding_dim
        depth = getattr(config, 'num_equivariant_layers', 2)
        # The first block consumes the concatenated (feature, class) pair, so
        # its input is twice as wide as that of the blocks that follow.
        blocks = [EquivariantBlock(dim * 2, dim)]
        for _ in range(depth - 1):
            blocks.append(EquivariantBlock(dim, dim))
        self.equivariant_layers = nn.ModuleList(blocks)
        self.logit_proj = nn.Linear(dim, 1)

    def forward(self, feature_rep, class_embeddings):
        """Return logits with one row per representation and one column per class.

        Args:
            feature_rep: A single vector or a batch of vectors whose last
                axis has ``config.model_dim`` entries; a single vector is
                treated as a batch of one.
            class_embeddings: One embedding per class along the first axis.
                They are concatenated with the projected representation, so
                their width presumably has to be ``embedding_dim`` to match
                the first block — confirm against the caller.
        """
        projected = self.input_proj(feature_rep)
        if projected.dim() == 1:
            projected = projected.unsqueeze(0)

        n_rows = projected.size(0)
        n_classes = class_embeddings.size(0)

        # Broadcast both operands onto a common (row, class) grid, then join
        # them along the feature axis.
        hidden = torch.cat(
            [
                projected.unsqueeze(1).expand(-1, n_classes, -1),
                class_embeddings.unsqueeze(0).expand(n_rows, -1, -1),
            ],
            dim=-1,
        )

        for block in self.equivariant_layers:
            hidden = block(hidden)

        return self.logit_proj(hidden).squeeze(-1)

class EquivariantBlock(nn.Module):
    """Shared per-element MLP followed by self-attention across the set.

    The same ``phi`` network is applied to every element, the elements then
    attend to each other, and the attention output is added back to the
    ``phi`` output and layer-normalised. No step depends on element order,
    so permuting the set axis of the input permutes the output the same way.

    Args:
        input_dim: Width of each incoming element.
        output_dim: Width of each outgoing element; must be divisible by
            ``num_heads`` (enforced by ``nn.MultiheadAttention``).
        num_heads: Number of attention heads. Defaults to 4.
    """

    def __init__(self, input_dim, output_dim, num_heads=4):
        super().__init__()
        self.phi = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU()
        )
        self.attention = nn.MultiheadAttention(
            embed_dim=output_dim,
            num_heads=num_heads,
            batch_first=True
        )
        self.norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        """Transform ``x``; with ``batch_first=True`` the attention treats
        axis 0 as the batch and axis 1 as the set being attended over."""
        x = self.phi(x)
        attn_out, _ = self.attention(x, x, x)
        # Residual connection around the attention, then normalisation.
        x = self.norm(attn_out + x)
        return x

class PEQBertTransformerLayer(nn.Module):
    """Set-pooling layer with a self-attention-style call signature.

    The input sequence is passed through an ``ISAB`` block, pooled by a
    single-seed ``PMA`` block, and the pooled result is broadcast back to
    every sequence position.

    Config attributes read: ``model_dim``, ``num_inds``, ``num_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.isab = ISAB(
            embed_dim=config.model_dim,
            num_inds=config.num_inds,
            num_heads=config.num_heads,
        )
        self.pma = PMA(
            embed_dim=config.model_dim,
            num_heads=config.num_heads,
            num_seeds=1,
        )

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        """Pool ``hidden_states`` and repeat the pooled vector per position.

        Args:
            hidden_states: Tensor indexed as ``(batch, seq_len, dim)``.
            attention_mask: Accepted but not used.
            head_mask: Accepted but not used.
            output_attentions: When true, a ``None`` placeholder is appended
                to the returned tuple in place of attention weights.

        Returns:
            ``(output,)`` or ``(output, None)``; ``output`` has ``seq_len``
            entries along dim 1.
        """
        # The expand below assumes PMA with one seed yields a size-1 dim 1.
        pooled = self.pma(self.isab(hidden_states))
        broadcast = pooled.expand(-1, hidden_states.size(1), -1)
        return (broadcast, None) if output_attentions else (broadcast,)

# class FSLModel(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.row_embed = PEQRowEmbedding(config)
#         self.tokenizer = get_tokenizer(config.model_name)
#         self.model = get_model(config.model_name)
#         self.max_desc_cache_size = 100
#         self.max_example_cache_size = 5000
#         self.dropout = nn.Dropout(0.1)
#         self.target_norm = nn.LayerNorm(config.model_dim)
#         self.class_norm = nn.LayerNorm(config.model_dim)
#         # if getattr(config, 'use_peq_bert', False):
#         # self._replace_attention_layers()
            
#         self.temperature = getattr(config, 'temperature', 1.0)
#         self.class_head = nn.Linear(config.model_dim, config.model_dim)
#         self.reg_head = MDNRegressionHead(config.model_dim, config.num_mixtures)
#         self.set_pool = PMA(
#             config.model_dim,
#             config.num_heads,
#             1
#         )
#         self.binary_prototypes = nn.Parameter(torch.randn(2, config.model_dim))
#         nn.init.normal_(self.binary_prototypes, mean=0.0, std=0.02)
#         # Pre-allocate buffers
#         max_batch_size = 256
#         max_seq_len = 512
#         self.register_buffer('_token_type_ids', torch.zeros(max_batch_size, max_seq_len, dtype=torch.long))
#         self.register_buffer('_position_ids', torch.arange(max_seq_len).unsqueeze(0).expand(max_batch_size, -1))
        
#         # Caches
#         self.desc_cache = {}
#         self.example_cache = {}
#         self._embedding_layer = None
#         if hasattr(torch, 'compile'):
#                 self.row_embed = torch.compile(self.row_embed)
#                 self.class_head = torch.compile(self.class_head)
#                 self.reg_head = torch.compile(self.reg_head)
#     # def _replace_attention_layers(self):
#     #     for layer in self.model.encoder.layer:
#     #         layer.attention.self = PEQBertTransformerLayer(self.config)
#     def clear_all_caches(self):
#         """Aggressively clear all caches"""
#         # Clear all module caches
#         self.desc_cache.clear()
#         self.example_cache.clear()
#         self.row_embed.row_cache.clear()
#         self.row_embed.feature_value_embed.cat_value_cache.clear()
#         self.row_embed.feature_embed.embedding_cache.clear()
        
#         # Force garbage collection
#         import gc
#         for _ in range(3):
#             gc.collect()
        
#         # Clear CUDA cache
#         if torch.cuda.is_available():
#             torch.cuda.empty_cache()
#             torch.cuda.synchronize()
#     @property
#     def embedding_layer(self):
#         if self._embedding_layer is None:
#             self._embedding_layer = self.model.get_input_embeddings()
#         return self._embedding_layer

#     def _prepare_batch(self, batch: list[Example]):
#         device = next(self.parameters()).device
#         processed = [self._prepare_example(ex) for ex in batch]
        
#         # Now unpack 4 values
#         embeds, masks, position_ids_list, positions = zip(*processed)
        
#         max_len = max(x.size(1) for x in embeds)
#         max_pos_emb = self.model.config.max_position_embeddings
        
#         if max_len > max_pos_emb:
#             max_len = max_pos_emb
#             positions = [min(pos, max_pos_emb - 1) for pos in positions]
            
#         # Pre-allocate tensors
#         padded_embeds = torch.zeros((len(batch), max_len, embeds[0].size(2)), device=device)
#         padded_masks = torch.zeros((len(batch), max_len), device=device)
#         padded_position_ids = torch.zeros((len(batch), max_len), dtype=torch.long, device=device)
        
#         for i, (emb, mask, pos_ids) in enumerate(zip(embeds, masks, position_ids_list)):
#             seq_len = min(emb.size(1), max_len)
#             padded_embeds[i, :seq_len] = emb.squeeze(0)[:seq_len]
#             padded_masks[i, :seq_len] = mask.squeeze(0)[:seq_len]
#             padded_position_ids[i, :seq_len] = pos_ids.squeeze(0)[:seq_len]
            
#         return padded_embeds, padded_masks, padded_position_ids, positions, batch
#     # Simplest implementation - just use 0, 1, 2 as position IDs

#     def _prepare_example(self, example: Example, feature_dropout_rate=0):
#         device = next(self.parameters()).device
        
#         # Description embedding
#         desc_key = example.description
#         if desc_key in self.desc_cache:
#             desc_emb = self.desc_cache[desc_key]
#         else:
#             toks = self.tokenizer(example.description, return_tensors="pt", max_length=64, truncation=True)
#             input_ids = toks.input_ids.to(device, non_blocking=True)
#             desc_emb = self.embedding_layer(input_ids).squeeze(0)
#             self.desc_cache[desc_key] = desc_emb.detach()
        
#         # Position IDs - super simple!
#         desc_len = desc_emb.size(0)
#         # desc_positions = torch.zeros(desc_len, dtype=torch.long, device=device)  # All 0s
        
#         # Few-shot embeddings
#         fs_embs = []
#         total_fs_len = 0
        
#         for i, row in enumerate(example.fewshot_rows):
#             try:
#                 # Feature selection
#                 feature_mask = torch.rand(len(example.features)) > feature_dropout_rate
#                 feature_mask = feature_mask.numpy()
#                 if not feature_mask.any():
#                     feature_mask[random.randrange(len(example.features))] = True
                
#                 selected_features = [f for f, m in zip(example.features, feature_mask) if m]
#                 selected_values = [v for v, m in zip(row, feature_mask) if m]
                
#                 emb = self.row_embed(selected_features, selected_values)
#                 fs_embs.append(emb)
#                 total_fs_len += emb.size(0)
                
#             except Exception as e:
#                 logger.warning(f"Error in few-shot row {i}: {e}")
#                 continue
        
#         # All few-shot positions = 1 (permutation invariant!)
#         # fs_positions = torch.ones(total_fs_len, dtype=torch.long, device=device)
        
#         # Target embedding
#         excluded = set(example.missing_column_ids + [example.target_column_id])
        
#         # Select features for target
#         feature_mask = torch.rand(len(example.features)) > feature_dropout_rate
#         feature_mask = feature_mask.numpy()
#         feature_mask[example.target_column_id] = True
#         for idx in excluded:
#             if idx != example.target_column_id and idx < len(feature_mask):
#                 feature_mask[idx] = False
        
#         selected_indices = np.where(feature_mask)[0].tolist()
#         selected_features = [example.features[i] for i in selected_indices]
#         selected_values = [None if i == example.target_column_id else example.target_row[i] 
#                         for i in selected_indices]
        
#         try:
#             tgt_emb = self.row_embed(selected_features, selected_values)
#         except Exception as e:
#             logger.warning(f"Error in target row: {e}")
#             tgt_emb = torch.zeros((len(selected_features), self.config.model_dim), device=device)
        
#         # All target positions = 2
#         tgt_len = tgt_emb.size(0)
#         # tgt_positions = torch.full((tgt_len,), 2, dtype=torch.long, device=device)
#         desc_positions = torch.arange(desc_len, device=device)  # 0, 1, 2, ...
#         fs_positions = torch.arange(desc_len, desc_len + total_fs_len, device=device)
#         tgt_positions = torch.arange(desc_len + total_fs_len, desc_len + total_fs_len + tgt_len, device=device)
#         # Find target position
#         target_pos_in_selected = selected_indices.index(example.target_column_id)
        
#         # Concatenate everything
#         seq = torch.cat([desc_emb] + fs_embs + [tgt_emb], dim=0)
#         position_ids = torch.cat([desc_positions, fs_positions, tgt_positions], dim=0)
#         attn_mask = torch.ones(1, seq.size(0), device=device)
        
#         # Absolute target position
#         target_pos = desc_len + total_fs_len + target_pos_in_selected
        
#         return (seq.unsqueeze(0).detach(), attn_mask.detach(), position_ids.unsqueeze(0).detach(), target_pos)

#     def forward(self, batch: list[Example]):
#         device = next(self.parameters()).device
#         # prepare raw embeddings from BERT input embeddings
#         inputs_embeds, attention_mask, position_ids, positions, examples = self._prepare_batch(batch)
#         bsz, seq_len, _ = inputs_embeds.size()
#         # slice buffers
#         token_type_ids = self._token_type_ids[:bsz, :seq_len]
#         # position_ids   = self._position_ids[:bsz, :seq_len]
#         # token_type_ids = self._token_type_ids[:bsz, :seq_len]
#         # run the BERT encoder
#                 # run BERT
#         out = self.model(
#             inputs_embeds=inputs_embeds,
#             attention_mask=attention_mask,
#             token_type_ids=token_type_ids,
#             position_ids=position_ids,
#             return_dict=True
#         )
#         h = out.last_hidden_state
#         h = self.dropout(h)
#         target_h = torch.stack([h[i, positions[i]] for i in range(len(positions))], dim=0)
#         target_h = self.target_norm(target_h)   
#         # target_h = torch.stack([h[i, positions[i]] for i in range(len(positions))], dim=0)

#         losses = []
#         prototype_cache = {}
#         real_idxs, cat_idxs = [], []
#         for i, ex in enumerate(examples):
#             if ex.features[ex.target_column_id].dtype == 'real':
#                 real_idxs.append(i)
#             else:
#                 cat_idxs.append(i)

#         # ── real‐valued → MDN regression loss ─────────────────────────────
#         if real_idxs:
#             targets = torch.tensor(
#                 [examples[i].target_row[examples[i].target_column_id] for i in real_idxs],
#                 device=h.device, dtype=torch.float32
#             )
#             pi, mu, sigma = self.reg_head(target_h[real_idxs])
#             pred = (pi * mu).sum(-1)
#             losses.append(F.mse_loss(pred, targets))

#         # ── categorical → prototype‐based cross‐entropy ─────────────────
#         for i in cat_idxs:
#             ex = examples[i]
#             vec = self.class_head(target_h[i])  # [model_dim]
#             cats = ex.features[ex.target_column_id].categories or []

#             # build class‐prototypes (one stack per feature)
#             fid = id(ex.features[ex.target_column_id])
#             if fid not in prototype_cache:
#                 protos = []
#                 for c in cats:
#                     emb = ex.features[ex.target_column_id].categories_embedding[self.config.embedding_model][c]
#                     protos.append( emb.to(h.device) if isinstance(emb, torch.Tensor)
#                                   else torch.tensor(emb, device=h.device) )
#                 prototype_cache[fid] = torch.stack(protos, dim=0)  # [n_cats, model_dim]

#             protos = prototype_cache[fid]       # multi‐class
#             logits = vec.unsqueeze(0) @ protos.t()

#             # ── if it’s actually a binary target disguised as real: use learned `self.binary_prototypes` ─
#             if len(cats) == 0:
#                 # fallback: binary from learned prototypes
#                 protos2 = F.normalize(self.binary_prototypes, dim=1)  # [2, D]
#                 logits = vec.unsqueeze(0) @ protos2.t()
#                 # decide true label by thresholding your val ≥ median stored somewhere…
#                 true_bin = 1 if float(ex.target_row[ex.target_column_id]) >= ex.median else 0
#                 losses.append(F.cross_entropy(logits, torch.tensor([true_bin], device=h.device)))
#             else:
#                 # normal multi‐class
#                 # find target index by string match…
#                 tv = str(ex.target_row[ex.target_column_id])
#                 idx = cats.index(tv) if tv in cats else 0
#                 losses.append(F.cross_entropy(logits, torch.tensor([idx], device=h.device)))

#         return torch.stack(losses).mean() if losses else torch.tensor(0.0, device=h.device)
class FSLModel(nn.Module):
    def __init__(self, config):
        """Build the model around a pretrained encoder.

        Args:
            config: Configuration object. Reads ``model_name``,
                ``model_dim``, ``num_mixtures`` and ``num_heads``;
                ``temperature`` is optional and defaults to 1.0.
        """
        super().__init__()
        self.config = config
        self.row_embed = PEQRowEmbedding(config)
        self.tokenizer = get_tokenizer(config.model_name)
        self.model = get_model(config.model_name)

        # Get max position embeddings from model config
        self.max_position_embeddings = self.model.config.max_position_embeddings

        self.max_desc_cache_size = 100
        self.max_example_cache_size = 5000
        self.dropout = nn.Dropout(0.1)
        self.target_norm = nn.LayerNorm(config.model_dim)
        self.class_norm = nn.LayerNorm(config.model_dim)

        # Add segment embeddings to distinguish parts
        # 0=desc, 1=fewshot, 2=target
        # NOTE(review): `_prepare_example` currently tags target tokens with
        # id 1 (same as few-shot), so index 2 appears to be unused.
        self.segment_embed = nn.Embedding(3, config.model_dim)

        self.temperature = getattr(config, 'temperature', 1.0)
        self.class_head = nn.Linear(config.model_dim, config.model_dim)
        self.reg_head = MDNRegressionHead(config.model_dim, config.num_mixtures)
        self.set_pool = PMA(config.model_dim, config.num_heads, 1)

        self.binary_prototypes = nn.Parameter(torch.randn(2, config.model_dim))
        nn.init.normal_(self.binary_prototypes, mean=0.0, std=0.02)

        # Pre-allocate buffers
        max_batch_size = 256
        max_seq_len = 512
        self.register_buffer('_token_type_ids', torch.zeros(max_batch_size, max_seq_len, dtype=torch.long))
        # `expand` returns a view whose rows all alias the same memory, and
        # in-place writes into such a tensor (e.g. the `copy_` performed by
        # `load_state_dict`) are not supported. Clone so the buffer owns
        # real storage; its shape and contents are unchanged.
        self.register_buffer(
            '_position_ids',
            torch.arange(max_seq_len).unsqueeze(0).expand(max_batch_size, -1).clone(),
        )

        # Caches
        self.desc_cache = {}
        self.example_cache = {}
        self._embedding_layer = None

        # DON'T enable gradient checkpointing with DDP - it causes conflicts
        # DON'T use torch.compile with DDP - it causes recompilation issues
    
    @property
    def embedding_layer(self):
        """Get the embedding layer from the model"""
        if self._embedding_layer is None:
            self._embedding_layer = self.model.get_input_embeddings()
        return self._embedding_layer
    
    def clear_all_caches(self):
        """Aggressively clear all caches"""
        # Clear all module caches
        self.desc_cache.clear()
        self.example_cache.clear()
        if hasattr(self.row_embed, 'row_cache'):
            self.row_embed.row_cache.cache.clear()
            self.row_embed.row_cache.memory_usage = 0
        if hasattr(self.row_embed.feature_value_embed, 'cat_value_cache'):
            self.row_embed.feature_value_embed.cat_value_cache.clear()
        if hasattr(self.row_embed.feature_embed, 'embedding_cache'):
            self.row_embed.feature_embed.embedding_cache.cache.clear()
            self.row_embed.feature_embed.embedding_cache.memory_usage = 0
        
        # Force garbage collection
        import gc
        for _ in range(3):
            gc.collect()
        
        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def _prepare_example(self, example: Example, feature_dropout_rate=0):
        device = next(self.parameters()).device
        
        # Add median to example if available
        if hasattr(example, 'dataset') and hasattr(example.dataset, 'median'):
            example.median = example.dataset.median
        elif not hasattr(example, 'median'):
            example.median = 0.0  # Default median
        
        # Description embedding
        desc_key = example.description
        if desc_key in self.desc_cache:
            desc_emb = self.desc_cache[desc_key]
        else:
            toks = self.tokenizer(example.description, return_tensors="pt", max_length=64, truncation=True)
            input_ids = toks.input_ids.to(device, non_blocking=True)
            desc_emb = self.embedding_layer(input_ids).squeeze(0)
            self.desc_cache[desc_key] = desc_emb.detach()
        
        desc_len = desc_emb.size(0)
        
        # Few-shot embeddings
        fs_embs = []
        total_fs_len = 0
        
        for i, row in enumerate(example.fewshot_rows):
            try:
                # Feature selection with dropout
                if self.training and feature_dropout_rate > 0:
                    feature_mask = torch.rand(len(example.features)) > feature_dropout_rate
                    feature_mask = feature_mask.numpy()
                    if not feature_mask.any():
                        feature_mask[random.randrange(len(example.features))] = True
                else:
                    feature_mask = np.ones(len(example.features), dtype=bool)
                
                selected_features = [f for f, m in zip(example.features, feature_mask) if m]
                selected_values = [v for v, m in zip(row, feature_mask) if m]
                
                emb = self.row_embed(selected_features, selected_values)
                fs_embs.append(emb)
                total_fs_len += emb.size(0)
                
            except Exception as e:
                logger.warning(f"Error in few-shot row {i}: {e}")
                continue
        
        # Target embedding
        excluded = set(example.missing_column_ids + [example.target_column_id])
        
        # Select features for target
        feature_mask = np.ones(len(example.features), dtype=bool)
        feature_mask[example.target_column_id] = True
        for idx in excluded:
            if idx != example.target_column_id and idx < len(feature_mask):
                feature_mask[idx] = False
        
        selected_indices = np.where(feature_mask)[0].tolist()
        selected_features = [example.features[i] for i in selected_indices]
        selected_values = [None if i == example.target_column_id else example.target_row[i] 
                          for i in selected_indices]
        
        try:
            tgt_emb = self.row_embed(selected_features, selected_values)
        except Exception as e:
            logger.warning(f"Error in target row: {e}")
            tgt_emb = torch.zeros((len(selected_features), self.config.model_dim), device=device)
        
        tgt_len = tgt_emb.size(0)
        
        # Create position IDs that don't exceed max_position_embeddings
        total_len = desc_len + total_fs_len + tgt_len
        
        if total_len > self.max_position_embeddings:
            # Truncate or use cyclic position IDs
            position_ids = torch.arange(total_len, device=device) % self.max_position_embeddings
        else:
            position_ids = torch.arange(total_len, device=device)
        
        # Create segment IDs
        # segment_ids = torch.cat([
        #     torch.zeros(desc_len, dtype=torch.long, device=device),      # description = 0
        #     torch.ones(total_fs_len, dtype=torch.long, device=device),   # few-shot = 1
        #     torch.full((tgt_len,), 2, dtype=torch.long, device=device)   # target = 2
        # ])
        segment_ids = torch.cat([
            torch.zeros(desc_len, dtype=torch.long, device=device),      # description = 0
            torch.ones(total_fs_len, dtype=torch.long, device=device),   # few-shot = 1
            torch.ones(tgt_len, dtype=torch.long, device=device)         # target = 1 (same as few-shot)
        ])
        # segment_ids = torch.cat([
        #     torch.arange(desc_len, device=device), 
        #     torch.arange(desc_len, desc_len + total_fs_len, device=device),
        #     torch.arange(desc_len + total_fs_len, desc_len + total_fs_len + tgt_len, device=device)
        # ])
        # Find target position
        target_pos_in_selected = selected_indices.index(example.target_column_id)
        
        # Concatenate everything
        seq = torch.cat([desc_emb] + fs_embs + [tgt_emb], dim=0)
        
        # Add segment embeddings
        segment_embeds = self.segment_embed(segment_ids)
        seq = seq + segment_embeds
        
        attn_mask = torch.ones(1, seq.size(0), device=device)
        
        # Absolute target position
        target_pos = desc_len + total_fs_len + target_pos_in_selected
        
        return (seq.unsqueeze(0).detach(), attn_mask.detach(), 
                position_ids.unsqueeze(0).detach(), target_pos, segment_ids)

    def _prepare_batch(self, batch: list[Example]):
        device = next(self.parameters()).device
        processed = [self._prepare_example(ex) for ex in batch]
        
        # Unpack 5 values now (added segment_ids)
        embeds, masks, position_ids_list, positions, segment_ids_list = zip(*processed)
        
        max_len = max(x.size(1) for x in embeds)
        max_pos_emb = self.max_position_embeddings
        
        if max_len > max_pos_emb:
            max_len = max_pos_emb
            positions = [min(pos, max_pos_emb - 1) for pos in positions]
        
        # Pre-allocate tensors
        padded_embeds = torch.zeros((len(batch), max_len, embeds[0].size(2)), device=device)
        padded_masks = torch.zeros((len(batch), max_len), device=device)
        padded_position_ids = torch.zeros((len(batch), max_len), dtype=torch.long, device=device)
        padded_segment_ids = torch.zeros((len(batch), max_len), dtype=torch.long, device=device)
        
        for i, (emb, mask, pos_ids, seg_ids) in enumerate(zip(embeds, masks, position_ids_list, segment_ids_list)):
            seq_len = min(emb.size(1), max_len)
            padded_embeds[i, :seq_len] = emb.squeeze(0)[:seq_len]
            padded_masks[i, :seq_len] = mask.squeeze(0)[:seq_len]
            # Clamp position IDs to max allowed
            padded_position_ids[i, :seq_len] = torch.clamp(pos_ids.squeeze(0)[:seq_len], max=max_pos_emb-1)
            if seq_len <= len(seg_ids):
                padded_segment_ids[i, :seq_len] = seg_ids[:seq_len]
            else:
                # Handle case where seg_ids is shorter than seq_len
                padded_segment_ids[i, :len(seg_ids)] = seg_ids
                # Fill rest with last segment type (target = 2)
                padded_segment_ids[i, len(seg_ids):seq_len] = 2
        
        return padded_embeds, padded_masks, padded_position_ids, positions, batch, padded_segment_ids

    def forward(self, batch: list[Example]):
        device = next(self.parameters()).device
        
        # Add warmup scaling based on training steps
        if self.training and hasattr(self, 'global_step'):
            warmup_steps = 1000
            if self.global_step < warmup_steps:
                warmup_scale = self.global_step / warmup_steps
            else:
                warmup_scale = 1.0
        else:
            warmup_scale = 1.0
        
        # Prepare batch
        inputs_embeds, attention_mask, position_ids, positions, examples, segment_ids = self._prepare_batch(batch)
        bsz, seq_len, _ = inputs_embeds.size()
        
        # Get token type IDs
        token_type_ids = self._token_type_ids[:bsz, :seq_len]
        
        # Run BERT encoder
        out = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            return_dict=True
        )
        
        h = out.last_hidden_state
        h = self.dropout(h)
        
        # Extract target hidden states
        target_h = torch.stack([h[i, positions[i]] for i in range(len(positions))], dim=0)
        target_h = self.target_norm(target_h)
        
        losses = []
        prototype_cache = {}
        
        for i, ex in enumerate(examples):
            vec = self.class_head(target_h[i])
            vec = self.class_norm(vec)
            
            feat = ex.features[ex.target_column_id]
            cats = feat.categories or []
            
            if feat.dtype == 'real' or len(cats) == 0:
                # BINARY CLASSIFICATION for all real-valued features
                protos2 = F.normalize(self.binary_prototypes, dim=1)
                logits = vec.unsqueeze(0) @ protos2.t()
                logits = logits / self.temperature
                
                # Get median from the dataset
                median = getattr(ex, 'median', 0.0)
                true_val = float(ex.target_row[ex.target_column_id])
                true_bin = int(true_val >= median)
                
                loss = F.cross_entropy(logits, torch.tensor([true_bin], device=device))
                loss = torch.clamp(loss, max=5.0)
                # loss = loss / 0.693
                losses.append(loss)
            else:
                # Multi-class classification for categorical
                fid = id(feat)
                if fid not in prototype_cache:
                    protos = []
                    for c in cats:
                        emb = feat.categories_embedding[self.config.embedding_model][c]
                        proto_emb = emb.to(device) if isinstance(emb, torch.Tensor) else torch.tensor(emb, device=device)
                        protos.append(F.normalize(proto_emb, dim=0))
                    prototype_cache[fid] = torch.stack(protos, dim=0)
                
                protos = prototype_cache[fid]
                logits = vec.unsqueeze(0) @ protos.t()
                logits = logits / self.temperature
                
                tv = str(ex.target_row[ex.target_column_id])
                try:
                    idx = cats.index(tv)
                except ValueError:
                    idx = 0
                    logger.warning(f"Target value '{tv}' not found in categories")
                
                loss = F.cross_entropy(logits, torch.tensor([idx], device=device))
                loss = torch.clamp(loss, max=5.0)
                # loss = loss / (math.log(len(cats)) + 1e-6)
                losses.append(loss)
        
        # Clip and average losses
        if losses:
            losses_tensor = torch.stack(losses)
            
            if len(losses_tensor) > 1:
                max_loss = torch.quantile(losses_tensor, 0.95)
                losses_tensor = torch.clamp(losses_tensor, max=max_loss.item())
            return losses_tensor.mean() * warmup_scale
        # if losses:
        #     losses_tensor = torch.stack(losses)
        #     final_loss = losses_tensor.mean()
            
        #     # Add noise to prevent zero loss
        #     noise = torch.rand(1, device=device) * 0.001
        #     final_loss = final_loss + noise + 0.001  # Minimum loss of 0.001
            
        #     return final_loss
        else:
            return torch.tensor(0.0, device=device, requires_grad=True) * warmup_scale