{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import scipy\n",
    "from MomentumOptimizer_LieGroup_SOn import LieGroupSGD\n",
    "# Run everything on the CPU, with float64 as torch's default floating dtype.\n",
    "device=torch.device('cpu')\n",
    "torch.set_default_dtype(torch.float64)\n",
    "# Seed torch's global RNG once for the session.\n",
    "torch.manual_seed(0)\n",
    "# Colours of matplotlib's active property cycle; the plotting cells index into this list.\n",
    "color_list=plt.rcParams['axes.prop_cycle'].by_key()['color']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "from eig_val_decomp import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# kappa dependence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_trajectory(scheme, seed, dim, kappa, num_iter=5000):\n",
    "    \"\"\"Run one optimizer on the eigenvalue-decomposition loss, starting near the solution.\n",
    "\n",
    "    The start point is ``X_sol`` right-multiplied by the matrix exponential of a\n",
    "    small random skew-symmetric matrix.\n",
    "\n",
    "    Args:\n",
    "        scheme: 'heavy_ball', 'NAG_SC' or 'momentumless'; selects the parameter\n",
    "            helper and is forwarded to ``LieGroupSGD``.\n",
    "        seed: seed for torch's global RNG, applied before the problem is built.\n",
    "        dim: forwarded to ``generate_eig_value_artificial_conditional_number``.\n",
    "        kappa: forwarded to the same helper.\n",
    "        num_iter: number of optimizer steps.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'loss_list' (loss minus ``min_val``, evaluated before each step)\n",
    "        and 'lyap_list' (Lyapunov value after each step; for 'momentumless' the\n",
    "        loss value itself is recorded).\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if ``scheme`` is not one of the three names above.\n",
    "    \"\"\"\n",
    "    # Honour the caller's seed; it used to be hard-coded to 0, silently ignoring `seed`.\n",
    "    torch.manual_seed(seed)\n",
    "    eig_vals=generate_eig_value_artificial_conditional_number(dim, kappa)\n",
    "    A, sol_dict=eig_val_decomp_problem(eig_vals, device=device)\n",
    "    U=lambda X:eig_val_decomp_loss(A, X)\n",
    "    mu=sol_dict['mu']\n",
    "    L=sol_dict['L']\n",
    "    min_val=sol_dict['min_val']\n",
    "    X_sol=sol_dict['X_sol']\n",
    "\n",
    "    # Skew-symmetric perturbation, pushed through matrix_exp and applied on the right of X_sol.\n",
    "    xi_noise=torch.randn_like(A)*0.01\n",
    "    xi_noise=xi_noise-xi_noise.T\n",
    "    g_init=X_sol@torch.matrix_exp(xi_noise)\n",
    "\n",
    "    g=torch.clone(g_init)\n",
    "    g.requires_grad_(True)\n",
    "    g_last=g.clone()\n",
    "    g_star=X_sol\n",
    "\n",
    "    # Each scheme has its own helper returning the step size 'h' (and 'gamma' where used).\n",
    "    if scheme=='heavy_ball':\n",
    "        parameter_dict=parameter_HB(mu, L)\n",
    "    elif scheme=='NAG_SC':\n",
    "        parameter_dict=parameter_NAG_SC(mu, L)\n",
    "    elif scheme=='momentumless':\n",
    "        parameter_dict=parameter_momentumless(mu, L)\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "    if scheme=='momentumless':\n",
    "        # No gamma is passed for the momentumless scheme.\n",
    "        optimizer=LieGroupSGD([g], lr=parameter_dict['h'], scheme=scheme)\n",
    "    else:\n",
    "        optimizer=LieGroupSGD([g], lr=parameter_dict['h'], gamma=parameter_dict['gamma'], scheme=scheme)\n",
    "\n",
    "    loss_list=[]\n",
    "    lyap_list=[]\n",
    "    for _ in tqdm(range(num_iter)):\n",
    "        loss=eig_val_decomp_loss(A, g)-min_val\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        # Snapshot the iterate before the step; the Lyapunov helpers receive it as 'g_last'.\n",
    "        # NOTE(review): this in-place copy runs with autograd enabled on a tensor that requires grad - confirm that is intended.\n",
    "        g_last.copy_(g)\n",
    "        optimizer.step()\n",
    "        xi=optimizer.state[g]['xi']\n",
    "\n",
    "        if scheme=='momentumless':\n",
    "            lyap_list.append(loss.item())\n",
    "        else:\n",
    "            lyap_args={'h':parameter_dict['h'],\n",
    "                'gamma':parameter_dict['gamma'],\n",
    "                'U':U,\n",
    "                'g':g,\n",
    "                'xi':xi,\n",
    "                'g_star':g_star,\n",
    "                'g_last':g_last\n",
    "            }\n",
    "            if scheme=='heavy_ball':\n",
    "                lyap_value=lyap_HB(lyap_args)\n",
    "            else:\n",
    "                # NAG_SC additionally receives the optimizer's stored 'trivialized_grad_last'.\n",
    "                lyap_args['nabla_g_last']=optimizer.state[g]['trivialized_grad_last']\n",
    "                lyap_value=lyap_NAG_SC(lyap_args)\n",
    "            lyap_list.append(lyap_value.item())\n",
    "        loss_list.append(loss.item())\n",
    "    return {'loss_list':loss_list, 'lyap_list':lyap_list}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_convergence_rate(lyap_list):\n",
    "    lyap_list=torch.Tensor(lyap_list)\n",
    "    return torch.max(lyap_list[1:]/lyap_list[:-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each scheme, sweep kappa and record the worst one-step ratio of the recorded Lyapunov values.\n",
    "kappa_list=np.arange(100, 10000, 500)\n",
    "convergence_rate_dict={}\n",
    "for scheme in ['heavy_ball', 'NAG_SC', 'momentumless']:\n",
    "    # The momentumless scheme is run for 3000 steps, the others for 100.\n",
    "    n_steps=3000 if scheme=='momentumless' else 100\n",
    "    rates=np.zeros_like(kappa_list, dtype=np.float64)\n",
    "    convergence_rate_dict[scheme]=rates\n",
    "    for i, kappa in enumerate(kappa_list):\n",
    "        result_dict=generate_trajectory(scheme, 0, 10, kappa, num_iter=n_steps)\n",
    "        rates[i]=get_convergence_rate(result_dict['lyap_list'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache the sweep results so the plotting cells can run without recomputing them.\n",
    "with open('convergence_rate_dict.pkl', 'wb') as cache_file:\n",
    "    pickle.dump(convergence_rate_dict, cache_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached rates; kappa_list is rebuilt so this cell works without the sweep cell.\n",
    "# The pickle is the one written by this notebook; do not point this at untrusted files.\n",
    "kappa_list=np.arange(100, 10000, 500)\n",
    "with open('convergence_rate_dict.pkl', 'rb') as cache_file:\n",
    "    convergence_rate_dict=pickle.load(cache_file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import linregress\n",
    "# Heavy ball: measured gap 1-rate, plus a C/kappa reference whose C is the slope of a\n",
    "# linear regression of rate/(1-rate) on kappa (the intercept is not used).\n",
    "hb_rate=convergence_rate_dict['heavy_ball']\n",
    "plt.plot(kappa_list, 1-hb_rate, color=color_list[1], label='heavy_ball')\n",
    "result=linregress(kappa_list, hb_rate/(1-hb_rate))\n",
    "fit_list=torch.from_numpy(result.slope*kappa_list)\n",
    "plt.plot(kappa_list, 1/fit_list, linestyle='--', dashes=(2,2), color=color_list[1], label=r'$C \\kappa^{-1}$')\n",
    "\n",
    "# NAG_SC: same construction, but the regression is on the squared ratio, giving a C/sqrt(kappa) reference.\n",
    "nag_rate=convergence_rate_dict['NAG_SC']\n",
    "plt.plot(kappa_list, 1-nag_rate, color=color_list[2], label='NAG_SC')\n",
    "result=linregress(kappa_list, (nag_rate/(1-nag_rate))**2)\n",
    "fit_list=torch.from_numpy(result.slope*kappa_list)\n",
    "plt.plot(kappa_list, 1/torch.sqrt(fit_list), linestyle='--', dashes=(2,2), color=color_list[2], label=r'$C \\kappa^{-0.5}$')\n",
    "\n",
    "# Axes, legend and export.\n",
    "plt.yscale('log')\n",
    "plt.legend(fontsize=12)\n",
    "plt.xlabel('condition number', fontsize=18)\n",
    "plt.ylabel('1-convergence rate', fontsize=18)\n",
    "plt.xticks(fontsize=14, rotation=10)\n",
    "plt.yticks(fontsize=14, rotation=60)\n",
    "plt.savefig('LEV_conv_kappa.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# kappa = 10**exponent for the trajectory plots; the exponents also index color_list there.\n",
    "exponent_list=[2, 3, 4]\n",
    "# One fixed colour per scheme, taken from the first three entries of the colour cycle.\n",
    "color={scheme_name:color_list[idx] for idx, scheme_name in enumerate(['heavy_ball', 'NAG_SC', 'momentumless'])}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Start from a fresh dict: `result_dict` was previously left over from the kappa sweep\n",
    "# (a single run's loss/lyap dict), so scheme entries were written into that stale object.\n",
    "result_dict={}\n",
    "for scheme in ['heavy_ball', 'NAG_SC', 'momentumless']:\n",
    "    result_dict[scheme]={}\n",
    "    # `10**exponent_list` raises TypeError (int ** list); exponentiate one entry at a time.\n",
    "    # Keys are the plain ints 10**e, which is how the plotting cells look them up.\n",
    "    for exponent in exponent_list:\n",
    "        kappa=10**exponent\n",
    "        result_dict[scheme][kappa]=generate_trajectory(scheme, 0, 10, kappa, num_iter=2000)\n",
    "\n",
    "with open('result_dict.pkl', 'wb') as f:\n",
    "    pickle.dump(result_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached trajectories (a pickle written by this notebook; do not load untrusted files).\n",
    "with open('result_dict.pkl', 'rb') as cache_file:\n",
    "    result_dict=pickle.load(cache_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Loss trajectories of the momentumless scheme: one dashed curve per kappa=10**i.\n",
    "scheme='momentumless'\n",
    "for i in exponent_list:\n",
    "    kappa=10**i\n",
    "    trajectory=result_dict[scheme][kappa]\n",
    "    plt.plot(trajectory['loss_list'], linestyle='--', dashes=(2,2), color=color_list[i], label='kappa='+str(kappa))\n",
    "\n",
    "# Axes, legend and export.\n",
    "plt.yscale('log')\n",
    "plt.legend(fontsize=12)\n",
    "plt.xlabel('#iter', fontsize=18)\n",
    "plt.ylabel('loss value', fontsize=18)\n",
    "plt.xticks(fontsize=14, rotation=10)\n",
    "plt.yticks(fontsize=14, rotation=60)\n",
    "plt.savefig('LEV_momentumless.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Heavy-ball trajectories per kappa=10**i: dashed = loss, solid = Lyapunov value, same colour.\n",
    "scheme='heavy_ball'\n",
    "for i in exponent_list:\n",
    "    kappa=10**i\n",
    "    trajectory=result_dict[scheme][kappa]\n",
    "    plt.plot(trajectory['loss_list'], linestyle='--', dashes=(2,2), color=color_list[i], label='kappa='+str(kappa))\n",
    "    plt.plot(trajectory['lyap_list'], color=color_list[i])\n",
    "\n",
    "# Axes, legend and export.\n",
    "plt.yscale('log')\n",
    "plt.legend(fontsize=12)\n",
    "plt.xlabel('#iter', fontsize=18)\n",
    "plt.ylabel('loss value', fontsize=18)\n",
    "plt.xticks(fontsize=14, rotation=10)\n",
    "plt.yticks(fontsize=14, rotation=60)\n",
    "plt.savefig('LEV_heavy_ball.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# NAG_SC trajectories per kappa=10**i: dashed = loss, solid = Lyapunov value, same colour.\n",
    "scheme='NAG_SC'\n",
    "for i in exponent_list:\n",
    "    kappa=10**i\n",
    "    trajectory=result_dict[scheme][kappa]\n",
    "    plt.plot(trajectory['loss_list'], linestyle='--', dashes=(2,2), color=color_list[i], label='kappa='+str(kappa))\n",
    "    plt.plot(trajectory['lyap_list'], color=color_list[i])\n",
    "\n",
    "# Axes, legend and export.\n",
    "plt.yscale('log')\n",
    "plt.legend(fontsize=12)\n",
    "plt.xlabel('#iter', fontsize=18)\n",
    "plt.ylabel('loss value', fontsize=18)\n",
    "plt.xticks(fontsize=14, rotation=10)\n",
    "plt.yticks(fontsize=14, rotation=60)\n",
    "plt.savefig('LEV_NAG_SC.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Non-convexity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_trajectory(scheme, seed, dim, kappa, num_iter=5000):\n",
    "    \"\"\"Run one optimizer on the eigenvalue-decomposition loss from a column-reversed start.\n",
    "\n",
    "    Running this cell rebinds ``generate_trajectory``, replacing the earlier\n",
    "    definition of the same name in the notebook namespace. The start point is\n",
    "    ``X_sol`` with its columns in reverse order, right-multiplied by the matrix\n",
    "    exponential of a small random skew-symmetric matrix.\n",
    "\n",
    "    Args:\n",
    "        scheme: 'heavy_ball', 'NAG_SC' or 'momentumless'; selects the parameter\n",
    "            helper and is forwarded to ``LieGroupSGD``.\n",
    "        seed: seed for torch's global RNG, applied before the problem is built.\n",
    "        dim: forwarded to ``generate_eig_value_artificial_conditional_number``.\n",
    "        kappa: forwarded to the same helper.\n",
    "        num_iter: number of optimizer steps.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'loss_list' (loss minus ``min_val``, evaluated before each step)\n",
    "        and 'lyap_list' (Lyapunov value after each step for 'heavy_ball' and\n",
    "        'NAG_SC'; left empty for 'momentumless').\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if ``scheme`` is not one of the three names above.\n",
    "    \"\"\"\n",
    "    # Honour the caller's seed; it used to be hard-coded to 0, silently ignoring `seed`.\n",
    "    torch.manual_seed(seed)\n",
    "    eig_vals=generate_eig_value_artificial_conditional_number(dim, kappa)\n",
    "    A, sol_dict=eig_val_decomp_problem(eig_vals, device=device)\n",
    "    U=lambda X:eig_val_decomp_loss(A, X)\n",
    "    mu=sol_dict['mu']\n",
    "    L=sol_dict['L']\n",
    "    min_val=sol_dict['min_val']\n",
    "    X_sol=sol_dict['X_sol']\n",
    "\n",
    "    # Skew-symmetric perturbation, pushed through matrix_exp.\n",
    "    xi_noise=torch.randn_like(A)*0.01\n",
    "    xi_noise=xi_noise-xi_noise.T\n",
    "    # Reverse the column order of X_sol before perturbing it.\n",
    "    g_init=X_sol[:, np.arange(X_sol.shape[1]-1, -1, -1)]@torch.matrix_exp(xi_noise)\n",
    "\n",
    "    g=torch.clone(g_init)\n",
    "    g.requires_grad_(True)\n",
    "    g_last=g.clone()\n",
    "    g_star=X_sol\n",
    "\n",
    "    # Each scheme has its own helper returning the step size 'h' (and 'gamma' where used).\n",
    "    if scheme=='heavy_ball':\n",
    "        parameter_dict=parameter_HB(mu, L)\n",
    "    elif scheme=='NAG_SC':\n",
    "        parameter_dict=parameter_NAG_SC(mu, L)\n",
    "    elif scheme=='momentumless':\n",
    "        parameter_dict=parameter_momentumless(mu, L)\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "    if scheme=='momentumless':\n",
    "        # No gamma is passed for the momentumless scheme.\n",
    "        optimizer=LieGroupSGD([g], lr=parameter_dict['h'], scheme=scheme)\n",
    "    else:\n",
    "        optimizer=LieGroupSGD([g], lr=parameter_dict['h'], gamma=parameter_dict['gamma'], scheme=scheme)\n",
    "\n",
    "    loss_list=[]\n",
    "    lyap_list=[]\n",
    "    for _ in tqdm(range(num_iter)):\n",
    "        loss=eig_val_decomp_loss(A, g)-min_val\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        # Snapshot the iterate before the step; the Lyapunov helpers receive it as 'g_last'.\n",
    "        # NOTE(review): this in-place copy runs with autograd enabled on a tensor that requires grad - confirm that is intended.\n",
    "        g_last.copy_(g)\n",
    "        optimizer.step()\n",
    "        xi=optimizer.state[g]['xi']\n",
    "\n",
    "        if scheme!='momentumless':\n",
    "            lyap_args={'h':parameter_dict['h'],\n",
    "                'gamma':parameter_dict['gamma'],\n",
    "                'U':U,\n",
    "                'g':g,\n",
    "                'xi':xi,\n",
    "                'g_star':g_star,\n",
    "                'g_last':g_last\n",
    "            }\n",
    "            if scheme=='heavy_ball':\n",
    "                lyap_value=lyap_HB(lyap_args)\n",
    "            else:\n",
    "                # NAG_SC additionally receives the optimizer's stored 'trivialized_grad_last'.\n",
    "                lyap_args['nabla_g_last']=optimizer.state[g]['trivialized_grad_last']\n",
    "                lyap_value=lyap_NAG_SC(lyap_args)\n",
    "            lyap_list.append(lyap_value.item())\n",
    "        loss_list.append(loss.item())\n",
    "    return {'loss_list':loss_list, 'lyap_list':lyap_list}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run NAG_SC, then heavy ball (seed 0, dim 10, kappa 100, 10000 steps) and cache both trajectories.\n",
    "result_dict_non_convex={\n",
    "    scheme_name:generate_trajectory(scheme_name, 0, 10, 100, num_iter=10000)\n",
    "    for scheme_name in ['NAG_SC', 'heavy_ball']\n",
    "}\n",
    "with open('result_dict_non_convex.pkl', 'wb') as cache_file:\n",
    "    pickle.dump(result_dict_non_convex, cache_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached non-convex runs (a pickle written by this notebook; do not load untrusted files).\n",
    "with open('result_dict_non_convex.pkl', 'rb') as cache_file:\n",
    "    result_dict_non_convex=pickle.load(cache_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss curves of the two momentum schemes, each in its per-scheme colour.\n",
    "for scheme in ('heavy_ball', 'NAG_SC'):\n",
    "    losses=result_dict_non_convex[scheme]['loss_list']\n",
    "    plt.plot(losses, color=color[scheme], label=scheme)\n",
    "\n",
    "# Axes, legend and export; only the first 8000 iterations are shown.\n",
    "plt.yscale('log')\n",
    "plt.legend(fontsize=12)\n",
    "plt.xlabel('#iter', fontsize=18)\n",
    "plt.ylabel('U', fontsize=18)\n",
    "plt.xlim([0, 8000])\n",
    "plt.xticks(fontsize=14, rotation=10)\n",
    "plt.yticks(fontsize=14, rotation=60)\n",
    "plt.savefig('LEV_non_convex.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
