# -*- coding: UTF-8 -*-


import torch.nn as nn
from torchvision.models import resnet
from torchvision.models.resnet import conv3x3, conv1x1, BasicBlock, Bottleneck
from torch.hub import load_state_dict_from_url
from typing import Type, Any, Callable, Union, List, Optional
from torch import Tensor

# Architecture table used by ``ResNet``:
#   net_name -> [number of residual blocks in layer1..layer4, residual block class]
# Depths follow torchvision's reference constructors; every key here also has an
# entry in ``model_urls``, so ``pretrained=True`` works for all of them.
settings = {
    'resnet18': [[2, 2, 2, 2], BasicBlock],
    'resnet34': [[3, 4, 6, 3], BasicBlock],
    'resnet50': [[3, 4, 6, 3], Bottleneck],
    'resnet101': [[3, 4, 23, 3], Bottleneck],
    'resnet152': [[3, 8, 36, 3], Bottleneck],
}

# Download locations of the ImageNet checkpoints used when ``pretrained=True``,
# keyed by architecture name (the same names used as ``arch`` in ``_resnet``).
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-f37072fd.pth",
    resnet34="https://download.pytorch.org/models/resnet34-b627a593.pth",
    resnet50="https://download.pytorch.org/models/resnet50-0676ba61.pth",
    resnet101="https://download.pytorch.org/models/resnet101-63fe2227.pth",
    resnet152="https://download.pytorch.org/models/resnet152-394f9c45.pth",
    resnext50_32x4d="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    resnext101_32x8d="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    wide_resnet50_2="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    wide_resnet101_2="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
)


def _resnet(
        arch: str,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        pretrained: bool,
        progress: bool,
        **kwargs: Any,
) -> resnet.ResNet:
    """
    Build a torchvision ``ResNet`` and, on request, load its ImageNet checkpoint.

    Args:
        arch (str): Architecture name; used only to look up the checkpoint URL in ``model_urls``.
        block: Residual block class passed to ``resnet.ResNet``.
        layers (List[int]): Number of blocks in each of the four stages.
        pretrained (bool): If True, download the checkpoint for ``arch`` and load it (strictly).
        progress (bool): Forwarded to ``load_state_dict_from_url`` to show a download progress bar.
        **kwargs: Extra keyword arguments for the ``resnet.ResNet`` constructor.

    Returns:
        resnet.ResNet: The constructed (and possibly weight-initialized) network.

    Raises:
        KeyError: If ``pretrained`` is True and ``arch`` has no entry in ``model_urls``.
    """
    net = resnet.ResNet(block, layers, **kwargs)
    if not pretrained:
        return net

    checkpoint_url = model_urls[arch]
    weights = load_state_dict_from_url(checkpoint_url, progress=progress)
    net.load_state_dict(weights)
    return net


class ResNet(object):
    """
    Factory that builds torchvision ResNet backbones as ``nn.Sequential`` feature extractors.

    ``net_name`` selects an entry of the module-level ``settings`` table (e.g. 'resnet18',
    'resnet34', 'resnet50'). Two optional flags adjust the architecture:

    * ``cifar``: replace the stem convolution with a 3x3, stride-1 convolution and drop the
      max-pool layer, so small inputs (e.g. 32x32 images) are not down-sampled in the stem.
    * ``preact``: build every stage from ``PreActBasicBlock`` instead of the architecture's
      default block. Only architectures built from ``BasicBlock`` support this.

    Each call of an instance builds a *new* model. The final ``nn.Linear`` classifier is
    replaced by ``nn.Flatten(1)``, so the returned module outputs pooled features rather
    than class logits.

    NOTE(review): with ``preact=True`` the stem's ``bn1``/``relu`` are kept and no extra
    normalization/activation is appended after the last stage; the canonical pre-activation
    ResNet adds a final BN+ReLU there -- confirm this is intended.
    """

    def __init__(self,
                 net_name,
                 cifar=False,
                 preact=False):
        """
        Store the configuration; no model is built until the instance is called.

        Args:
            net_name (str): Key into ``settings`` naming the architecture to build.
            cifar (bool, optional): Use the small-image stem (3x3 stride-1 ``conv1``, no
                max-pool). Defaults to False.
            preact (bool, optional): Use pre-activation basic blocks. Defaults to False.
        """
        self.net_name = net_name
        self.cifar = cifar
        self.preact = preact

    def __call__(self, pretrained: bool = False, progress: bool = True, **kwargs):
        """
        Build a new model with the stored configuration.

        Args:
            pretrained (bool): If True, load the ImageNet weights from ``model_urls``. They are
                loaded before the CIFAR stem substitution, so a replaced ``conv1`` starts untrained.
            progress (bool): If True, display a progress bar of the download on stderr.
            **kwargs: Extra keyword arguments forwarded to the torchvision ``ResNet`` constructor.
                ``arch``, ``layers`` and ``block`` are always overwritten from the configuration.

        Returns:
            nn.Sequential: The backbone's top-level children in their original order, with the
            CIFAR stem changes applied (if enabled) and the ``nn.Linear`` head replaced by
            ``nn.Flatten(1)``.

        Raises:
            KeyError: If ``net_name`` is not a key of ``settings``.
            ValueError: If ``preact`` is requested for an architecture whose block type is not a
                base class of ``PreActBasicBlock`` (e.g. the Bottleneck-based 'resnet50'), or
                together with ``pretrained=True``.
        """
        try:
            layers, block = settings[self.net_name]
        except KeyError:
            raise KeyError(
                f"unknown ResNet architecture {self.net_name!r}; "
                f"expected one of {sorted(settings)}"
            ) from None

        if self.preact:
            # PreActBasicBlock can only stand in for a block type it derives from; swapping it
            # in for Bottleneck would silently build a different (expansion=1) network.
            if not issubclass(PreActBasicBlock, block):
                raise ValueError(
                    f"preact=True is only supported for BasicBlock architectures, "
                    f"not {self.net_name!r}"
                )
            # PreActBasicBlock resizes bn1 to the block input and strips the norm from the
            # shortcut, so a strict load of the ImageNet checkpoint cannot succeed.
            if pretrained:
                raise ValueError("pretrained weights are not available for preact=True")
            block = PreActBasicBlock

        kwargs.update({
            'arch': self.net_name,
            'layers': layers,
            'block': block,
        })
        model = _resnet(pretrained=pretrained, progress=progress, **kwargs)

        nets = []
        for name, module in model.named_children():
            if self.cifar:
                if name == 'conv1':
                    # Small-image stem: 3x3 stride-1 conv with the original channel counts.
                    module = nn.Conv2d(module.in_channels, module.out_channels,
                                       kernel_size=3, stride=1, padding=1, bias=False)
                if isinstance(module, nn.MaxPool2d):
                    continue
            if isinstance(module, nn.Linear):
                # Drop the classifier head and keep the flattened pooled features.
                nets.append(nn.Flatten(1))
                continue
            nets.append(module)

        return nn.Sequential(*nets)


