{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3046844b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cvxpy as cp\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import scipy.stats\n",
    "\n",
    "from copy import deepcopy\n",
    "\n",
    "from numba import njit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "489e6cda",
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_margin(X, p = 2):\n",
    "    n, d = X.shape\n",
    "    \n",
    "    w = cp.Variable(d)\n",
    "    obj = cp.Minimize(cp.pnorm(w, p))\n",
    "    constr = [w.T @ x >= 1 for x in X]\n",
    "    \n",
    "    prob = cp.Problem(obj, constr)\n",
    "    prob.solve()\n",
    "    \n",
    "    if 'optimal' in prob.status:\n",
    "        return np.array(w.value)\n",
    "    else:\n",
    "        print(prob.status)\n",
    "        return np.nan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d02fad19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def margin(w, X):\n",
    "    return min(w.T @ x for x in X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd0ae18e",
   "metadata": {},
   "outputs": [],
   "source": [
    "@njit\n",
    "def clsf_md(X, w0, it, eta, p = 2.0):\n",
    "    n, d = X.shape\n",
    "    \n",
    "    w = w0\n",
    "    \n",
    "    for t in range(it):\n",
    "        dl = np.zeros(d)\n",
    "        for x in X:\n",
    "            dl += np.exp(-np.dot(w, x)) * x\n",
    "        tmp = np.power(np.abs(w), p-1) * np.sign(w) + eta * dl\n",
    "        w = np.power(np.abs(tmp), 1/(p-1)) * np.sign(tmp)\n",
    "\n",
    "    return w  #, hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93462682",
   "metadata": {},
   "outputs": [],
   "source": [
    "n, d = 15, 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e65c84fe-9083-4c8d-b3ee-718241c2ae79",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(2)\n",
    "#np.random.seed(99)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aabde6fa-8234-44ff-8dc9-1c0cde086eb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "init_pt = scipy.stats.norm.rvs(0, 0.1, size=d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ba80f54",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.vstack([sp.sparse.random(1, d, density = 0.1, data_rvs = sp.stats.uniform(-2, 4).rvs).toarray() for _ in range(n)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1163ee96-5e39-4734-8dd5-ff3a9555a452",
   "metadata": {},
   "outputs": [],
   "source": [
    "p = [1.1, 1.5, 2, 3, 6, 10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c3d1c61",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "w = [clsf_md(X, init_pt, 250000, 1e-4, p = p0) for p0 in p]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bba86940-a3d8-4bcb-b30b-dc40655ca0b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "for w0, p0 in zip(w, p):\n",
    "    print(f'$p={p0}$ &', f'{np.linalg.norm(w0 / margin(w0, X), 1):.3f} &',\n",
    "          ' & '.join(f'{np.linalg.norm(w0 / margin(w0, X), p0):.3f}' for p0 in p), \n",
    "          f'& {np.linalg.norm(w0 / margin(w0, X), np.inf):.3f}', \n",
    "          '\\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0920c08d-27f2-4d5d-be5a-8b7bce6eb3ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "for w0, p0 in zip(w, p):\n",
    "    print(f'$p={p0}$ &',\n",
    "          ' & '.join(f'{np.linalg.norm(w0 / margin(w0, X), p0):.3f}' for p0 in p), \n",
    "          '\\\\\\\\')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
