{"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3","language":"python"},"language_info":{"name":"python","version":"3.10.14","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"colab":{"toc_visible":true,"provenance":[],"gpuType":"T4"},"accelerator":"GPU","kaggle":{"accelerator":"gpu","dataSources":[{"sourceId":6799,"databundleVersionId":4225553,"sourceType":"competition"}],"dockerImageVersionId":31090,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"## Initialization","metadata":{"id":"oCZaXC_BndfO"}},{"cell_type":"markdown","source":"### Imports","metadata":{}},{"cell_type":"code","source":"import time\nimport math\nimport os\nimport pickle\nimport json\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport torchvision\nimport torchvision.transforms as transforms\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom PIL import Image","metadata":{"id":"0C8HttVn_7tn","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:02.131994Z","iopub.execute_input":"2024-12-31T22:31:02.132330Z","iopub.status.idle":"2024-12-31T22:31:06.346477Z","shell.execute_reply.started":"2024-12-31T22:31:02.132300Z","shell.execute_reply":"2024-12-31T22:31:06.345567Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Device","metadata":{}},{"cell_type":"code","source":"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")","metadata":{"id":"Zc_sS6ZVOQID","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:06.347822Z","iopub.execute_input":"2024-12-31T22:31:06.348211Z","iopub.status.idle":"2024-12-31T22:31:06.403852Z","shell.execute_reply.started":"2024-12-31T22:31:06.348183Z","shell.execute_reply":"2024-12-31T22:31:06.402803Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Datasets","metadata":{"id":"_Srf5dXxnhEq"}},{"cell_type":"markdown","source":"### Sampler","metadata":{"id":"7bIxJA5OAAcz"}},{"cell_type":"code","source":"def sample_dataset(full_ds, size):\n    num_samples = int(len(full_ds) * size)\n    random_indices = np.random.choice(len(full_ds), num_samples, replace=False)\n    return torch.utils.data.Subset(full_ds, random_indices)","metadata":{"id":"LU32VnrBDtaq","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:09.029989Z","iopub.execute_input":"2024-12-31T22:31:09.030365Z","iopub.status.idle":"2024-12-31T22:31:09.034902Z","shell.execute_reply.started":"2024-12-31T22:31:09.030334Z","shell.execute_reply":"2024-12-31T22:31:09.033976Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### MNIST","metadata":{"id":"7UYmDC4PnjDO"}},{"cell_type":"code","source":"mnist_transform = transforms.Compose([\n    transforms.ToTensor(),\n    transforms.Normalize((0.1307,), (0.3081,))\n])\n\nmnist_train_dataset = torchvision.datasets.MNIST(\n    root=\"./data\",\n    train=True,\n    download=True,\n    transform=mnist_transform\n)\n\nmnist_test_dataset = torchvision.datasets.MNIST(\n    root=\"./data\",\n    train=False,\n    download=True,\n    transform=mnist_transform\n)","metadata":{"id":"BhGaW6EIAQhz","outputId":"0e9d3b68-0f93-4553-ebbf-b9b8785183bc","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:09.453769Z","iopub.execute_input":"2024-12-31T22:31:09.454320Z","iopub.status.idle":"2024-12-31T22:31:14.426089Z","shell.execute_reply.started":"2024-12-31T22:31:09.454286Z","shell.execute_reply":"2024-12-31T22:31:14.425172Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### CIFAR-10","metadata":{"id":"vNfURcEonkeb"}},{"cell_type":"code","source":"normalize = transforms.Normalize(\n    mean=[0.4914, 0.4822, 0.4465],\n    std=[0.2023, 0.1994, 0.2010]\n)\n\ncifar10_transform_train = transforms.Compose([\n    transforms.RandomCrop(32, padding=4),\n    transforms.RandomHorizontalFlip(),\n    transforms.ToTensor(),\n    normalize\n])\n\ncifar10_transform_test = transforms.Compose([\n    transforms.ToTensor(),\n    normalize\n])\n\ncifar10_train_dataset = torchvision.datasets.CIFAR10(\n    root=\"./data\",\n    train=True,\n    download=True,\n    transform=cifar10_transform_train\n)\n\ncifar10_test_dataset = torchvision.datasets.CIFAR10(\n    root=\"./data\",\n    train=False,\n    download=True,\n    transform=cifar10_transform_test\n)","metadata":{"id":"9jn3AO3DnPB-","outputId":"44e90d71-2577-47e6-ed37-7e6c8a8fe7c7","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:14.427860Z","iopub.execute_input":"2024-12-31T22:31:14.428252Z","iopub.status.idle":"2024-12-31T22:31:22.920648Z","shell.execute_reply.started":"2024-12-31T22:31:14.428212Z","shell.execute_reply":"2024-12-31T22:31:22.919775Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### ImageNet","metadata":{}},{"cell_type":"code","source":"ImageNet_DATAPATH = \"/kaggle/input/imagenet-object-localization-challenge\"\nImageNet_CLASSPATH = \"/kaggle/working/imagenet-classes\"\n\n! wget -nc -P {ImageNet_CLASSPATH} https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json\n! wget -nc -P {ImageNet_CLASSPATH} https://gist.githubusercontent.com/paulgavrikov/3af1efe6f3dff63f47d48b91bb1bca6b/raw/00bad6903b5e4f84c7796b982b72e2e617e5fde1/ILSVRC2012_val_labels.json","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class ImageNetKaggle(torch.utils.data.Dataset):\n    def __init__(self, data_path, class_path, split, transform=None):\n        self.samples = []\n        self.targets = []\n        self.transform = transform\n        self.syn_to_class = {}\n        with open(os.path.join(class_path, \"imagenet_class_index.json\"), \"rb\") as f:\n            json_file = json.load(f)\n            for class_id, v in json_file.items():\n                self.syn_to_class[v[0]] = int(class_id)\n        with open(os.path.join(class_path, \"ILSVRC2012_val_labels.json\"), \"rb\") as f:\n            self.val_to_syn = json.load(f)\n        samples_dir = os.path.join(data_path, \"ILSVRC/Data/CLS-LOC\", split)\n        for entry in os.listdir(samples_dir):\n            if split == \"train\":\n                syn_id = entry\n                target = self.syn_to_class[syn_id]\n                syn_folder = os.path.join(samples_dir, syn_id)\n                for sample in os.listdir(syn_folder):\n                    sample_path = os.path.join(syn_folder, sample)\n                    self.samples.append(sample_path)\n                    self.targets.append(target)\n            elif split == \"val\":\n                syn_id = self.val_to_syn[entry]\n                target = self.syn_to_class[syn_id]\n                sample_path = os.path.join(samples_dir, entry)\n                self.samples.append(sample_path)\n                self.targets.append(target)\n\n    def __len__(self):\n        return len(self.samples)\n\n    def __getitem__(self, idx):\n        x = Image.open(self.samples[idx]).convert(\"RGB\")\n        if self.transform:\n            x = self.transform(x)\n        return x, self.targets[idx]","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"imagenet_mean = (0.485, 0.456, 0.406)\nimagenet_std = (0.229, 0.224, 0.225)\n\ntransform = transforms.Compose(\n    [\n        transforms.Resize(256),\n        transforms.CenterCrop(224),\n        transforms.ToTensor(),\n        transforms.Normalize(imagenet_mean, imagenet_std),\n    ]\n)\n\nimagenet_train_dataset = ImageNetKaggle(\n    data_path=ImageNet_DATAPATH,\n    class_path=ImageNet_CLASSPATH,\n    split=\"train\",\n    transform=transform\n)\n\nimagenet_test_dataset = ImageNetKaggle(\n    data_path=ImageNet_DATAPATH,\n    class_path=ImageNet_CLASSPATH,\n    split=\"val\",\n    transform=transform\n)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Models","metadata":{"id":"0HzChSLJnnkN"}},{"cell_type":"markdown","source":"### LeNet-300-100-10 (MNIST)","metadata":{"id":"DLBim5KLnr0F"}},{"cell_type":"code","source":"class LeNet300(nn.Module):\n    def __init__(self, num_classes=10):\n        super(LeNet300, self).__init__()\n        self.fc1 = nn.Linear(28 * 28, 300)\n        self.fc2 = nn.Linear(300, 100)\n        self.fc3 = nn.Linear(100, num_classes)\n\n    def forward(self, x):\n        x = x.view(-1, 28 * 28)  # Flatten the input\n        x = torch.nn.functional.relu(self.fc1(x))\n        x = torch.nn.functional.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x","metadata":{"id":"MmYZIXDfcvPh","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:22.926872Z","iopub.execute_input":"2024-12-31T22:31:22.927153Z","iopub.status.idle":"2024-12-31T22:31:22.943807Z","shell.execute_reply.started":"2024-12-31T22:31:22.927094Z","shell.execute_reply":"2024-12-31T22:31:22.942863Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### LeNet","metadata":{"id":"CNpj-u2wno_a"}},{"cell_type":"code","source":"class LeNet(nn.Module):\n    def __init__(self):\n        super(LeNet, self).__init__()\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1   = nn.Linear(16*5*5, 120)\n        self.fc2   = nn.Linear(120, 84)\n        self.fc3   = nn.Linear(84, 10)\n\n    def forward(self, x):\n        out = F.relu(self.conv1(x))\n        out = F.max_pool2d(out, 2)\n        out = F.relu(self.conv2(out))\n        out = F.max_pool2d(out, 2)\n        out = out.view(out.size(0), -1)\n        out = F.relu(self.fc1(out))\n        out = F.relu(self.fc2(out))\n        out = self.fc3(out)\n        return out","metadata":{"id":"DN662fmiAk-6","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:22.944968Z","iopub.execute_input":"2024-12-31T22:31:22.945343Z","iopub.status.idle":"2024-12-31T22:31:22.955500Z","shell.execute_reply.started":"2024-12-31T22:31:22.945293Z","shell.execute_reply":"2024-12-31T22:31:22.954588Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### VGG","metadata":{"id":"wjoNlnui2mzp"}},{"cell_type":"code","source":"class VGG(nn.Module):\n    '''\n    VGG model\n    '''\n    def __init__(self, features):\n        super(VGG, self).__init__()\n        self.features = features\n        self.classifier = nn.Sequential(\n            nn.Dropout(),\n            nn.Linear(512, 512),\n            nn.ReLU(True),\n            nn.Dropout(),\n            nn.Linear(512, 512),\n            nn.ReLU(True),\n            nn.Linear(512, 10),\n        )\n         # Initialize weights\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n                m.bias.data.zero_()\n\n\n    def forward(self, x):\n        x = self.features(x)\n        x = x.view(x.size(0), -1)\n        x = self.classifier(x)\n        return x\n\n\ndef make_vgg_layers(cfg, batch_norm=False):\n    layers = []\n    in_channels = 3\n    for v in cfg:\n        if v == 'M':\n            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]\n        else:\n            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)\n            if batch_norm:\n                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]\n            else:\n                layers += [conv2d, nn.ReLU(inplace=True)]\n            in_channels = v\n    return nn.Sequential(*layers)\n\n\nvgg_cfg = {\n    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],\n    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],\n    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],\n    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',\n          512, 512, 512, 512, 'M'],\n}\n\n\ndef vgg11():\n    \"\"\"VGG 11-layer model (configuration \"A\")\"\"\"\n    return VGG(make_vgg_layers(vgg_cfg['A']))\n\n\ndef vgg11_bn():\n    \"\"\"VGG 11-layer model (configuration \"A\") with batch normalization\"\"\"\n    return VGG(make_vgg_layers(vgg_cfg['A'], batch_norm=True))\n\n\ndef vgg13():\n    \"\"\"VGG 13-layer model (configuration \"B\")\"\"\"\n    return VGG(make_vgg_layers(vgg_cfg['B']))\n\n\ndef vgg13_bn():\n    \"\"\"VGG 13-layer model (configuration \"B\") with batch normalization\"\"\"\n    return VGG(make_vgg_layers(vgg_cfg['B'], batch_norm=True))\n\n\ndef vgg16():\n    \"\"\"VGG 16-layer model (configuration \"D\")\"\"\"\n    return VGG(make_vgg_layers(vgg_cfg['D']))\n\n\ndef vgg16_bn():\n    \"\"\"VGG 16-layer model (configuration \"D\") with batch normalization\"\"\"\n    return VGG(make_vgg_layers(vgg_cfg['D'], batch_norm=True))\n\n\ndef vgg19():\n    \"\"\"VGG 19-layer model (configuration \"E\")\"\"\"\n    return VGG(make_vgg_layers(vgg_cfg['E']))\n\n\ndef vgg19_bn():\n    \"\"\"VGG 19-layer model (configuration 'E') with batch normalization\"\"\"\n    return VGG(make_vgg_layers(vgg_cfg['E'], batch_norm=True))","metadata":{"id":"oVq0L4cW2pGv","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:22.956682Z","iopub.execute_input":"2024-12-31T22:31:22.957013Z","iopub.status.idle":"2024-12-31T22:31:22.972994Z","shell.execute_reply.started":"2024-12-31T22:31:22.956987Z","shell.execute_reply":"2024-12-31T22:31:22.972170Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### MobileNetV2","metadata":{}},{"cell_type":"code","source":"'''MobileNetV2 in PyTorch.\n\nSee the paper \"Inverted Residuals and Linear Bottlenecks:\nMobile Networks for Classification, Detection and Segmentation\" for more details.\n'''\n\nclass MobileNetV2_Block(nn.Module):\n    '''expand + depthwise + pointwise'''\n    def __init__(self, in_planes, out_planes, expansion, stride):\n        super(MobileNetV2_Block, self).__init__()\n        self.stride = stride\n\n        planes = expansion * in_planes\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)\n        self.bn3 = nn.BatchNorm2d(out_planes)\n\n        self.shortcut = nn.Sequential()\n        if stride == 1 and in_planes != out_planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),\n                nn.BatchNorm2d(out_planes),\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = F.relu(self.bn2(self.conv2(out)))\n        out = self.bn3(self.conv3(out))\n        out = out + self.shortcut(x) if self.stride==1 else out\n        return out\n\n\nclass MobileNetV2(nn.Module):\n    # (expansion, out_planes, num_blocks, stride)\n    cfg = [(1,  16, 1, 1),\n           (6,  24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10\n           (6,  32, 3, 2),\n           (6,  64, 4, 2),\n           (6,  96, 3, 1),\n           (6, 160, 3, 2),\n           (6, 320, 1, 1)]\n\n    def __init__(self, num_classes=10):\n        super(MobileNetV2, self).__init__()\n        # NOTE: change conv1 stride 2 -> 1 for CIFAR10\n        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(32)\n        self.layers = self._make_layers(in_planes=32)\n        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)\n        self.bn2 = nn.BatchNorm2d(1280)\n        self.linear = nn.Linear(1280, num_classes)\n\n    def _make_layers(self, in_planes):\n        layers = []\n        for expansion, out_planes, num_blocks, stride in self.cfg:\n            strides = [stride] + [1]*(num_blocks-1)\n            for stride in strides:\n                layers.append(MobileNetV2_Block(in_planes, out_planes, expansion, stride))\n                in_planes = out_planes\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = self.layers(out)\n        out = F.relu(self.bn2(self.conv2(out)))\n        # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10\n        out = F.avg_pool2d(out, 4)\n        out = out.view(out.size(0), -1)\n        out = self.linear(out)\n        return out","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:26.988630Z","iopub.execute_input":"2024-12-31T22:31:26.989424Z","iopub.status.idle":"2024-12-31T22:31:27.001146Z","shell.execute_reply.started":"2024-12-31T22:31:26.989390Z","shell.execute_reply":"2024-12-31T22:31:27.000222Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### ResNet","metadata":{}},{"cell_type":"code","source":"'''ResNet in PyTorch.\nReference:\n[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun\n    Deep Residual Learning for Image Recognition. arXiv:1512.03385\n'''\n\nclass ResNet_BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, in_planes, planes, stride=1):\n        super(ResNet_BasicBlock, self).__init__()\n        self.conv1 = nn.Conv2d(\n            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n                               stride=1, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != self.expansion*planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, self.expansion*planes,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(self.expansion*planes)\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = self.bn2(self.conv2(out))\n        out += self.shortcut(x)\n        out = F.relu(out)\n        return out\n\n\nclass ResNet_Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, in_planes, planes, stride=1):\n        super(ResNet_Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n                               stride=stride, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, self.expansion *\n                               planes, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(self.expansion*planes)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != self.expansion*planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, self.expansion*planes,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(self.expansion*planes)\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = F.relu(self.bn2(self.conv2(out)))\n        out = self.bn3(self.conv3(out))\n        out += self.shortcut(x)\n        out = F.relu(out)\n        return out\n\n\nclass ResNet(nn.Module):\n    def __init__(self, block, num_blocks, num_classes=10):\n        super(ResNet, self).__init__()\n        self.in_planes = 64\n\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,\n                               stride=1, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n        self.linear = nn.Linear(512*block.expansion, num_classes)\n\n    def _make_layer(self, block, planes, num_blocks, stride):\n        strides = [stride] + [1]*(num_blocks-1)\n        layers = []\n        for stride in strides:\n            layers.append(block(self.in_planes, planes, stride))\n            self.in_planes = planes * block.expansion\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = self.layer1(out)\n        out = self.layer2(out)\n        out = self.layer3(out)\n        out = self.layer4(out)\n        out = F.avg_pool2d(out, 4)\n        out = out.view(out.size(0), -1)\n        out = self.linear(out)\n        return out\n\n\ndef ResNet18():\n    return ResNet(ResNet_BasicBlock, [2, 2, 2, 2])\n\n\ndef ResNet34():\n    return ResNet(ResNet_BasicBlock, [3, 4, 6, 3])\n\n\ndef ResNet50():\n    return ResNet(ResNet_Bottleneck, [3, 4, 6, 3])\n\n\ndef ResNet101():\n    return ResNet(ResNet_Bottleneck, [3, 4, 23, 3])\n\n\ndef ResNet152():\n    return ResNet(ResNet_Bottleneck, [3, 8, 36, 3])","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:30.014876Z","iopub.execute_input":"2024-12-31T22:31:30.015531Z","iopub.status.idle":"2024-12-31T22:31:30.031727Z","shell.execute_reply.started":"2024-12-31T22:31:30.015495Z","shell.execute_reply":"2024-12-31T22:31:30.030808Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Inception: GoogleNet","metadata":{}},{"cell_type":"code","source":"'''GoogLeNet with PyTorch.'''\n\nclass Inception(nn.Module):\n    def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):\n        super(Inception, self).__init__()\n        # 1x1 conv branch\n        self.b1 = nn.Sequential(\n            nn.Conv2d(in_planes, n1x1, kernel_size=1),\n            nn.BatchNorm2d(n1x1),\n            nn.ReLU(True),\n        )\n\n        # 1x1 conv -> 3x3 conv branch\n        self.b2 = nn.Sequential(\n            nn.Conv2d(in_planes, n3x3red, kernel_size=1),\n            nn.BatchNorm2d(n3x3red),\n            nn.ReLU(True),\n            nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1),\n            nn.BatchNorm2d(n3x3),\n            nn.ReLU(True),\n        )\n\n        # 1x1 conv -> 5x5 conv branch\n        self.b3 = nn.Sequential(\n            nn.Conv2d(in_planes, n5x5red, kernel_size=1),\n            nn.BatchNorm2d(n5x5red),\n            nn.ReLU(True),\n            nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1),\n            nn.BatchNorm2d(n5x5),\n            nn.ReLU(True),\n            nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),\n            nn.BatchNorm2d(n5x5),\n            nn.ReLU(True),\n        )\n\n        # 3x3 pool -> 1x1 conv branch\n        self.b4 = nn.Sequential(\n            nn.MaxPool2d(3, stride=1, padding=1),\n            nn.Conv2d(in_planes, pool_planes, kernel_size=1),\n            nn.BatchNorm2d(pool_planes),\n            nn.ReLU(True),\n        )\n\n    def forward(self, x):\n        y1 = self.b1(x)\n        y2 = self.b2(x)\n        y3 = self.b3(x)\n        y4 = self.b4(x)\n        return torch.cat([y1,y2,y3,y4], 1)\n\n\nclass GoogLeNet(nn.Module):\n    def __init__(self):\n        super(GoogLeNet, self).__init__()\n        self.pre_layers = nn.Sequential(\n            nn.Conv2d(3, 192, kernel_size=3, padding=1),\n            nn.BatchNorm2d(192),\n            nn.ReLU(True),\n        )\n\n        self.a3 = Inception(192,  64,  96, 128, 16, 32, 32)\n        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)\n\n        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)\n\n        self.a4 = Inception(480, 192,  96, 208, 16,  48,  64)\n        self.b4 = Inception(512, 160, 112, 224, 24,  64,  64)\n        self.c4 = Inception(512, 128, 128, 256, 24,  64,  64)\n        self.d4 = Inception(512, 112, 144, 288, 32,  64,  64)\n        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)\n\n        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)\n        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)\n\n        self.avgpool = nn.AvgPool2d(8, stride=1)\n        self.linear = nn.Linear(1024, 10)\n\n    def forward(self, x):\n        out = self.pre_layers(x)\n        out = self.a3(out)\n        out = self.b3(out)\n        out = self.maxpool(out)\n        out = self.a4(out)\n        out = self.b4(out)\n        out = self.c4(out)\n        out = self.d4(out)\n        out = self.e4(out)\n        out = self.maxpool(out)\n        out = self.a5(out)\n        out = self.b5(out)\n        out = self.avgpool(out)\n        out = out.view(out.size(0), -1)\n        out = self.linear(out)\n        return out","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:48.092794Z","iopub.execute_input":"2024-12-31T22:31:48.093150Z","iopub.status.idle":"2024-12-31T22:31:48.108316Z","shell.execute_reply.started":"2024-12-31T22:31:48.093091Z","shell.execute_reply":"2024-12-31T22:31:48.107125Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### EfficientNet","metadata":{}},{"cell_type":"code","source":"'''EfficientNet in PyTorch.\n\nPaper: \"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks\".\n\nReference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py\n'''\n\ndef swish(x):\n    return x * x.sigmoid()\n\n\ndef drop_connect(x, drop_ratio):\n    keep_ratio = 1.0 - drop_ratio\n    mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)\n    mask.bernoulli_(keep_ratio)\n    x.div_(keep_ratio)\n    x.mul_(mask)\n    return x\n\n\nclass SE(nn.Module):\n    '''Squeeze-and-Excitation block with Swish.'''\n\n    def __init__(self, in_channels, se_channels):\n        super(SE, self).__init__()\n        self.se1 = nn.Conv2d(in_channels, se_channels,\n                             kernel_size=1, bias=True)\n        self.se2 = nn.Conv2d(se_channels, in_channels,\n                             kernel_size=1, bias=True)\n\n    def forward(self, x):\n        out = F.adaptive_avg_pool2d(x, (1, 1))\n        out = swish(self.se1(out))\n        out = self.se2(out).sigmoid()\n        out = x * out\n        return out\n\n\nclass Efficient_Block(nn.Module):\n    '''expansion + depthwise + pointwise + squeeze-excitation'''\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 kernel_size,\n                 stride,\n                 expand_ratio=1,\n                 se_ratio=0.,\n                 drop_rate=0.):\n        super(Efficient_Block, self).__init__()\n        self.stride = stride\n        self.drop_rate = drop_rate\n        self.expand_ratio = expand_ratio\n\n        # Expansion\n        channels = expand_ratio * in_channels\n        self.conv1 = nn.Conv2d(in_channels,\n                               channels,\n                               kernel_size=1,\n                               stride=1,\n                               padding=0,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(channels)\n\n        # Depthwise conv\n        self.conv2 = nn.Conv2d(channels,\n                               channels,\n                               kernel_size=kernel_size,\n                               stride=stride,\n                               padding=(1 if kernel_size == 3 else 2),\n                               groups=channels,\n                               bias=False)\n        self.bn2 = nn.BatchNorm2d(channels)\n\n        # SE layers\n        se_channels = int(in_channels * se_ratio)\n        self.se = SE(channels, se_channels)\n\n        # Output\n        self.conv3 = nn.Conv2d(channels,\n                               out_channels,\n                               kernel_size=1,\n                               stride=1,\n                               padding=0,\n                               bias=False)\n        self.bn3 = nn.BatchNorm2d(out_channels)\n\n        # Skip connection if in and out shapes are the same (MV-V2 style)\n        self.has_skip = (stride == 1) and (in_channels == out_channels)\n\n    def forward(self, x):\n        out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x)))\n        out = swish(self.bn2(self.conv2(out)))\n        out = self.se(out)\n        out = self.bn3(self.conv3(out))\n        if self.has_skip:\n            if self.training and self.drop_rate > 0:\n                out = drop_connect(out, self.drop_rate)\n            out = out + x\n        return out\n\n\nclass EfficientNet(nn.Module):\n    def __init__(self, cfg, num_classes=10):\n        super(EfficientNet, self).__init__()\n        self.cfg = cfg\n        self.conv1 = nn.Conv2d(3,\n                               32,\n                               kernel_size=3,\n                               stride=1,\n                               padding=1,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(32)\n        self.layers = self._make_layers(in_channels=32)\n        self.linear = nn.Linear(cfg['out_channels'][-1], num_classes)\n\n    def _make_layers(self, in_channels):\n        layers = []\n        cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size',\n                                     'stride']]\n        b = 0\n        blocks = sum(self.cfg['num_blocks'])\n        for expansion, out_channels, num_blocks, kernel_size, stride in zip(*cfg):\n            strides = [stride] + [1] * (num_blocks - 1)\n            for stride in strides:\n                drop_rate = self.cfg['drop_connect_rate'] * b / blocks\n                layers.append(\n                    Efficient_Block(in_channels,\n                          out_channels,\n                          kernel_size,\n                          stride,\n                          expansion,\n                          se_ratio=0.25,\n                          drop_rate=drop_rate))\n                in_channels = out_channels\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = swish(self.bn1(self.conv1(x)))\n        out = self.layers(out)\n        out = F.adaptive_avg_pool2d(out, 1)\n        out = out.view(out.size(0), -1)\n        dropout_rate = self.cfg['dropout_rate']\n        if self.training and dropout_rate > 0:\n            out = F.dropout(out, p=dropout_rate)\n        out = self.linear(out)\n        return out\n\n\ndef EfficientNetB0():\n    cfg = {\n        'num_blocks': [1, 2, 2, 3, 3, 4, 1],\n        'expansion': [1, 6, 6, 6, 6, 6, 6],\n        'out_channels': [16, 24, 40, 80, 112, 192, 320],\n        'kernel_size': [3, 3, 5, 3, 5, 5, 3],\n        'stride': [1, 2, 2, 2, 1, 2, 1],\n        'dropout_rate': 0.2,\n        'drop_connect_rate': 0.2,\n    }\n    return EfficientNet(cfg)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:52.144087Z","iopub.execute_input":"2024-12-31T22:31:52.144462Z","iopub.status.idle":"2024-12-31T22:31:52.163147Z","shell.execute_reply.started":"2024-12-31T22:31:52.144430Z","shell.execute_reply":"2024-12-31T22:31:52.162241Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### ShuffleNetV2","metadata":{}},{"cell_type":"code","source":"'''ShuffleNetV2 in PyTorch.\n\nSee the paper \"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design\" for more details.\n'''\n\nclass ShuffleBlock(nn.Module):\n    def __init__(self, groups=2):\n        super(ShuffleBlock, self).__init__()\n        self.groups = groups\n\n    def forward(self, x):\n        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]'''\n        N, C, H, W = x.size()\n        g = self.groups\n        return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)\n\n\nclass SplitBlock(nn.Module):\n    def __init__(self, ratio):\n        super(SplitBlock, self).__init__()\n        self.ratio = ratio\n\n    def forward(self, x):\n        c = int(x.size(1) * self.ratio)\n        return x[:, :c, :, :], x[:, c:, :, :]\n\n\nclass Shuffle_BasicBlock(nn.Module):\n    def __init__(self, in_channels, split_ratio=0.5):\n        super(Shuffle_BasicBlock, self).__init__()\n        self.split = SplitBlock(split_ratio)\n        in_channels = int(in_channels * split_ratio)\n        self.conv1 = nn.Conv2d(in_channels, in_channels,\n                               kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(in_channels)\n        self.conv2 = nn.Conv2d(in_channels, in_channels,\n                               kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)\n        self.bn2 = nn.BatchNorm2d(in_channels)\n        self.conv3 = nn.Conv2d(in_channels, in_channels,\n                               kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(in_channels)\n        self.shuffle = ShuffleBlock()\n\n    def forward(self, x):\n        x1, x2 = self.split(x)\n        out = F.relu(self.bn1(self.conv1(x2)))\n        out = self.bn2(self.conv2(out))\n        out = F.relu(self.bn3(self.conv3(out)))\n        out = torch.cat([x1, out], 1)\n        out = self.shuffle(out)\n        return out\n\n\nclass Shuffle_DownBlock(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super(Shuffle_DownBlock, self).__init__()\n        mid_channels = out_channels // 2\n        # left\n        self.conv1 = nn.Conv2d(in_channels, in_channels,\n                               kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)\n        self.bn1 = nn.BatchNorm2d(in_channels)\n        self.conv2 = nn.Conv2d(in_channels, mid_channels,\n                               kernel_size=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(mid_channels)\n        # right\n        self.conv3 = nn.Conv2d(in_channels, mid_channels,\n                               kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(mid_channels)\n        self.conv4 = nn.Conv2d(mid_channels, mid_channels,\n                               kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)\n        self.bn4 = nn.BatchNorm2d(mid_channels)\n        self.conv5 = nn.Conv2d(mid_channels, mid_channels,\n                               kernel_size=1, bias=False)\n        self.bn5 = nn.BatchNorm2d(mid_channels)\n\n        self.shuffle = ShuffleBlock()\n\n    def forward(self, x):\n        # left\n        out1 = self.bn1(self.conv1(x))\n        out1 = F.relu(self.bn2(self.conv2(out1)))\n        # right\n        out2 = F.relu(self.bn3(self.conv3(x)))\n        out2 = self.bn4(self.conv4(out2))\n        out2 = F.relu(self.bn5(self.conv5(out2)))\n        # concat\n        out = torch.cat([out1, out2], 1)\n        out = self.shuffle(out)\n        return out\n\n\nclass ShuffleNetV2(nn.Module):\n    def __init__(self, net_size):\n        super(ShuffleNetV2, self).__init__()\n        out_channels = Shuffle_configs[net_size]['out_channels']\n        num_blocks = Shuffle_configs[net_size]['num_blocks']\n\n        self.conv1 = nn.Conv2d(3, 24, kernel_size=3,\n                               stride=1, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(24)\n        self.in_channels = 24\n        self.layer1 = self._make_layer(out_channels[0], num_blocks[0])\n        self.layer2 = self._make_layer(out_channels[1], num_blocks[1])\n        self.layer3 = self._make_layer(out_channels[2], num_blocks[2])\n        self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],\n                               kernel_size=1, stride=1, padding=0, bias=False)\n        self.bn2 = nn.BatchNorm2d(out_channels[3])\n        self.linear = nn.Linear(out_channels[3], 10)\n\n    def _make_layer(self, out_channels, num_blocks):\n        layers = [Shuffle_DownBlock(self.in_channels, out_channels)]\n        for i in range(num_blocks):\n            layers.append(Shuffle_BasicBlock(out_channels))\n            self.in_channels = out_channels\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        # out = F.max_pool2d(out, 3, stride=2, padding=1)\n        out = self.layer1(out)\n        out = self.layer2(out)\n        out = self.layer3(out)\n        out = F.relu(self.bn2(self.conv2(out)))\n        out = F.avg_pool2d(out, 4)\n        out = out.view(out.size(0), -1)\n        out = self.linear(out)\n        return out\n\n\nShuffle_configs = {\n    0.5: {\n        'out_channels': (48, 96, 192, 1024),\n        'num_blocks': (3, 7, 3)\n    },\n\n    1: {\n        'out_channels': (116, 232, 464, 1024),\n        'num_blocks': (3, 7, 3)\n    },\n    1.5: {\n        'out_channels': (176, 352, 704, 1024),\n        'num_blocks': (3, 7, 3)\n    },\n    2: {\n        'out_channels': (224, 488, 976, 2048),\n        'num_blocks': (3, 7, 3)\n    }\n}","metadata":{"_kg_hide-input":true,"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:31:59.063388Z","iopub.execute_input":"2024-12-31T22:31:59.063725Z","iopub.status.idle":"2024-12-31T22:31:59.082453Z","shell.execute_reply.started":"2024-12-31T22:31:59.063694Z","shell.execute_reply":"2024-12-31T22:31:59.081526Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### DenseNet","metadata":{}},{"cell_type":"code","source":"class DenseNet_Bottleneck(nn.Module):\n    def __init__(self, in_planes, growth_rate):\n        super(DenseNet_Bottleneck, self).__init__()\n        self.bn1 = nn.BatchNorm2d(in_planes)\n        self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(4*growth_rate)\n        self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)\n\n    def forward(self, x):\n        out = self.conv1(F.relu(self.bn1(x)))\n        out = self.conv2(F.relu(self.bn2(out)))\n        out = torch.cat([out,x], 1)\n        return out\n\n\nclass DenseNet_Transition(nn.Module):\n    def __init__(self, in_planes, out_planes):\n        super(DenseNet_Transition, self).__init__()\n        self.bn = nn.BatchNorm2d(in_planes)\n        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)\n\n    def forward(self, x):\n        out = self.conv(F.relu(self.bn(x)))\n        out = F.avg_pool2d(out, 2)\n        return out\n\n\nclass DenseNet(nn.Module):\n    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):\n        super(DenseNet, self).__init__()\n        self.growth_rate = growth_rate\n\n        num_planes = 2*growth_rate\n        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)\n\n        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])\n        num_planes += nblocks[0]*growth_rate\n        out_planes = int(math.floor(num_planes*reduction))\n        self.trans1 = DenseNet_Transition(num_planes, out_planes)\n        num_planes = out_planes\n\n        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])\n        num_planes += nblocks[1]*growth_rate\n        out_planes = int(math.floor(num_planes*reduction))\n        self.trans2 = DenseNet_Transition(num_planes, out_planes)\n        num_planes = out_planes\n\n        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])\n        num_planes += nblocks[2]*growth_rate\n        out_planes = int(math.floor(num_planes*reduction))\n        self.trans3 = DenseNet_Transition(num_planes, out_planes)\n        num_planes = out_planes\n\n        self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])\n        num_planes += nblocks[3]*growth_rate\n\n        self.bn = nn.BatchNorm2d(num_planes)\n        self.linear = nn.Linear(num_planes, num_classes)\n\n    def _make_dense_layers(self, block, in_planes, nblock):\n        layers = []\n        for i in range(nblock):\n            layers.append(block(in_planes, self.growth_rate))\n            in_planes += self.growth_rate\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.trans1(self.dense1(out))\n        out = self.trans2(self.dense2(out))\n        out = self.trans3(self.dense3(out))\n        out = self.dense4(out)\n        out = F.avg_pool2d(F.relu(self.bn(out)), 4)\n        out = out.view(out.size(0), -1)\n        out = self.linear(out)\n        return out\n\ndef DenseNet121():\n    return DenseNet(DenseNet_Bottleneck, [6,12,24,16], growth_rate=32)\n\ndef DenseNet169():\n    return DenseNet(DenseNet_Bottleneck, [6,12,32,32], growth_rate=32)\n\ndef DenseNet201():\n    return DenseNet(DenseNet_Bottleneck, [6,12,48,32], growth_rate=32)\n\ndef DenseNet161():\n    return DenseNet(DenseNet_Bottleneck, [6,12,36,24], growth_rate=48)\n\ndef densenet_cifar():\n    return DenseNet(DenseNet_Bottleneck, [6,12,24,16], growth_rate=12)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:53:01.991969Z","iopub.execute_input":"2024-12-31T22:53:01.992408Z","iopub.status.idle":"2024-12-31T22:53:02.009939Z","shell.execute_reply.started":"2024-12-31T22:53:01.992370Z","shell.execute_reply":"2024-12-31T22:53:02.009004Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### InceptionV3","metadata":{}},{"cell_type":"code","source":"from collections import namedtuple\n\n_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])\n\n\nclass Inception3(nn.Module):\n    def __init__(self, num_classes=10, aux_logits=True, transform_input=True):\n        super(Inception3, self).__init__()\n        self.aux_logits = aux_logits\n        self.transform_input = transform_input\n        self.Conv2d_4a_3x3 = BasicConv2d(3, 32, kernel_size=3,padding=1)\n        self.Mixed_5b = InceptionA(32, pool_features=8)\n        self.Mixed_5c = InceptionA(64, pool_features=72)\n        self.Mixed_6a = InceptionB(128)\n        self.Mixed_6b = InceptionC(256, channels_7x7=64)\n        if aux_logits:\n            self.AuxLogits = InceptionAux(512, num_classes)\n        self.Mixed_7a = InceptionD(512)\n        self.fc = nn.Linear(768, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):\n                import scipy.stats as stats\n                stddev = m.stddev if hasattr(m, 'stddev') else 0.1\n                X = stats.truncnorm(-2, 2, scale=stddev)\n                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)\n                values = values.view(m.weight.size())\n                with torch.no_grad():\n                    m.weight.copy_(values)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        global aux\n        x = self.Conv2d_4a_3x3(x)\n        x = self.Mixed_5b(x)\n        x = self.Mixed_5c(x)\n        x = self.Mixed_6a(x)\n        x = self.Mixed_6b(x)\n        if self.training and self.aux_logits:\n            aux = self.AuxLogits(x)\n        x = self.Mixed_7a(x)\n        x = F.adaptive_avg_pool2d(x, (1, 1))\n        x = F.dropout(x, training=self.training)\n        x = torch.flatten(x, 1)\n        x = self.fc(x)\n        if self.training and self.aux_logits:\n            return _InceptionOutputs(x, aux)\n        return x\n\n\nclass InceptionA(nn.Module):\n  def __init__(self, in_channels, pool_features):\n    super(InceptionA, self).__init__()\n    self.branch1x1 = BasicConv2d(in_channels, 8, kernel_size=1)\n\n    self.branch5x5_1 = BasicConv2d(in_channels, 8, kernel_size=1)\n    self.branch5x5_2 = BasicConv2d(8, 16, kernel_size=5, padding=2)\n\n    self.branch3x3dbl_1 = BasicConv2d(in_channels, 8, kernel_size=1)\n    self.branch3x3dbl_2 = BasicConv2d(8, 16, kernel_size=3, padding=1)\n    self.branch3x3dbl_3 = BasicConv2d(16, 32, kernel_size=3, padding=1)\n\n    self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)\n\n  def forward(self, x):\n    branch1x1 = self.branch1x1(x)\n\n    branch5x5 = self.branch5x5_1(x)\n    branch5x5 = self.branch5x5_2(branch5x5)\n\n    branch3x3dbl = self.branch3x3dbl_1(x)\n    branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n    branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)\n\n    branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)\n    branch_pool = self.branch_pool(branch_pool)\n\n    outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]\n    return torch.cat(outputs, 1)\n\n\nclass InceptionB(nn.Module):\n  def __init__(self, in_channels):\n    super(InceptionB, self).__init__()\n    self.branch3x3 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)\n\n    self.branch3x3dbl_1 = BasicConv2d(in_channels, 32, kernel_size=1)\n    self.branch3x3dbl_2 = BasicConv2d(32, 64, kernel_size=3, padding=1)\n    self.branch3x3dbl_3 = BasicConv2d(64, 96, kernel_size=3, stride=2)\n\n  def forward(self, x):\n    branch3x3 = self.branch3x3(x)\n\n    branch3x3dbl = self.branch3x3dbl_1(x)\n    branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n    branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)\n\n    branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)\n\n    outputs = [branch3x3, branch3x3dbl, branch_pool]\n    return torch.cat(outputs, 1)\n    \n    \nclass InceptionC(nn.Module):\n  def __init__(self, in_channels, channels_7x7):\n    super(InceptionC, self).__init__()\n    self.branch1x1 = BasicConv2d(in_channels, 128, kernel_size=1)\n\n    c7 = channels_7x7\n    self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)\n    self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))\n    self.branch7x7_3 = BasicConv2d(c7, 128, kernel_size=(7, 1), padding=(3, 0))\n\n    self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)\n    self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))\n    self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))\n    self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))\n    self.branch7x7dbl_5 = BasicConv2d(c7, 128, kernel_size=(1, 7), padding=(0, 3))\n\n    self.branch_pool = BasicConv2d(in_channels, 128, kernel_size=1)\n\n  def forward(self, x):\n    branch1x1 = self.branch1x1(x)\n\n    branch7x7 = self.branch7x7_1(x)\n    branch7x7 = self.branch7x7_2(branch7x7)\n    branch7x7 = self.branch7x7_3(branch7x7)\n\n    branch7x7dbl = self.branch7x7dbl_1(x)\n    branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)\n    branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)\n    branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)\n    branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)\n\n    branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)\n    branch_pool = self.branch_pool(branch_pool)\n\n    outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]\n    return torch.cat(outputs, 1)\n    \n    \nclass InceptionD(nn.Module):\n  def __init__(self, in_channels):\n    super(InceptionD, self).__init__()\n    self.branch3x3_1 = BasicConv2d(in_channels, 32, kernel_size=1)\n    self.branch3x3_2 = BasicConv2d(32,64, kernel_size=3, stride=2)\n\n    self.branch7x7x3_1 = BasicConv2d(in_channels, 32, kernel_size=1)\n    self.branch7x7x3_2 = BasicConv2d(32,64, kernel_size=(1, 7), padding=(0, 3))\n    self.branch7x7x3_3 = BasicConv2d(64, 128, kernel_size=(7, 1), padding=(3, 0))\n    self.branch7x7x3_4 = BasicConv2d(128,192, kernel_size=3, stride=2)\n\n  def forward(self, x):\n    branch3x3 = self.branch3x3_1(x)\n    branch3x3 = self.branch3x3_2(branch3x3)\n\n    branch7x7x3 = self.branch7x7x3_1(x)\n    branch7x7x3 = self.branch7x7x3_2(branch7x7x3)\n    branch7x7x3 = self.branch7x7x3_3(branch7x7x3)\n    branch7x7x3 = self.branch7x7x3_4(branch7x7x3)\n\n    branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)\n    outputs = [branch3x3, branch7x7x3, branch_pool]\n    return torch.cat(outputs, 1)\n    \n\nclass InceptionAux(nn.Module):\n  def __init__(self, in_channels, num_classes):\n    super(InceptionAux, self).__init__()\n    self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)\n    self.conv1 = BasicConv2d(128, 512, kernel_size=5)\n    self.conv1.stddev = 0.01\n    self.fc = nn.Linear(512, num_classes)\n    self.fc.stddev = 0.001\n\n  def forward(self, x):\n    x = F.avg_pool2d(x, kernel_size=3, stride=3)\n    x = self.conv0(x)\n    x = self.conv1(x)\n    x = F.adaptive_avg_pool2d(x, (1, 1))\n    x = torch.flatten(x, 1)\n    x = self.fc(x)\n    return x\n    \n    \nclass BasicConv2d(nn.Module):\n    def __init__(self, in_channels, out_channels, **kwargs):\n        super(BasicConv2d, self).__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)\n        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.bn(x)\n        return F.relu(x, inplace=True)\n\n\ndef inception_v3():\n    return Inception3(num_classes=10, aux_logits=True, transform_input=False)\n\n\n# Replace these lines in train_one_epoch\n# outputs, aux_outputs = model(inputs)\n# loss1 = criterion(outputs, labels)\n# loss2 = criterion(aux_outputs, labels)\n# loss = loss1 + 0.4*loss2","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Utility","metadata":{"id":"TDrSyPQPAWSg"}},{"cell_type":"markdown","source":"### Sparsity Calculator","metadata":{"id":"JSIfYY4mAfWL"}},{"cell_type":"code","source":"def calculate_sparsity(model):\n    total_zeros = 0\n    total_elements = 0\n    for param in model.parameters():\n        zero_elements = torch.sum(param == 0).item()\n        total_zeros += zero_elements\n        total_elements += param.numel()\n    sparsity = total_zeros / total_elements\n    return sparsity * 100","metadata":{"id":"uW5Iw8zOkoZG","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:11.546260Z","iopub.execute_input":"2024-12-31T22:32:11.547146Z","iopub.status.idle":"2024-12-31T22:32:11.552929Z","shell.execute_reply.started":"2024-12-31T22:32:11.547087Z","shell.execute_reply":"2024-12-31T22:32:11.551986Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Metrics Plotter","metadata":{"id":"jp7VvLlSAimx"}},{"cell_type":"code","source":"def plot_metrics(metrics):\n    fig, axs = plt.subplots(4, 1, figsize=(10, 15))\n\n    epochs = range(1, len(metrics[\"train_losses\"]) + 1)\n\n    axs[0].plot(epochs, metrics[\"train_losses\"], label=\"Train\")\n    axs[0].plot(epochs, metrics[\"val_losses\"], label=\"Validation\")\n    axs[0].set_title(\"Train and Validation Loss\")\n    axs[0].set_xlabel(\"Epochs\")\n    axs[0].set_ylabel(\"Loss\")\n    axs[0].legend()\n\n    axs[1].plot(epochs, metrics[\"train_accs\"], label=\"Train\")\n    axs[1].plot(epochs, metrics[\"val_accs\"], label=\"Validation\")\n    axs[1].set_title(\"Train and Validation Accuracy\")\n    axs[1].set_xlabel(\"Epochs\")\n    axs[1].set_ylabel(\"Accuracy\")\n    axs[1].legend()\n\n    axs[2].plot(epochs, metrics[\"grads_norm_epochs\"])\n    axs[2].set_title(\"Gradients L2-Norm\")\n    axs[2].set_xlabel(\"Epochs\")\n    axs[2].set_ylabel(\"Gradients L2-Norm\")\n\n    axs[3].plot(epochs, metrics[\"sparsity_epochs\"])\n    axs[3].set_title(\"Sparsity\")\n    axs[3].set_xlabel(\"Epochs\")\n    axs[3].set_ylabel(\"Sparsity\")\n\n    plt.tight_layout()\n    plt.show()","metadata":{"id":"Q90DclagAZ1Z","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:12.803439Z","iopub.execute_input":"2024-12-31T22:32:12.803786Z","iopub.status.idle":"2024-12-31T22:32:12.811943Z","shell.execute_reply.started":"2024-12-31T22:32:12.803755Z","shell.execute_reply":"2024-12-31T22:32:12.810998Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Result Saver","metadata":{"id":"iCKzRHyBCLnZ"}},{"cell_type":"code","source":"def save_results(results, model, name):\n    with open(f\"./{name}_results.pkl\", \"wb\") as fp:\n        pickle.dump(results, fp)\n    torch.save(model, f\"./{name}_model.pb\")","metadata":{"id":"sAajtmw1CObd","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:14.718519Z","iopub.execute_input":"2024-12-31T22:32:14.719179Z","iopub.status.idle":"2024-12-31T22:32:14.723679Z","shell.execute_reply.started":"2024-12-31T22:32:14.719144Z","shell.execute_reply":"2024-12-31T22:32:14.722732Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Tok-k Accuracy","metadata":{}},{"cell_type":"code","source":"def topk_accuracy(logits, labels, k=1):\n    \"\"\"\n    Computes the Top-k Accuracy for a single batch.\n    Args:\n        logits (torch.Tensor): Model predictions of shape (batch_size, num_classes).\n        labels (torch.Tensor): Ground truth labels of shape (batch_size).\n        k (int): The 'k' in Top-k accuracy.\n    Returns:\n        float: Top-k accuracy.\n    \"\"\"\n    # Get the top-k predictions\n    topk_values, topk_indices = torch.topk(logits, k=k, dim=1)\n\n    # Check if the true label is among the top-k predictions\n    matches = (topk_indices == labels.unsqueeze(1))  # Compare along the top-k axis\n    topk_correct = matches.sum().item()  # Count correct predictions\n\n    # Calculate accuracy\n    accuracy = topk_correct / labels.size(0)  # Batch size\n    return accuracy * 100","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:15.987009Z","iopub.execute_input":"2024-12-31T22:32:15.987363Z","iopub.status.idle":"2024-12-31T22:32:15.992498Z","shell.execute_reply.started":"2024-12-31T22:32:15.987333Z","shell.execute_reply":"2024-12-31T22:32:15.991481Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Pruner","metadata":{"id":"zZj03NE0hh24"}},{"cell_type":"markdown","source":"### Prune Rate Scheduler","metadata":{}},{"cell_type":"code","source":"def get_prune_rate(start_value, end_value, total_epochs, current_epoch):\n    \"\"\"Calculate the hyperparameter value for a specific epoch, changing linearly from start_value to end_value.\"\"\"\n    # Linearly decreasing or increasing or constant\n    # current epochs: [0, total_epochs]\n    return start_value + current_epoch / total_epochs * (end_value - start_value)","metadata":{"id":"ZPR_ILNwmpbB","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:17.638077Z","iopub.execute_input":"2024-12-31T22:32:17.638476Z","iopub.status.idle":"2024-12-31T22:32:17.643268Z","shell.execute_reply.started":"2024-12-31T22:32:17.638445Z","shell.execute_reply":"2024-12-31T22:32:17.642228Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Quantile Estimation","metadata":{}},{"cell_type":"code","source":"def mean_chunked_quantiles(tensor, quantile, chunks):\n    split_tensors = torch.chunk(tensor, chunks)\n    quantiles = [torch.quantile(t, quantile) for t in split_tensors]\n    return torch.mean(torch.tensor(quantiles))","metadata":{"id":"K1u_t90VfbGn","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:18.062032Z","iopub.execute_input":"2024-12-31T22:32:18.062395Z","iopub.status.idle":"2024-12-31T22:32:18.067040Z","shell.execute_reply.started":"2024-12-31T22:32:18.062364Z","shell.execute_reply":"2024-12-31T22:32:18.066149Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Pruner Class","metadata":{}},{"cell_type":"code","source":"class MyPruner:\n    def __init__(self, model, device, prune_rate_start=0.2, prune_rate_end=0.95, decay_param=0.9, prune_threshold=1e-5,\n                 quantile_chunks=1, rewind_mode=\"rewind\", decay_step_mode=\"rewind\"):\n        \"\"\"Initialize the pruner with model reference and configuration parameters.\"\"\"\n        self.model = model\n        self.device = device\n        self.prune_rate = prune_rate_start\n        self.prune_rate_end = prune_rate_end\n        self.decay_param = decay_param\n        self.prune_threshold = prune_threshold\n        self.quantile_chunks = quantile_chunks\n        self.rewind_mode = rewind_mode\n        self.decay_step_mode = decay_step_mode\n\n        self.epoch = 0\n        self.decay_step = 1\n\n        self.pruned_weights_mask = {}\n\n        self.rewinding_weights = {}\n        self.rewinding_epoch = 1 if self.rewind_mode == \"keep\" else 0\n\n        # SSG = Sum of Squared Gradients\n        self.ssg_tensor = {}\n        self.min_ssg_scalar = float(\"inf\")\n\n        self.init_weights_and_masks()\n\n    def init_weights_and_masks(self):\n        \"\"\"Initialize weights and masks for each parameter in the model.\"\"\"\n        with torch.no_grad():\n            for name, param in self.model.named_parameters():\n                if param.requires_grad:\n                    self.rewinding_weights[name] = param.data.clone().to(self.device)\n                    self.pruned_weights_mask[name] = torch.zeros_like(param, dtype=torch.bool, device=self.device)\n                    self.ssg_tensor[name] = torch.zeros_like(param, device=self.device)\n\n    def reset_ssg_tensor(self):\n        \"\"\"Reset ssg_tensor to all zeros at the start of each loop.\"\"\"\n        with torch.no_grad():\n            for name in self.ssg_tensor:\n                self.ssg_tensor[name].zero_()\n\n    def update_ssg_tensor(self):\n        \"\"\"Accumulate squared gradients for each parameter during training.\"\"\"\n        with torch.no_grad():\n            for name, param in self.model.named_parameters():\n                if param.requires_grad and param.grad != None:\n                    self.ssg_tensor[name] += param.grad.pow(2)\n\n    def update_rewinding_weights(self):\n        \"\"\"Check and update the rewinding weights based on the current sum of squared gradients.\"\"\"\n        with torch.no_grad():\n            current_ssg_scalar = sum(self.ssg_tensor[name].sum().item() for name in self.ssg_tensor)\n            if current_ssg_scalar < self.min_ssg_scalar:\n                self.min_ssg_scalar = current_ssg_scalar\n                if self.rewind_mode == \"rewind\":\n                    self.rewinding_epoch = self.epoch\n                    print(f\"Rewinding weights updated to W(k={self.rewinding_epoch})\")\n                    for name, param in self.model.named_parameters():\n                        if param.requires_grad:\n                                self.rewinding_weights[name] = param.data.clone()\n                if self.decay_step_mode == \"rewind\":\n                    self.decay_step = 1\n                    print(f\"Decay step reset to {self.decay_step}\")\n                    \n            if self.rewind_mode == \"keep\":\n                self.rewinding_epoch = self.epoch + 1\n\n    def prune(self):\n        \"\"\"Prune the model by applying decay and thresholding based on ssg tensor.\"\"\"\n        with torch.no_grad():\n            ssg_tensor_flatten = torch.cat([self.ssg_tensor[name].view(-1) for name in self.ssg_tensor])\n            threshold_quantile = mean_chunked_quantiles(ssg_tensor_flatten, self.prune_rate,\n                                                        chunks=self.quantile_chunks)\n            \n            decay_factor = self.decay_param ** self.decay_step\n\n            for name, param in self.model.named_parameters():\n                if param.requires_grad:\n                    retain_mask = self.ssg_tensor[name] >= threshold_quantile\n\n                    if self.rewind_mode == \"keep\":\n                        param.data[~retain_mask] *= decay_factor\n                    else:\n                        param.data[~retain_mask] = self.rewinding_weights[name][~retain_mask] * decay_factor\n\n                    # Apply thresholding to enforce zeroing out small weights\n                    small_weights_mask = torch.abs(param.data) < self.prune_threshold\n                    self.pruned_weights_mask[name] |= small_weights_mask\n                    param.data[self.pruned_weights_mask[name]] = 0.0\n\n        self.increase_index()\n\n    def increase_index(self):\n        self.epoch += 1\n        if self.decay_step_mode != \"nostep\":\n            self.decay_step += 1","metadata":{"id":"_9lx-wtISMlw","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:19.314620Z","iopub.execute_input":"2024-12-31T22:32:19.314947Z","iopub.status.idle":"2024-12-31T22:32:19.328684Z","shell.execute_reply.started":"2024-12-31T22:32:19.314917Z","shell.execute_reply":"2024-12-31T22:32:19.327650Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Trainer","metadata":{"id":"hHEO4Kefn__Q"}},{"cell_type":"markdown","source":"### LR Scheduler","metadata":{}},{"cell_type":"code","source":"class CustomLRScheduler(torch.optim.lr_scheduler._LRScheduler):\n    def __init__(self, optimizer, interval=30, last_epoch=-1):\n        \"\"\"\n        Custom learning rate scheduler that divides the learning rate by 2 every 'interval' epochs.\n\n        Args:\n            optimizer (Optimizer): Wrapped optimizer.\n            interval (int): Number of epochs between learning rate updates.\n            last_epoch (int): The index of the last epoch. Default: -1.\n        \"\"\"\n        self.interval = interval\n        super().__init__(optimizer, last_epoch)\n\n    def get_lr(self):\n        \"\"\"\n        Compute the learning rates for this epoch.\n\n        Returns:\n            list of float: Updated learning rates for each parameter group.\n        \"\"\"\n        if self.last_epoch < 0 or (self.last_epoch + 1) % self.interval != 0:\n            # Return the current learning rate if no update is needed\n            return [group['lr'] for group in self.optimizer.param_groups]\n\n        # Return the updated learning rate (divide by 2)\n        return [group['lr'] / 2 for group in self.optimizer.param_groups]","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:23.753994Z","iopub.execute_input":"2024-12-31T22:32:23.754835Z","iopub.status.idle":"2024-12-31T22:32:23.760497Z","shell.execute_reply.started":"2024-12-31T22:32:23.754800Z","shell.execute_reply":"2024-12-31T22:32:23.759547Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Train one epoch","metadata":{"id":"0I16fLYyAqSu"}},{"cell_type":"code","source":"def train_one_epoch(model, train_loader, optimizer, criterion, pruner, device, topk):\n    \"\"\"Train the model for one epoch, return metrics.\"\"\"\n    model.train()\n\n    epoch_loss = 0.0\n    epoch_topk_acc = 0.0\n    epoch_sum_squared_grads = 0.0\n\n    for inputs, labels in train_loader:\n        inputs, labels = inputs.to(device), labels.to(device)\n        optimizer.zero_grad()\n        outputs = model(inputs)\n        loss = criterion(outputs, labels)\n        loss.backward()\n        optimizer.step()\n\n        if pruner:\n            pruner.update_ssg_tensor()\n\n        with torch.no_grad():\n            epoch_loss += loss.item()\n            epoch_topk_acc += topk_accuracy(outputs, labels, k=topk)\n\n            sum_squared_grads = sum(\n                (p.grad.data.norm() ** 2).item()\n                for p in model.parameters()\n                if p.requires_grad and p.grad is not None\n            )\n\n            epoch_sum_squared_grads += sum_squared_grads\n\n    # L2-Norm of gradients\n    epoch_grads_norm = epoch_sum_squared_grads ** 0.5\n\n    # Average metrics across batches\n    epoch_loss /= len(train_loader)\n    epoch_topk_acc /= len(train_loader)\n    epoch_grads_norm /= len(train_loader)\n\n    return epoch_loss, epoch_topk_acc, epoch_grads_norm","metadata":{"id":"YniXfP8EpRjU","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:25.596950Z","iopub.execute_input":"2024-12-31T22:32:25.597800Z","iopub.status.idle":"2024-12-31T22:32:25.606927Z","shell.execute_reply.started":"2024-12-31T22:32:25.597755Z","shell.execute_reply":"2024-12-31T22:32:25.606014Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Validator","metadata":{"id":"qrwnr6VsAsy2"}},{"cell_type":"code","source":"def validate(model, test_loader, criterion, device, topk):\n    \"\"\"Validate the model, return loss and accuracy.\"\"\"\n    model.eval()\n\n    val_loss = 0.0\n    val_topk_acc = 0.0\n\n    with torch.no_grad():\n        for inputs, labels in test_loader:\n            inputs, labels = inputs.to(device), labels.to(device)\n            outputs = model(inputs)\n            loss = criterion(outputs, labels)\n\n            val_loss += loss.item()\n            val_topk_acc += topk_accuracy(outputs, labels, k=topk)\n\n    # Average metrics across batches\n    val_loss /= len(test_loader)\n    val_topk_acc /= len(test_loader)\n\n    return val_loss, val_topk_acc","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:27.334465Z","iopub.execute_input":"2024-12-31T22:32:27.335267Z","iopub.status.idle":"2024-12-31T22:32:27.340657Z","shell.execute_reply.started":"2024-12-31T22:32:27.335230Z","shell.execute_reply":"2024-12-31T22:32:27.339574Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Wrapper","metadata":{"id":"-chbpcG5AwVv"}},{"cell_type":"code","source":"def run_training(experiment_name, train_ds, test_ds, model, criterion, optimizer, scheduler, epochs, batch_size,\n                 val_batch_size, pruner, device, topk, save_each_epoch, num_workers):\n    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True,\n                                               num_workers=num_workers)\n    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=val_batch_size, shuffle=False, pin_memory=True,\n                                              num_workers=num_workers)\n\n    train_loss_list = []\n    train_acc_list = []\n    val_loss_list = []\n    val_acc_list = []\n    grads_norm_list = []\n    sparsity_list = []\n    total_time = 0.0  # Minutes\n\n    if pruner:\n        start_prune_rate = pruner.prune_rate\n        end_prune_rate = pruner.prune_rate_end\n\n    for epoch in range(epochs):\n        start_time = time.time()\n        print(f\"Epoch: {epoch + 1}\")\n\n        if pruner:\n            pruner.prune_rate = get_prune_rate(start_prune_rate, end_prune_rate, epochs, epoch)\n\n        train_loss, train_acc, grads_norm = train_one_epoch(\n            model, train_loader, optimizer, criterion, pruner, device, topk\n        )\n\n        print(f\"Train Loss: {round(train_loss, 3)}\")\n        print(f\"Train Acc: {round(train_acc, 3)}\")\n\n        if pruner:\n            print(f\"Prune Rate: {pruner.prune_rate}\")\n            print(f\"Decay Step: {pruner.decay_step}\")\n            print(f\"Pruned Rewinding Weights: W(k={pruner.rewinding_epoch})\")\n            print(f\"Grads L2-Norm: {round(grads_norm, 5)}\")\n\n            pruner.prune()\n            pruner.update_rewinding_weights()\n            pruner.reset_ssg_tensor()\n\n        sparsity = calculate_sparsity(model)\n\n        val_loss, val_acc = validate(model, test_loader, criterion, device, topk)\n\n        if scheduler:\n            scheduler.step()\n\n        end_time = time.time()\n        epoch_time = (end_time - start_time) / 60\n\n        print(f\"Val Loss: {round(val_loss, 3)}\")\n        print(f\"Val Acc: {round(val_acc, 3)}\")\n        print(f\"Sparsity: {round(sparsity, 3)}\")\n        print(f\"Time: {round(epoch_time, 3)}\")\n        print(\"-\" * 40)\n\n        train_loss_list.append(train_loss)\n        train_acc_list.append(train_acc)\n        val_loss_list.append(val_loss)\n        val_acc_list.append(val_acc)\n        grads_norm_list.append(grads_norm)\n        sparsity_list.append(sparsity)\n        total_time += epoch_time\n\n        if save_each_epoch or (epoch + 1 == epochs):\n            results = {\n                \"train_losses\": train_loss_list,\n                \"train_accs\": train_acc_list,\n                \"val_losses\": val_loss_list,\n                \"val_accs\": val_acc_list,\n                \"grads_norm_epochs\": grads_norm_list,\n                \"sparsity_epochs\": sparsity_list,\n                \"total_time\": total_time\n            }\n            save_results(results, model, f\"{experiment_name}_epoch{epoch + 1}\")\n\n    print(f\"Total Time: {round(total_time, 3)} minutes\")\n\n    return results, model","metadata":{"id":"UHCQLzcVALsL","trusted":true,"execution":{"iopub.status.busy":"2024-12-31T22:32:29.755512Z","iopub.execute_input":"2024-12-31T22:32:29.756286Z","iopub.status.idle":"2024-12-31T22:32:29.766291Z","shell.execute_reply.started":"2024-12-31T22:32:29.756250Z","shell.execute_reply":"2024-12-31T22:32:29.765435Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Experiments","metadata":{"id":"ssJHh1TMoCJB"}},{"cell_type":"code","source":"experiment_name = \"DenseNet-CIFAR-Prune-Rewind\"\n\nmodel = densenet_cifar().to(device)\ntrain_ds = cifar10_train_dataset\ntest_ds = cifar10_test_dataset\n\nepochs = 301\nbatch_size = 128\nval_batch_size = 100\ntopk = 1\nsave_each_epoch = False\nnum_workers = 4\n\npruner = MyPruner(\n    model,\n    device,\n    prune_rate_start=0.95,\n    prune_rate_end=0.25,\n    decay_param=0.75,\n    prune_threshold=5.5e-5,\n    quantile_chunks=1,\n    rewind_mode=\"rewind\",  # rewind, reset, keep\n    decay_step_mode=\"rewind\"  # epoch, rewind, nostep\n)\n\n# pruner = None\n\noptimizer = torch.optim.SGD(\n    model.parameters(),\n    lr=0.05,\n    momentum=0.9,\n    weight_decay=5e-4\n)\n\n# optimizer = torch.optim.Adam(\n#     model.parameters(),\n#     lr=12e-4\n# )\n\n# scheduler = CustomLRScheduler(optimizer, interval=30)\nscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)\n# scheduler = None\n\ncriterion = nn.CrossEntropyLoss()\n\nresults, model = run_training(\n    experiment_name=experiment_name,\n    train_ds=train_ds,\n    test_ds=test_ds,\n    model=model,\n    criterion=criterion,\n    optimizer=optimizer,\n    scheduler=scheduler,\n    epochs=epochs,\n    batch_size=batch_size,\n    val_batch_size=val_batch_size,\n    pruner=pruner,\n    device=device,\n    topk=topk,\n    save_each_epoch=save_each_epoch,\n    num_workers=num_workers\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T19:32:44.885059Z","iopub.execute_input":"2024-12-31T19:32:44.885406Z","iopub.status.idle":"2024-12-31T21:49:40.345500Z","shell.execute_reply.started":"2024-12-31T19:32:44.885375Z","shell.execute_reply":"2024-12-31T21:49:40.344379Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"plot_metrics(results)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-12-31T21:50:22.910240Z","iopub.execute_input":"2024-12-31T21:50:22.910987Z","iopub.status.idle":"2024-12-31T21:50:23.945121Z","shell.execute_reply.started":"2024-12-31T21:50:22.910949Z","shell.execute_reply":"2024-12-31T21:50:23.944277Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Cleaning\n\nmodel = None\noptimizer = None\npruner = None\ntorch.cuda.empty_cache()","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}