from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional as F

from CLIP_utils.misc import freeze_batch_norm_2d

"""
Modified ResNet Model

This implementation adapts the standard ResNet architecture with the following changes:

1. **Modified Stem**:
   - The original single convolution layer is replaced with a 3-layer convolutional stem, with an average pooling layer replacing the max pooling layer.
   - Provides better feature extraction capabilities for image data, particularly at the initial layers.

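2. **Anti-Aliasing Strided Convolutions**:
   - Instead of using strided convolutions, every convolution in the residual blocks runs at stride 1 and spatial downsampling is done by average pooling (after the 3x3 convolution in the main path, and before the 1x1 projection in the shortcut path), which reduces aliasing artifacts.
   - Improves robustness to input image transformations and variations.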

3. **Attention Pooling**:
   - The final average pooling layer is replaced with a QKV attention pooling mechanism.
   - This enables better context-aware global feature aggregation.

4. **Batch Normalization**:
   - Every convolution is followed by a batch normalization layer to stabilize the training process.
   - `ModifiedResNet.lock(freeze_bn_stats=True)` can additionally freeze batch normalization statistics via `freeze_batch_norm_2d`.

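5. **Custom Initialization**:
   - The attention-pooling query/key/value/output projection weights are drawn from a normal distribution with standard deviation `embed_dim ** -0.5`.
   - The last batch normalization weight (`bn3`) of every bottleneck block is zero-initialized, so each residual branch initially contributes nothing to its block's output.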

This model is designed to serve as a feature extractor for downstream tasks and is fully compatible with PyTorch's `nn.Module`.

Note:
- `ModifiedResNet` uses `AttentionPool2d` (defined below) as its final pooling layer.
- `lock` freezes all parameters for transfer learning or fine-tuning (partial locking is not supported); `set_grad_checkpointing` is currently a no-op.

Adapted for integration with the CLIP framework from OpenAI, and modified for better flexibility and feature extraction.
"""


class Bottleneck(nn.Module):
    """
    ResNet bottleneck block that downsamples with average pooling.

    Every convolution in the block runs at stride 1. When ``stride > 1`` the
    main path average-pools right after its 3x3 convolution, and the shortcut
    path average-pools before its 1x1 projection, instead of relying on
    strided convolutions (the anti-aliasing variant used by CLIP's ResNets).

    Attributes:
        expansion (int): Ratio between the block's output channels and ``planes``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        """
        Build the block's layers.

        Args:
            inplanes (int): Channel count of the incoming tensor.
            planes (int): Channel count of the inner 1x1 / 3x3 convolutions; the
                block emits ``planes * expansion`` channels.
            stride (int): Spatial downsampling factor, applied via average pooling.
        """
        super().__init__()
        wide = planes * self.expansion

        # 1x1 channel reduction
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.act1 = nn.ReLU(inplace=True)

        # 3x3 spatial convolution; padding=1 with stride 1 keeps H and W unchanged
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.act2 = nn.ReLU(inplace=True)

        # downsampling happens here instead of inside a strided convolution
        self.avgpool = nn.Identity() if stride <= 1 else nn.AvgPool2d(stride)

        # 1x1 channel expansion
        self.conv3 = nn.Conv2d(planes, wide, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)
        self.act3 = nn.ReLU(inplace=True)

        self.stride = stride

        # Projection shortcut whenever the output shape differs from the input.
        # The child names "-1"/"0"/"1" keep parameter names such as
        # "downsample.0.weight" unchanged. Note that the channel test reads the
        # base-class Bottleneck.expansion, not self.expansion.
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, wide, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(wide)),
            ]))
        else:
            self.downsample = None

    def forward(self, x: torch.Tensor):
        """
        Run the residual block.

        Args:
            x (torch.Tensor): Feature map with ``inplanes`` channels.

        Returns:
            torch.Tensor: ``relu(main_path(x) + shortcut(x))`` with
            ``planes * expansion`` channels, spatially reduced by ``stride``.
        """
        h = self.conv1(x)
        h = self.act1(self.bn1(h))
        h = self.conv2(h)
        h = self.act2(self.bn2(h))
        h = self.conv3(self.avgpool(h))
        h = self.bn3(h)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        h += shortcut
        return self.act3(h)


class AttentionPool2d(nn.Module):
    """
    Pool an NCHW feature map into one vector per sample with multi-head attention.

    The H*W spatial positions become a token sequence. Their mean is prepended
    as an extra token and learned positional embeddings are added. After one
    round of self-attention, the output at that prepended (mean) token is
    returned.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        """
        Create the projections and positional embedding.

        Args:
            spacial_dim (int): Side length of the expected feature map. The positional
                embedding has ``spacial_dim ** 2 + 1`` rows, so inputs must have
                ``H * W == spacial_dim ** 2``.
            embed_dim (int): Channel count C of the input; also the width of q/k/v.
            num_heads (int): Number of attention heads.
            output_dim (int, optional): Width of the output projection; falls back to
                ``embed_dim`` when None (or any other falsy value).
        """
        super().__init__()
        n_tokens = spacial_dim ** 2 + 1  # H*W spatial tokens plus the mean token
        scale = embed_dim ** 0.5
        self.positional_embedding = nn.Parameter(torch.randn(n_tokens, embed_dim) / scale)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """
        Attention-pool a feature map.

        Args:
            x (torch.Tensor): Tensor of shape [N, C, H, W] with ``C == embed_dim``.

        Returns:
            torch.Tensor: Tensor of shape [N, output_dim or embed_dim], taken from the
            attention output at the prepended mean token.
        """
        # [N, C, H, W] -> [H*W, N, C]: the sequence-first layout used below
        tokens = x.flatten(start_dim=2).permute(2, 0, 1)
        summary = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([summary, tokens], dim=0)  # [H*W + 1, N, C]
        pos = self.positional_embedding.unsqueeze(1).to(tokens.dtype)
        tokens = tokens + pos

        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        attended, _ = F.multi_head_attention_forward(
            query=tokens,
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return attended[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions: the residual blocks use stride-1 convolutions
      and do their spatial downsampling with average pooling (see ``Bottleneck``).
    - The final pooling layer is a QKV attention instead of an average pool.

    ``forward`` maps an (N, 3, H, W) image batch to (N, output_dim) embeddings.
    """

    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
        """
        Initialize the modified ResNet.

        Args:
            layers (list): Number of Bottleneck blocks in each of the four stages.
            output_dim (int): Dimension of the output features.
            heads (int): Number of attention heads in the attention pooling layer;
                presumably must divide ``width * 32`` (the attention embed dim) — confirm.
            image_size (int): Side length of the (square) input images. It fixes the size
                of the attention-pooling positional embedding, so ``forward`` expects
                inputs of exactly this size.
            width (int): Base width of the network (channel count after the stem).
        """
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        # the 3-layer stem: only conv1 is strided; the avgpool halves the size once more
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.act2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.act3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # layer4 emits (width * 8) * Bottleneck.expansion channels
        # Size the positional embedding from the exact layer4 feature-map size. The
        # previous `image_size // 32` is one too small when image_size % 32 == 31
        # (e.g. 223), which made forward() fail; for even sizes both agree.
        self.attnpool = AttentionPool2d(self._feature_map_size(image_size), embed_dim, heads, output_dim)

        self.init_parameters()

    @staticmethod
    def _feature_map_size(image_size):
        """Return the side length of the feature map produced by ``layer4`` for a given input size."""
        # conv1 (kernel 3, stride 2, padding 1) produces ceil(image_size / 2); the stem
        # avgpool and the three stride-2 stages then floor-halve it four more times.
        return ((image_size + 1) // 2) // 16

    def _make_layer(self, planes, blocks, stride=1):
        """
        Create a ResNet stage of Bottleneck blocks.

        Only the first block receives ``stride``; it also consumes ``self._inplanes``,
        which is then advanced to ``planes * Bottleneck.expansion`` for the rest of the
        stage and for the next call.

        Args:
            planes (int): Base width of the stage.
            blocks (int): Number of bottleneck blocks in the stage (at least one is built).
            stride (int): Downsampling factor of the first block.

        Returns:
            nn.Sequential: The stage's blocks in order.
        """
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def init_parameters(self):
        """
        Apply the custom weight initialization.

        The q/k/v/c projection weights of the attention pool are drawn from
        N(0, embed_dim ** -0.5); their biases keep the ``nn.Linear`` default. The
        ``bn3`` weight of every Bottleneck in layer1-layer4 is zeroed, so each
        residual branch initially contributes nothing. The stem's own ``bn3`` is not
        touched.
        """
        if self.attnpool is not None:
            std = self.attnpool.c_proj.in_features ** -0.5
            nn.init.normal_(self.attnpool.q_proj.weight, std=std)
            nn.init.normal_(self.attnpool.k_proj.weight, std=std)
            nn.init.normal_(self.attnpool.v_proj.weight, std=std)
            nn.init.normal_(self.attnpool.c_proj.weight, std=std)

        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for name, param in resnet_block.named_parameters():
                if name.endswith("bn3.weight"):
                    nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """
        Freeze the whole model for fine-tuning / transfer learning.

        Args:
            unlocked_groups (int): Must be 0; partial locking is not supported.
            freeze_bn_stats (bool): If True, also hand the model to ``freeze_batch_norm_2d``
                so batch-norm statistics stop updating.

        Raises:
            NotImplementedError: If ``unlocked_groups`` is not 0.
        """
        # Explicit check rather than `assert`, so the guard still fires under `python -O`.
        if unlocked_groups != 0:
            raise NotImplementedError('partial locking not currently supported for this model')
        for param in self.parameters():
            param.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """
        Accept a gradient-checkpointing request; currently a no-op for this model.

        Args:
            enable (bool): Ignored.
        """
        # FIXME support for non-transformer
        pass

    def stem(self, x):
        """
        Run the 3-layer convolutional stem.

        Args:
            x (torch.Tensor): Image batch with 3 input channels.

        Returns:
            torch.Tensor: Feature map with ``width`` channels at 1/4 of the input
            resolution (conv1 stride 2, then a 2x2 average pool).
        """
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        return x

    def forward(self, x):
        """
        Encode an image batch.

        Args:
            x (torch.Tensor): Image batch of shape (N, 3, image_size, image_size).

        Returns:
            torch.Tensor: Attention-pooled features of shape (N, output_dim).
        """
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x