{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "T4",
   "authorship_tag": "ABX9TyMYSijN9MRXG33lbQRrRz08"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oslgnIntroMd"
   },
   "source": [
    "# Operand-selecting logic-gate networks on binarised MNIST\n",
    "\n",
    "Two 3-layer networks built from operand-selecting logic-gate layers (hidden width 512) are trained side by side on identical batches. Each unit picks two binary inputs with a hard, straight-through operand selector and combines them with one of the 16 two-input Boolean gates.\n",
    "\n",
    "- **Model A** (`oslgn` layers): gradients flow through both the gate choices and the operand selectors.\n",
    "- **Model B** (`detached_oslgn` layers): operand-selector outputs are detached, so only the gate choices are trained.\n",
    "\n",
    "**Recorded run (10 epochs, unseeded):** Model A reaches 0.4219 test accuracy; Model B stays near chance at 0.0885.\n",
    "\n",
    "**Note on the loss:** both models output a one-hot vector, which `nn.CrossEntropyLoss` treats as logits in {0, 1}. With 10 classes the per-sample loss is log(1 + 9/e) ≈ 1.46 when correct and log(e + 9) ≈ 2.46 when wrong, so the reported training loss is ≈ 2.461 − accuracy and carries no information beyond accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "forGJyrvGk6N",
    "executionInfo": {
     "status": "ok",
     "timestamp": 1747167052287,
     "user_tz": -540,
     "elapsed": 19,
     "user": {
      "displayName": "Nyang",
      "userId": "06114252970420503354"
     }
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Default device for every helper below: the GPU when one is available.\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "def assert_binary(tensor, name=\"\"):\n",
    "    \"\"\"Raise ValueError unless every element of `tensor` is exactly 0 or 1.\n",
    "\n",
    "    Called after each layer as a runtime check that the forward pass carries\n",
    "    strictly binary activations.\n",
    "    \"\"\"\n",
    "    if not torch.all((tensor == 0) | (tensor == 1)):\n",
    "        raise ValueError(f\"[{name}] Tensor is not strictly binary (0 or 1).\")\n",
    "\n",
    "\n",
    "def straight_through(hard, soft):\n",
    "    \"\"\"Return `hard` in the forward pass while routing gradients to `soft`.\n",
    "\n",
    "    `soft - soft.detach()` is exactly zero in floating point, so the forward\n",
    "    value is bit-identical to `hard`. The form `hard - soft.detach() + soft`\n",
    "    rounds twice and can land one ulp away from `hard`, which would make\n",
    "    `assert_binary` fail spuriously. The gradient w.r.t. `soft` is 1 in both.\n",
    "    \"\"\"\n",
    "    return hard + (soft - soft.detach())\n",
    "\n",
    "\n",
    "class RoundSTE(torch.autograd.Function):\n",
    "    \"\"\"Round to the nearest integer in the forward pass; pass gradients through unchanged.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, input):\n",
    "        return torch.round(input)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        return grad_output\n",
    "\n",
    "\n",
    "class Operand_selector(nn.Module):\n",
    "    \"\"\"Route exactly one of `x` input features to each of `y` output units.\n",
    "\n",
    "    Output unit j copies input feature `argmax(p.weight[j])`; the weights are\n",
    "    trained through a straight-through estimator. Binary inputs therefore give\n",
    "    binary outputs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x, y):\n",
    "        super().__init__()\n",
    "        self.p = nn.Linear(x, y, bias=False)  # weight shape: (y, x)\n",
    "\n",
    "    def forward(self, x):\n",
    "        w = self.p.weight\n",
    "        one_hot = torch.zeros_like(w).scatter_(1, w.argmax(dim=-1, keepdim=True), 1.0)\n",
    "        out = F.linear(x, straight_through(one_hot, w))\n",
    "        assert_binary(out, \"OperandSelector\")\n",
    "        return out\n",
    "\n",
    "\n",
    "def bin_op(a, b, i):\n",
    "    \"\"\"Apply the i-th of the 16 two-input Boolean functions to `a` and `b`.\n",
    "\n",
    "    Each gate is an arithmetic expression that equals the Boolean result when\n",
    "    `a` and `b` hold only 0s and 1s.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `i` is not in range(16).\n",
    "    \"\"\"\n",
    "    if i == 0: return torch.zeros_like(a)         # FALSE\n",
    "    elif i == 1: return a * b                     # A AND B\n",
    "    elif i == 2: return a - a * b                 # A AND NOT B\n",
    "    elif i == 3: return a                         # A\n",
    "    elif i == 4: return b - a * b                 # NOT A AND B\n",
    "    elif i == 5: return b                         # B\n",
    "    elif i == 6: return a + b - 2 * a * b         # A XOR B\n",
    "    elif i == 7: return a + b - a * b             # A OR B\n",
    "    elif i == 8: return 1 - (a + b - a * b)       # A NOR B\n",
    "    elif i == 9: return 1 - (a + b - 2 * a * b)   # A XNOR B\n",
    "    elif i == 10: return 1 - b                    # NOT B\n",
    "    elif i == 11: return 1 - b + a * b            # A OR NOT B\n",
    "    elif i == 12: return 1 - a                    # NOT A\n",
    "    elif i == 13: return 1 - a + a * b            # NOT A OR B\n",
    "    elif i == 14: return 1 - a * b                # A NAND B\n",
    "    elif i == 15: return torch.ones_like(a)       # TRUE\n",
    "    raise ValueError(f\"gate index must be in range(16), got {i}\")\n",
    "\n",
    "\n",
    "def bin_op_s(a, b, i_s):\n",
    "    \"\"\"Evaluate one selected gate per unit and return the exact 0/1 result.\n",
    "\n",
    "    Args:\n",
    "        a, b: binary operand tensors whose last dimension has one entry per unit.\n",
    "        i_s: gate-selection weights with 16 columns (shape (y, 16) when called\n",
    "            from Operator, where each row is one-hot in the forward pass).\n",
    "\n",
    "    All 16 gates are summed with these weights so every gate stays in the\n",
    "    autograd graph; RoundSTE then snaps the forward value to an integer.\n",
    "    \"\"\"\n",
    "    r = torch.zeros_like(a)\n",
    "    for i in range(16):\n",
    "        r += i_s[..., i] * bin_op(a, b, i)\n",
    "    result = RoundSTE.apply(r)\n",
    "    assert_binary(result, \"Operator Output\")\n",
    "    return result\n",
    "\n",
    "\n",
    "class Operator(nn.Module):\n",
    "    \"\"\"Learnable choice of one of the 16 Boolean gates for each of `y` units.\"\"\"\n",
    "\n",
    "    def __init__(self, y):\n",
    "        super().__init__()\n",
    "        self.weights = nn.Parameter(torch.randn(y, 16))\n",
    "\n",
    "    def forward(self, a, b):\n",
    "        w = self.weights\n",
    "        one_hot = torch.zeros_like(w).scatter_(1, w.argmax(dim=-1, keepdim=True), 1.0)\n",
    "        return bin_op_s(a, b, straight_through(one_hot, w))\n",
    "\n",
    "\n",
    "class oslgn(nn.Module):\n",
    "    \"\"\"Operand-selecting logic-gate layer: `x` binary features in, `y` binary units out.\n",
    "\n",
    "    Each unit picks two inputs (os1, os2) and combines them with one learned\n",
    "    gate (op). Gradients flow through the gate choice, the operand selectors\n",
    "    and on to earlier layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x, y):\n",
    "        super().__init__()\n",
    "        self.os1 = Operand_selector(x, y)\n",
    "        self.os2 = Operand_selector(x, y)\n",
    "        self.op = Operator(y)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.op(self.os1(x), self.os2(x))\n",
    "        assert_binary(out, \"oslgn Layer Output\")\n",
    "        return out\n",
    "\n",
    "\n",
    "class detached_oslgn(oslgn):\n",
    "    \"\"\"Variant of `oslgn` whose operand-selector outputs are detached.\n",
    "\n",
    "    Detaching cuts the gradient path into os1/os2, so their weights and every\n",
    "    earlier layer receive no gradient through this layer; only the gate\n",
    "    weights of `op` are trained.\n",
    "    \"\"\"\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.op(self.os1(x).detach(), self.os2(x).detach())\n",
    "\n",
    "\n",
    "class _LogicNet(nn.Module):\n",
    "    \"\"\"Three stacked logic layers on flattened 28x28 images with a one-hot readout.\n",
    "\n",
    "    Args:\n",
    "        layer_cls: layer type to stack (`oslgn` or `detached_oslgn`).\n",
    "        width: number of units in each hidden layer.\n",
    "\n",
    "    The forward pass returns a one-hot (batch, 10) tensor marking the first\n",
    "    active output unit (torch.argmax returns the first maximal index, so an\n",
    "    all-zero row predicts class 0). Gradients reach the last layer's outputs\n",
    "    through a straight-through estimator.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, layer_cls, width=512):\n",
    "        super().__init__()\n",
    "        self.layers = nn.ModuleList([\n",
    "            layer_cls(28 * 28, width),\n",
    "            layer_cls(width, width),\n",
    "            layer_cls(width, 10),\n",
    "        ])\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.flatten(1)\n",
    "        for layer in self.layers:\n",
    "            x = layer(x)\n",
    "        one_hot = torch.zeros_like(x).scatter(1, x.argmax(dim=1, keepdim=True), 1.0)\n",
    "        return straight_through(one_hot, x)\n",
    "\n",
    "\n",
    "class ModelA(_LogicNet):\n",
    "    \"\"\"Logic network whose operand selectors and gates are all trained.\"\"\"\n",
    "\n",
    "    def __init__(self, width=512):\n",
    "        super().__init__(oslgn, width)\n",
    "\n",
    "\n",
    "class ModelB(_LogicNet):\n",
    "    \"\"\"Logic network whose operand-selector outputs are detached: only the gate choices train.\"\"\"\n",
    "\n",
    "    def __init__(self, width=512):\n",
    "        super().__init__(detached_oslgn, width)\n",
    "\n",
    "\n",
    "def get_mnist_loader(batch_size=64, root=\".\"):\n",
    "    \"\"\"Return (train_loader, test_loader) for MNIST with pixels thresholded at 0.5.\n",
    "\n",
    "    Thresholding makes every input exactly 0. or 1., matching the binary\n",
    "    invariant checked by each layer. The training loader shuffles; the test\n",
    "    loader uses batch size 256 without shuffling.\n",
    "    \"\"\"\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Lambda(lambda x: (x > 0.5).float()),\n",
    "    ])\n",
    "    train = datasets.MNIST(root=root, train=True, download=True, transform=transform)\n",
    "    test = datasets.MNIST(root=root, train=False, download=True, transform=transform)\n",
    "    return DataLoader(train, batch_size=batch_size, shuffle=True), DataLoader(test, batch_size=256)\n",
    "\n",
    "\n",
    "def _train_step(model, optimizer, criterion, x, y):\n",
    "    \"\"\"Run one optimisation step on a batch.\n",
    "\n",
    "    Returns:\n",
    "        (loss summed over the batch, number of correct predictions)\n",
    "    \"\"\"\n",
    "    optimizer.zero_grad()\n",
    "    pred = model(x)\n",
    "    loss = criterion(pred, y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.item() * x.size(0), (pred.argmax(dim=1) == y).sum().item()\n",
    "\n",
    "\n",
    "def train_one_epoch(model, loader, optimizer, criterion, device=None):\n",
    "    \"\"\"Train `model` for one pass over `loader`; return (mean loss, accuracy).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `loader` yields no batches.\n",
    "    \"\"\"\n",
    "    device = DEVICE if device is None else device\n",
    "    model.train()\n",
    "    total_loss, correct, total = 0.0, 0, 0\n",
    "    for x, y in loader:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        batch_loss, batch_correct = _train_step(model, optimizer, criterion, x, y)\n",
    "        total_loss += batch_loss\n",
    "        correct += batch_correct\n",
    "        total += x.size(0)\n",
    "    if total == 0:\n",
    "        raise ValueError(\"loader yielded no batches\")\n",
    "    return total_loss / total, correct / total\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "i8AeIRz0PGEu",
    "executionInfo": {
     "status": "ok",
     "timestamp": 1747167054867,
     "user_tz": -540,
     "elapsed": 11,
     "user": {
      "displayName": "Nyang",
      "userId": "06114252970420503354"
     }
    }
   },
   "outputs": [],
   "source": [
    "def train_model_pair(model_a, model_b, loader, epochs=10, lr=1e-3, device=None):\n",
    "    \"\"\"Train two models side by side on identical batches and log per-epoch metrics.\n",
    "\n",
    "    Each batch is used for one Adam step on `model_a` and then one on\n",
    "    `model_b`, so the two training curves are directly comparable.\n",
    "\n",
    "    Args:\n",
    "        model_a, model_b: models mapping a batch of images to class scores.\n",
    "        loader: training DataLoader yielding (images, labels).\n",
    "        epochs: number of passes over `loader`.\n",
    "        lr: Adam learning rate, shared by both optimisers.\n",
    "        device: where to train; defaults to DEVICE (CUDA when available).\n",
    "\n",
    "    Returns:\n",
    "        dict of per-epoch lists keyed by 'epoch', 'loss_a', 'acc_a', 'loss_b', 'acc_b'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `loader` yields no batches.\n",
    "    \"\"\"\n",
    "    device = DEVICE if device is None else device\n",
    "    model_a, model_b = model_a.to(device), model_b.to(device)\n",
    "    opt_a = torch.optim.Adam(model_a.parameters(), lr=lr)\n",
    "    opt_b = torch.optim.Adam(model_b.parameters(), lr=lr)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    log = {'epoch': [], 'loss_a': [], 'acc_a': [], 'loss_b': [], 'acc_b': []}\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        model_a.train()\n",
    "        model_b.train()\n",
    "        loss_a = loss_b = 0.0\n",
    "        correct_a = correct_b = 0\n",
    "        total = 0\n",
    "\n",
    "        for x, y in loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            batch_loss_a, batch_correct_a = _train_step(model_a, opt_a, criterion, x, y)\n",
    "            batch_loss_b, batch_correct_b = _train_step(model_b, opt_b, criterion, x, y)\n",
    "            loss_a += batch_loss_a\n",
    "            correct_a += batch_correct_a\n",
    "            loss_b += batch_loss_b\n",
    "            correct_b += batch_correct_b\n",
    "            total += x.size(0)\n",
    "\n",
    "        if total == 0:\n",
    "            raise ValueError(\"loader yielded no batches\")\n",
    "\n",
    "        log['epoch'].append(epoch)\n",
    "        log['loss_a'].append(loss_a / total)\n",
    "        log['acc_a'].append(correct_a / total)\n",
    "        log['loss_b'].append(loss_b / total)\n",
    "        log['acc_b'].append(correct_b / total)\n",
    "\n",
    "        print(f\"[Epoch {epoch}] \"\n",
    "              f\"Model A - Loss: {loss_a / total:.4f}, Acc: {correct_a / total:.4f} | \"\n",
    "              f\"Model B - Loss: {loss_b / total:.4f}, Acc: {correct_b / total:.4f}\")\n",
    "\n",
    "    return log\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "ajlnNU8BOpm9",
    "executionInfo": {
     "status": "ok",
     "timestamp": 1747167057387,
     "user_tz": -540,
     "elapsed": 5,
     "user": {
      "displayName": "Nyang",
      "userId": "06114252970420503354"
     }
    }
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def evaluate(model, loader, device=None):\n",
    "    \"\"\"Return the classification accuracy of `model` on `loader`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `loader` yields no batches.\n",
    "    \"\"\"\n",
    "    device = DEVICE if device is None else device\n",
    "    model.eval()\n",
    "    correct, total = 0, 0\n",
    "    for x, y in loader:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        correct += (model(x).argmax(dim=1) == y).sum().item()\n",
    "        total += x.size(0)\n",
    "    if total == 0:\n",
    "        raise ValueError(\"loader yielded no batches\")\n",
    "    return correct / total\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "bi-1PCR2NepF",
    "executionInfo": {
     "status": "ok",
     "timestamp": 1747167343674,
     "user_tz": -540,
     "elapsed": 285418,
     "user": {
      "displayName": "Nyang",
      "userId": "06114252970420503354"
     }
    },
    "outputId": "a3b76fac-e443-41b6-afe1-63740f63dfe0"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[Epoch 0] Model A - Loss: 2.1513, Acc: 0.3099 | Model B - Loss: 2.3626, Acc: 0.0986\n",
      "[Epoch 1] Model A - Loss: 2.0766, Acc: 0.3846 | Model B - Loss: 2.3589, Acc: 0.1022\n",
      "[Epoch 2] Model A - Loss: 2.0895, Acc: 0.3716 | Model B - Loss: 2.3498, Acc: 0.1114\n",
      "[Epoch 3] Model A - Loss: 2.1239, Acc: 0.3372 | Model B - Loss: 2.3532, Acc: 0.1080\n",
      "[Epoch 4] Model A - Loss: 2.1155, Acc: 0.3457 | Model B - Loss: 2.3580, Acc: 0.1031\n",
      "[Epoch 5] Model A - Loss: 2.0330, Acc: 0.4282 | Model B - Loss: 2.3555, Acc: 0.1056\n",
      "[Epoch 6] Model A - Loss: 2.0295, Acc: 0.4317 | Model B - Loss: 2.3561, Acc: 0.1050\n",
      "[Epoch 7] Model A - Loss: 2.0316, Acc: 0.4296 | Model B - Loss: 2.3577, Acc: 0.1035\n",
      "[Epoch 8] Model A - Loss: 2.0274, Acc: 0.4338 | Model B - Loss: 2.3591, Acc: 0.1020\n",
      "[Epoch 9] Model A - Loss: 2.0278, Acc: 0.4333 | Model B - Loss: 2.3608, Acc: 0.1003\n",
      "Test Accuracy — Model A: 0.4219, Model B: 0.0885\n"
     ]
    }
   ],
   "source": [
    "# Set SEED to an int for a reproducible run; the outputs below come from an unseeded run.\n",
    "SEED = None\n",
    "if SEED is not None:\n",
    "    torch.manual_seed(SEED)\n",
    "\n",
    "train_loader, test_loader = get_mnist_loader()\n",
    "\n",
    "model_a = ModelA().to(DEVICE)\n",
    "model_b = ModelB().to(DEVICE)\n",
    "\n",
    "log = train_model_pair(\n",
    "    model_a=model_a,\n",
    "    model_b=model_b,\n",
    "    loader=train_loader,\n",
    "    epochs=10,\n",
    "    lr=1e-3\n",
    ")\n",
    "\n",
    "acc_a = evaluate(model_a, test_loader)\n",
    "acc_b = evaluate(model_b, test_loader)\n",
    "\n",
    "print(f\"Test Accuracy — Model A: {acc_a:.4f}, Model B: {acc_b:.4f}\")"
   ]
  }
 ]
}