{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import math\n",
    "import shutil\n",
    "import sys\n",
    "from collections import defaultdict\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.patches as patches\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from matplotlib.patches import Rectangle\n",
    "\n",
    "# Make the local generalized_VI_ICLR package importable (replace the placeholder path).\n",
    "sys.path.append('path_to_the_folder')\n",
    "from generalized_VI_ICLR.algorithms import Algorithm, SGD, Popov, Extragradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from Matplotlib's defaults, then apply this notebook's overrides. The\n",
    "# reset must come first; run afterwards it would discard the overrides.\n",
    "matplotlib.rcParams.update(matplotlib.rcParamsDefault)\n",
    "plt.rcParams.update({\n",
    "    # text.usetex needs a LaTeX installation; fall back to mathtext when none is on PATH.\n",
    "    \"text.usetex\": shutil.which(\"latex\") is not None,\n",
    "    \"font.size\": 20,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _clone_points(points):\n",
    "    \"\"\"Return independent copies of a list of starting-point tensors.\"\"\"\n",
    "    return [point.clone() for point in points]\n",
    "\n",
    "\n",
    "def _plot_convergence(results, n_steps, filename):\n",
    "    \"\"\"Plot trial-averaged distance-to-solution curves on a log scale and save them.\n",
    "\n",
    "    `results` holds arrays under \"projection\", \"projection_same\", \"popov\" and\n",
    "    \"extra\" whose first axis indexes trials; the mean over that axis is plotted.\n",
    "    The figure is written to `filename` and then closed.\n",
    "    \"\"\"\n",
    "    # markevery=0 would make matplotlib slice with step 0, so use at least 1.\n",
    "    mark_every = max(1, n_steps // 10)\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    curves = (\n",
    "        (\"projection\", \"Projection\", \"o\"),\n",
    "        (\"projection_same\", \"Projection-Same\", \"o\"),\n",
    "        (\"popov\", \"Popov\", \"o\"),\n",
    "        (\"extra\", \"Korpelevich\", \"s\"),\n",
    "    )\n",
    "    for key, label, marker in curves:\n",
    "        ax.plot(results[key].mean(0), label=label, marker=marker, markevery=mark_every)\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.set_title(\"Convergence\", fontsize=20)\n",
    "    ax.set_ylabel(\"Distance to solution\", fontsize=\"15\")\n",
    "    ax.set_xlabel(\"Number of iterations\", fontsize=\"15\")\n",
    "    ax.legend(fontsize=\"15\")\n",
    "    fig.savefig(filename)\n",
    "    # Close explicitly: this runs once per configuration of a sweep.\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def run_experiment(n_exper, n_samples, dim, mu, L, n_steps=20000, sampler=None, trashold=None, noise=False, twod=False, p=1.5, lr_p=None, scheduler_lr=\"q\", closeness=1, q=2/3):\n",
    "    \"\"\"Run each solver `n_exper` times, plot the mean convergence and save it as a PDF.\n",
    "\n",
    "    Every trial starts four solvers from the same two initial points (scaled by\n",
    "    `closeness`): SGD (\"projection\"), SGD with `same=True` (\"projection_same\"),\n",
    "    Popov (\"popov\") and Extragradient (\"extra\"). The plot is saved as\n",
    "    \"Exp_p_{p:.1f}_sch_{scheduler_lr}_close_{closeness}_noise{noise}_q{q}.pdf\".\n",
    "\n",
    "    Args:\n",
    "        n_exper: number of independent trials; must be at least 1.\n",
    "        n_samples, dim, mu, L, twod: not used here; kept for call compatibility.\n",
    "        n_steps: number of iterations per run.\n",
    "        sampler: forwarded to every solver's `run`.\n",
    "        trashold: forwarded to the Popov and Extragradient runs (SGD always gets 1).\n",
    "        noise, p: forwarded to every solver's constructor.\n",
    "        lr_p: `lr_0` for Popov and Extragradient.\n",
    "        scheduler_lr: scheduler for Popov and Extragradient (SGD always uses \"q\").\n",
    "        closeness: scale factor applied to the initial points.\n",
    "        q: forwarded to every solver's `run`.\n",
    "\n",
    "    Returns:\n",
    "        The results `defaultdict`: numpy arrays \"projection\", \"projection_same\",\n",
    "        \"popov\" and \"extra\" stacked from each trial's \"Dist2Sol\"; \"kappa\" with\n",
    "        `extra.L / extra.mu` per trial; and the last trial's \"u\" values under\n",
    "        \"points_proj\", \"points_proj_same\", \"points_pop\" and \"points_extr\".\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `n_exper` is less than 1.\n",
    "    \"\"\"\n",
    "    if n_exper < 1:\n",
    "        raise ValueError(\"n_exper must be at least 1, got {}\".format(n_exper))\n",
    "\n",
    "    results = defaultdict(list)\n",
    "    last_runs = {}\n",
    "    for i in range(n_exper):\n",
    "        if i % 5 == 0:\n",
    "            print(i)\n",
    "        init_points = [closeness * torch.tensor([13.1, 7.2]), closeness * torch.tensor([12.2, 7.5])]\n",
    "\n",
    "        # Both SGD variants always use scheduler \"q\" and trashold=1, whatever the arguments.\n",
    "        # Each solver gets its own copy of the start points in case an algorithm\n",
    "        # modifies `init` in place (not visible from this notebook).\n",
    "        sgd = SGD(p=p, noise=noise, init=_clone_points(init_points))\n",
    "        last_runs[\"projection\"] = sgd.run(n_steps=n_steps, sampler=sampler, scheduler=\"q\", q=q, trashold=1)\n",
    "\n",
    "        sgd_same = SGD(p=p, noise=noise, init=_clone_points(init_points), same=True)\n",
    "        last_runs[\"projection_same\"] = sgd_same.run(n_steps=n_steps, sampler=sampler, scheduler=\"q\", q=q, trashold=1)\n",
    "\n",
    "        popov = Popov(p=p, lr_0=lr_p, noise=noise, init=_clone_points(init_points))\n",
    "        last_runs[\"popov\"] = popov.run(n_steps=n_steps, sampler=sampler, scheduler=scheduler_lr, q=q, trashold=trashold)\n",
    "\n",
    "        extra = Extragradient(p=p, lr_0=lr_p, noise=noise, init=_clone_points(init_points))\n",
    "        last_runs[\"extra\"] = extra.run(n_steps=n_steps, sampler=sampler, scheduler=scheduler_lr, q=q, trashold=trashold)\n",
    "\n",
    "        for name, outcome in last_runs.items():\n",
    "            results[name].append(outcome[\"Dist2Sol\"])\n",
    "        results[\"kappa\"].append(extra.L / extra.mu)\n",
    "\n",
    "    for name in (\"projection\", \"projection_same\", \"popov\", \"extra\", \"kappa\"):\n",
    "        results[name] = np.array(results[name])\n",
    "\n",
    "    # Last-trial trajectories, each taken from its own solver's run.\n",
    "    results[\"points_proj\"] = last_runs[\"projection\"][\"u\"]\n",
    "    results[\"points_proj_same\"] = last_runs[\"projection_same\"][\"u\"]\n",
    "    results[\"points_pop\"] = last_runs[\"popov\"][\"u\"]\n",
    "    results[\"points_extr\"] = last_runs[\"extra\"][\"u\"]\n",
    "\n",
    "    filename = \"Exp_p_{:.1f}_sch_{}_close_{}_noise{}_q{}.pdf\".format(p, scheduler_lr, closeness, noise, q)\n",
    "    _plot_convergence(results, n_steps, filename)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
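    "## Experiments\n",
    "\n",
    "Sweep over the step-size schedule parameter `q` and the parameter `p`; each configuration averages `n_exper` noisy runs of every solver and saves its convergence plot as a PDF."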
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "5\n",
      "10\n",
      "15\n",
      "0\n",
      "5\n",
      "10\n",
      "15\n",
      "0\n",
      "5\n",
      "10\n",
      "15\n",
      "0\n",
      "5\n",
      "10\n",
      "15\n",
      "0\n",
      "5\n",
      "10\n",
      "15\n",
      "0\n",
      "5\n",
      "10\n",
      "15\n",
      "0\n",
      "5\n",
      "10\n",
      "15\n",
      "0\n",
      "5\n",
      "10\n",
      "15\n",
      "0\n",
      "5\n",
      "10\n",
      "15\n"
     ]
    }
   ],
   "source": [
    "# Sweep configuration; dim, n_samples, mu and L are passed through to\n",
    "# run_experiment, which does not use them.\n",
    "dim = 1\n",
    "n_samples = 1\n",
    "mu = 0.05\n",
    "L = 1.0\n",
    "qs = [1/2 + 0.001, 2/3, 1 - 0.001]\n",
    "n_exper = 20\n",
    "ps = [2.5, 3.0, 6.0]\n",
    "close = [1.0]\n",
    "n_steps = 1000\n",
    "schs = [\"q\"]\n",
    "SEED = 0\n",
    "\n",
    "# Keep every configuration's results instead of overwriting a single variable.\n",
    "results_by_config = {}\n",
    "for q, p, cl, sch in itertools.product(qs, ps, close, schs):\n",
    "    # Re-seed per configuration so each one is reproducible on its own\n",
    "    # (assumes the solvers draw their noise from torch or numpy -- TODO confirm).\n",
    "    torch.manual_seed(SEED)\n",
    "    np.random.seed(SEED)\n",
    "    results_by_config[(q, p, cl, sch)] = run_experiment(\n",
    "        n_exper, n_samples, dim, mu, L, n_steps=n_steps, p=p, sampler=None, trashold=1,\n",
    "        noise=True, scheduler_lr=sch, closeness=cl, q=q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.2 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
