{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "030fa828",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cvxpy as cp\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import scipy.stats\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from copy import deepcopy\n",
    "\n",
    "from numba import njit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fad836a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib styling: 5x5 figures, LaTeX text rendering, 20pt serif font.\n",
    "plt.rcParams.update({\n",
    "    'figure.figsize': (5, 5),\n",
    "    'text.usetex': True,\n",
    "    'font.size': 20,\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a19c70b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_margin(X, p = 2):\n",
    "    \"\"\"Return the unit l_p-norm max-margin direction for the rows of X.\n",
    "\n",
    "    Solves  min ||w||_p  s.t.  <w, x_i> >= 1  for every row x_i of X and\n",
    "    divides the minimizer by the optimal value ||w*||_p, so the result has\n",
    "    unit p-norm (up to solver tolerance).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : 2-D array, one sample per row.\n",
    "    p : norm order, passed to cvxpy.pnorm.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray of length X.shape[1].\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the solver does not report an (approximately) optimal solution,\n",
    "        e.g. when no w with <w, x_i> >= 1 for all rows exists.\n",
    "    \"\"\"\n",
    "    d = X.shape[1]\n",
    "\n",
    "    w = cp.Variable(d)\n",
    "    # A single vectorized constraint instead of one scalar constraint per row.\n",
    "    constr = [X @ w >= 1]\n",
    "    prob = cp.Problem(cp.Minimize(cp.pnorm(w, p)), constr)\n",
    "    prob.solve()\n",
    "\n",
    "    # Without this check an infeasible problem surfaces as an opaque\n",
    "    # TypeError from `None / inf` (w.value is None, prob.value is inf).\n",
    "    if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):\n",
    "        raise ValueError(f'max_margin: solver status {prob.status!r}; '\n",
    "                         'is X linearly separable through the origin?')\n",
    "\n",
    "    return w.value / prob.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a16ece",
   "metadata": {},
   "outputs": [],
   "source": [
    "def margin(w, X):\n",
    "    \"\"\"Return the smallest inner product w^T x over the rows x of X.\n",
    "\n",
    "    Raises ValueError (from the builtin min) when X has no rows.\n",
    "    \"\"\"\n",
    "    w_t = w.T\n",
    "    return min(map(lambda row: w_t @ row, X))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89ad4469",
   "metadata": {},
   "outputs": [],
   "source": [
    "@njit\n",
    "def dual_angle(x, y, p = 2.0):\n",
    "    \"\"\"Alignment of x with y in the l_p geometry.\n",
    "\n",
    "    Both vectors are scaled to unit p-norm; x is then sent through the mirror\n",
    "    map |x|^(p-1) * sign(x) and paired with y. By Hoelder's inequality the\n",
    "    result lies in [-1, 1] and equals 1 when the normalized vectors coincide.\n",
    "    \"\"\"\n",
    "    x_hat = x / np.linalg.norm(x, p)\n",
    "    y_hat = y / np.linalg.norm(y, p)\n",
    "    # Gradient of ||x||_p^p / p evaluated at the normalized x.\n",
    "    x_mirror = np.sign(x_hat) * np.power(np.abs(x_hat), p-1)\n",
    "    return np.dot(x_mirror, y_hat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "083c84e0-2599-40df-a2a7-f4c828dd90dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "@njit\n",
    "def clsf_help(X, w0, it, eta, wp, p = 2.0):\n",
    "    \"\"\"Run `it` steps of p-norm mirror descent on the exponential loss.\n",
    "\n",
    "    The loss is L(w) = sum_i exp(-<x_i, w>) over the rows x_i of X, and the\n",
    "    mirror map is grad psi(w) = |w|^(p-1) * sign(w), i.e. psi(w) = ||w||_p^p / p.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : 2-D array whose rows are the samples x_i.\n",
    "    w0 : initial iterate (not modified; w is rebound, never written in place).\n",
    "    it : number of iterations.\n",
    "    eta : step size.\n",
    "    wp : reference direction for the alignment gap.\n",
    "    p : norm order; must differ from 1 since the inverse mirror map uses 1/(p-1).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    w : final iterate.\n",
    "    hist : array of shape (3, it); column t holds, after step t,\n",
    "        [L(w) / n, 1 - dual_angle(w, wp, p), ||w||_p].\n",
    "    \"\"\"\n",
    "    n = X.shape[0]\n",
    "    w = w0\n",
    "\n",
    "    hist = np.zeros((3, it))\n",
    "\n",
    "    # exp(-<x_i, w>) at the current iterate is both the per-sample loss and\n",
    "    # the gradient weight for the next step, so compute it once per iterate.\n",
    "    weights = np.exp(-X @ w)\n",
    "\n",
    "    for t in range(it):\n",
    "        # Negative gradient of L: sum_i exp(-<x_i, w>) x_i.\n",
    "        neg_grad = weights @ X\n",
    "        # Gradient step taken in the mirror (dual) space ...\n",
    "        dual = np.power(np.abs(w), p-1) * np.sign(w) + eta * neg_grad\n",
    "        # ... then mapped back with the inverse mirror map |z|^(1/(p-1)) * sign(z).\n",
    "        w = np.power(np.abs(dual), 1/(p-1)) * np.sign(dual)\n",
    "\n",
    "        weights = np.exp(-X @ w)\n",
    "        loss = np.sum(weights)\n",
    "\n",
    "        da = dual_angle(w, wp, p)\n",
    "        hist[:, t] = [loss / n, 1 - da, np.linalg.norm(w, p)]\n",
    "\n",
    "    return w, hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "745557d0-0b64-4703-aeb7-2dfcc3cda7ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clsf_md(X, w0, it, eta, p = 2.0):\n",
    "    \"\"\"Run p-norm mirror descent on X, tracking alignment with the l_p max-margin direction.\n",
    "\n",
    "    Returns the (w, hist) pair produced by clsf_help, with max_margin(X, p)\n",
    "    as the reference direction.\n",
    "    \"\"\"\n",
    "    reference = max_margin(X, p)\n",
    "    w_final, history = clsf_help(X, w0, it, eta, reference, p)\n",
    "    return w_final, history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53307779-61c0-4256-a2af-4b578d1697ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG; np.random.* and scipy.stats rvs (no random_state) draw from it.\n",
    "SEED = 10\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ba95968",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 12  # number of Gaussian samples\n",
    "d = 2   # ambient dimension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b91dbfee-a772-43c4-a9fc-e44d31f01ec9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random initial iterate for mirror descent, drawn from N(0, I_d).\n",
    "init_pt = scipy.stats.norm.rvs(loc=0, scale=1, size=(d,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b6cec7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# n isotropic Gaussian samples centred at (0.5, 0.5) with std 0.15 ...\n",
    "X_gauss = scipy.stats.norm.rvs(np.array([0.5, 0.5]), 0.15, size=(n, d))\n",
    "# ... plus three fixed points, all on the line x_1 + x_2 = 2/3.\n",
    "X_fixed = 2/3 * np.array([[0.25, 0.75], [0.75, 0.25], [0.5, 0.5]])\n",
    "X = np.vstack((X_gauss, X_fixed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b1beb0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One random boolean per row of X; in the scatter plot True rows are drawn\n",
    "# as-is and False rows negated. Sized from X (not n+3) so it stays in sync\n",
    "# with however many fixed points were appended to X.\n",
    "rs = np.random.random_sample(X.shape[0]) > 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a94bd3f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.xlim([-1.0, 1.0])\n",
    "plt.ylim([-1.0, 1.0])\n",
    "plt.grid(alpha=0.6)\n",
    "\n",
    "# Rows flagged by rs are drawn as-is (red dots), the rest negated (blue crosses).\n",
    "# NOTE(review): presumably X stores label-signed samples y_i * x_i, so this\n",
    "# recovers the raw points -- confirm.\n",
    "not_rs = ~rs\n",
    "plt.scatter(X[rs, 0], X[rs, 1], color='red')\n",
    "plt.scatter(-X[not_rs, 0], -X[not_rs, 1], color='blue', marker='x')\n",
    "\n",
    "# l_2 max-margin direction; the decision boundary {x : <wp, x> = 0} is the\n",
    "# line through the origin spanned by (-wp[1], wp[0]).\n",
    "wp = max_margin(X, 2)\n",
    "line_param = np.linspace(-2, 2, 100)\n",
    "plt.plot(line_param * -wp[1], line_param * wp[0], color='olive', linestyle='-.')\n",
    "\n",
    "# Tight bounding box, consistent with loss.pdf and angle.pdf.\n",
    "plt.savefig('scatter.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f23b9eed",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "# One mirror-descent run per norm order, all from the same initial point.\n",
    "N_ITER = 1000000\n",
    "ETA = 1e-4\n",
    "runs = {}\n",
    "for p_val in (2, 3, 1.5):\n",
    "    runs[p_val] = clsf_md(X, init_pt, N_ITER, ETA, p = p_val)\n",
    "\n",
    "w1, (a1, b1, c1) = runs[2]\n",
    "w2, (a2, b2, c2) = runs[3]\n",
    "w3, (a3, b3, c3) = runs[1.5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fe91086",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1-based iteration index shared by the convergence plots.\n",
    "t = np.arange(len(a1)) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2950976",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean exponential loss (hist row 0) for each norm order.\n",
    "for curve, lbl in ((a3, '$p=1.5$'), (a1, '$p=2$'), (a2, '$p=3$')):\n",
    "    plt.plot(t, curve, label=lbl)\n",
    "\n",
    "plt.xscale('log')\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "\n",
    "plt.savefig('loss.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0ee84d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax1 = plt.subplots()\n",
    "# Gap 1 - dual_angle(w_t, wp) to the l_p max-margin direction (hist row 1).\n",
    "for curve, lbl in ((b3, '$p=1.5$'), (b1, '$p=2$'), (b2, '$p=3$')):\n",
    "    ax1.plot(t, curve, label=lbl)\n",
    "\n",
    "ax1.set_xscale('log')\n",
    "ax1.set_xlabel('Iteration')\n",
    "ax1.set_ylabel('Convergence gap')\n",
    "ax1.legend()\n",
    "\n",
    "fig.savefig('angle.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e7c9747",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# p-norm of the iterate (hist row 2) for each norm order.\n",
    "plt.plot(t, c3, label='$p=1.5$')\n",
    "plt.plot(t, c1, label='$p=2$')\n",
    "plt.plot(t, c2, label='$p=3$')\n",
    "\n",
    "plt.xscale('log')\n",
    "# \\| is the LaTeX norm delimiter; a plain || renders as two loosely spaced bars.\n",
    "plt.ylabel(r'$\\|w_t\\|_p$', fontsize=22)\n",
    "plt.xlabel('Iteration')\n",
    "\n",
    "plt.legend()\n",
    "\n",
    "# Tight bounding box (as for loss.pdf / angle.pdf) guards against the enlarged\n",
    "# y-label being cut off at the figure edge.\n",
    "plt.savefig('norm.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