class PreActBasicBlock(BasicBlock):
    """
    Pre-activation variant of torchvision's ``BasicBlock``.

    Batch normalization and ReLU are applied *before* each convolution instead of after it,
    and nothing is applied after the residual addition. ``forward`` computes:

    1. ``out = relu(bn1(x))``. ``bn1`` normalizes the block input, so it has ``inplanes`` channels.
    2. The shortcut: ``x`` itself, or ``downsample(out)`` when a ``downsample`` module is
       present. The projection is applied to the pre-activated input, not to ``x``.
    3. ``out = conv2(relu(bn2(conv1(out))))``, where the first convolution uses ``stride``.
    4. ``out + shortcut`` is returned without a trailing activation.

    When ``downsample`` is an ``nn.Sequential`` (torchvision builds it as
    ``Sequential(conv1x1, norm)``), only its first module is kept, because the shortcut input
    has already been normalized in step 1.

    ``expansion = 1``: the block outputs ``planes`` channels, like ``BasicBlock``.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels.
        stride (int): Stride of the first convolution layer.
        downsample (nn.Module, optional): Shortcut projection used when the shape changes.
        groups (int): Number of groups in the convolution layers.
        base_width (int): Base width of the convolution layers.
        dilation (int): Dilation rate of the convolution layers.
        norm_layer (Callable[..., nn.Module], optional): Normalization layer factory;
            defaults to ``nn.BatchNorm2d``.
    """

    expansion = 1

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Initialize the block, then adapt ``BasicBlock``'s layers to the pre-activation layout.

        Args:
            inplanes (int): The number of input channels.
            planes (int): The number of output channels.
            stride (int, optional): The stride of the first convolution. Default is 1.
            downsample (nn.Module, optional): An optional shortcut projection module.
            groups (int, optional): The number of blocked connections from input channels to output channels. Default is 1.
            base_width (int, optional): The base width of the channels. Default is 64.
            dilation (int, optional): The dilation rate of the convolution. Default is 1.
            norm_layer (Callable[..., nn.Module], optional): The normalization layer factory.
                Defaults to ``nn.BatchNorm2d``.
        """
        super(PreActBasicBlock, self).__init__(inplanes, planes, stride, downsample, groups, base_width, dilation,
                                               norm_layer)
        if norm_layer is None:
            # BasicBlock applies this default only to its own local variable; without it a
            # standalone PreActBasicBlock(...) would end up calling None(inplanes).
            norm_layer = nn.BatchNorm2d
        # Pre-activation normalizes the block *input*, whose width is inplanes (not planes).
        self.bn1 = norm_layer(inplanes)
        if isinstance(self.downsample, nn.Sequential):
            # Keep only the projection conv; its input is already normalized by bn1.
            self.downsample = self.downsample[0]

    def forward(self, x: Tensor) -> Tensor:
        """
        Run the pre-activation residual block.

        Args:
            x (Tensor): Block input with ``inplanes`` channels.

        Returns:
            Tensor: ``conv2(relu(bn2(conv1(relu(bn1(x)))))) + shortcut`` with ``planes`` channels.
        """
        identity = x

        out = self.bn1(x)
        out = self.relu(out)
        if self.downsample is not None:
            # The projection shortcut consumes the pre-activated tensor.
            identity = self.downsample(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out += identity

        return out


# Demo: build a pre-activation ResNet-18 with the CIFAR stem, print its layer
# structure, and push a random batch of two 3x32x32 inputs through it.
if __name__ == '__main__':
    import torch

    factory = ResNet('resnet18',
                     cifar=True,
                     preact=True)
    net = factory()
    print(net)

    sample = torch.randn(2, 3, 32, 32)
    print(net(sample))
