"""
CNN7 architecture used in CACTUS experiments.

This is a 7-layer convolutional neural network commonly used in
certified robustness literature for MNIST and CIFAR-10 experiments.
Following Shi et al. (2021).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN7(nn.Module):
    """
    7-layer CNN architecture for MNIST and CIFAR-10.

    Architecture:
    - Conv2d (64 filters, 3x3, stride=1, padding=1) + BatchNorm + ReLU
    - Conv2d (64 filters, 3x3, stride=1, padding=1) + BatchNorm + ReLU
    - Conv2d (128 filters, 3x3, stride=2, padding=1) + BatchNorm + ReLU
    - Conv2d (128 filters, 3x3, stride=1, padding=1) + BatchNorm + ReLU
    - Conv2d (128 filters, 3x3, stride=1, padding=1) + BatchNorm + ReLU
    - Flatten
    - Linear (512 hidden units) + ReLU
    - Linear (num_classes output units for classification)

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of output units of the final linear layer.
        input_size: Height and width of the (square) input images. Both
            even and odd sizes are supported.
    """

    def __init__(self, input_channels: int = 1, num_classes: int = 10,
                 input_size: int = 28):
        super().__init__()

        self.input_channels = input_channels
        self.num_classes = num_classes
        self.input_size = input_size

        # 5 convolutional layers, each followed by batch normalization.
        # Attribute names and registration order are part of the checkpoint
        # format (state_dict keys) and of the seeded initialization order.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(128)

        # Only conv3 changes the spatial size. With kernel 3, stride 2 and
        # padding 1 it maps n to floor((n - 1) / 2) + 1 == ceil(n / 2), which
        # matches n // 2 only for even n; use the exact formula so odd input
        # sizes produce a correctly sized fc1.
        reduced_size = (input_size - 1) // 2 + 1
        conv_output_size = reduced_size ** 2 * 128

        # Fully connected layers
        self.fc1 = nn.Linear(conv_output_size, 512)
        self.fc2 = nn.Linear(512, num_classes)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize network weights using Xavier/Glorot initialization.

        Conv2d and Linear weights are Xavier-uniform with zero biases;
        BatchNorm2d layers start with weight 1 and bias 0.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _bn_layers(self):
        """Return the five batch-norm layers in network order."""
        return (self.bn1, self.bn2, self.bn3, self.bn4, self.bn5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network.

        Args:
            x: Input batch; the first dimension is treated as the batch
                dimension when flattening.

        Returns:
            The raw outputs of the last linear layer (no activation).
        """
        # The conv stack is kept as straight-line code (no loop) so the
        # module stays easy to trace/export.
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))  # stride=2: halves H and W
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))

        # Flatten for fully connected layers. reshape (unlike view) also
        # accepts non-contiguous activations, e.g. channels_last inputs.
        x = x.reshape(x.size(0), -1)

        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # No activation on output layer

        return x

    def forward_with_bn_eval(self, x: torch.Tensor,
                             enable_grad: bool = False) -> torch.Tensor:
        """
        Forward pass with batch normalization in evaluation mode.
        Used during PGD attacks as specified in Shi et al. (2021).

        Args:
            x: Input batch.
            enable_grad: If False (default), the pass runs with gradient
                tracking disabled and the result carries no autograd graph.
                If True, the surrounding grad mode is left untouched so
                gradients can flow back to ``x`` (needed when the attack
                differentiates through this call).

        Returns:
            The network output computed with all batch-norm layers in
            eval mode.
        """
        bn_layers = self._bn_layers()
        # Remember each layer's own mode so layers that were deliberately
        # frozen (eval) inside a training-mode model stay frozen afterwards.
        previous_modes = [bn.training for bn in bn_layers]
        for bn in bn_layers:
            bn.eval()

        try:
            # Never turns gradients on inside an outer no_grad(); only
            # decides whether to additionally turn them off here.
            track_grad = bool(enable_grad) and torch.is_grad_enabled()
            with torch.set_grad_enabled(track_grad):
                return self.forward(x)
        finally:
            # Restore even if forward raised, so a failed call cannot leave
            # the model with batch norm stuck in eval mode.
            for bn, was_training in zip(bn_layers, previous_modes):
                bn.train(was_training)

    def get_activations(self, x: torch.Tensor, layer_names=None):
        """
        Get intermediate activations for bound computation.

        For each conv block ``i`` in 1..5 the keys ``conv{i}_pre_bn``,
        ``conv{i}_post_bn`` and ``conv{i}`` (after ReLU) are recorded,
        followed by ``flatten``, ``fc1_pre_relu``, ``fc1`` and ``fc2``.

        Args:
            x: Input tensor
            layer_names: List of layer names to return activations for.
                Names that are not recorded are silently skipped. If None,
                all activations are returned.

        Returns:
            Dictionary mapping layer names to activation tensors
        """
        activations = {}

        conv_blocks = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for index, (conv, bn) in enumerate(conv_blocks, start=1):
            x = conv(x)
            activations[f'conv{index}_pre_bn'] = x
            x = bn(x)
            activations[f'conv{index}_post_bn'] = x
            x = F.relu(x)
            activations[f'conv{index}'] = x

        # Flatten (reshape also handles non-contiguous activations)
        x = x.reshape(x.size(0), -1)
        activations['flatten'] = x

        # FC1
        x = self.fc1(x)
        activations['fc1_pre_relu'] = x
        x = F.relu(x)
        activations['fc1'] = x

        # FC2 (output)
        x = self.fc2(x)
        activations['fc2'] = x

        if layer_names is None:
            return activations
        return {name: activations[name] for name in layer_names if name in activations}


