# -*- coding: UTF-8 -*-
"""
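Pre-activation ResNet (PreAct ResNet) feature extractor.

Defines a bias-free ``conv3x3`` helper, the pre-activation residual blocks
``PreActBlock`` and ``PreActBottleneck``, a four-stage ``ResNet`` backbone and
a ``ResNet18`` factory. The backbone takes 3-channel images and returns
flattened, average-pooled features rather than class logits; for 32x32 inputs
``ResNet18`` produces a tensor of shape (batch_size, 512).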
"""


import torch
import torch.nn as nn
import torch.nn.functional as F


def conv3x3(in_planes, out_planes, stride=1):
    """Build (but do not apply) a bias-free 3x3 convolution layer.

    A padding of 1 keeps the spatial size unchanged when ``stride`` is 1 and
    divides it by ``stride`` otherwise (rounding as ``nn.Conv2d`` does).

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int, optional): Convolution stride. Defaults to 1.

    Returns:
        nn.Conv2d: The configured convolution module.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


class PreActBlock(nn.Module):
    """
    Pre-activation version of the basic two-convolution residual block.

    Batch normalization and ReLU are applied *before* each 3x3 convolution
    (BN -> ReLU -> conv, twice). When the stride is not 1 or the channel count
    changes, a bias-free 1x1 convolution projects the pre-activated input to the
    output shape. Otherwise the shortcut is the identity and carries the raw,
    un-normalized input ``x`` unchanged, as in the pre-activation ResNet design.

    Args:
        in_planes (int): Number of input channels.
        planes (int): Number of output channels (``expansion`` is 1).
        stride (int, optional): Stride of the first 3x3 convolution and of the
            projection shortcut. Defaults to 1.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        # An empty nn.Sequential means "identity shortcut"; it is replaced by a
        # 1x1 projection whenever the residual branch changes the tensor shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        """
        Apply the residual block to ``x``.

        Args:
            x (torch.Tensor): Input with ``in_planes`` channels; assumed to be
                (N, C, H, W) since BatchNorm2d/Conv2d are used.

        Returns:
            torch.Tensor: Residual branch output plus the shortcut, with
            ``planes`` channels and spatial size reduced by ``stride``.
        """

        out = F.relu(self.bn1(x), inplace=True)
        # A projection shortcut consumes the pre-activated tensor. The identity
        # shortcut must pass the raw input through so the residual path stays
        # free of BN/ReLU.
        shortcut = self.shortcut(out) if len(self.shortcut) > 0 else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out), inplace=True))
        out += shortcut
        return out


class PreActBottleneck(nn.Module):
    """
    Pre-activation version of the three-convolution bottleneck block.

    The residual branch has three stages:
    BN -> ReLU -> 1x1 conv (to ``planes`` channels),
    BN -> ReLU -> 3x3 conv (carries the stride), and
    BN -> ReLU -> 1x1 conv (expands to ``expansion * planes`` channels).

    When the stride is not 1 or the channel count changes, a bias-free 1x1
    projection of the pre-activated input forms the shortcut. Otherwise the
    shortcut is the identity and carries the raw input ``x`` unchanged, as in
    the pre-activation ResNet design.

    Args:
        in_planes (int): Number of input channels.
        planes (int): Bottleneck width; the block outputs
            ``expansion * planes`` (= 4 * planes) channels.
        stride (int, optional): Stride of the 3x3 convolution and of the
            projection shortcut. Defaults to 1.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        """
        Build the three BN/conv stages and, if shapes change, the 1x1
        projection shortcut.

        Args:
            in_planes (int): Number of input channels.
            planes (int): Bottleneck width (output has ``4 * planes`` channels).
            stride (int, optional): Stride of the middle 3x3 convolution and of
                the projection shortcut. Defaults to 1.
        """

        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)

        # An empty nn.Sequential means "identity shortcut"; it is replaced by a
        # 1x1 projection whenever the stride or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        """
        Apply the bottleneck block to ``x``.

        Steps:
        1. BN then ReLU on ``x``.
        2. Shortcut: the 1x1 projection of that pre-activated tensor if one
           exists, otherwise the raw input ``x``.
        3. conv1, then BN -> ReLU -> conv2, then BN -> ReLU -> conv3.
        4. Add the shortcut and return.

        Args:
            x (torch.Tensor): Input with ``in_planes`` channels; assumed to be
                (N, C, H, W) since BatchNorm2d/Conv2d are used.

        Returns:
            torch.Tensor: Output with ``expansion * planes`` channels and
            spatial size reduced by ``stride``.
        """

        out = F.relu(self.bn1(x), inplace=True)
        # The identity shortcut must carry the raw input, not the BN/ReLU
        # output, so the residual path stays clean.
        shortcut = self.shortcut(out) if len(self.shortcut) > 0 else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out), inplace=True))
        out = self.conv3(F.relu(self.bn3(out), inplace=True))
        out += shortcut
        return out


class ResNet(nn.Module):
    """
    Pre-activation ResNet backbone that returns pooled feature vectors.

    The network has the following parts:
    - A 3x3 stem convolution followed by BN and ReLU.
    - Four stages of residual blocks with 64, 128, 256 and 512 base channels
      (strides 1, 2, 2, 2).
    - A final BN and ReLU, 4x4 average pooling, and flattening.

    No classification head is attached. For a (N, 3, 32, 32) input the stages
    reduce the spatial size to 4x4, so ``forward`` returns features of shape
    (N, 512 * block.expansion). Other input sizes change the flattened length.

    Args:
        block (type): Residual block class (e.g. ``PreActBlock`` or
            ``PreActBottleneck``). It must accept ``(in_planes, planes, stride)``
            and define an integer ``expansion`` attribute.
        num_blocks (Sequence[int]): Number of blocks in each of the four stages.
    """

    def __init__(self, block, num_blocks):
        """
        Build the stem, the four residual stages and the final BN layer.

        Args:
            block (type): Residual block class used in every stage.
            num_blocks (Sequence[int]): Blocks per stage (four entries are read).
        """

        super(ResNet, self).__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # layer4 emits 512 * block.expansion channels (2048 for a bottleneck
        # block), so the final BN must match that width, not a fixed 512.
        self.bn = nn.BatchNorm2d(512 * block.expansion)

    def _make_layer(self, block, planes, num_blocks, stride):
        """
        Build one stage of residual blocks.

        Only the first block uses ``stride``; the rest use stride 1. Updates
        ``self.in_planes`` to the stage's output channel count.

        Args:
            block (type): Residual block class to instantiate.
            planes (int): Base channel count passed to each block.
            num_blocks (int): Number of blocks in the stage.
            stride (int): Stride for the first block of the stage.

        Returns:
            nn.Sequential: The blocks of the stage, in order.
        """

        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Compute pooled features for a batch of images.

        Args:
            x (torch.Tensor): Input batch with 3 channels, (N, 3, H, W).

        Returns:
            torch.Tensor: Flattened features of shape (N, 512 * block.expansion)
            for 32x32 inputs.
        """

        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.relu(self.bn(out), inplace=True)
        # A fixed 4x4 window matches the 4x4 map left by a 32x32 input.
        out = F.avg_pool2d(out, 4)
        return out.view(out.size(0), -1)


def ResNet18():
    """
    Build the 18-layer pre-activation ResNet backbone.

    Uses ``PreActBlock`` with two blocks in each of the four stages.

    Returns:
        ResNet: A newly initialized ResNet-18 feature extractor.
    """

    blocks_per_stage = [2] * 4
    return ResNet(PreActBlock, blocks_per_stage)


"""
Smoke test: build the ResNet-18 defined in this module (not a torchvision model), pass a random input tensor of shape (2, 3, 32, 32) through it, and print the size of the output feature tensor (expected: torch.Size([2, 512])).

This is only a minimal demonstration of the model, not a complete example of how to use it in a real-world application. For more advanced usage, refer to the PyTorch documentation and examples.
"""
if __name__ == '__main__':
    # Build the backbone and push a small random batch through it.
    model = ResNet18()
    dummy_batch = torch.randn(2, 3, 32, 32)
    features = model(dummy_batch)
    print(features.size())
