import math
import torch
import torch.nn as nn

class IncrementalMLP(nn.Module):
    """MLP classifier that can attend over features from earlier calls.

    ``forward`` returns both the class logits and the hidden features. Those
    features can be collected by the caller and handed back in as
    ``prev_features`` on later calls; when attention is enabled they are
    summarised into a context vector that is concatenated to the input.

    Args:
        input_size: Width of the input vectors.
        num_classes: Number of output logits.
        use_attention: If True, a context vector of width ``hidden_size`` is
            concatenated to the input before the feature extractor.
        window_size: Stored on the module. Inside this class it only acts as
            a switch: a value of 0 disables the context (zeros are used).
            The history itself is NOT truncated here.
            NOTE(review): presumably the caller keeps ``prev_features`` to at
            most ``window_size`` entries -- confirm against the call site.
    """

    def __init__(self, input_size, num_classes, use_attention=True, window_size=10):
        super().__init__()
        self.window_size = window_size
        self.use_attention = use_attention
        self.input_size = input_size
        self.hidden_size = 256
        # Scaled dot-product attention factor, 1 / sqrt(attention width).
        self.scale = 1.0 / math.sqrt(self.hidden_size)

        # Attention projections. They are created regardless of
        # ``use_attention`` so the parameter / state_dict layout of the module
        # does not depend on that flag.
        self.query = nn.Linear(input_size, self.hidden_size)
        self.key = nn.Linear(self.hidden_size, self.hidden_size)

        # The context vector (when used) is appended to the raw input.
        total_input_size = input_size + (self.hidden_size if use_attention else 0)

        self.feature_extractor = nn.Sequential(
            nn.Linear(total_input_size, 512),
            nn.ReLU(),
            nn.Linear(512, self.hidden_size),
            nn.ReLU(),
        )

        self.classifier = nn.Linear(self.hidden_size, num_classes)

    def compute_context(self, x, prev_features):
        """Summarise ``prev_features`` into one context vector per sample.

        Args:
            x: Input batch of shape ``[B, input_size]``.
            prev_features: Sequence of ``T`` tensors, each ``[B, hidden_size]``
                (they are stacked along a new dim 1). May be ``None`` or empty.

        Returns:
            Tensor of shape ``[B, hidden_size]``. All zeros (with the dtype and
            device of ``x``) when there is no history or ``window_size`` is 0.
        """
        if not prev_features or self.window_size == 0:
            # new_zeros inherits dtype AND device from x, so the fallback does
            # not silently become float32 for half/bfloat16/double inputs.
            return x.new_zeros(x.size(0), self.hidden_size)

        history = torch.stack(prev_features, dim=1)        # [B, T, H]
        keys = self.key(history)                           # [B, T, H]
        query = self.query(x).unsqueeze(1)                 # [B, 1, H]

        # Scaled dot-product score of the query against every history step.
        energy = torch.bmm(keys, query.transpose(1, 2)).squeeze(-1)  # [B, T]
        alpha = torch.softmax(energy * self.scale, dim=1)            # [B, T]

        # The projected keys double as the values being averaged.
        return torch.bmm(alpha.unsqueeze(1), keys).squeeze(1)        # [B, H]

    def forward(self, x, prev_features=None):
        """Classify ``x``, optionally conditioning on earlier features.

        Args:
            x: Input batch of shape ``[B, input_size]``.
            prev_features: Optional sequence of earlier feature tensors, each
                ``[B, hidden_size]``. Ignored when ``use_attention`` is False.

        Returns:
            Tuple ``(logits, features)`` with shapes ``[B, num_classes]`` and
            ``[B, hidden_size]``.
        """
        if self.use_attention:
            # compute_context already yields zeros for a missing/empty
            # history, so no separate fallback is needed here.
            context = self.compute_context(x, prev_features)
            x = torch.cat([x, context], dim=1)
        features = self.feature_extractor(x)
        return self.classifier(features), features