from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F

import errno
import hashlib
import os
import warnings
import re
import shutil
import sys
import tempfile
from tqdm import tqdm
from urllib.request import urlopen
from urllib.parse import urlparse  # noqa: F401


def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL, caching it on disk.

    If the object is already present in `model_dir`, it's deserialized and
    returned without downloading. The filename part of the URL should follow
    the naming convention ``filename-<sha256>.ext`` where ``<sha256>`` is the
    first eight or more digits of the SHA256 hash of the contents of the file.
    The hash is used to ensure unique names and to verify the contents of the
    file. When the filename carries no such suffix the file is downloaded
    without hash verification instead of failing.

    The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` defaults to ``~/.cache`` if not set.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr

    Returns:
        The object produced by ``torch.load`` on the cached file.

    Raises:
        RuntimeError: propagated from the download helper when the SHA256
            digest of the downloaded file does not start with the hash
            embedded in the filename.

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        model_dir = os.path.join(_get_torch_home(), 'checkpoints')

    # Creating an already existing directory is not an error; any other
    # OSError (permissions, path is a regular file, ...) still propagates.
    os.makedirs(model_dir, exist_ok=True)

    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # A filename without a "-<hex>." part yields no match; fall back to
        # hash_prefix=None, which the download helper treats as "skip
        # verification", rather than crashing on ``None.group``.
        hash_match = HASH_REGEX.search(filename)
        hash_prefix = hash_match.group(1) if hash_match else None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    # NOTE(security): torch.load unpickles the file, so only URLs and cache
    # directories that are trusted should be used here.
    return torch.load(cached_file, map_location=map_location)


def _download_url_to_file(url, dst, hash_prefix, progress):
    """Stream `url` to `dst`, optionally verifying a SHA256 prefix.

    The payload is first written to a temporary file in the directory of
    `dst` and only moved into place once the download (and hash check, if
    any) succeeded.

    Args:
        url (string): URL opened with ``urlopen``.
        dst (string): final path of the downloaded file.
        hash_prefix (string or None): expected leading hex digits of the
            SHA256 digest of the payload; ``None`` disables verification.
        progress (bool): whether to display the tqdm progress bar.

    Raises:
        RuntimeError: if the computed digest does not start with `hash_prefix`.
    """
    file_size = None
    u = urlopen(url)
    try:
        meta = u.info()
        if hasattr(meta, 'getheaders'):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        # We deliberately save it in a temp file and move it after
        # download is complete. This prevents a local working checkpoint
        # being overridden by a broken download.
        dst_dir = os.path.dirname(dst)
        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

        try:
            sha256 = hashlib.sha256() if hash_prefix is not None else None
            with tqdm(total=file_size, disable=not progress,
                      unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    buffer = u.read(8192)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            # Flush everything to disk before hashing result is trusted and
            # the file is moved into place.
            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                       .format(hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            # Closing twice is harmless; the temp file only still exists here
            # when the download or the hash check failed.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
    finally:
        # Release the network connection on every exit path.
        u.close()


# Name of the environment variable that, when set, gives the torch cache root.
ENV_TORCH_HOME = 'TORCH_HOME'
# Name of the environment variable consulted for the cache root when
# TORCH_HOME is not set.
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
# Cache root used when XDG_CACHE_HOME is not set either; the '~' is expanded
# by _get_torch_home().
DEFAULT_CACHE_DIR = '~/.cache'
# Captures the run of lowercase hex digits between a '-' and the following
# '.' in a filename, e.g. '5c106cde' in 'resnet18-5c106cde.pth'.
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')


def _get_torch_home():
    """Return the torch cache root with a leading ``~`` expanded.

    ``$TORCH_HOME`` is used when set; otherwise the result is the ``torch``
    sub-directory of ``$XDG_CACHE_HOME``, which itself falls back to
    ``DEFAULT_CACHE_DIR``.
    """
    xdg_cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(xdg_cache_root, 'torch')
    configured_home = os.getenv(ENV_TORCH_HOME, fallback_home)
    return os.path.expanduser(configured_home)


# Names exported by ``from <module> import *``: the ResNet class and its
# factory functions.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']

# Checkpoint URL per architecture name; ``_resnet`` looks the URL up by its
# ``arch`` argument when ``pretrained`` is true. Every filename ends in
# ``-<hex>.pth``; that hex run is what HASH_REGEX extracts as the expected
# SHA256 prefix of the download.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` with the given stride."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class conv_block(nn.Module):
    """3x3 convolution with an optional parallel 3x3 adapter convolution.

    When `mode` is ``'parallel_adapters'`` a second 3x3 convolution with the
    same channel counts and stride is created and its output is added to the
    main convolution's output; for any other `mode` only the main convolution
    is applied.
    """

    def __init__(self, in_planes, planes, stride=1, mode='normal'):
        super().__init__()
        self.conv = conv3x3(in_planes, planes, stride)
        self.mode = mode
        wants_adapter = mode == 'parallel_adapters'
        self.adapter = conv3x3(in_planes, planes, stride) if wants_adapter else None

    def forward(self, x):
        main_out = self.conv(x)
        if self.adapter is None:
            return main_out
        return main_out + self.adapter(x)


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip.

    The block outputs ``planes * expansion`` channels. The 3x3 convolution
    carries the stride, groups and dilation.

    Args:
        inplanes (int): number of input channels.
        planes (int): base channel count; the output has ``planes * 4`` channels.
        stride (int): stride of the 3x3 convolution.
        downsample (nn.Module or None): applied to the input to produce the
            skip connection when given; otherwise the input is used as is.
        groups (int): groups of the 3x3 convolution.
        base_width (int): scales the inner width as ``planes * base_width / 64``.
        dilation (int): dilation (and padding) of the 3x3 convolution.
        norm_layer (callable or None): normalisation layer factory; defaults
            to ``nn.BatchNorm2d``.
        include_adapter (bool): accepted so that callers which pass this
            keyword to every block type (as ``ResNet._make_layer`` does) can
            also build Bottleneck networks. This block has no adapter branch,
            so the flag is ignored.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, include_adapter=False):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the skip path when its shape differs from the main path.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


# class BasicBlock(nn.Module):
#     expansion = 1
#
#     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
#                  base_width=64, dilation=1, norm_layer=None):
#         super(BasicBlock, self).__init__()
#         if norm_layer is None:
#             norm_layer = nn.BatchNorm2d
#         self.conv1 = conv3x3(inplanes, planes, stride)
#         self.bn1 = norm_layer(planes)
#         self.relu = nn.ReLU(inplace=True)
#         self.conv2 = conv3x3(planes, planes)
#         self.bn2 = norm_layer(planes)
#         self.downsample = downsample
#         self.stride = stride
#         self.adapter1 = AdapterLayer(inplanes, planes, stride)
#         # self.adapter1 = None
#         self.adapter2 = AdapterLayer(planes, planes)
#
#     def forward(self, x):
#         current_mode = mode_context.get()
#         identity = x
#
#         if current_mode == 'parallel_adapters' and self.adapter1 is not None:
#             out = self.conv1(x) + self.adapter1(x)
#         else:
#             out = self.conv1(x)
#
#         out = self.bn1(out)
#         out = self.relu(out)
#
#         if current_mode == 'parallel_adapters':
#             out = self.conv2(out) + self.adapter2(out)
#         else:
#             out = self.conv2(out)
#
#         out = self.bn2(out)
#
#         if self.downsample is not None:
#             identity = self.downsample(x)
#
#         out += identity
#         out = self.relu(out)
#
#         return out
from contextvars import ContextVar

# Context variable selecting the adapter mode at forward time. It defaults to
# 'normal' if not set; BasicBlock.forward only adds its adapter branches when
# the value is 'parallel_adapters'.
mode_context = ContextVar('mode', default='normal')


class BottleneckAdapterLayer(nn.Module):
    """Adapter made of three bias-free convolutions: 1x1 -> 3x3 -> 1x1.

    The first convolution halves the channel count (``inplanes // 2``), the
    3x3 convolution keeps that width and carries `stride`, and the last
    convolution maps to `planes` channels. No activation or normalisation is
    applied between them, and no explicit weight initialisation is done here.
    """

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        hidden = inplanes // 2
        self.adapter1 = nn.Conv2d(inplanes, hidden, kernel_size=1, stride=1, padding=0, bias=False)
        self.adapter2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1, bias=False)
        self.adapter3 = nn.Conv2d(hidden, planes, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        reduced = self.adapter1(x)
        mixed = self.adapter2(reduced)
        return self.adapter3(mixed)


class AdapterLayer(nn.Module):
    """Adapter consisting of a single 1x1 convolution built by ``conv1x1``.

    The module maps `inplanes` channels to `planes` channels with the given
    `stride`; no normalisation or activation is applied.
    """

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.adapter = conv1x1(inplanes, planes, stride)

    def forward(self, x):
        return self.adapter(x)


class BasicBlock(nn.Module):
    """Two-convolution residual block with optional parallel 1x1 adapters.

    When the block was built with ``include_adapter=True`` and the
    ``mode_context`` context variable is ``'parallel_adapters'`` at forward
    time, the output of each 3x3 convolution is augmented with
    ``alpha * adapter(input_of_that_convolution)`` before normalisation.
    Otherwise the block computes a plain residual block.

    Args:
        inplanes (int): number of input channels.
        planes (int): number of output channels.
        stride (int): stride of the first convolution (and first adapter).
        downsample (nn.Module or None): applied to the input to produce the
            skip connection when given.
        include_adapter (bool): whether to create the two adapter layers.
        groups, base_width, dilation: accepted for signature compatibility
            with other block types; they are not used by this block.
        norm_layer (callable or None): normalisation layer factory; defaults
            to ``nn.BatchNorm2d``.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, include_adapter=False,
                 groups=1, base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.include_adapter = include_adapter
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        if self.include_adapter:
            self.adapter1 = AdapterLayer(inplanes, planes, stride)
            self.adapter2 = AdapterLayer(planes, planes)
        else:
            self.adapter1 = None
            self.adapter2 = None
        self.adapter3 = None
        # Fixed (non-trainable) scaling factors. They are registered as
        # parameters, so they appear in the state dict of every block; `beta`
        # is not read by forward() but is kept so existing checkpoints that
        # contain it still load.
        self.alpha = nn.Parameter(torch.tensor([0.15]), requires_grad=False)
        self.beta = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        identity = x
        # Read the mode once so both adapters are gated consistently.
        adapters_on = self.include_adapter and mode_context.get() == 'parallel_adapters'

        # First convolution, with adapter1 in parallel on the block input.
        out = self.conv1(x)
        if adapters_on and self.adapter1 is not None:
            out = out + self.alpha * self.adapter1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Second convolution, with adapter2 in parallel on the same input.
        # When adapters are off nothing is added: no zero tensor is allocated
        # and no multiply-add is performed.
        hidden = out
        out = self.conv2(hidden)
        if adapters_on and self.adapter2 is not None:
            out = out + self.alpha * self.adapter2(hidden)
        out = self.bn2(out)

        # Project the skip path when its shape differs from the main path.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet backbone that returns the ``layer4`` feature map.

    No classifier head is built: ``forward`` returns the output of the last
    residual stage without pooling or flattening. Adapters are requested
    (``include_adapter=True``) for the blocks of ``layer2`` and ``layer3``
    only; ``layer1`` and ``layer4`` are built without them.

    Args:
        block: residual block class. It must expose an ``expansion`` attribute
            and accept the ``include_adapter``, ``groups``, ``base_width``,
            ``dilation`` and ``norm_layer`` keyword arguments.
        layers (sequence of int): number of blocks in each of the four stages.
        num_classes (int): accepted for signature compatibility; unused
            because no fully connected layer is created.
        zero_init_residual (bool): zero the weight of the last norm layer of
            every ``Bottleneck``/``BasicBlock`` so each block starts as an
            identity mapping.
        groups (int): forwarded to every block.
        width_per_group (int): forwarded to every block as ``base_width``.
        replace_stride_with_dilation (sequence of 3 bools or None): for
            ``layer2``, ``layer3`` and ``layer4`` respectively, replace the
            stage's stride by a dilation that is passed to the blocks.
            ``BasicBlock`` as defined in this file accepts but does not apply
            ``dilation``, so with that block only the stride is removed.
        norm_layer (callable or None): normalisation layer factory; defaults
            to ``nn.BatchNorm2d``.
        mode (str): stored on the instance as ``self.mode``; it is not read
            anywhere else in this class.

    Raises:
        ValueError: if `replace_stride_with_dilation` does not have 3 elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, mode='parallel_adapters'):
        super(ResNet, self).__init__()
        self.mode = mode
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Adapters are only requested for the two middle stages.
        self.layer1 = self._make_layer(block, 64, layers[0], include_adapter=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0],
                                       include_adapter=True)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1],
                                       include_adapter=True)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2],
                                       include_adapter=False)
        # Kept as an attribute although forward() does not apply it.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Every convolution (adapter convolutions included) gets Kaiming
        # initialisation; affine norm layers start at weight 1 / bias 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, include_adapter=False):
        """Build one residual stage of `blocks` blocks as an ``nn.Sequential``.

        Args:
            block: residual block class.
            planes (int): base channel count of the stage.
            blocks (int): number of blocks in the stage.
            stride (int): stride of the first block; later blocks use 1.
            dilate (bool): trade the stage's stride for dilation.
            include_adapter (bool): forwarded to every block of the stage.
        """
        norm_layer = self._norm_layer
        downsample = None
        # The first block still runs at the dilation of the previous stage;
        # only the following blocks use the (possibly increased) new dilation.
        previous_dilation = self.dilation

        if dilate:
            self.dilation *= stride
            stride = 1

        # A projection is needed on the skip path whenever the first block
        # changes the spatial size or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        for i in range(blocks):
            is_first = i == 0
            layers.append(block(self.inplanes, planes,
                                stride if is_first else 1,
                                downsample if is_first else None,
                                include_adapter=include_adapter,
                                groups=self.groups, base_width=self.base_width,
                                dilation=previous_dilation if is_first else self.dilation,
                                norm_layer=norm_layer))
            # After the first block the stage runs at its output width.
            self.inplanes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem and the four stages; return the ``layer4`` output."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # No pooling, flattening or classification is applied here.
        return x


# class ResNet(nn.Module):
#
#     def __init__(self, block, layers, num_classes=100, mode='normal'):
#         self.inplanes = 64
#         super(ResNet, self).__init__()
#         self.mode = mode
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
#                                bias=True)
#         self.relu = nn.ReLU(inplace=True)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
#         self.layer1 = self._make_layer(block, 64, layers[0])
#         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
#         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
#         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
#         self.feature = nn.AvgPool2d(4, stride=1)
#         self.feature_dim = 512
#
#         if self.mode == 'parallel_adapters':
#             for m in self.modules():
#                 if isinstance(m, nn.Conv2d):
#                     m.weight.data.zero_()
#                 elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
#                     m.weight.data.zero_()
#                     m.bias.data.zero_()
#         else:
#             for m in self.modules():
#                 if isinstance(m, nn.Conv2d):
#                     nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#                 elif isinstance(m, nn.BatchNorm2d):
#                     nn.init.constant_(m.weight, 1)
#                     nn.init.constant_(m.bias, 0)
#
#     def _make_layer(self, block, planes, blocks, stride=1):
#         downsample = None
#         if stride != 1 or self.inplanes != planes * block.expansion:
#             downsample = nn.Sequential(
#                 nn.Conv2d(self.inplanes, planes * block.expansion,
#                           kernel_size=1, stride=stride, bias=True),
#             )
#         layers = []
#         layers.append(block(self.inplanes, planes, self.mode, stride, downsample))
#         self.inplanes = planes * block.expansion
#         for i in range(1, blocks):
#             layers.append(block(self.inplanes, planes, self.mode))
#
#         return nn.Sequential(*layers)
#
#     def switch(self, mode='normal'):
#         for name, module in self.named_modules():
#             if hasattr(module, 'mode'):
#                 module.mode = mode
#
#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.relu(x)
#         x = self.layer1(x)
#         x = self.layer2(x)
#         x = self.layer3(x)
#         x = self.layer4(x)
#         # dim = x.size()[-1]
#         # pool = nn.AvgPool2d(dim, stride=1)
#         # x = pool(x)
#         # x = x.view(x.size(0), -1)
#         return x


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally copy pretrained weights into it.

    Args:
        arch (str): key into ``model_urls`` selecting the checkpoint.
        block: residual block class passed to ``ResNet``.
        layers (list of int): number of blocks per stage.
        pretrained (bool): when true, download (or reuse the cached)
            checkpoint and load every tensor whose key also exists in the
            model; ``fc.weight`` and ``fc.bias`` are always skipped. Keys of
            the model that are absent from the checkpoint (e.g. adapter
            weights) keep their freshly initialised values.
        progress (bool): show a download progress bar.
        **kwargs: forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model

    # Load the pretrained checkpoint.
    state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)

    # Snapshot the model's own state once instead of rebuilding it for every
    # checkpoint key.
    model_dict = model.state_dict()

    # Keep only checkpoint entries that the model actually has, dropping the
    # final fully connected layer.
    skipped_keys = ('fc.weight', 'fc.bias')
    pretrained_dict = {k: v for k, v in state_dict.items()
                       if k not in skipped_keys and k in model_dict}

    # Overlay the pretrained tensors on the model's state and load the result.
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    return model


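# Load the pretrained model's state dict
# state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
#
# # Filter out parameters that do not exist in the pretrained model.
# # This ensures we only try to update parameters that exist in the pretrained model.
# pretrained_dict = {k: v for k, v in state_dict.items() if k in model.state_dict() and 'adapter' not in k}
#
# # From the model's current state dict, filter out the parameters that do not need updating (i.e. the custom layers' parameters).
# # This step is optional; it depends on whether the custom layers were given a special initialisation.
# model_dict = model.state_dict()
# model_dict.update(pretrained_dict)  # update the existing model state dict
#
# # Load the updated state dict into the model
# model.load_state_dict(model_dict, strict=True)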
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-18 variant of this file's ``ResNet``.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_: ``BasicBlock`` with
    ``[2, 2, 2, 2]`` blocks per stage.

    Args:
        pretrained (bool): If True, weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, blocks_per_stage,
                   pretrained, progress, **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-34 variant of this file's ``ResNet``.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_: ``BasicBlock`` with
    ``[3, 4, 6, 3]`` blocks per stage.

    Args:
        pretrained (bool): If True, weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, blocks_per_stage,
                   pretrained, progress, **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-50 variant of this file's ``ResNet``.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_: ``Bottleneck`` with
    ``[3, 4, 6, 3]`` blocks per stage.

    Args:
        pretrained (bool): If True, weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, blocks_per_stage,
                   pretrained, progress, **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-101 variant of this file's ``ResNet``.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_: ``Bottleneck`` with
    ``[3, 4, 23, 3]`` blocks per stage.

    Args:
        pretrained (bool): If True, weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    blocks_per_stage = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, blocks_per_stage,
                   pretrained, progress, **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-152 variant of this file's ``ResNet``.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_: ``Bottleneck`` with
    ``[3, 8, 36, 3]`` blocks per stage.

    Args:
        pretrained (bool): If True, weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    blocks_per_stage = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck, blocks_per_stage,
                   pretrained, progress, **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNeXt-50 32x4d variant of this file's ``ResNet``.

    Architecture from `"Aggregated Residual Transformation for Deep Neural
    Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_: ``Bottleneck`` with
    ``[3, 4, 6, 3]`` blocks per stage, 32 groups and a width of 4 per group.
    Any ``groups`` or ``width_per_group`` given by the caller is overwritten.

    Args:
        pretrained (bool): If True, weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    kwargs.update(groups=32, width_per_group=4)
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('resnext50_32x4d', Bottleneck, blocks_per_stage,
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNeXt-101 32x8d variant of this file's ``ResNet``.

    Architecture from `"Aggregated Residual Transformation for Deep Neural
    Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_: ``Bottleneck`` with
    ``[3, 4, 23, 3]`` blocks per stage, 32 groups and a width of 8 per group.
    Any ``groups`` or ``width_per_group`` given by the caller is overwritten.

    Args:
        pretrained (bool): If True, weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    kwargs.update(groups=32, width_per_group=8)
    blocks_per_stage = [3, 4, 23, 3]
    return _resnet('resnext101_32x8d', Bottleneck, blocks_per_stage,
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Build the Wide ResNet-50-2 variant of this file's ``ResNet``.

    Architecture from `"Wide Residual Networks"
    <https://arxiv.org/pdf/1605.07146.pdf>`_: ``Bottleneck`` with
    ``[3, 4, 6, 3]`` blocks per stage and ``width_per_group`` forced to 128,
    i.e. twice the default of 64. This doubles the inner (3x3) width of every
    bottleneck while the outer 1x1 channel counts stay the same, e.g. the
    last block has 2048-1024-2048 channels instead of 2048-512-2048.

    Args:
        pretrained (bool): If True, weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    kwargs['width_per_group'] = 128
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('wide_resnet50_2', Bottleneck, blocks_per_stage,
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Build the Wide ResNet-101-2 variant of this file's ``ResNet``.

    Architecture from `"Wide Residual Networks"
    <https://arxiv.org/pdf/1605.07146.pdf>`_: ``Bottleneck`` with
    ``[3, 4, 23, 3]`` blocks per stage and ``width_per_group`` forced to 128,
    i.e. twice the default of 64. This doubles the inner (3x3) width of every
    bottleneck while the outer 1x1 channel counts stay the same.

    Args:
        pretrained (bool): If True, weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    kwargs['width_per_group'] = 128
    blocks_per_stage = [3, 4, 23, 3]
    return _resnet('wide_resnet101_2', Bottleneck, blocks_per_stage,
                   pretrained, progress, **kwargs)
