{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1152,
   "id": "29ba15ff-b628-4f24-9835-5d8c1d9cd49c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "import pickle\n",
    "from scipy.stats import pearsonr\n",
    "from matplotlib.ticker import LogLocator, AutoMinorLocator\n",
    "\n",
    "# Curve colours: baseline optimizer vs. the same optimizer with teleportation.\n",
    "_BASELINE_COLOR = '#1f77b4'\n",
    "_TELEPORT_COLOR = '#ff7f0e'\n",
    "\n",
    "# Path templates; the log fields are (dataset, dataset, opt_method, lr, model_type, run_num).\n",
    "_BASELINE_LOG = 'logs/optimization_final/{}/archive/{}_{}_lr_{}_{}_{}.pkl'\n",
    "_TELEPORT_LOG = 'logs/optimization_final/{}/archive/{}_{}_lr_{}_{}_teleport_{}.pkl'\n",
    "_FIGURE_PATH = 'figures/optimization_final/{}_{}_{}_loss_vs_epoch.pdf'\n",
    "\n",
    "# Optimizers that share a single tick set within each dataset.\n",
    "_SHARED_TICK_OPTIMIZERS = ('RMSprop', 'Adamax', 'CustomAdam', 'AdamW')\n",
    "\n",
    "# Y-axis ticks used when the dataset has no dedicated entry.\n",
    "_FALLBACK_Y_TICKS = {\n",
    "    'Adagrad': (2.5, 2, 1.3),\n",
    "    'SGD': (10, 5),\n",
    "    'momentum': (2.5, 2, 1.3),\n",
    "    'RMSprop': (1.5, 1, 0.5),\n",
    "    'Adam': (10, 5),\n",
    "    'Adamax': (1.5, 1, 0.5),\n",
    "    'CustomAdam': (1.5, 1, 0.5),\n",
    "    'AdamW': (1.5, 1, 0.5),\n",
    "}\n",
    "\n",
    "\n",
    "def _tick_table(adagrad, sgd_momentum, adam, shared):\n",
    "    \"\"\"Build the optimizer -> y-ticks mapping for one dataset.\"\"\"\n",
    "    table = {'Adagrad': adagrad, 'SGD': sgd_momentum, 'momentum': sgd_momentum, 'Adam': adam}\n",
    "    table.update((name, shared) for name in _SHARED_TICK_OPTIMIZERS)\n",
    "    return table\n",
    "\n",
    "\n",
    "# Dataset-specific y-axis ticks, keyed by dataset and then by optimizer.\n",
    "_DATASET_Y_TICKS = {\n",
    "    'MNIST': _tick_table(\n",
    "        adagrad=(2, 0.8, 0.3, 0.1, 0.03),\n",
    "        sgd_momentum=(5, 1.6, 0.5, 0.13, 0.03),\n",
    "        adam=(5, 1.9, 0.7, 0.25, 0.1),\n",
    "        shared=(1.5, 1, 0.5),\n",
    "    ),\n",
    "    'FashionMNIST': _tick_table(\n",
    "        adagrad=(1, 0.7, 0.5, 0.36, 0.25),\n",
    "        sgd_momentum=(4, 2, 1, 0.5, 0.25),\n",
    "        adam=(2, 1.1, 0.7, 0.4, 0.25),\n",
    "        shared=(1.5, 1, 0.5),\n",
    "    ),\n",
    "    'CIFAR100': _tick_table(\n",
    "        adagrad=(5, 4.2, 3.6, 3, 2.5),\n",
    "        sgd_momentum=(5, 3.6, 2.7, 2, 1.5),\n",
    "        adam=(5, 4.2, 3.6, 3.1, 2.6),\n",
    "        shared=(4.5, 4, 3.3),\n",
    "    ),\n",
    "    'CIFAR10': _tick_table(\n",
    "        adagrad=(2.3, 1.9, 1.5, 1.2, 1.0),\n",
    "        sgd_momentum=(3.5, 2.1, 1.3, 0.8, 0.5),\n",
    "        adam=(2.6, 2.2, 1.8, 1.5, 1.2),\n",
    "        shared=(4.5, 4, 3.3),\n",
    "    ),\n",
    "    'Imagenet': _tick_table(\n",
    "        adagrad=(5.5, 3.3, 3.9, 4.6, 2.8),\n",
    "        sgd_momentum=(6, 4.3, 3.1, 2.2, 1.6),\n",
    "        adam=(6, 3.4, 4.1, 5, 2.8),\n",
    "        shared=(4.5, 4, 3.3),\n",
    "    ),\n",
    "    'electricity': _tick_table(\n",
    "        adagrad=(0.004, 0.008, 0.015, 0.03, 0.05),\n",
    "        sgd_momentum=(0.02, 0.05, 0.11, 0.25, 0.6),\n",
    "        adam=(0.004, 0.015, 0.05, 0.16, 0.5),\n",
    "        shared=(4.5, 4, 3.3),\n",
    "    ),\n",
    "    'traffic': _tick_table(\n",
    "        adagrad=(0.003, 0.0053, 0.0095, 0.017, 0.03),\n",
    "        sgd_momentum=(0.01, 0.027, 0.08, 0.25, 0.6),\n",
    "        adam=(0.003, 0.009, 0.028, 0.085, 0.25),\n",
    "        shared=(4.5, 4, 3.3),\n",
    "    ),\n",
    "    'PennTree': _tick_table(\n",
    "        adagrad=(4.5, 4.9, 5.4, 5.9, 6.5),\n",
    "        sgd_momentum=(5.6, 6.3, 7.1, 8, 9),\n",
    "        adam=(4.8, 5.6, 6.5, 7.5, 9),\n",
    "        shared=(4.5, 4, 3.3),\n",
    "    ),\n",
    "}\n",
    "\n",
    "\n",
    "def _resolve_y_ticks(dataset, opt_method):\n",
    "    \"\"\"Return the y-ticks for (dataset, opt_method), or None when none are configured.\"\"\"\n",
    "    per_dataset = _DATASET_Y_TICKS.get(dataset, {})\n",
    "    return per_dataset.get(opt_method, _FALLBACK_Y_TICKS.get(opt_method))\n",
    "\n",
    "\n",
    "def _load_mean_curves(path_template, dataset, opt_method, lr, model_type, num_runs, epoch_num):\n",
    "    \"\"\"Load every run archive and return (train_mean, valid_mean, valid_std) cut to epoch_num points.\"\"\"\n",
    "    train_runs, valid_runs = [], []\n",
    "    for run_num in range(num_runs):\n",
    "        path = path_template.format(dataset, dataset, opt_method, lr, model_type, run_num)\n",
    "        # NOTE: pickle.load can execute arbitrary code; only open archives you produced yourself.\n",
    "        with open(path, 'rb') as f:\n",
    "            record = pickle.load(f)\n",
    "        # Train and validation losses are the first two fields of the archive; indexing\n",
    "        # (instead of unpacking a fixed number of fields) tolerates extra trailing fields.\n",
    "        train_runs.append(record[0])\n",
    "        valid_runs.append(record[1])\n",
    "    train_mean = np.mean(train_runs, axis=0)[:epoch_num]\n",
    "    valid_mean = np.mean(valid_runs, axis=0)[:epoch_num]\n",
    "    valid_std = np.std(valid_runs, axis=0)[:epoch_num]\n",
    "    return train_mean, valid_mean, valid_std\n",
    "\n",
    "\n",
    "def plot_optimization(opt_method_list, dataset, lr, model_type, epoch_num,\n",
    "                      num_runs=3, x_ticks=(0, 10, 20, 30, 40, 50), y_lim=None, title=None):\n",
    "    \"\"\"Plot mean train/test loss curves for each optimizer, with and without teleportation.\n",
    "\n",
    "    For every optimizer in ``opt_method_list`` the archived pickles of the plain and\n",
    "    the ``teleport`` runs are loaded and averaged over ``num_runs`` runs. The curves are\n",
    "    drawn on a log-scaled y-axis; the validation curves (labelled 'test') also get a\n",
    "    band of +/- one standard deviation. One PDF per optimizer is written under\n",
    "    ``figures/optimization_final/``.\n",
    "\n",
    "    Args:\n",
    "        opt_method_list: Iterable of optimizer names, e.g. ``['SGD', 'Adam']``.\n",
    "        dataset: Dataset name; used in the log/figure paths, the tick lookup and the default title.\n",
    "        lr: Learning rate; formatted as-is into the log file names.\n",
    "        model_type: Model identifier used in the log/figure file names.\n",
    "        epoch_num: Maximum number of points of each curve to draw.\n",
    "        num_runs: Number of run archives (indexed from 0) to average.\n",
    "        x_ticks: X-axis tick positions; ``None`` keeps matplotlib's automatic ticks.\n",
    "        y_lim: ``(low, high)`` y-axis limits; defaults to the range spanned by the y-ticks.\n",
    "        title: Figure title; defaults to ``dataset``.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If an expected run archive does not exist.\n",
    "    \"\"\"\n",
    "    for opt_method in opt_method_list:\n",
    "        # Looked up afresh for each optimizer so ticks never carry over from the previous one.\n",
    "        y_ticks = _resolve_y_ticks(dataset, opt_method)\n",
    "\n",
    "        # Load everything first so a missing archive cannot leave an empty figure behind.\n",
    "        curves = [\n",
    "            (opt_method, _BASELINE_COLOR,\n",
    "             _load_mean_curves(_BASELINE_LOG, dataset, opt_method, lr, model_type, num_runs, epoch_num)),\n",
    "            ('{}+teleport'.format(opt_method), _TELEPORT_COLOR,\n",
    "             _load_mean_curves(_TELEPORT_LOG, dataset, opt_method, lr, model_type, num_runs, epoch_num)),\n",
    "        ]\n",
    "\n",
    "        out_path = _FIGURE_PATH.format(dataset, opt_method, model_type)\n",
    "        os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
    "\n",
    "        # Enter the style before the figure is created so every artist is styled the same way\n",
    "        # on every call, and so the global rcParams are restored afterwards.\n",
    "        with plt.style.context('classic'):\n",
    "            fig, ax = plt.subplots(figsize=(10, 7.5), facecolor='white')\n",
    "            for label, color, (train_mean, valid_mean, valid_std) in curves:\n",
    "                ax.plot(train_mean, '--', linewidth=5, color=color, label='{} train'.format(label))\n",
    "                ax.plot(valid_mean, '-', linewidth=5, color=color, label='{} test'.format(label))\n",
    "                # Size the x values from the sliced curve so the band always matches its data.\n",
    "                epochs = np.arange(len(valid_mean))\n",
    "                ax.fill_between(epochs, valid_mean - valid_std, valid_mean + valid_std,\n",
    "                                color=color, alpha=0.5)\n",
    "\n",
    "            ax.set_xlabel('Epoch', fontsize=28)\n",
    "            ax.set_ylabel('Loss', fontsize=28)\n",
    "            ax.set_yscale('log')\n",
    "            ax.minorticks_off()\n",
    "            ax.grid(which='both', linestyle='-', linewidth=0.5, alpha=0.4)\n",
    "\n",
    "            if x_ticks is not None:\n",
    "                ax.set_xticks(list(x_ticks))\n",
    "            if y_ticks is not None:\n",
    "                ax.set_yticks(list(y_ticks))\n",
    "                ax.set_yticklabels([str(tick) for tick in y_ticks])\n",
    "            ax.tick_params(axis='both', labelsize=22)\n",
    "\n",
    "            # Limits are applied last so that setting the ticks cannot widen them again.\n",
    "            if y_lim is not None:\n",
    "                ax.set_ylim(*y_lim)\n",
    "            elif y_ticks is not None:\n",
    "                ax.set_ylim(min(y_ticks), max(y_ticks))\n",
    "\n",
    "            ax.legend(fontsize=28, frameon=True, shadow=True, fancybox=True)\n",
    "            ax.set_title(dataset if title is None else title, fontsize=28)\n",
    "            fig.tight_layout()\n",
    "\n",
    "            fig.savefig(out_path, bbox_inches='tight', dpi=1000)\n",
    "            # Render in the notebook, then release the figure so loops do not pile up open figures.\n",
    "            plt.show()\n",
    "            plt.close(fig)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "35492739-918c-4c69-b9a4-8211186e1cca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Plot Adam with and without teleportation for the transformer runs on the traffic dataset.\n",
    "plot_optimization(\n",
    "    opt_method_list=['Adam'],\n",
    "    dataset='traffic',\n",
    "    lr=0.0001,\n",
    "    model_type='transformer',\n",
    "    epoch_num=400,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4388d8ac-61e9-488d-a421-a7deb20b0299",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5498abac-8ed7-4afe-823c-5e2cadd9f4fb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
