"""
KV Predictor

This module contains the MLP network that predicts delta key and delta value
for correcting anchor token representations.
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn


class DeltaKVPredictor(nn.Module):
    """
    MLP network for predicting delta K and delta V to correct anchor token KV states.

    The network takes as input:
        - hidden_state: Current token's hidden state at the trigger position
        - anchor_k: Original key of the anchor token
        - anchor_v: Original value of the anchor token

    And outputs:
        - delta_k: Correction to add to anchor token's key
        - delta_v: Correction to add to anchor token's value

    Architecture:
        Input (hidden_state + anchor_k + anchor_v) -> FC1 -> Activation -> Dropout ->
        FC2 -> Activation -> Dropout -> two independent linear heads (delta_k, delta_v)

    All Linear layers, including both output heads, get Xavier-uniform weights and
    zero biases. The predicted deltas are therefore generally non-zero at
    initialization rather than starting as an identity correction.
    """

    # Supported activation names mapped to their module constructors.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'silu': nn.SiLU,
    }

    def __init__(
        self,
        hidden_dim: int,
        kv_dim: int,
        mlp_hidden_dim: Optional[int] = None,
        activation: str = 'relu',
        dropout: float = 0.1,
    ):
        """
        Initialize the Delta KV Predictor.

        Args:
            hidden_dim: Dimension of hidden state
            kv_dim: Dimension of key and value vectors
            mlp_hidden_dim: Hidden dimension of MLP. Defaults to (hidden_dim + 2*kv_dim)
            activation: Activation function ('relu', 'gelu', or 'silu')
            dropout: Dropout probability applied after each hidden-layer activation

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.kv_dim = kv_dim

        # The three inputs are concatenated along the feature axis.
        input_dim = hidden_dim + 2 * kv_dim

        if mlp_hidden_dim is None:
            mlp_hidden_dim = input_dim

        # TypeError covers unhashable inputs, so every unsupported value
        # surfaces as the same ValueError.
        try:
            activation_cls = self._ACTIVATIONS[activation]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unsupported activation: {activation}. Choose from ['relu', 'gelu', 'silu']"
            ) from None
        # One stateless activation module is reused after both hidden layers.
        # It is registered first so module and state_dict ordering match earlier versions.
        self.activation = activation_cls()

        # Layer 1: Input -> Hidden
        self.fc1 = nn.Linear(input_dim, mlp_hidden_dim)
        self.dropout1 = nn.Dropout(dropout)

        # Layer 2: Hidden -> Hidden (increases model capacity)
        self.fc2 = nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
        self.dropout2 = nn.Dropout(dropout)

        # Independent output heads for the key and value corrections.
        self.fc_delta_k = nn.Linear(mlp_hidden_dim, kv_dim)
        self.fc_delta_v = nn.Linear(mlp_hidden_dim, kv_dim)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every Linear weight; zero every Linear bias."""
        # self.modules() yields layers in registration order (fc1, fc2, fc_delta_k,
        # fc_delta_v), which fixes the order in which the RNG is consumed.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        hidden_state: torch.Tensor,
        anchor_k: torch.Tensor,
        anchor_v: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass to predict delta K and delta V.

        Inputs are concatenated on the last axis, so any leading dimensions are
        allowed as long as they match across all three tensors.

        Args:
            hidden_state: Current token's hidden state, shape: (..., hidden_dim)
            anchor_k: Anchor token's key, shape: (..., kv_dim)
            anchor_v: Anchor token's value, shape: (..., kv_dim)

        Returns:
            delta_k: Predicted correction for anchor key, shape: (..., kv_dim)
            delta_v: Predicted correction for anchor value, shape: (..., kv_dim)
        """
        x = torch.cat([hidden_state, anchor_k, anchor_v], dim=-1)  # (..., input_dim)

        x = self.dropout1(self.activation(self.fc1(x)))  # (..., mlp_hidden_dim)
        x = self.dropout2(self.activation(self.fc2(x)))  # (..., mlp_hidden_dim)

        delta_k = self.fc_delta_k(x)  # (..., kv_dim)
        delta_v = self.fc_delta_v(x)  # (..., kv_dim)
        return delta_k, delta_v

    def predict_corrected_kv(
        self,
        hidden_state: torch.Tensor,
        anchor_k: torch.Tensor,
        anchor_v: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict corrected key and value by adding deltas to original anchor KV.

        Args:
            hidden_state: Current token's hidden state, shape: (..., hidden_dim)
            anchor_k: Anchor token's original key, shape: (..., kv_dim)
            anchor_v: Anchor token's original value, shape: (..., kv_dim)

        Returns:
            corrected_k: Corrected key (anchor_k + delta_k), shape: (..., kv_dim)
            corrected_v: Corrected value (anchor_v + delta_v), shape: (..., kv_dim)
        """
        # Call the module, not self.forward, so registered forward hooks and
        # pre-hooks run here just as they do for model(...).
        delta_k, delta_v = self(hidden_state, anchor_k, anchor_v)

        return anchor_k + delta_k, anchor_v + delta_v


# Example usage and validation
if __name__ == "__main__":
    print("=" * 60)
    print("Delta KV Predictor - Example Usage")
    print("=" * 60)

    # Set parameters. The printed config is driven by these same variables,
    # so it cannot drift from what the model is actually built with.
    batch_size = 4
    hidden_dim = 768  # e.g., BERT hidden size
    kv_dim = 64  # Key and value dimension
    mlp_hidden_dim = 512
    activation = "relu"

    # Create model
    print("\nCreating model with:")
    print(f"  - Hidden dim: {hidden_dim}")
    print(f"  - KV dim: {kv_dim}")
    print(f"  - MLP hidden dim: {mlp_hidden_dim}")
    print(f"  - Activation: {activation}")

    model = DeltaKVPredictor(
        hidden_dim=hidden_dim,
        kv_dim=kv_dim,
        mlp_hidden_dim=mlp_hidden_dim,
        activation=activation,
    )
    # A fresh module is in training mode, where dropout makes every forward pass
    # random. Switch to inference mode so the two passes below are comparable.
    model.eval()

    # Create example inputs
    hidden_state = torch.randn(batch_size, hidden_dim)
    anchor_k = torch.randn(batch_size, kv_dim)
    anchor_v = torch.randn(batch_size, kv_dim)

    print("\nInput shapes:")
    print(f"  - hidden_state: {hidden_state.shape}")
    print(f"  - anchor_k: {anchor_k.shape}")
    print(f"  - anchor_v: {anchor_v.shape}")

    with torch.no_grad():
        # Forward pass
        delta_k, delta_v = model(hidden_state, anchor_k, anchor_v)

        print("\nOutput shapes:")
        print(f"  - delta_k: {delta_k.shape}")
        print(f"  - delta_v: {delta_v.shape}")

        # Validate dimensions
        assert delta_k.shape == anchor_k.shape, "delta_k shape must match anchor_k shape"
        assert delta_v.shape == anchor_v.shape, "delta_v shape must match anchor_v shape"
        print("\n✓ All dimension checks passed!")

        # Test corrected KV prediction
        corrected_k, corrected_v = model.predict_corrected_kv(hidden_state, anchor_k, anchor_v)
        print("\nCorrected KV shapes:")
        print(f"  - corrected_k: {corrected_k.shape}")
        print(f"  - corrected_v: {corrected_v.shape}")

        # In eval mode the correction is deterministic, so it must equal anchor + delta.
        assert torch.allclose(corrected_k, anchor_k + delta_k), "corrected_k must equal anchor_k + delta_k"
        assert torch.allclose(corrected_v, anchor_v + delta_v), "corrected_v must equal anchor_v + delta_v"

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nModel parameters:")
    print(f"  - Total: {total_params:,}")
    print(f"  - Trainable: {trainable_params:,}")

    print(f"\n{'=' * 60}")
    print("Example completed successfully!")
    print("=" * 60)