def create_cnn7_mnist():
    """Build a CNN7 configured for MNIST: 1 input channel, 28x28 images, 10 classes."""
    mnist_config = {"input_channels": 1, "num_classes": 10, "input_size": 28}
    return CNN7(**mnist_config)


def create_cnn7_cifar10():
    """Build a CNN7 configured for CIFAR-10: 3 input channels, 32x32 images, 10 classes."""
    cifar_config = {"input_channels": 3, "num_classes": 10, "input_size": 32}
    return CNN7(**cifar_config)


if __name__ == "__main__":
    # Smoke test: build both configurations, push a random batch through
    # each, and print shapes, parameter counts and activation keys.
    print("Testing CNN7 architectures...")

    batch = 4

    # MNIST configuration on a random batch of 1x28x28 inputs.
    mnist_net = create_cnn7_mnist()
    mnist_batch = torch.randn(batch, 1, 28, 28)
    mnist_logits = mnist_net(mnist_batch)
    mnist_param_count = sum(w.numel() for w in mnist_net.parameters())
    print(f"MNIST model output shape: {mnist_logits.shape}")
    print(f"MNIST model parameters: {mnist_param_count:,}")

    # CIFAR-10 configuration on a random batch of 3x32x32 inputs.
    cifar_net = create_cnn7_cifar10()
    cifar_batch = torch.randn(batch, 3, 32, 32)
    cifar_logits = cifar_net(cifar_batch)
    cifar_param_count = sum(w.numel() for w in cifar_net.parameters())
    print(f"CIFAR-10 model output shape: {cifar_logits.shape}")
    print(f"CIFAR-10 model parameters: {cifar_param_count:,}")

    # List every intermediate activation recorded for the MNIST model.
    recorded = mnist_net.get_activations(mnist_batch)
    print(f"Available activation layers: {list(recorded.keys())}")

    # Compare a regular training-mode pass with the pass that keeps the
    # batch-norm layers in eval mode.
    cifar_net.train()
    train_logits = cifar_net(cifar_batch)
    bn_eval_logits = cifar_net.forward_with_bn_eval(cifar_batch)
    print(f"Train mode output norm: {train_logits.norm().item():.4f}")
    print(f"BN eval mode output norm: {bn_eval_logits.norm().item():.4f}")

    print("Model tests passed!")