"""
SABR (Small Adversarial Bounding Regions) loss implementation.

SABR is a certified training method that uses smaller unsound IBP boxes
around adversarial examples to reduce approximation errors during bound
propagation, leading to higher standard and certified accuracy.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, Dict
from .bounds import BoundComputation


class SABRLoss(nn.Module):
    """
    SABR loss implementation for certified training.

    SABR combines adversarial training with bound-based certification
    by using smaller IBP boxes around adversarial examples.

    The returned loss is::

        alpha * CE(model(x_adv), y) + (1 - alpha) * mean(relu(-margin))

    where ``x_adv`` comes from PGD within an L-inf ball of radius
    ``epsilon`` and ``margin`` is the worst-case logit margin obtained from
    IBP bounds on a box of radius ``beta * epsilon`` around ``x_adv``.
    """

    def __init__(self, epsilon: float, alpha: float = 0.5,
                 num_steps: int = 10, step_size: float = 0.01,
                 beta: float = 0.1):
        """
        Args:
            epsilon: L-inf perturbation budget for adversarial examples
            alpha: Weight balancing adversarial and bound losses
                (``alpha`` on the adversarial CE, ``1 - alpha`` on the bound loss)
            num_steps: Number of PGD steps for adversarial examples
            step_size: Step size for PGD
            beta: IBP box size factor (box radius is ``beta * epsilon``)
        """
        super().__init__()
        self.epsilon = epsilon
        self.alpha = alpha
        self.num_steps = num_steps
        self.step_size = step_size
        self.beta = beta
        self.bound_computer = BoundComputation()

    def forward(self, model: nn.Module, inputs: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """
        Compute SABR loss.

        Args:
            model: Neural network model
            inputs: Input batch (PGD projects into [0, 1], so inputs are
                expected to lie in that range)
            targets: Integer class labels, one per sample

        Returns:
            Scalar SABR loss value
        """
        adv_inputs = self._generate_adversarial_examples(model, inputs, targets)

        # Standard adversarial-training term.
        adv_loss = F.cross_entropy(model(adv_inputs), targets)

        # Bound term on a small box centred at the adversarial examples.
        bound_loss = self._compute_bound_loss(model, adv_inputs, targets)

        return self.alpha * adv_loss + (1 - self.alpha) * bound_loss

    def _generate_adversarial_examples(self, model: nn.Module,
                                       inputs: torch.Tensor,
                                       targets: torch.Tensor) -> torch.Tensor:
        """
        Generate adversarial examples with L-inf PGD (zero initialisation).

        The model is put in eval mode for the attack and its previous
        train/eval mode is restored afterwards. Model parameter gradients are
        not modified.

        Returns:
            ``inputs + delta`` with ``|delta| <= epsilon`` element-wise and,
            for inputs already in [0, 1], the result also in [0, 1].
        """
        was_training = model.training
        model.eval()

        # Attack a detached copy so the PGD graph never reaches the caller's
        # input graph; the final ``inputs + delta`` below keeps the original
        # ``inputs`` tensor in the returned expression.
        x = inputs.detach()
        delta = torch.zeros_like(x)

        try:
            # The caller may be inside torch.no_grad(); PGD needs gradients.
            with torch.enable_grad():
                for _ in range(self.num_steps):
                    delta.requires_grad_(True)
                    loss = F.cross_entropy(model(x + delta), targets)

                    # autograd.grad returns d(loss)/d(delta) only, so the
                    # model's parameter .grad buffers are left untouched
                    # (loss.backward() would accumulate into them).
                    (grad,) = torch.autograd.grad(loss, delta)

                    with torch.no_grad():
                        # Signed-gradient ascent step; produces a fresh leaf
                        # tensor, re-marked as requiring grad next iteration.
                        delta = delta + self.step_size * grad.sign()
                        # Project to the L-inf ball of radius epsilon.
                        delta = torch.clamp(delta, -self.epsilon, self.epsilon)
                        # Project to the valid input range [0, 1].
                        delta = torch.clamp(x + delta, 0, 1) - x
        finally:
            model.train(was_training)

        return inputs + delta.detach()

    def _compute_bound_loss(self, model: nn.Module,
                            adv_inputs: torch.Tensor,
                            targets: torch.Tensor) -> torch.Tensor:
        """
        Hinge loss on the worst-case margin over a small IBP box.

        The margin per sample is ``lower[target] - max_{j != target} upper[j]``;
        the loss is ``mean(relu(-margin))``.
        """
        # Use smaller IBP box around adversarial examples.
        # NOTE(review): PGD above uses the full epsilon, so (assuming
        # compute_ibp_bounds builds an L-inf box of radius ibp_epsilon around
        # adv_inputs) the box may reach up to (1 + beta) * epsilon from the
        # clean input. Confirm this is intended.
        ibp_epsilon = self.beta * self.epsilon

        # Expected to be (batch, num_classes) logit bounds; the dim-1
        # gather/scatter below requires that layout.
        lower_bounds, upper_bounds = self.bound_computer.compute_ibp_bounds(
            model, adv_inputs, ibp_epsilon
        )

        target_idx = targets.unsqueeze(1)

        # Worst case for the correct class: its lower bound. gather avoids the
        # 0 * -inf = NaN hazard of a one-hot product.
        correct_class_bounds = lower_bounds.gather(1, target_idx).squeeze(1)

        # Worst case for the other classes: their largest upper bound. The
        # target column is set to -inf so it can never be selected; zeroing
        # it instead would wrongly win the max when all other bounds are < 0.
        other_upper = upper_bounds.scatter(1, target_idx, float("-inf"))
        incorrect_class_bounds = other_upper.max(dim=1).values

        # Hinge: penalise only when the correct class is not guaranteed on top.
        margin = correct_class_bounds - incorrect_class_bounds
        return F.relu(-margin).mean()


def compute_sabr_loss(model: nn.Module, inputs: torch.Tensor,
                      targets: torch.Tensor, epsilon: float,
                      alpha: float = 0.5, **kwargs) -> torch.Tensor:
    """
    One-shot convenience wrapper around :class:`SABRLoss`.

    Builds a fresh ``SABRLoss`` configured with ``epsilon``, ``alpha`` and
    any extra keyword arguments, then evaluates it once on the given batch.

    Args:
        model: Neural network model
        inputs: Input batch
        targets: Target labels
        epsilon: L-inf perturbation budget
        alpha: Weight balancing adversarial and bound losses
        **kwargs: Forwarded unchanged to the ``SABRLoss`` constructor
            (e.g. ``num_steps``, ``step_size``, ``beta``)

    Returns:
        SABR loss value
    """
    criterion = SABRLoss(epsilon=epsilon, alpha=alpha, **kwargs)
    result = criterion(model, inputs, targets)
    return result


class SABRTrainer:
    """Trainer class for SABR-based certified training."""

    def __init__(self, model: nn.Module, epsilon: float,
                 alpha: float = 0.5, **sabr_kwargs):
        """
        Args:
            model: Model to train
            epsilon: L-inf perturbation budget
            alpha: Weight balancing adversarial and bound losses
            **sabr_kwargs: Additional arguments for SABRLoss
        """
        self.model = model
        self.sabr_loss = SABRLoss(epsilon=epsilon, alpha=alpha, **sabr_kwargs)

    def compute_loss(self, inputs: torch.Tensor,
                     targets: torch.Tensor) -> torch.Tensor:
        """Compute SABR loss for training."""
        return self.sabr_loss(self.model, inputs, targets)

    def compute_certified_accuracy(self, inputs: torch.Tensor,
                                   targets: torch.Tensor) -> float:
        """
        Compute certified accuracy using IBP bound propagation.

        A sample counts as certified-correct when the class with the highest
        lower bound (``predicted``) has a lower bound strictly greater than
        the upper bound of every other class, and ``predicted`` equals the
        target. Bounds are computed at the full ``epsilon`` of the SABR loss.

        The model is evaluated in eval mode; its previous train/eval mode is
        restored afterwards.

        Args:
            inputs: Input batch
            targets: Target labels

        Returns:
            Certified accuracy as a fraction of the batch
            (NaN for an empty batch, as the mean of zero elements).
        """
        was_training = self.model.training
        self.model.eval()

        try:
            with torch.no_grad():
                lower_bounds, upper_bounds = self.sabr_loss.bound_computer.compute_ibp_bounds(
                    self.model, inputs, self.sabr_loss.epsilon
                )

                predicted = lower_bounds.argmax(dim=1)
                pred_idx = predicted.unsqueeze(1)

                # Lower bound of the predicted class, per sample.
                pred_lower = lower_bounds.gather(1, pred_idx).squeeze(1)

                # Largest upper bound among the *other* classes: mask the
                # predicted column with -inf so it is never selected. With a
                # single class this yields -inf, i.e. trivially certified.
                other_upper = upper_bounds.scatter(1, pred_idx, float("-inf"))
                certified = pred_lower > other_upper.max(dim=1).values

                correct_certified = certified & (predicted == targets)
        finally:
            self.model.train(was_training)

        return correct_certified.float().mean().item()


if __name__ == "__main__":
    # Smoke test for the SABR loss. This module uses package-relative imports,
    # so run it as a module (``python -m <package>.<module>``), not as a script.
    print("Testing SABR loss implementation...")

    # Create test data
    from ..models import create_cnn7_mnist

    torch.manual_seed(0)
    model = create_cnn7_mnist()
    # PGD projects adversarial inputs into [0, 1]; sample inputs from that
    # range (torch.randn would put most pixels outside it and make the
    # projection move perturbations far beyond epsilon).
    inputs = torch.rand(4, 1, 28, 28)
    targets = torch.randint(0, 10, (4,))
    epsilon = 0.1

    # Test SABR loss
    sabr_loss = SABRLoss(epsilon=epsilon)
    loss = sabr_loss(model, inputs, targets)
    assert torch.isfinite(loss), "SABR loss is not finite"
    print(f"SABR loss: {loss.item():.4f}")

    # Test trainer
    trainer = SABRTrainer(model, epsilon=epsilon)
    loss = trainer.compute_loss(inputs, targets)
    assert torch.isfinite(loss), "Trainer loss is not finite"
    print(f"Trainer loss: {loss.item():.4f}")

    # Test certified accuracy
    cert_acc = trainer.compute_certified_accuracy(inputs, targets)
    assert 0.0 <= cert_acc <= 1.0, "Certified accuracy out of range"
    print(f"Certified accuracy: {cert_acc:.4f}")

    print("SABR tests passed!")