{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import OrderedDict\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torchvision.datasets import MNIST\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "from torchvision import transforms\n",
    "\n",
    "import cooper\n",
    "from copy import deepcopy as copy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(torch.nn.Module):\n",
    "    \"\"\"Linear classifier: flattens each input and applies one affine layer.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Number of features per example after flattening.\n",
    "        output_dim: Number of outputs produced by the linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super().__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.linear = torch.nn.Linear(input_dim, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Reshape `x` to (-1, input_dim) and return the linear layer's output.\"\"\"\n",
    "        flattened = x.view(-1, self.input_dim)\n",
    "        return self.linear(flattened)\n",
    "\n",
    "\n",
    "class NormConstrainedLogReg(cooper.ConstrainedMinimizationProblem):\n",
    "    \"\"\"Cross-entropy minimization with a norm bound on each row of the weights.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.criterion = torch.nn.CrossEntropyLoss()\n",
    "        super().__init__(is_constrained=True)\n",
    "\n",
    "    def closure(self, model, inputs, targets):\n",
    "        \"\"\"Evaluate the loss and the inequality defects for one batch.\n",
    "\n",
    "        Returns a `cooper.CMPState` holding the cross-entropy loss and, for each\n",
    "        row of `model.linear.weight`, its norm minus 1. No equality defect is set.\n",
    "        \"\"\"\n",
    "        # Objective: cross-entropy between the model's outputs and the targets.\n",
    "        loss = self.criterion(model.forward(inputs), targets)\n",
    "\n",
    "        # Constraint ||W_i|| <= 1 for every row i, rewritten as ||W_i|| - 1 <= 0.\n",
    "        row_norms = model.linear.weight.norm(dim=1)\n",
    "        ineq_defect = row_norms - 1\n",
    "\n",
    "        return cooper.CMPState(loss=loss, ineq_defect=ineq_defect, eq_defect=None)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download MNIST (if not already present) into a `data` folder under the\n",
    "# current working directory; images are converted to tensors.\n",
    "dataset = MNIST(\n",
    "    os.path.join(os.getcwd(), \"data\"), download=True, transform=transforms.ToTensor()\n",
    ")\n",
    "\n",
    "# Seed the split so the train/validation partition is identical on every run;\n",
    "# without a generator the partition depends on the global RNG state.\n",
    "train_set, val_set = random_split(\n",
    "    dataset, [55000, 5000], generator=torch.Generator().manual_seed(0)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constrained problem and the Lagrangian formulation built on top of it.\n",
    "cmp = NormConstrainedLogReg()\n",
    "formulation = cooper.LagrangianFormulation(cmp)\n",
    "\n",
    "# 784 input features (28 * 28), 10 outputs.\n",
    "model = LogisticRegression(input_dim=28 * 28, output_dim=10)\n",
    "\n",
    "# Alternative optimizer pair (disabled): cooper's ExtraSGD.\n",
    "# primal_optimizer = cooper.optim.ExtraSGD(model.parameters(), lr=1e-3, momentum=0.9)\n",
    "# dual_optimizer = cooper.optim.partial_optimizer(cooper.optim.ExtraSGD, lr=5e-3)\n",
    "\n",
    "# The primal optimizer is built directly on the model parameters; the dual one\n",
    "# is handed over via `partial_optimizer` with only its class and learning rate.\n",
    "primal_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)\n",
    "dual_optimizer = cooper.optim.partial_optimizer(torch.optim.SGD, lr=5e-3)\n",
    "\n",
    "coop = cooper.ConstrainedOptimizer(\n",
    "    formulation=formulation, primal_optimizer=primal_optimizer, dual_optimizer=dual_optimizer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshots keyed by iteration number, in insertion order.\n",
    "state_history = OrderedDict()\n",
    "\n",
    "# No shuffling is requested, so a single loader instance yields the same\n",
    "# batch sequence on every epoch.\n",
    "train_loader = DataLoader(train_set, batch_size=64)\n",
    "\n",
    "iter_num = 0\n",
    "for epoch in range(3):\n",
    "    for inputs, targets in train_loader:\n",
    "        # Clear gradients, evaluate the composite objective on this batch,\n",
    "        # backpropagate it, then take an optimizer step.\n",
    "        coop.zero_grad()\n",
    "        lagrangian = formulation.composite_objective(cmp.closure, model, inputs, targets)\n",
    "        formulation.custom_backward(lagrangian)\n",
    "        coop.step(cmp.closure, model, inputs, targets)\n",
    "\n",
    "        # Every 5th iteration, record the problem state and a deep copy of the\n",
    "        # formulation state (`copy` is `deepcopy` in this notebook).\n",
    "        if iter_num % 5 == 0:\n",
    "            snapshot = {\"cmp\": cmp.state, \"dual\": copy(formulation.state())}\n",
    "            state_history[iter_num] = snapshot\n",
    "\n",
    "        iter_num += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logged iteration numbers paired with the loss recorded at each of them.\n",
    "iters, loss_history = zip(\n",
    "    *((step, snap[\"cmp\"].loss.item()) for step, snap in state_history.items())\n",
    ")\n",
    "\n",
    "# Stack per-snapshot arrays along axis 1, i.e. one snapshot per column.\n",
    "# NOTE(review): assumes entry 0 of the stored formulation state is a tensor of\n",
    "# multipliers -- confirm against the cooper version in use.\n",
    "snapshots = list(state_history.values())\n",
    "mult_hist = np.stack([snap[\"dual\"][0].data.numpy() for snap in snapshots], axis=1)\n",
    "defect_hist = np.stack(\n",
    "    [snap[\"cmp\"].ineq_defect.data.numpy() for snap in snapshots], axis=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row 0 of `mult_hist` only: the first multiplier across logged iterations.\n",
    "plt.plot(iters, mult_hist[0:1, ...].T)\n",
    "plt.xlabel(\"Iteration\")\n",
    "plt.ylabel(\"Multiplier value\")\n",
    "plt.title(\"First dual multiplier\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss recorded at each logged iteration (one batch per point).\n",
    "plt.plot(iters, loss_history)\n",
    "plt.xlabel(\"Iteration\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.title(\"Training loss\")\n",
    "# Render explicitly, matching the multiplier plot, so the cell does not echo\n",
    "# the list of Line2D objects.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation accuracy, in percent, as a plain Python float.\n",
    "total, correct = 0, 0\n",
    "with torch.no_grad():\n",
    "    for inputs, targets in DataLoader(val_set, batch_size=64):\n",
    "        # Predicted class = index of the largest output per example.\n",
    "        predicted = model(inputs).argmax(dim=1)\n",
    "        total += targets.size(0)\n",
    "        # .item() keeps `correct` an int instead of accumulating a 0-dim tensor.\n",
    "        correct += (predicted == targets).sum().item()\n",
    "accuracy = 100 * correct / total\n",
    "\n",
    "accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "act-lbl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
