# Adapted from XXXX

import torch.nn as nn


class ResnetBlockConv(nn.Module):
    """
    Convolutional ResNet block class for processing image data.

    Main branch: conv1 (with ``stride``) -> activation -> conv2 (stride 1)
    -> activation. The main branch is added to the shortcut branch (identity,
    or a 1x1 strided projection when the shape changes) and the sum is passed
    through the activation once more.

    Parameters:
        in_channels (int): number of input channels
        out_channels (int, optional): number of output channels; defaults to
            ``in_channels``
        kernel_size (int): kernel size of both main-branch convolutions
        stride (int): stride of the first convolution and of the shortcut
            projection
        padding (int, optional): padding of both main-branch convolutions.
            Defaults to ``kernel_size // 2`` (1 for the default kernel_size=3),
            which keeps the main branch and the shortcut the same spatial size
            for odd kernel sizes.
        activation (type[nn.Module]): activation class; it is instantiated once
            and the same instance is used at every activation site
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        padding=None,
        activation=nn.LeakyReLU,
    ):
        super().__init__()

        # If output channel number is not specified, use input channel number
        if out_channels is None:
            out_channels = in_channels
        # "Same" padding for odd kernels: the main branch then produces
        # floor((H - 1) / stride) + 1 rows, exactly like the 1x1 shortcut, so
        # the residual addition in forward() has matching shapes.
        if padding is None:
            padding = kernel_size // 2

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.activation = activation()

        # Main branch (bias-free convolutions)
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            bias=False,
        )

        # Residual connection: project with a 1x1 strided conv only when the
        # spatial size or channel count changes. The nn.Sequential wrapper is
        # kept so parameter names stay "shortcut.0.weight".
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """
        Apply the block to ``x``.

        Parameters:
            x: input tensor passed to ``nn.Conv2d`` with ``in_channels``
               channels (presumably [N, C, H, W] — confirm against callers)

        Returns:
            activation(main_branch(x) + shortcut(x)) with ``out_channels``
            channels
        """
        identity = x

        # Main branch
        out = self.conv1(x)
        out = self.activation(out)
        out = self.conv2(out)
        # NOTE(review): the main branch is activated before the residual add as
        # well as after it; standard ResNet blocks only activate after the add.
        out = self.activation(out)

        # Residual connection
        shortcut = self.shortcut(identity)
        out = out + shortcut

        # Final activation
        out = self.activation(out)

        return out


class ResNetFeatureExtractor(nn.Module):
    """
    Feature extractor, converts RGB images to a feature tensor.

    Pipeline: 3x3 conv -> LayerNorm over (C, H, W) -> 3x3 / stride-3 max-pool
    -> activation -> four stride-2 ResNet blocks (16, 32, 64, 128 channels)
    -> flatten -> Linear to 128 features.

    Parameters:
        input_channels (int): number of input channels, default is 3 (RGB)
        initial_features (int): number of channels produced by the initial
            convolution, default is 8
        image_size (int): height and width of the square input images, default
            is 128. The LayerNorm and the input width of the final Linear layer
            are both derived from it, so inputs must have this size.
        activation (type[nn.Module]): activation class, default nn.LeakyReLU
    """

    def __init__(
        self,
        input_channels=3,
        initial_features=8,
        image_size=128,
        activation=nn.LeakyReLU,
    ):
        super().__init__()

        self.activation = activation()

        self.initial_conv = nn.Conv2d(
            input_channels, initial_features, kernel_size=3, stride=1, padding=1
        )
        # Normalizes over channels and both spatial dims, which ties the module
        # to a fixed input resolution.
        self.initial_layer_norm = nn.LayerNorm(
            [initial_features, image_size, image_size]
        )
        self.initial_pool = nn.MaxPool2d(kernel_size=3, stride=3)

        # Four stride-2 ResNet blocks
        self.resnet_block_1 = ResnetBlockConv(initial_features, 16, stride=2)
        self.resnet_block_2 = ResnetBlockConv(16, 32, stride=2)
        self.resnet_block_3 = ResnetBlockConv(32, 64, stride=2)
        self.resnet_block_4 = ResnetBlockConv(64, 128, stride=2)

        # Final layer: input width follows image_size instead of being
        # hard-coded (128 * 3 * 3 = 1152 for the default image_size of 128).
        spatial = self._output_spatial_size(image_size)
        self.final_layer = nn.Linear(
            self.resnet_block_4.out_channels * spatial * spatial, 128
        )

    @staticmethod
    def _output_spatial_size(image_size):
        """Return the height/width of the feature map fed to ``final_layer``."""
        # initial_conv (k=3, s=1, p=1) keeps the size;
        # initial_pool (k=3, s=3, p=0): floor((H - 3) / 3) + 1
        size = (image_size - 3) // 3 + 1
        # Each stride-2 ResnetBlockConv (k=3, p=1): floor((H - 1) / 2) + 1
        for _ in range(4):
            size = (size - 1) // 2 + 1
        return size

    def forward(self, x):
        """
        Forward propagation function
        Note: N_C: Max camera number in the batch

        Parameters:
            x: input images with shape
               [batch*N_C, input_channels, image_size, image_size]
               Note: multiple cameras are combined in the batch dimension

        Returns:
            feature tensor with shape [batch*N_C, 128]
        """
        # Shape comments use the defaults (input_channels=3,
        # initial_features=8, image_size=128); B = batch*N_C.

        # Initial conv, norm, pool: [B, 3, 128, 128] -> [B, 8, 42, 42]
        x = self.activation(
            self.initial_pool(self.initial_layer_norm(self.initial_conv(x)))
        )
        # [B, 8, 42, 42] -> [B, 16, 21, 21]
        x = self.resnet_block_1(x)
        # [B, 16, 21, 21] -> [B, 32, 11, 11]
        x = self.resnet_block_2(x)
        # [B, 32, 11, 11] -> [B, 64, 6, 6]
        x = self.resnet_block_3(x)
        # [B, 64, 6, 6] -> [B, 128, 3, 3]
        x = self.resnet_block_4(x)

        # [B, 128, 3, 3] -> [B, 1152] -> [B, 128]
        # flatten() reshapes, so unlike view() it also accepts non-contiguous
        # feature maps (e.g. channels_last memory format).
        features = self.final_layer(x.flatten(1))
        return features
