"""
Bound computation utilities for certified training.

This module implements Interval Bound Propagation (IBP) for certified
training, plus a hook for Auto-LiRPA (currently a placeholder that falls
back to IBP).
"""

import warnings
from dataclasses import dataclass
from typing import Tuple, Dict, List, Optional, NamedTuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class CertificationResult:
    """Outcome of certifying a single example with interval bounds."""

    # True when the target class's output lower bound strictly exceeds the
    # largest output upper bound among all other classes.
    is_certified: bool
    # Output lower bounds for the example (row 0 of the batch bounds).
    lower_bounds: torch.Tensor
    # Output upper bounds for the example (row 0 of the batch bounds).
    upper_bounds: torch.Tensor
    # argmax of the model's output on the unperturbed input.
    predicted_class: int
    # Target lower bound minus the largest other-class upper bound; positive
    # exactly when certified, and +inf when the model has a single output.
    confidence_margin: float


class BoundComputation:
    """
    Bound computation utilities for neural networks.

    Implements Interval Bound Propagation (IBP) over a model's registered
    layers. Auto-LiRPA can be requested, but its integration is currently a
    placeholder that falls back to the same manual IBP.
    """

    # Leaf modules skipped without a soundness warning because they act as
    # the identity in eval mode.
    _PASSTHROUGH_MODULES = (nn.Identity, nn.Dropout, nn.Dropout2d)

    def __init__(self, use_auto_lirpa: bool = False):
        """
        Args:
            use_auto_lirpa: Whether to use Auto-LiRPA for bound computation.
                If the package cannot be imported, a RuntimeWarning is issued
                and manual IBP is used instead.
        """
        self.use_auto_lirpa = use_auto_lirpa
        if use_auto_lirpa:
            try:
                import auto_LiRPA
            except ImportError:
                # Report through `warnings` (filterable, stderr) rather than
                # printing to stdout from library code.
                warnings.warn("Auto-LiRPA not available, falling back to IBP",
                              RuntimeWarning, stacklevel=2)
                self.use_auto_lirpa = False
            else:
                self.auto_lirpa = auto_LiRPA

    def compute_ibp_bounds(self, model: nn.Module, inputs: torch.Tensor,
                          epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute IBP bounds for the model output.

        Args:
            model: Neural network model
            inputs: Input batch
            epsilon: L∞ perturbation budget

        Returns:
            Tuple of (lower_bounds, upper_bounds) for model outputs
        """
        if self.use_auto_lirpa:
            return self._compute_auto_lirpa_bounds(model, inputs, epsilon)
        return self._compute_manual_ibp_bounds(model, inputs, epsilon)

    def _compute_manual_ibp_bounds(self, model: nn.Module, inputs: torch.Tensor,
                                 epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Manual IBP over the model's registered submodules.

        Layers are visited in ``model.named_modules()`` order, i.e. the order
        in which they were registered, which is not necessarily the order
        ``forward`` applies them. This is therefore only correct for models
        whose forward pass applies each registered layer once, in
        registration order (e.g. a plain ``nn.Sequential``):

        * functional ops inside ``forward`` (``F.relu``, residual adds, ...)
          are invisible here;
        * a module instance reused several times in ``forward`` is visited
          only once.

        Use Auto-LiRPA for anything more complex.

        Leaf layers other than Conv2d / Linear / ReLU / Flatten and the
        eval-mode identities in ``_PASSTHROUGH_MODULES`` are skipped, which
        makes the bounds unsound. A RuntimeWarning lists their types.

        The perturbation box is clipped to [0, 1]. Inputs are presumably
        normalized to that range; inputs outside it get a box that does not
        contain them.
        """
        lower_bounds = torch.clamp(inputs - epsilon, 0, 1)
        upper_bounds = torch.clamp(inputs + epsilon, 0, 1)

        skipped_types = set()
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                lower_bounds, upper_bounds = self._propagate_conv2d(
                    module, lower_bounds, upper_bounds
                )
            elif isinstance(module, nn.Linear):
                lower_bounds, upper_bounds = self._propagate_linear(
                    module, lower_bounds, upper_bounds
                )
            elif isinstance(module, nn.ReLU):
                lower_bounds, upper_bounds = self._propagate_relu(
                    lower_bounds, upper_bounds
                )
            elif isinstance(module, nn.Flatten):
                # Honour the layer's own flattening range.
                lower_bounds = lower_bounds.flatten(module.start_dim, module.end_dim)
                upper_bounds = upper_bounds.flatten(module.start_dim, module.end_dim)
            elif 'flatten' in name.lower():
                # Name-matched custom flatten: keep the batch dim, merge the rest.
                batch_size = lower_bounds.size(0)
                lower_bounds = lower_bounds.reshape(batch_size, -1)
                upper_bounds = upper_bounds.reshape(batch_size, -1)
            elif (not isinstance(module, self._PASSTHROUGH_MODULES)
                  and next(module.children(), None) is None):
                # Unhandled leaf layer: containers are fine to skip (their
                # children are visited), leaves are not.
                skipped_types.add(type(module).__name__)

        if skipped_types:
            warnings.warn(
                "Manual IBP skipped unsupported layer types "
                f"{sorted(skipped_types)}; the returned bounds may be unsound.",
                RuntimeWarning, stacklevel=3,
            )

        return lower_bounds, upper_bounds

    def _propagate_conv2d(self, layer: nn.Conv2d, lower: torch.Tensor,
                         upper: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Propagate bounds through a Conv2d layer.

        The kernel is split into positive and negative parts. Each output
        bound pairs positive weights with the same-side input bound and
        negative weights with the opposite-side input bound. Stride, padding,
        dilation and groups are all forwarded, so the bounds follow the
        layer's own geometry.
        """
        if layer.padding_mode != 'zeros':
            # F.conv2d always zero-pads; the layer would pad differently.
            warnings.warn(
                f"Conv2d padding_mode={layer.padding_mode!r} is treated as zero "
                "padding; the returned bounds may be unsound.",
                RuntimeWarning, stacklevel=2,
            )

        weight_pos = torch.clamp(layer.weight, min=0)
        weight_neg = torch.clamp(layer.weight, max=0)

        def conv(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            return F.conv2d(x, weight, stride=layer.stride, padding=layer.padding,
                            dilation=layer.dilation, groups=layer.groups)

        lower_new = conv(lower, weight_pos) + conv(upper, weight_neg)
        upper_new = conv(upper, weight_pos) + conv(lower, weight_neg)

        if layer.bias is not None:
            bias = layer.bias.view(1, -1, 1, 1)
            lower_new = lower_new + bias
            upper_new = upper_new + bias

        return lower_new, upper_new

    def _propagate_linear(self, layer: nn.Linear, lower: torch.Tensor,
                         upper: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Propagate bounds through a Linear layer (same weight split as conv)."""
        weight_pos = torch.clamp(layer.weight, min=0)
        weight_neg = torch.clamp(layer.weight, max=0)

        lower_new = F.linear(lower, weight_pos) + F.linear(upper, weight_neg)
        upper_new = F.linear(upper, weight_pos) + F.linear(lower, weight_neg)

        if layer.bias is not None:
            lower_new = lower_new + layer.bias
            upper_new = upper_new + layer.bias

        return lower_new, upper_new

    def _propagate_relu(self, lower: torch.Tensor,
                       upper: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Propagate bounds through ReLU.

        ReLU is monotone, so applying it to each bound is exact:
        - upper <= 0: output is 0;
        - lower >= 0: output is the identity;
        - lower < 0 < upper: output lower is 0, upper is unchanged.
        """
        return torch.clamp(lower, min=0), torch.clamp(upper, min=0)

    def _compute_auto_lirpa_bounds(self, model: nn.Module, inputs: torch.Tensor,
                                 epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute bounds using Auto-LiRPA (placeholder).

        Raises:
            RuntimeError: if Auto-LiRPA was not enabled / importable.
        """
        if not self.use_auto_lirpa:
            raise RuntimeError("Auto-LiRPA not available")

        # Placeholder: a real integration would wrap the model in Auto-LiRPA's
        # BoundedModule and the input in a BoundedTensor. For now this falls
        # back to manual IBP.
        return self._compute_manual_ibp_bounds(model, inputs, epsilon)

    def certify_prediction(self, model: nn.Module, inputs: torch.Tensor,
                          epsilon: float, target_class: Optional[int] = None
                          ) -> CertificationResult:
        """
        Certify if model prediction is robust.

        Args:
            model: Neural network model
            inputs: A single example. A 3-D tensor gets a batch dimension
                added; any other rank must already be a batch of one.
            epsilon: L∞ perturbation budget
            target_class: Expected class (if None, use model prediction)

        Returns:
            CertificationResult with plain Python ``is_certified`` (bool) and
            ``confidence_margin`` (float) values.
        """
        if inputs.dim() == 3:  # assume a 3-D input is one unbatched image
            inputs = inputs.unsqueeze(0)

        lower_bounds, upper_bounds = self.compute_ibp_bounds(model, inputs, epsilon)

        with torch.no_grad():
            predicted_class = model(inputs).argmax(dim=1).item()

        if target_class is None:
            target_class = predicted_class

        target_lower = lower_bounds[0, target_class]
        other_classes = [i for i in range(lower_bounds.size(1)) if i != target_class]
        if other_classes:
            max_other_upper = upper_bounds[0, other_classes].max()
            # Convert to plain Python values; otherwise a 0-d tensor leaks into
            # a field declared as bool.
            is_certified = bool(target_lower > max_other_upper)
            confidence_margin = (target_lower - max_other_upper).item()
        else:
            # Single-output model: there is no competing class.
            is_certified = True
            confidence_margin = float('inf')

        return CertificationResult(
            is_certified=is_certified,
            lower_bounds=lower_bounds[0],
            upper_bounds=upper_bounds[0],
            predicted_class=predicted_class,
            confidence_margin=confidence_margin
        )

    def compute_certified_accuracy(self, model: nn.Module, data_loader,
                                 epsilon: float) -> float:
        """
        Compute certified accuracy over a dataset.

        An example counts when it is both certified for its label and
        correctly classified. The model is evaluated in eval mode, and its
        previous train/eval mode is restored afterwards.

        Args:
            model: Neural network model
            data_loader: Iterable of (inputs, targets) batches, e.g. a
                PyTorch DataLoader
            epsilon: L∞ perturbation budget

        Returns:
            Certified accuracy as fraction (0.0 for an empty loader)
        """
        was_training = model.training
        model.eval()
        certified_correct = 0
        total = 0

        try:
            with torch.no_grad():
                for inputs, targets in data_loader:
                    # split(1) keeps a batch dim of one, so non-image inputs
                    # work too.
                    for example, target in zip(inputs.split(1), targets):
                        label = target.item()
                        result = self.certify_prediction(model, example, epsilon, label)
                        if result.is_certified and result.predicted_class == label:
                            certified_correct += 1
                        total += 1
        finally:
            model.train(was_training)

        return certified_correct / total if total > 0 else 0.0


def compute_margin_loss(lower_bounds: torch.Tensor, upper_bounds: torch.Tensor,
                       targets: torch.Tensor) -> torch.Tensor:
    """
    Compute a hinge loss on the certification margin.

    For each example the margin is the target class's lower bound minus the
    largest upper bound among the other classes. The loss is
    ``mean(relu(-margin))``, so it is zero once every example is certified.

    Args:
        lower_bounds: Lower bounds for each class, indexed [example, class]
        upper_bounds: Upper bounds for each class, same shape as lower_bounds
        targets: Target class indices, one per example (integer tensor usable
            with ``gather``, i.e. int64)

    Returns:
        Scalar margin loss tensor. Examples from a single-class output have
        margin +inf and contribute 0.
    """
    index = targets.unsqueeze(1)

    # Select, don't multiply: 0 * inf would turn infinite bounds into NaN.
    target_lower = lower_bounds.gather(1, index).squeeze(1)

    # Exclude the target with -inf rather than 0. A 0 placeholder would win
    # the max whenever every non-target upper bound is negative.
    class_ids = torch.arange(upper_bounds.size(1), device=upper_bounds.device)
    is_target = class_ids.unsqueeze(0) == index
    non_target_upper = upper_bounds.masked_fill(is_target, float('-inf')).max(dim=1).values

    margin = target_lower - non_target_upper
    return F.relu(-margin).mean()


if __name__ == "__main__":
    # Smoke test for bound computation. Run as a module
    # (`python -m <package>.<this_module>`) so the relative import resolves.
    print("Testing bound computation...")

    from ..models import create_cnn7_mnist

    # Create test model and data. The manual IBP clips the perturbation box
    # to [0, 1], so sample inputs from that range; randn inputs would fall
    # outside their own box.
    model = create_cnn7_mnist()
    model.eval()
    inputs = torch.rand(2, 1, 28, 28)
    epsilon = 0.1

    # Test bound computation
    bound_computer = BoundComputation()
    lower_bounds, upper_bounds = bound_computer.compute_ibp_bounds(model, inputs, epsilon)

    print(f"Input shape: {inputs.shape}")
    print(f"Lower bounds shape: {lower_bounds.shape}")
    print(f"Upper bounds shape: {upper_bounds.shape}")
    assert lower_bounds.shape == upper_bounds.shape, "bound shapes differ"
    assert bool((lower_bounds <= upper_bounds).all()), "lower bound exceeds upper bound"

    # Test certification
    single_input = inputs[0]
    result = bound_computer.certify_prediction(model, single_input, epsilon)
    print(f"Certification result: {result.is_certified}")
    print(f"Predicted class: {result.predicted_class}")
    print(f"Confidence margin: {result.confidence_margin:.4f}")

    print("Bound computation tests passed!")