import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights

class ResNet18(nn.Module):
    """ResNet-18 backbone from torchvision with a task-specific linear head.

    Args:
        num_classes: Width of the replacement classification layer, i.e. the
            number of logits returned by ``forward``.
        pretrained: When True, the backbone is built with
            ``ResNet18_Weights.DEFAULT``; otherwise it is built with
            ``weights=None`` (random initialisation).

    The wrapped torchvision network is stored as ``self.model``. Its ``fc``
    layer is always replaced by a freshly initialised ``nn.Linear`` sized for
    ``num_classes`` -- this happens even when ``pretrained`` is True, so the
    original torchvision classifier weights are never kept.
    """

    def __init__(self, num_classes=10, pretrained=False):
        super().__init__()

        backbone_weights = ResNet18_Weights.DEFAULT if pretrained else None
        self.model = resnet18(weights=backbone_weights)

        # Replace the stock head so the output width matches num_classes;
        # the input width is read from the existing layer rather than hard-coded.
        feature_dim = self.model.fc.in_features
        self.model.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return the logits produced by the wrapped torchvision ResNet-18.

        NOTE(review): the stem is torchvision's unmodified conv1, so ``x`` is
        presumably an image batch of shape (N, 3, H, W) -- confirm with callers.
        """
        return self.model(x)


#
# def cfgRes(depth):
#     depth_lst = [18, 34, 50, 101, 152]
#     assert (depth in depth_lst), "Error: Resnet depth should be one of [18,34,50,101,152]"
#     cf_dict = {
#         '18': (BasicBlock, [2, 2, 2, 2]),
#         '34': (BasicBlock, [3, 4, 6, 3]),
#         '50': (Bottleneck, [3, 4, 6, 3]),
#         '101': (Bottleneck, [3, 4, 23, 3]),
#         '152': (Bottleneck, [3, 8, 36, 3]),
#     }
#     return cf_dict[str(depth)]
#
# class BasicBlock(nn.Module):
#     expansion = 1
#
#     def __init__(self, in_planes, planes, stride=1):
#         super(BasicBlock, self).__init__()
#         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
#                                stride=stride, padding=1, bias=False)
#         self.bn1   = nn.BatchNorm2d(planes)
#         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
#                                stride=1, padding=1, bias=False)
#         self.bn2   = nn.BatchNorm2d(planes)
#
#         self.shortcut = nn.Sequential()
#         if stride != 1 or in_planes != planes * BasicBlock.expansion:
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_planes, planes * BasicBlock.expansion,
#                           kernel_size=1, stride=stride, bias=False),
#                 nn.BatchNorm2d(planes * BasicBlock.expansion)
#             )
#
#     def forward(self, x):
#         out = F.relu(self.bn1(self.conv1(x)))
#         out = self.bn2(self.conv2(out))
#         out += self.shortcut(x)
#         return F.relu(out)
#
# class Bottleneck(nn.Module):
#     expansion = 4
#
#     def __init__(self, in_planes, planes, stride=1):
#         super(Bottleneck, self).__init__()
#         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
#         self.bn1   = nn.BatchNorm2d(planes)
#         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
#                                stride=stride, padding=1, bias=False)
#         self.bn2   = nn.BatchNorm2d(planes)
#         self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion,
#                                kernel_size=1, bias=False)
#         self.bn3   = nn.BatchNorm2d(planes * Bottleneck.expansion)
#
#         self.shortcut = nn.Sequential()
#         if stride != 1 or in_planes != planes * Bottleneck.expansion:
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_planes, planes * Bottleneck.expansion,
#                           kernel_size=1, stride=stride, bias=False),
#                 nn.BatchNorm2d(planes * Bottleneck.expansion)
#             )
#
#     def forward(self, x):
#         out = F.relu(self.bn1(self.conv1(x)))
#         out = F.relu(self.bn2(self.conv2(out)))
#         out = self.bn3(self.conv3(out))
#         out += self.shortcut(x)
#         return F.relu(out)
#
# class ResNet18(nn.Module):
#     def __init__(self, args, num_classes=10):
#         super(ResNet18, self).__init__()
#         self.in_planes = 64
#         self.args = args
#         block, num_blocks = cfgRes(18)
#
#         if args.dataset_list == 'DIGIT10' or args.dataset_list == 'CIFAR10' or args.dataset_list == 'CIFAR100':
#             # CIFAR / Digit10 inputs are typically 32×32, so skip the early downsampling
#             self.conv1 = nn.Sequential(
#                 nn.Conv2d(3, self.in_planes,
#                           kernel_size=3, stride=1, padding=1, bias=False),
#                 nn.BatchNorm2d(self.in_planes),
#                 nn.ReLU(inplace=True)
#             )
#
#             # 2. ImageNet-style stem for Office-31 or Office-Caltech10
#         elif args.dataset_list == 'OFFICE31' or args.dataset_list == 'OFFICE_CALTECH_10':
#             # Use this block to replicate the official ResNet-18 stem exactly
#             self.conv1 = nn.Sequential(
#                 nn.Conv2d(3, self.in_planes,
#                           kernel_size=7, stride=2, padding=3, bias=False),
#                 nn.BatchNorm2d(self.in_planes),
#                 nn.ReLU(inplace=True),
#                 nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
#             )
#
#         # # --- Stem ---
#         # self.conv1 = nn.Conv2d(3, 64,
#         #                        kernel_size=3,
#         #                        stride=1,
#         #                        padding=1,
#         #                        bias=False)
#         # self.bn1   = nn.BatchNorm2d(64)
#
#         # --- Residual Layers ---
#         self.layer1 = self._make_layer(block,  64, num_blocks[0], stride=1)
#         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
#         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
#         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
#         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
#         # --- Adaptive global pooling & fully connected classifier ---
#         # After layer4, any input size becomes [batch, 512, H', W'] (H', W' vary with the input)
#         # adaptive_avg_pool2d(..., 1) → [batch, 512, 1, 1]
#         # After flattening: [batch, 512]
#         self.classifier = nn.Linear(512 * block.expansion, num_classes)
#
#     def _make_layer(self, block, planes, num_blocks, stride):
#         strides = [stride] + [1] * (num_blocks - 1)
#         layers = []
#         for s in strides:
#             layers.append(block(self.in_planes, planes, s))
#             self.in_planes = planes * block.expansion
#         return nn.Sequential(*layers)
#
#     def forward(self, x):
#         # x: [batch, 3, H, W]; H and W can be any positive integers (as long as they are >= 32)
#         out = self.conv1(x)
#         out = self.layer1(out)
#         out = self.layer2(out)
#         out = self.layer3(out)
#         out = self.layer4(out)
#         # ---- Key step: adaptive pooling down to 1×1 ----
#         out = self.avgpool(out)    # → [batch, 512, 1, 1]
#         out = torch.flatten(out, 1)         # → [batch, 512]
#         out = self.classifier(out)             # → [batch, num_classes]
#         return out