{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.integrate as integrate\n",
    "import torch\n",
    "\n",
    "sys.path.append('../Optimizers')\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizers compared in the main convergence and timing experiments.\n",
    "method_list=['StiefelSGD_ours', 'StiefelAdam_ours', 'ProjectedStiefelSGD', 'ProjectedStiefelAdam', 'MomentumlessStiefelSGD']\n",
    "\n",
    "# These optimizers are handed the transposed parameter Y = X^T (presumably they expect\n",
    "# orthonormal rows rather than columns); the experiments keep X as the view Y^T.\n",
    "transpose_needed=['ProjectedStiefelAdam', 'ProjectedStiefelSGD']\n",
    "\n",
    "# Legend label for every optimizer key used in the plots.\n",
    "legend_dict={\n",
    "    'RegularizerStiefelSGD': 'Regularizer SGD (Cisse et al)',\n",
    "    'RegularizerStiefelAdam': 'Regularizer Adam (Cisse et al)',\n",
    "    'StiefelSGD_ours': 'Stiefel SGD (Ours)',\n",
    "    'StiefelAdam_ours': 'Stiefel Adam (Ours)',\n",
    "    'ProjectedStiefelSGD': 'Projected Stiefel SGD (Li et al)',\n",
    "    'ProjectedStiefelAdam': 'Projected Stiefel Adam (Li et al)',\n",
    "    'MomentumlessStiefelSGD': 'Momentumless Stiefel SGD (Wen & Yin)',\n",
    "    'LiCombinedOptimizer': 'Our retraction + Li algo',\n",
    "}\n",
    "\n",
    "# Project-local optimizers (found via the sys.path entries added in the first cell).\n",
    "from StiefelRegularizer import RegularizerStiefelSGD, RegularizerStiefelAdam\n",
    "from StiefelOptimizers import StiefelSGD, StiefelAdam\n",
    "from ProjectedStiefelOptimizer.stiefel_optimizer import SGDG as ProjectedStiefelSGD\n",
    "from ProjectedStiefelOptimizer.stiefel_optimizer import AdamG as ProjectedStiefelAdam\n",
    "from MomentumlessStiefelSGD import MomentumlessStiefelSGD\n",
    "from LiCombinedOptimizer import LiCombinedOptimizer\n",
    "\n",
    "# Factories mapping a parameter list to a freshly configured optimizer, so every\n",
    "# experiment run starts from clean optimizer state.\n",
    "optimizer_dict={\n",
    "    'RegularizerStiefelSGD': lambda param: RegularizerStiefelSGD(param, lr=1e-3, momentum=0.9, stiefel_regularizer=1),\n",
    "    'RegularizerStiefelAdam': lambda param: RegularizerStiefelAdam(param, lr=1e-3, betas=(0.9, 0.999), stiefel_regularizer=1),\n",
    "    'StiefelSGD_ours': lambda param: StiefelSGD(param, lr=1e-1, momentum=0.9),\n",
    "    'StiefelAdam_ours': lambda param: StiefelAdam(param, lr=1e-3, betas=(0.9, 0.999)),\n",
    "    'ProjectedStiefelSGD': lambda param: ProjectedStiefelSGD(param, lr=2e-1, momentum=0.9, stiefel=True),\n",
    "    'ProjectedStiefelAdam': lambda param: ProjectedStiefelAdam(param, lr=5e-0, momentum=0.9, beta2= 0.999, stiefel=True),\n",
    "    'MomentumlessStiefelSGD': lambda param: MomentumlessStiefelSGD(param, lr=1e-1),\n",
    "    'LiCombinedOptimizer': lambda param: LiCombinedOptimizer(param, lr=1e-1, momentum=0.9),\n",
    "}\n",
    "\n",
    "device=torch.device('cpu')\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "def lev_problem(n,m, device='cpu', dtype=None):\n",
    "    \"\"\"Build a random leading-eigenvector problem on the Stiefel manifold St(n, m).\n",
    "\n",
    "    Args:\n",
    "        n: ambient dimension (number of rows of X); requires n >= m.\n",
    "        m: number of orthonormal columns of X.\n",
    "        device: device on which all tensors are created.\n",
    "        dtype: floating dtype; defaults to torch.get_default_dtype().\n",
    "\n",
    "    Returns:\n",
    "        A: symmetric n x n matrix (G + G^T) / (2 sqrt(n)), G with i.i.d. N(0, 1) entries.\n",
    "        X_init: n x m matrix with orthonormal columns (torch.nn.init.orthogonal_).\n",
    "        sol: sum of the m largest eigenvalues of A, i.e. the maximum of tr(X^T A X)\n",
    "            over St(n, m) (Ky Fan), so -tr(X^T A X) + sol >= 0 on the manifold.\n",
    "    \"\"\"\n",
    "    assert n >= m\n",
    "    if dtype is None:\n",
    "        dtype=torch.get_default_dtype()\n",
    "\n",
    "    A=torch.randn(n, n, device=device, dtype=dtype)\n",
    "    A=(A+A.t())/2/np.sqrt(n)\n",
    "    X_init=torch.zeros(n, m, device=device, dtype=dtype)\n",
    "    torch.nn.init.orthogonal_(X_init)\n",
    "    # torch.symeig was deprecated and later removed from PyTorch; torch.linalg.eigvalsh\n",
    "    # is its replacement for symmetric matrices (eigenvalues in ascending order).\n",
    "    eig_vals=torch.linalg.eigvalsh(A)\n",
    "    eig_vals=eig_vals.sort(descending=True).values\n",
    "    sol=torch.sum(eig_vals[0:m])\n",
    "    return A, X_init, sol\n",
    "\n",
    "\n",
    "def lev_loss(A, X):\n",
    "    \"\"\"Return -tr(X^T A X); minimizing it maximizes the trace of the quadratic form.\"\"\"\n",
    "    quadratic_form=X.t()@A@X\n",
    "    return -quadratic_form.trace()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convergence and deviation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run every method in method_list from the same X_init for num_iter steps, recording\n",
    "# the shifted loss lev_loss + sol (0 at the optimum) and the deviation ||X^T X - I||_F.\n",
    "torch.manual_seed(0)\n",
    "n=1000\n",
    "m=10\n",
    "\n",
    "A, X_init, sol=lev_problem(n,m, device=device)\n",
    "\n",
    "num_iter=5000\n",
    "\n",
    "dev_dict={}\n",
    "loss_dict={}\n",
    "\n",
    "# Built once instead of on every iteration.\n",
    "eye_m=torch.eye(m, device=device)\n",
    "\n",
    "for method in method_list:\n",
    "    print(method)\n",
    "    loss_mem=loss_dict[method]=[]\n",
    "    dev_mem=dev_dict[method]=[]\n",
    "    X=X_init.clone().to(device)\n",
    "    if method in transpose_needed:\n",
    "        # Optimize the transposed leaf Y; X is a view of Y (assumes in-place updates of Y).\n",
    "        Y=X.t()\n",
    "        Y.requires_grad=True\n",
    "        X=Y.t()\n",
    "        optimizer=optimizer_dict[method]([Y])\n",
    "    else:\n",
    "        X.requires_grad=True\n",
    "        optimizer=optimizer_dict[method]([X])\n",
    "    for i in range(num_iter):\n",
    "        optimizer.zero_grad()\n",
    "        loss=lev_loss(A, X)+sol\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Diagnostic only: no autograd graph is needed for it.\n",
    "        with torch.no_grad():\n",
    "            dev=torch.norm(X.t()@X-eye_m)\n",
    "        dev_mem.append(dev.item())\n",
    "        loss_mem.append(loss.item())\n",
    "        if i%100==0:\n",
    "            print(loss.item())\n",
    "\n",
    "with open('dev_dict.pkl', 'wb') as handle:\n",
    "    pickle.dump(dev_dict, handle)\n",
    "with open('loss_dict.pkl', 'wb') as handle:\n",
    "    pickle.dump(loss_dict, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Time dependence on m with fixed n/m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average wall-clock time of a single optimizer.step() as m grows with n = nm_ratio * m.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "m_range=[8,16,32,64,128, 256]\n",
    "nm_ratio=10\n",
    "\n",
    "mean_iter=100\n",
    "\n",
    "time_dict_m={}\n",
    "for method in method_list:\n",
    "    time_mem=time_dict_m[method]=[None]*len(m_range)\n",
    "    for idx, m in enumerate(m_range):\n",
    "        n=round(nm_ratio*m)\n",
    "        print((method, m))\n",
    "\n",
    "        A, X_init, sol=lev_problem(n,m, device=device)\n",
    "        X=X_init.clone().to(device)\n",
    "        if method in transpose_needed:\n",
    "            Y=X.t()\n",
    "            Y.requires_grad=True\n",
    "            X=Y.t()\n",
    "            optimizer=optimizer_dict[method]([Y])\n",
    "        else:\n",
    "            X.requires_grad=True\n",
    "            optimizer=optimizer_dict[method]([X])\n",
    "        step_time=0.0\n",
    "        for i in range(mean_iter):\n",
    "            optimizer.zero_grad()\n",
    "            loss=lev_loss(A, X)+sol\n",
    "            loss.backward()\n",
    "            # Only the step is timed. perf_counter is the monotonic high-resolution clock\n",
    "            # meant for short intervals; time.time() may be adjusted and be coarser.\n",
    "            t=time.perf_counter()\n",
    "            optimizer.step()\n",
    "            step_time+=time.perf_counter()-t\n",
    "        time_mem[idx]=step_time/mean_iter\n",
    "\n",
    "with open('time_dict_m.pkl', 'wb') as handle:\n",
    "    pickle.dump(time_dict_m, handle)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Time dependence on n with fixed m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device=torch.device('cpu')\n",
    "torch.set_default_dtype(torch.float64)\n",
    "# Pin intra-op parallelism to one thread (presumably so the timing runs are not skewed by threading).\n",
    "torch.set_num_threads(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average wall-clock time of a single optimizer.step() as n grows with m fixed.\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "n_range=[100, 200, 300, 500, 750, 1000, 2000, 3000, 5000]\n",
    "m=10\n",
    "\n",
    "mean_iter=100\n",
    "\n",
    "time_dict_n={}\n",
    "for method in method_list:\n",
    "    time_mem=time_dict_n[method]=[None]*len(n_range)\n",
    "    for idx, n in enumerate(n_range):\n",
    "        print((method, n))\n",
    "\n",
    "        A, X_init, sol=lev_problem(n,m, device=device)\n",
    "        X=X_init.clone().to(device)\n",
    "        if method in transpose_needed:\n",
    "            Y=X.t()\n",
    "            Y.requires_grad=True\n",
    "            X=Y.t()\n",
    "            optimizer=optimizer_dict[method]([Y])\n",
    "        else:\n",
    "            X.requires_grad=True\n",
    "            optimizer=optimizer_dict[method]([X])\n",
    "        step_time=0.0\n",
    "        for i in range(mean_iter):\n",
    "            optimizer.zero_grad()\n",
    "            loss=lev_loss(A, X)+sol\n",
    "            loss.backward()\n",
    "            # Only the step is timed, using the monotonic high-resolution clock.\n",
    "            t=time.perf_counter()\n",
    "            optimizer.step()\n",
    "            step_time+=time.perf_counter()-t\n",
    "        time_mem[idx]=step_time/mean_iter\n",
    "\n",
    "with open('time_dict_n.pkl', 'wb') as handle:\n",
    "    pickle.dump(time_dict_n, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Different inner products and matrix-exponential approximations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stiefel Adam with every (expm approximation, inner product) pair on one LEV problem.\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n=1000\n",
    "m=10\n",
    "\n",
    "A, X_init, sol=lev_problem(n,m, device=device)\n",
    "\n",
    "num_iter=5000\n",
    "\n",
    "eye_m=torch.eye(m, device=device)\n",
    "\n",
    "expm_innerprod_adam_dict={}\n",
    "expm_method_list=['MatrixExp', 'Cayley', 'ForwardEuler']\n",
    "inner_prod_list=['Euclidean', 'Canonical']\n",
    "for expm_method in expm_method_list:\n",
    "    for inner_prod in inner_prod_list:\n",
    "        name=expm_method+'+'+inner_prod\n",
    "        loss_mem=[]\n",
    "        dev_mem=[]\n",
    "        X=X_init.clone().to(device)\n",
    "        X.requires_grad=True\n",
    "        # Constructed directly rather than through a lambda closing over the loop variables.\n",
    "        optimizer=StiefelAdam([X], lr=1e-3, betas=(0.9, 0.999), inner_prod=inner_prod, expm_method=expm_method)\n",
    "        for i in range(num_iter):\n",
    "            optimizer.zero_grad()\n",
    "            loss=lev_loss(A, X)+sol\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            with torch.no_grad():\n",
    "                dev=torch.norm(X.t()@X-eye_m)\n",
    "            dev_mem.append(dev.item())\n",
    "            loss_mem.append(loss.item())\n",
    "            if i%100==0:\n",
    "                print(loss.item())\n",
    "        # NOTE(review): only the loss curve is persisted; dev_mem is not saved.\n",
    "        expm_innerprod_adam_dict[name]=loss_mem\n",
    "\n",
    "with open('expm_innerprod_adam_dict.pkl', 'wb') as handle:\n",
    "    pickle.dump(expm_innerprod_adam_dict, handle)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stiefel SGD with every (expm approximation, inner product) pair on one LEV problem.\n",
    "torch.manual_seed(0)\n",
    "n=1000\n",
    "m=10\n",
    "\n",
    "A, X_init, sol=lev_problem(n,m, device=device)\n",
    "\n",
    "num_iter=5000\n",
    "\n",
    "eye_m=torch.eye(m, device=device)\n",
    "\n",
    "expm_innerprod_sgd_dict={}\n",
    "expm_method_list=['MatrixExp', 'Cayley', 'ForwardEuler']\n",
    "inner_prod_list=['Euclidean', 'Canonical']\n",
    "for expm_method in expm_method_list:\n",
    "    for inner_prod in inner_prod_list:\n",
    "        name=expm_method+'+'+inner_prod\n",
    "        loss_mem=[]\n",
    "        dev_mem=[]\n",
    "        X=X_init.clone().to(device)\n",
    "        X.requires_grad=True\n",
    "        # Constructed directly rather than through a lambda closing over the loop variables.\n",
    "        optimizer=StiefelSGD([X], lr=1e-1, momentum=0.9, inner_prod=inner_prod, expm_method=expm_method)\n",
    "        for i in range(num_iter):\n",
    "            optimizer.zero_grad()\n",
    "            loss=lev_loss(A, X)+sol\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            with torch.no_grad():\n",
    "                dev=torch.norm(X.t()@X-eye_m)\n",
    "            dev_mem.append(dev.item())\n",
    "            loss_mem.append(loss.item())\n",
    "            if i%100==0:\n",
    "                print(loss.item())\n",
    "        # NOTE(review): only the loss curve is persisted; dev_mem is not saved.\n",
    "        expm_innerprod_sgd_dict[name]=loss_mem\n",
    "\n",
    "with open('expm_innerprod_sgd_dict.pkl', 'wb') as handle:\n",
    "    pickle.dump(expm_innerprod_sgd_dict, handle)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Li's Projected Stiefel SGD with and without our retraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Li's projected SGD, Li's algorithm with our retraction, and our Stiefel SGD.\n",
    "method_list=['LiCombinedOptimizer', 'ProjectedStiefelSGD', 'StiefelSGD_ours']\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n=1000\n",
    "m=10\n",
    "\n",
    "A, X_init, sol=lev_problem(n,m, device=device)\n",
    "\n",
    "num_iter=2000\n",
    "\n",
    "eye_m=torch.eye(m, device=device)\n",
    "\n",
    "loss_dict={}\n",
    "dev_dict={}\n",
    "\n",
    "for method in method_list:\n",
    "    print(method)\n",
    "    loss_mem=loss_dict[method]=[]\n",
    "    dev_mem=dev_dict[method]=[]\n",
    "    X=X_init.clone().to(device)\n",
    "    if method in transpose_needed:\n",
    "        Y=X.t()\n",
    "        Y.requires_grad=True\n",
    "        X=Y.t()\n",
    "        optimizer=optimizer_dict[method]([Y])\n",
    "    else:\n",
    "        X.requires_grad=True\n",
    "        optimizer=optimizer_dict[method]([X])\n",
    "    for i in range(num_iter):\n",
    "        optimizer.zero_grad()\n",
    "        loss=lev_loss(A, X)+sol\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Diagnostic only: no autograd graph is needed for it.\n",
    "        with torch.no_grad():\n",
    "            dev=torch.norm(X.t()@X-eye_m)\n",
    "        dev_mem.append(dev.item())\n",
    "        loss_mem.append(loss.item())\n",
    "        if i%100==0:\n",
    "            print(loss.item())\n",
    "\n",
    "with open('dev_dict_LiCombined.pkl', 'wb') as handle:\n",
    "    pickle.dump(dev_dict, handle)\n",
    "with open('loss_dict_LiCombined.pkl', 'wb') as handle:\n",
    "    pickle.dump(loss_dict, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Number of inner (Cayley) loop iterations in Projected Stiefel SGD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Projected Stiefel SGD for several Cayley_loop_num values.\n",
    "torch.manual_seed(0)\n",
    "n=100\n",
    "m=10\n",
    "\n",
    "A, X_init, sol=lev_problem(n,m, device=device)\n",
    "\n",
    "num_iter=2000\n",
    "\n",
    "eye_m=torch.eye(m, device=device)\n",
    "\n",
    "dev_loop_dict={}\n",
    "loss_loop_dict={}\n",
    "\n",
    "method_list=['ProjectedStiefelSGD']\n",
    "# Alternative sweep over the QR frequency:\n",
    "# loop_list = [2]\n",
    "# qr_every_list = [1,2,4,8,16]\n",
    "loop_list = [1,2,4,6, 8,16]\n",
    "# NOTE(review): a QR period far beyond num_iter presumably means QR never runs here.\n",
    "qr_every_list = [int(1e6)]\n",
    "\n",
    "for method in method_list:\n",
    "    for loop_num in loop_list:\n",
    "        for qr_every in qr_every_list:\n",
    "            key=str((loop_num, qr_every))\n",
    "            dev_dict=dev_loop_dict[key]={}\n",
    "            loss_dict=loss_loop_dict[key]={}\n",
    "            print(method)\n",
    "            loss_mem=loss_dict[method]=[]\n",
    "            dev_mem=dev_dict[method]=[]\n",
    "            X=X_init.clone().to(device)\n",
    "            if method in transpose_needed:\n",
    "                Y=X.t()\n",
    "                Y.requires_grad=True\n",
    "                X=Y.t()\n",
    "                optimizer=optimizer_dict[method]([Y])\n",
    "            else:\n",
    "                X.requires_grad=True\n",
    "                optimizer=optimizer_dict[method]([X])\n",
    "\n",
    "            # Override the Cayley / QR settings for this sweep point.\n",
    "            optimizer.param_groups[0]['QR_every'] = qr_every\n",
    "            optimizer.param_groups[0]['Cayley_loop_num'] = loop_num\n",
    "            for i in range(num_iter):\n",
    "                optimizer.zero_grad()\n",
    "                loss=lev_loss(A, X)+sol\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                with torch.no_grad():\n",
    "                    dev=torch.norm(X.t()@X-eye_m)\n",
    "                dev_mem.append(dev.item())\n",
    "                loss_mem.append(loss.item())\n",
    "                if i%100==0:\n",
    "                    print(loss.item())\n",
    "\n",
    "with open('dev_loop_dict_Li.pkl', 'wb') as handle:\n",
    "    pickle.dump(dev_loop_dict, handle)\n",
    "with open('loss_loop_dict_Li.pkl', 'wb') as handle:\n",
    "    pickle.dump(loss_loop_dict, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# QR frequency is important in projected Stiefel SGD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Projected Stiefel SGD with Cayley_loop_num=5 and QR_every in {1, 2},\n",
    "# plus our Stiefel SGD on the same problem as a reference.\n",
    "torch.manual_seed(0)\n",
    "n=100\n",
    "m=10\n",
    "\n",
    "A, X_init, sol=lev_problem(n,m, device=device)\n",
    "\n",
    "num_iter=1500\n",
    "\n",
    "eye_m=torch.eye(m, device=device)\n",
    "\n",
    "dev_loop_dict={}\n",
    "loss_loop_dict={}\n",
    "\n",
    "method_list=['ProjectedStiefelSGD']\n",
    "\n",
    "loop_list = [5]\n",
    "qr_every_list = [1,2]\n",
    "# Alternative sweep over the Cayley loop count:\n",
    "# loop_list = [1,2,4,6, 8,16]\n",
    "# qr_every_list = [1e6]\n",
    "\n",
    "for method in method_list:\n",
    "    for loop_num in loop_list:\n",
    "        for qr_every in qr_every_list:\n",
    "            key=str((loop_num, qr_every))\n",
    "            dev_dict=dev_loop_dict[key]={}\n",
    "            loss_dict=loss_loop_dict[key]={}\n",
    "            print(method)\n",
    "            loss_mem=loss_dict[method]=[]\n",
    "            dev_mem=dev_dict[method]=[]\n",
    "            X=X_init.clone().to(device)\n",
    "            if method in transpose_needed:\n",
    "                Y=X.t()\n",
    "                Y.requires_grad=True\n",
    "                X=Y.t()\n",
    "                optimizer=optimizer_dict[method]([Y])\n",
    "            else:\n",
    "                X.requires_grad=True\n",
    "                optimizer=optimizer_dict[method]([X])\n",
    "\n",
    "            # Override the Cayley / QR settings for this sweep point.\n",
    "            optimizer.param_groups[0]['QR_every'] = qr_every\n",
    "            optimizer.param_groups[0]['Cayley_loop_num'] = loop_num\n",
    "            for i in range(num_iter):\n",
    "                optimizer.zero_grad()\n",
    "                loss=lev_loss(A, X)+sol\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                with torch.no_grad():\n",
    "                    dev=torch.norm(X.t()@X-eye_m)\n",
    "                dev_mem.append(dev.item())\n",
    "                loss_mem.append(loss.item())\n",
    "                if i%100==0:\n",
    "                    print(loss.item())\n",
    "\n",
    "# Reference run: our Stiefel SGD, stored under its own key (twice nested) for the plots.\n",
    "method ='StiefelSGD_ours'\n",
    "dev_dict=dev_loop_dict[method]={}\n",
    "loss_dict=loss_loop_dict[method]={}\n",
    "print(method)\n",
    "loss_mem=loss_dict[method]=[]\n",
    "dev_mem=dev_dict[method]=[]\n",
    "X=X_init.clone().to(device)\n",
    "if method in transpose_needed:\n",
    "    Y=X.t()\n",
    "    Y.requires_grad=True\n",
    "    X=Y.t()\n",
    "    optimizer=optimizer_dict[method]([Y])\n",
    "else:\n",
    "    X.requires_grad=True\n",
    "    optimizer=optimizer_dict[method]([X])\n",
    "\n",
    "for i in range(num_iter):\n",
    "    optimizer.zero_grad()\n",
    "    loss=lev_loss(A, X)+sol\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    with torch.no_grad():\n",
    "        dev=torch.norm(X.t()@X-eye_m)\n",
    "    dev_mem.append(dev.item())\n",
    "    loss_mem.append(loss.item())\n",
    "    if i%100==0:\n",
    "        print(loss.item())\n",
    "\n",
    "with open('projected_qr_dev.pkl', 'wb') as handle:\n",
    "    pickle.dump(dev_loop_dict, handle)\n",
    "with open('projected_qr_loss.pkl', 'wb') as handle:\n",
    "    pickle.dump(loss_loop_dict, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convergence and deviation from the manifold for different numbers of inner-loop iterations in our Stiefel SGD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Our Stiefel SGD for several max_inner_iter values.\n",
    "method_list=['StiefelSGD_ours']\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n=1000\n",
    "m=10\n",
    "\n",
    "A, X_init, sol=lev_problem(n,m, device=device)\n",
    "\n",
    "num_iter=1500\n",
    "\n",
    "eye_m=torch.eye(m, device=device)\n",
    "\n",
    "dev_loop_dict={}\n",
    "loss_loop_dict={}\n",
    "\n",
    "loop_list = [1,2, 3,4, 5, 6, 7, 8, 100]\n",
    "\n",
    "for loop_num in loop_list:\n",
    "    dev_dict=dev_loop_dict[str(loop_num)]={}\n",
    "    loss_dict=loss_loop_dict[str(loop_num)]={}\n",
    "    for method in method_list:\n",
    "        print(method)\n",
    "        loss_mem=loss_dict[method]=[]\n",
    "        dev_mem=dev_dict[method]=[]\n",
    "        X=X_init.clone().to(device)\n",
    "        if method in transpose_needed:\n",
    "            Y=X.t()\n",
    "            Y.requires_grad=True\n",
    "            X=Y.t()\n",
    "            optimizer=optimizer_dict[method]([Y])\n",
    "        else:\n",
    "            X.requires_grad=True\n",
    "            optimizer=optimizer_dict[method]([X])\n",
    "        # Override the inner-iteration cap for this sweep point.\n",
    "        optimizer.param_groups[0]['max_inner_iter'] = loop_num\n",
    "        for i in range(num_iter):\n",
    "            optimizer.zero_grad()\n",
    "            loss=lev_loss(A, X)+sol\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            with torch.no_grad():\n",
    "                dev=torch.norm(X.t()@X-eye_m)\n",
    "            dev_mem.append(dev.item())\n",
    "            loss_mem.append(loss.item())\n",
    "            if i%100==0:\n",
    "                print(loss.item())\n",
    "\n",
    "with open('dev_loop_dict_ours.pkl', 'wb') as handle:\n",
    "    pickle.dump(dev_loop_dict, handle)\n",
    "with open('loss_loop_dict_ours.pkl', 'wb') as handle:\n",
    "    pickle.dump(loss_loop_dict, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('expm_innerprod_sgd_dict.pkl', 'rb') as handle:\n",
    "    expm_innerprod_sgd_dict = pickle.load(handle)\n",
    "\n",
    "# Readable names for the expm approximations and inner products.\n",
    "label_dict={'MatrixExp': 'Matrix Exp', 'Cayley':'Cayley map', 'ForwardEuler':'Forward Euler', 'Euclidean':'Euclidean', 'Canonical':'Canonical'}\n",
    "\n",
    "# One |loss| curve per (expm approximation, inner product) pair, on a log scale.\n",
    "for expm_method in expm_method_list:\n",
    "    for inner_prod in inner_prod_list:\n",
    "        curve=np.abs(expm_innerprod_sgd_dict[expm_method+'+'+inner_prod])\n",
    "        name=label_dict[expm_method]+'+'+label_dict[inner_prod]\n",
    "        plt.plot(curve, label=name)\n",
    "plt.xlabel('iter')\n",
    "plt.ylabel('loss')\n",
    "plt.yscale('log')\n",
    "plt.legend()\n",
    "plt.savefig('./lev_compare_SGD.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('expm_innerprod_adam_dict.pkl', 'rb') as handle:\n",
    "    expm_innerprod_adam_dict = pickle.load(handle)\n",
    "\n",
    "# Readable names for the expm approximations and inner products.\n",
    "label_dict={'MatrixExp': 'Matrix Exp', 'Cayley':'Cayley map', 'ForwardEuler':'Forward Euler', 'Euclidean':'Euclidean', 'Canonical':'Canonical'}\n",
    "\n",
    "# One |loss| curve per (expm approximation, inner product) pair, on a log scale.\n",
    "for expm_method in expm_method_list:\n",
    "    for inner_prod in inner_prod_list:\n",
    "        curve=np.abs(expm_innerprod_adam_dict[expm_method+'+'+inner_prod])\n",
    "        name=label_dict[expm_method]+'+'+label_dict[inner_prod]\n",
    "        plt.plot(curve, label=name)\n",
    "plt.xlabel('iter')\n",
    "plt.ylabel('loss')\n",
    "plt.yscale('log')\n",
    "plt.legend()\n",
    "plt.savefig('./lev_compare_Adam.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('dev_dict.pkl', 'rb') as handle:\n",
    "    dev_dict = pickle.load(handle)\n",
    "with open('loss_dict.pkl', 'rb') as handle:\n",
    "    loss_dict = pickle.load(handle)\n",
    "\n",
    "method_list=['StiefelSGD_ours', 'StiefelAdam_ours', 'ProjectedStiefelSGD', 'ProjectedStiefelAdam', 'MomentumlessStiefelSGD']\n",
    "\n",
    "# Loss curves. No legend is drawn here; the next cell builds a stand-alone\n",
    "# legend figure from the handles on this `ax`.\n",
    "fig, ax = plt.subplots()\n",
    "for method in method_list:\n",
    "    ax.plot(loss_dict[method], label=legend_dict[method])\n",
    "ax.set_xlabel('iter')\n",
    "ax.set_ylabel('loss')\n",
    "ax.set_yscale('log')\n",
    "ax.text(0.15, 0.85,'Convergence',\n",
    "     horizontalalignment='left',\n",
    "     verticalalignment='bottom',\n",
    "     transform = ax.transAxes,\n",
    "     size=13)\n",
    "fig.savefig('./lev_convergence.pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# Deviation curves on their own axes, so the caption's transAxes belongs to the\n",
    "# figure that is actually being saved.\n",
    "fig_dev, ax_dev = plt.subplots()\n",
    "for method in method_list:\n",
    "    ax_dev.plot(dev_dict[method], label=legend_dict[method])\n",
    "ax_dev.set_xlabel('iter')\n",
    "ax_dev.set_ylabel('deviation')\n",
    "ax_dev.set_yscale('log')\n",
    "ax_dev.text(0.15, 0.85,'Manifold preservance',\n",
    "     horizontalalignment='left',\n",
    "     verticalalignment='bottom',\n",
    "     transform = ax_dev.transAxes,\n",
    "     size=13)\n",
    "fig_dev.savefig('./lev_dev.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Legend for the curves on `ax`, rendered as its own small figure.\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "figlegend = plt.figure(figsize=(3,2))\n",
    "figlegend.legend(handles, labels)\n",
    "\n",
    "figlegend.savefig('lev_legend.pdf', bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('time_dict_m.pkl', 'rb') as handle:\n",
    "    time_dict_m = pickle.load(handle)\n",
    "\n",
    "# Own figure/axes, so the caption's transAxes refers to this plot rather than an earlier one.\n",
    "fig_m, ax_m = plt.subplots()\n",
    "for method in method_list:\n",
    "    ax_m.plot(m_range, time_dict_m[method], label=legend_dict[method])\n",
    "ax_m.set_xlabel('m (fix n/m=10)')\n",
    "ax_m.set_ylabel('time consuming on CPU (s per iter)')\n",
    "ax_m.set_yscale('log')\n",
    "ax_m.set_xscale('log')\n",
    "ax_m.text(0.45, 0.05,'Time complexity against m \\n (fix n/m=10)',\n",
    "     horizontalalignment='left',\n",
    "     verticalalignment='bottom',\n",
    "     transform = ax_m.transAxes,\n",
    "     size=13)\n",
    "fig_m.savefig('./lev_time_m_log.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('time_dict_n.pkl', 'rb') as handle:\n",
    "    time_dict_n = pickle.load(handle)\n",
    "\n",
    "# Own figure/axes, so the caption's transAxes refers to this plot rather than an earlier one.\n",
    "fig_n, ax_n = plt.subplots()\n",
    "for method in method_list:\n",
    "    ax_n.plot(n_range, time_dict_n[method], label=legend_dict[method])\n",
    "ax_n.set_xlabel('n (fix m=10)')\n",
    "# Same units as the m plot: time_dict_n stores mean seconds per optimizer step.\n",
    "ax_n.set_ylabel('time consuming on CPU (s per iter)')\n",
    "ax_n.set_yscale('log')\n",
    "ax_n.set_xscale('log')\n",
    "ax_n.text(0.05, 0.85,'Time complexity against n \\n (fix m=10)',\n",
    "     horizontalalignment='left',\n",
    "     verticalalignment='bottom',\n",
    "     transform = ax_n.transAxes,\n",
    "     size=13)\n",
    "fig_n.savefig('./lev_time_n_log.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def moving_average(x, w):\n",
    "    \"\"\"Mean of x over each full window of w samples ('valid' mode: len(x) - w + 1 values).\"\"\"\n",
    "    return np.convolve(x, np.ones(w), 'valid') / w\n",
    "window_width=10\n",
    "\n",
    "with open('projected_qr_dev.pkl', 'rb') as handle:\n",
    "    dev_loop_dict = pickle.load(handle)\n",
    "with open('projected_qr_loss.pkl', 'rb') as handle:\n",
    "    loss_loop_dict = pickle.load(handle)\n",
    "\n",
    "method_list=['ProjectedStiefelSGD']\n",
    "\n",
    "loop_list = [5]\n",
    "qr_every_list = [1, 2]\n",
    "ours='StiefelSGD_ours'\n",
    "\n",
    "def sweep_label(loop_num, qr_every):\n",
    "    \"\"\"Legend text distinguishing runs by Cayley loop count AND QR period.\"\"\"\n",
    "    return 'Cayley iter '+str(loop_num)+', QR every '+str(qr_every)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for loop_num in loop_list:\n",
    "    for qr_every in qr_every_list:\n",
    "        loss_dict = loss_loop_dict[str((loop_num, qr_every))]\n",
    "        for method in method_list:\n",
    "            ax.plot(moving_average(np.abs(loss_dict[method]), window_width), label=sweep_label(loop_num, qr_every))\n",
    "ax.plot(moving_average(np.abs(loss_loop_dict[ours][ours]), window_width), label=legend_dict[ours])\n",
    "ax.set_xlabel('iter')\n",
    "ax.set_ylabel('loss (abs val)')\n",
    "ax.set_yscale('log')\n",
    "ax.set_title('Convergence')\n",
    "ax.legend()\n",
    "fig.savefig('./cayley_iter_convergence.pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# Deviation curves (norms, already non-negative) in a separate figure.\n",
    "fig, ax = plt.subplots()\n",
    "for loop_num in loop_list:\n",
    "    for qr_every in qr_every_list:\n",
    "        dev_dict = dev_loop_dict[str((loop_num, qr_every))]\n",
    "        for method in method_list:\n",
    "            ax.plot(moving_average(dev_dict[method], window_width), label=sweep_label(loop_num, qr_every))\n",
    "ax.plot(moving_average(dev_loop_dict[ours][ours], window_width), label=legend_dict[ours])\n",
    "ax.set_xlabel('iter')\n",
    "ax.set_ylabel('deviation')\n",
    "ax.set_yscale('log')\n",
    "ax.set_title('Deviation from manifold')\n",
    "ax.legend()\n",
    "fig.savefig('./cayley_iter_dev.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def moving_average(x, w):\n",
    "    \"\"\"Average x over each full window of w consecutive samples.\"\"\"\n",
    "    kernel=np.ones(w)\n",
    "    return np.convolve(x, kernel, 'valid') / w\n",
    "window_width=1\n",
    "\n",
    "with open('dev_dict_LiCombined.pkl', 'rb') as handle:\n",
    "    dev_dict = pickle.load(handle)\n",
    "with open('loss_dict_LiCombined.pkl', 'rb') as handle:\n",
    "    loss_dict = pickle.load(handle)\n",
    "\n",
    "method_list=['LiCombinedOptimizer', 'ProjectedStiefelSGD', 'StiefelSGD_ours']\n",
    "\n",
    "# Left panel: |loss|; right panel: deviation from the manifold.\n",
    "fig_li, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,5))\n",
    "for method in method_list:\n",
    "    label=legend_dict[method]\n",
    "    ax1.plot(moving_average(np.abs(loss_dict[method]), window_width), label=label)\n",
    "    ax2.plot(moving_average(dev_dict[method], window_width), label=label)\n",
    "\n",
    "ax1.set_xlabel('iter')\n",
    "ax1.set_ylabel('loss')\n",
    "ax1.set_yscale('log')\n",
    "ax1.set_title('Convergence')\n",
    "ax1.legend()\n",
    "\n",
    "ax2.set_xlabel('iter')\n",
    "ax2.set_ylabel('deviation')\n",
    "ax2.set_yscale('log')\n",
    "ax2.set_title('Manifold preservance')\n",
    "ax2.legend()\n",
    "\n",
    "fig_li.savefig('./LiWithOurRetraction.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def moving_average(x, w):\n",
    "    \"\"\"Return the moving average of `x` over a window of `w` samples.\n",
    "\n",
    "    Uses np.convolve in 'valid' mode, so for len(x) >= w the result has\n",
    "    len(x) - w + 1 points.\n",
    "    \"\"\"\n",
    "    return np.convolve(x, np.ones(w), 'valid') / w\n",
    "\n",
    "\n",
    "# Average over 10 consecutive iterations to make the per-setting trends readable.\n",
    "window_width = 10\n",
    "\n",
    "# Results keyed by str(loop_num), one entry per inner-iteration setting.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files produced locally by this project.\n",
    "with open('dev_loop_dict_ours.pkl', 'rb') as handle:\n",
    "    dev_loop_dict = pickle.load(handle)\n",
    "with open('loss_loop_dict_ours.pkl', 'rb') as handle:\n",
    "    loss_loop_dict = pickle.load(handle)\n",
    "\n",
    "method_list = ['StiefelSGD_ours']\n",
    "loop_list = ['1', '2', '3', '4', '5', '6', '7', '8', '100']\n",
    "\n",
    "# Convergence: |loss| for every inner-iteration setting, on its own figure.\n",
    "fig, ax = plt.subplots()\n",
    "for loop_num in loop_list:\n",
    "    loss_dict = loss_loop_dict[str(loop_num)]\n",
    "    for method in method_list:\n",
    "        ax.plot(moving_average(np.abs(loss_dict[method]), window_width), label='mat root inv ' + str(loop_num) + ' iter')\n",
    "ax.set_xlabel('iter')\n",
    "ax.set_ylabel('loss')\n",
    "ax.set_yscale('log')\n",
    "ax.set_title('Convergence')\n",
    "ax.legend()\n",
    "fig.savefig('./ours_inner_iter_convergence.pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# Deviation: a fresh figure, rather than relying on plt.show() having released the previous one\n",
    "# (with a non-blocking backend the curves would otherwise land on the convergence axes).\n",
    "fig, ax = plt.subplots()\n",
    "for loop_num in loop_list:\n",
    "    dev_dict = dev_loop_dict[str(loop_num)]\n",
    "    for method in method_list:\n",
    "        ax.plot(moving_average(dev_dict[method], window_width), label='mat root inv ' + str(loop_num) + ' iter')\n",
    "ax.set_xlabel('iter')\n",
    "ax.set_ylabel('deviation')\n",
    "ax.set_yscale('log')\n",
    "ax.set_title('Deviation from manifold')\n",
    "ax.legend()\n",
    "fig.savefig('./ours_inner_iter_dev.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def moving_average(x, w):\n",
    "    \"\"\"Return the moving average of `x` over a window of `w` samples.\n",
    "\n",
    "    Uses np.convolve in 'valid' mode, so for len(x) >= w the result has\n",
    "    len(x) - w + 1 points.\n",
    "    \"\"\"\n",
    "    return np.convolve(x, np.ones(w), 'valid') / w\n",
    "\n",
    "\n",
    "# Average over 10 consecutive iterations to make the per-setting trends readable.\n",
    "window_width = 10\n",
    "\n",
    "# Results keyed by str((loop_num, qr_every)), one entry per setting.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files produced locally by this project.\n",
    "with open('dev_loop_dict_Li.pkl', 'rb') as handle:\n",
    "    dev_loop_dict = pickle.load(handle)\n",
    "with open('loss_loop_dict_Li.pkl', 'rb') as handle:\n",
    "    loss_loop_dict = pickle.load(handle)\n",
    "\n",
    "method_list = ['ProjectedStiefelSGD']\n",
    "\n",
    "loop_list = [1, 2, 4, 8, 16]\n",
    "# NOTE(review): presumably the interval between QR re-orthonormalizations; 1e6 effectively disables it - confirm.\n",
    "qr_every_list = [int(1e6)]\n",
    "\n",
    "# Convergence: |loss| for every (loop_num, qr_every) setting, on its own figure.\n",
    "fig, ax = plt.subplots()\n",
    "for loop_num in loop_list:\n",
    "    for qr_every in qr_every_list:\n",
    "        loss_dict = loss_loop_dict[str((loop_num, qr_every))]\n",
    "        for method in method_list:\n",
    "            ax.plot(moving_average(np.abs(loss_dict[method]), window_width), label='Cayley iter ' + str(loop_num))\n",
    "ax.set_xlabel('iter')\n",
    "ax.set_ylabel('loss (abs val)')\n",
    "ax.set_yscale('log')\n",
    "ax.set_title('Convergence')\n",
    "ax.legend()\n",
    "# NOTE(review): an earlier Cayley-iteration cell saves to this same path; whichever cell runs last wins.\n",
    "fig.savefig('./cayley_iter_convergence.pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# Deviation: a fresh figure, rather than relying on plt.show() having released the previous one.\n",
    "fig, ax = plt.subplots()\n",
    "for loop_num in loop_list:\n",
    "    for qr_every in qr_every_list:\n",
    "        dev_dict = dev_loop_dict[str((loop_num, qr_every))]\n",
    "        for method in method_list:\n",
    "            ax.plot(moving_average(dev_dict[method], window_width), label='Cayley iter ' + str(loop_num))\n",
    "ax.set_xlabel('iter')\n",
    "ax.set_ylabel('deviation')\n",
    "ax.set_yscale('log')\n",
    "ax.set_title('Deviation from manifold')\n",
    "ax.legend()\n",
    "# NOTE(review): an earlier Cayley-iteration cell saves to this same path; whichever cell runs last wins.\n",
    "fig.savefig('./cayley_iter_dev.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "26de051ba29f2982a8de78e945f0abaf191376122a1563185a90213a26c5da77"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
