{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zOxVeYlMBiXl",
        "outputId": "d8041b76-ae0f-4d45-b716-899e95d4b143"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Thu May 26 14:06:22 2022       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   46C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "|  No running processes found                                                 |\n",
            "+-----------------------------------------------------------------------------+\n",
            "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n",
            "/content/drive/.shortcut-targets-by-id/1G0GaLFeHhbaKAYK9VCyXO3hxE7g0ZCUJ/NC_UFM\n"
          ]
        }
      ],
      "source": [
        "# Ask nvidia-smi about the attached GPU; the shell capture is a list of output lines.\n",
        "gpu_info = !nvidia-smi\n",
        "gpu_info = '\\n'.join(gpu_info)\n",
        "if 'failed' in gpu_info:\n",
        "    # The output mentions a failure, so point the user at the runtime settings.\n",
        "    print('Select the Runtime > \"Change runtime type\" menu to enable a GPU accelerator, ')\n",
        "    print('and then re-execute this cell.')\n",
        "else:\n",
        "    print(gpu_info)\n",
        "\n",
        "# Mount Google Drive and switch the working directory to the project folder.\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "path = '/content/drive/My Drive/NC_UFM/'\n",
        "%cd $path  "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "unxtGw_dtEdZ"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "import numpy as np\n",
        "import cvxpy as cp\n",
        "import torch.nn as nn\n",
        "import shutil\n",
        "import os\n",
        "import numpy as np\n",
        "import matplotlib\n",
        "matplotlib.use('Agg')\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.metrics import confusion_matrix\n",
        "from sklearn.utils.multiclass import unique_labels\n",
        "import torch.utils.data as data_utils\n",
        "from scipy import spatial\n",
        "import scipy.linalg as scilin\n",
        "import seaborn as sns\n",
        "from google.colab import files\n",
        "import torch.optim as optim\n",
        "from sklearn.decomposition import PCA\n",
        "from sklearn.manifold import TSNE\n",
        "import statistics as st\n",
        "from mpl_toolkits.mplot3d import Axes3D\n",
        "import torch.nn.functional as F\n",
        "\n",
        "# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.\n",
        "if torch.cuda.is_available():\n",
        "    device = torch.device(\"cuda:0\")\n",
        "else:\n",
        "    device = torch.device(\"cpu\")\n",
        "\n",
        "# Plot styling: seaborn defaults, then the \"paper\" context and the \"whitegrid\" style.\n",
        "sns.set()\n",
        "sns.set_context(\"paper\")\n",
        "sns.set_style(\"whitegrid\")\n",
        "# sns.set_facecolor('white')\n",
        " \n",
        "\n",
        "class LogRegCE(nn.Module):\n",
        "    \"\"\"Cross-entropy plus a squared-Frobenius penalty on the product W @ H^T.\n",
        "\n",
        "    Here W is ``net.lastlayer.weight`` and H is ``net.x``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, net, lambda_logreg = 10**(-2)):\n",
        "        \"\"\"Store the penalty weight; ``net`` is accepted but not used here.\"\"\"\n",
        "        super().__init__()\n",
        "        self.lambda_logreg = lambda_logreg\n",
        "\n",
        "    def forward(self, x, target, net):\n",
        "        \"\"\"Return ``F.cross_entropy(x, target) + 0.5 * lambda_logreg * ||W @ H^T||_F^2``.\"\"\"\n",
        "        data_term = F.cross_entropy(x, target)\n",
        "        logit_matrix = net.lastlayer.weight @ net.x.T\n",
        "        penalty = 0.5*self.lambda_logreg*torch.norm(logit_matrix, 'fro')**2\n",
        "        return data_term + penalty\n",
        "      \n",
        "class RidgeRegCE(nn.Module):\n",
        "    \"\"\"Cross-entropy plus a ridge penalty on the classifier weights and the features.\n",
        "\n",
        "    The penalty is ``0.5 * lambda_ridgereg * (||W||_F^2 + ||H||_F^2)`` with\n",
        "    W = ``net.lastlayer.weight`` and H = ``net.x``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, net, lambda_ridgereg = 10**(-2)):\n",
        "        \"\"\"Store the penalty weight; ``net`` is accepted but not used here.\"\"\"\n",
        "        super().__init__()\n",
        "        self.lambda_ridgereg = lambda_ridgereg\n",
        "\n",
        "    def forward(self, x, target, net):\n",
        "        \"\"\"Return the cross-entropy of ``x`` against ``target`` plus the ridge penalty.\"\"\"\n",
        "        data_term = F.cross_entropy(x, target)\n",
        "        weight_sq = torch.norm(net.lastlayer.weight, 'fro')**2\n",
        "        feature_sq = torch.norm(net.x, 'fro')**2\n",
        "        return data_term + 0.5*self.lambda_ridgereg*(weight_sq + feature_sq)\n",
        "      # return F.cross_entropy(x, target) + 0.5*self.lambda_logreg*torch.trace(x.T@x)\n",
        "      \n",
        "\n",
        "def get_lr(optimizer):\n",
        "    \"\"\"Return the ``'lr'`` entry of the optimizer's first parameter group.\n",
        "\n",
        "    Returns ``None`` when ``optimizer.param_groups`` is empty.\n",
        "    \"\"\"\n",
        "    first_group = next(iter(optimizer.param_groups), None)\n",
        "    if first_group is None:\n",
        "        return None\n",
        "    return first_group['lr']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "r_KQgiqotQzp"
      },
      "outputs": [],
      "source": [
        "# Start from a clean output tree: delete any existing ./graphs directory, then\n",
        "# recreate it together with its tsne_progression sub-folders.\n",
        "# NOTE: `rm -rf` permanently removes whatever an earlier run saved under ./graphs.\n",
        "%rm -rf graphs\n",
        "!mkdir graphs\n",
        "!mkdir graphs/tsne_progression\n",
        "!mkdir graphs/tsne_progression/tsne_all\n",
        "!mkdir graphs/tsne_progression/tsne_min_only"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "Pa-ar1iNsB0R"
      },
      "outputs": [],
      "source": [
        "load_data = 0\n",
        "# Notebook settings\n",
        "R = 10                  # imbalance ratio: each minority class gets n // R samples\n",
        "maj_classes = [0,1]\n",
        "min_classes = [2,3]\n",
        "classes = maj_classes + min_classes\n",
        "K = len(classes)        # number of classes\n",
        "# n = 500\n",
        "n = 10                  # samples per majority class\n",
        "\n",
        "# Per-class sample counts, indexed by class label 0..K-1.\n",
        "n_c = []\n",
        "for i in range(K):\n",
        "    if i in maj_classes:\n",
        "        n_c.append(n)\n",
        "    elif i in min_classes:\n",
        "        n_c.append(n // R)\n",
        "    else:\n",
        "        # Fail fast: a label outside both lists would leave n_c shorter than K\n",
        "        # and misalign everything that is derived from it.\n",
        "        raise ValueError(\"Error, the maj min are not divided properly.\")\n",
        "\n",
        "N = sum(n_c)            # total number of samples\n",
        "# The feature dimension is K-1, but never smaller than 2.\n",
        "feature_dim = max(K - 1, 2)\n",
        "# feature_dim = 10\n",
        "\n",
        "# Optimisation hyperparameters\n",
        "lr_init             = 1*10**(0)\n",
        "lr = lr_init\n",
        "weight_decay_scheduler = 1  # Ridge-decay\n",
        "cvx_init            = 0     # 0: random feature init; otherwise hard-coded values are looked up by (K, R)\n",
        "weight_decay        = 1*10**(-2)\n",
        "lambda_logreg       = 1*10**(-3)  # Logit-regularization\n",
        "lambda_ridgereg     = 0*10**(-10)\n",
        "momentum            = 0.0\n",
        "lr_decay            = 0.15*10**0\n",
        "bias                = False\n",
        "epochs              = int(1*10**5)\n",
        "# epochs_lr_decay     = [round(epochs/8), 2*round(epochs/8), 3*round(epochs/8), 4*round(epochs/8), 5*round(epochs/8), 6*round(epochs/8), 7*round(epochs/8)]\n",
        "epochs_lr_decay     = []\n",
        "# loss_function       = \"CE\"\n",
        "loss_function       = \"LogRegCE\"\n",
        "# loss_function       = \"RidgeRegCE\"\n",
        "optim_algo          = \"SGD\"\n",
        "# Per-class weights: 1/R for majority classes, 1.0 for the others. Created on `device`\n",
        "# (not a hard-coded .cuda()) so the cell also runs on CPU-only runtimes.\n",
        "class_weights       = torch.tensor([1/R if i in maj_classes else 1.0 for i in range(0,K)], device=device)\n",
        "\n",
        "# Analysis epochs: a hand-picked schedule up to 200, then every 100 epochs,\n",
        "# and finally the last epoch index (epochs - 1).\n",
        "epoch_list          = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,\n",
        "                       12, 13, 14, 16, 17, 19, 20, 22, 24, 27, 29,\n",
        "                       32, 35, 38, 42, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95,\n",
        "                       100, 110, 120, 130, 150, 175, 200]\n",
        "# epoch_list          = [1,   2, 50]                       \n",
        "# Keep stepping by 100 while the new value stays strictly below epochs - 1.\n",
        "epoch_list.extend(range(epoch_list[-1] + 100, epochs - 1, 100))\n",
        "last_epoch = epoch_list[-1]\n",
        "if last_epoch != epochs - 1:\n",
        "    epoch_list.append(epochs - 1)\n",
        "\n",
        "\n",
        "  \n",
        "\n",
        "# Snapshot of the run configuration, written next to the generated graphs.\n",
        "notebook_setting_parameters = [\"K\", \"R\", \"n_c\", \"N\", \"classes\", \"maj_classes\", \"min_classes\", \"lr\", \"weight_decay\", \"momentum\", \"lr_decay\", \"epochs_lr_decay\", \"bias\", \"epochs\", \"loss_function\"]\n",
        "notebook_setting_values     = [K, R, n_c, N, classes, maj_classes, min_classes, lr, weight_decay, momentum, lr_decay, epochs_lr_decay, bias, epochs,  loss_function]\n",
        "# Pair every parameter name with its value, keeping the list order.\n",
        "notebook_setting_dict = dict(zip(notebook_setting_parameters, notebook_setting_values))\n",
        "\n",
        "with open('./graphs/notebook_setting.txt', 'w') as f:\n",
        "    print(notebook_setting_dict, file=f)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "008Zxcpkfs8M",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ef8fca96-9c1f-46ea-f628-ef86ae17ccca"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "LPMnet_one_layer(\n",
            "  (lastlayer): Linear(in_features=3, out_features=4, bias=False)\n",
            "  (relu): ReLU()\n",
            ")\n",
            "Parameter containing:\n",
            "tensor([[-0.3906,  0.1637, -0.3257],\n",
            "        [ 0.2687,  0.5566, -0.0393],\n",
            "        [ 0.5245, -0.1132, -0.5773],\n",
            "        [-0.2184,  0.5567,  0.2548]], device='cuda:0', requires_grad=True)\n",
            "Parameter containing:\n",
            "tensor([[ 2.6890, -2.3735, -1.6890],\n",
            "        [-1.2274,  2.5323,  0.9747],\n",
            "        [ 0.4906, -0.3190,  0.2517],\n",
            "        [ 0.6231,  0.9225,  0.4285],\n",
            "        [-0.0363, -1.8889, -1.5454],\n",
            "        [ 1.6203, -1.4242,  0.3518],\n",
            "        [-0.3656, -1.2114,  1.4143],\n",
            "        [-1.2313,  2.2743, -0.9926],\n",
            "        [ 1.5086,  0.2523, -1.4276],\n",
            "        [ 0.0989,  0.1252,  0.4566],\n",
            "        [-0.5568,  0.2402,  0.1063],\n",
            "        [-0.2387, -1.4342,  0.6548],\n",
            "        [ 0.5405, -2.7642,  0.5749],\n",
            "        [ 0.6118, -1.8912, -0.6141],\n",
            "        [-0.0505,  0.1589, -0.3393],\n",
            "        [-0.4412, -0.5796, -0.9480],\n",
            "        [ 2.1642, -0.0605,  0.6781],\n",
            "        [ 1.0264, -0.6214,  0.6931],\n",
            "        [-0.4447,  0.1179, -1.0811],\n",
            "        [ 1.9213,  0.8664,  0.3129],\n",
            "        [-0.5226, -0.3660, -0.7825],\n",
            "        [ 1.6436, -0.7298, -0.3867]], device='cuda:0', requires_grad=True)\n"
          ]
        }
      ],
      "source": [
        "class Net(nn.Module):\n",
        "    \"\"\"Two stacked linear maps: ``fc1`` (N -> feature_dim, no bias) then ``fc2`` (feature_dim -> K).\n",
        "\n",
        "    The output width ``K`` is read from the notebook-level variable of that name.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, N = 5000, feature_dim = 512, bias = False):\n",
        "        super().__init__()\n",
        "        self.fc1 = nn.Linear(N, feature_dim, bias = False)\n",
        "        self.fc2 = nn.Linear(feature_dim, K, bias = bias)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return ``fc2`` applied to the ``fc1`` output for ``x``.\"\"\"\n",
        "        return self.fc2(self.fc1(x))\n",
        "\n",
        "    def get_features(self,x):\n",
        "        \"\"\"Return the ``fc1`` output for ``x`` (the input that ``fc2`` receives).\"\"\"\n",
        "        return self.fc1(x)\n",
        "\n",
        "class LPMnet_one_layer(nn.Module):\n",
        "    \"\"\"Single linear classifier on top of a trainable feature matrix.\n",
        "\n",
        "    ``self.x`` is an ``nn.Parameter`` holding one feature row per sample and\n",
        "    ``self.label`` holds the class index of every row, classes laid out\n",
        "    consecutively in the order given by ``n_c``.\n",
        "\n",
        "    Args:\n",
        "        n_c: Sequence with the number of samples of each class.\n",
        "        feature_dim: Width of each feature row (used for the random init).\n",
        "        bias: Whether ``lastlayer`` gets a bias term.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If ``cvx_init`` is non-zero and there is no hard-coded\n",
        "            initialisation for the current ``(K, R)`` combination.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, n_c = 1, feature_dim = 512, bias = False):\n",
        "        super(LPMnet_one_layer, self).__init__()\n",
        "        self.feature_dim = feature_dim\n",
        "        self.n_c = n_c\n",
        "        # Sizes are derived from n_c itself rather than from notebook globals.\n",
        "        self.N = sum(n_c)\n",
        "        self.K = len(n_c)\n",
        "        # cvx_init and R are notebook-level settings.\n",
        "        if cvx_init == 0:\n",
        "            self.x = nn.Parameter(1*torch.randn([self.N, feature_dim]), requires_grad=True) # init for H\n",
        "        elif self.K == 3 and R == 2:\n",
        "            self.x = nn.Parameter(1*torch.tensor([[0.7029, 0.0618],[0.7029, 0.0618],[-0.3515,    0.6119],[-0.3515,    0.6119],[-0.3515,   -0.6737]]), requires_grad=True) # debug: init for H\n",
        "        elif self.K == 3 and R == 1:\n",
        "            self.x = nn.Parameter(1*torch.tensor([[0.8165,   -0.0000],[-0.4082 ,   0.7071],[-0.4082,   -0.7071]]), requires_grad=True) # debug: init for H\n",
        "        else:\n",
        "            # Without this, self.x would be missing and the failure would only show up later.\n",
        "            raise ValueError(f\"CVX init NOT loaded: no hard-coded features for K={self.K}, R={R}\")\n",
        "        # One label per sample: class k repeated n_c[k] times (kept as a float tensor).\n",
        "        labels = [k for k, count in enumerate(n_c) for _ in range(count)]\n",
        "        self.label = torch.Tensor(labels)\n",
        "        # Honour the bias argument instead of always disabling it.\n",
        "        self.lastlayer = nn.Linear(feature_dim, self.K, bias=bias)\n",
        "        self.relu = nn.ReLU()\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return the logits ``lastlayer(x)``.\"\"\"\n",
        "        return self.lastlayer(x)\n",
        "\n",
        "# Create the model\n",
        "# net = Net(N = N, feature_dim = feature_dim, bias = bias)\n",
        "net = LPMnet_one_layer(n_c = n_c, feature_dim = feature_dim, bias = bias)\n",
        "net.to(device)\n",
        "print(net)\n",
        "\n",
        "# Training data: the trainable feature rows and their integer class labels.\n",
        "x_data = net.x.to(device)\n",
        "y_data = net.label.long().to(device)\n",
        "# torch.nn.init.constant_(net.lastlayer.weight, 0)  # zero init for W\n",
        "if cvx_init == 1:\n",
        "  # W_cvx is only defined here; the loop that would copy it into lastlayer is commented out below.\n",
        "  if K == 3 and R == 2:\n",
        "    W_cvx = torch.tensor([[0.9484, 0],[-0.5426, 0.7779],[-0.4058, -0.7779]]) # K=3, R = 2\n",
        "  elif K == 3 and R == 1:\n",
        "    W_cvx = torch.tensor([[0.8165,         0],[-0.4082,    0.7071],[-0.4082,   -0.7071]]) # K=3, R = 1\n",
        "  else:\n",
        "    print(\"CVX init NOT loaded\")\n",
        "# for ii in range(K):\n",
        "#   for jj in range(feature_dim):\n",
        "#     torch.nn.init.constant_(net.lastlayer.weight[ii,jj],W_cvx[ii,jj])# debug: CVX init for W, K=3, R=2, feature_dim = 2, n_c = [2,2,1]\n",
        "# NOTE(review): features_init is the same object as x_data (no copy), so it follows\n",
        "# any in-place update of the features -- confirm that this is intended.\n",
        "features_init = x_data\n",
        "# Pin the one-hot width to K so that trailing classes with zero samples still get a column.\n",
        "y_data_one_hot = F.one_hot(y_data, num_classes=K)\n",
        "# Centred one-hot targets: subtract 1/K from every entry.\n",
        "z_hat_transpose = y_data_one_hot -1/K\n",
        "print(net.lastlayer.weight)\n",
        "print(features_init)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from numpy.linalg import svd\n",
        "from numpy.linalg import norm\n",
        "# gram matrices\n",
        "\n",
        "# Thin SVD of the centred one-hot targets, truncated to the leading K-1 components.\n",
        "u, s, vh = svd(z_hat_transpose.cpu(), full_matrices=False)\n",
        "s, u, vh = s[:K-1], u[:, :K-1], vh[:K-1, :]\n",
        "\n",
        "print(s.shape)\n",
        "print(u.shape)\n",
        "print(vh.shape)\n",
        "\n",
        "# Gram matrices built from the truncated factors, plus Frobenius-normalised copies.\n",
        "GW_hat = vh.T @ np.diag(s) @ vh\n",
        "GH_hat = u @ np.diag(s) @ u.T\n",
        "GW_hat_norm = torch.tensor(GW_hat / norm(GW_hat, 'fro'))\n",
        "GH_hat_norm = torch.tensor(GH_hat / norm(GH_hat, 'fro'))\n",
        "print(GW_hat.shape)\n",
        "print(GH_hat.shape)\n",
        "\n",
        "# Row index of the first sample of every class: 0 followed by the running class sizes.\n",
        "n_c_cumsum = np.cumsum(n_c)\n",
        "new_class_idc = np.concatenate(([0], n_c_cumsum[:-1]))\n",
        "# K x K sub-matrix of GH_hat at those rows/columns, stored as float64\n",
        "# (the dtype of the np.zeros buffer the element-wise loop used to fill).\n",
        "GM_hat = GH_hat[np.ix_(new_class_idc, new_class_idc)].astype(np.float64)\n",
        "GM_hat_norm = torch.tensor(GM_hat / norm(GM_hat, 'fro'))\n",
        "\n",
        "# Reference Gram matrix I - 1/K, shared by the W and M variants.\n",
        "GW_etf = np.eye(K) - 1/K\n",
        "GM_etf = GW_etf\n",
        "GW_etf_norm = torch.tensor(GW_etf / norm(GW_etf, 'fro'))\n",
        "GM_etf_norm = torch.tensor(GM_etf / norm(GM_etf, 'fro'))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IBCoDpd0360H",
        "outputId": "874dcbf5-b7b2-4ca1-d1a2-e6c94abcfb37"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "(3,)\n",
            "(22, 3)\n",
            "(3, 4)\n",
            "(4, 4)\n",
            "(22, 22)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "-x1UbjBq3Zvj",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "1a02e22c-c2f2-403b-ffdf-891df2c89377"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "% Completed:\n",
            "0.0\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:282: UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matricesor `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2318.)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "lr: \n",
            "1\n",
            "Epoch: 1000\n",
            "Loss: 0.18460777401924133\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 2000\n",
            "Loss: 0.18441452085971832\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 3000\n",
            "Loss: 0.18441584706306458\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 4000\n",
            "Loss: 0.18441586196422577\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 5000\n",
            "Loss: 0.1844158172607422\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 6000\n",
            "Loss: 0.18441587686538696\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 7000\n",
            "Loss: 0.18441587686538696\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 8000\n",
            "Loss: 0.18441587686538696\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 9000\n",
            "Loss: 0.18441587686538696\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "% Completed:\n",
            "10.0\n",
            "lr: \n",
            "1\n",
            "Epoch: 10000\n",
            "Loss: 0.18441587686538696\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 11000\n",
            "Loss: 0.1840895116329193\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 12000\n",
            "Loss: 0.1840895116329193\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 13000\n",
            "Loss: 0.1840895116329193\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 14000\n",
            "Loss: 0.1840895116329193\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 15000\n",
            "Loss: 0.1840895116329193\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 16000\n",
            "Loss: 0.1840895116329193\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 17000\n",
            "Loss: 0.1840895116329193\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 18000\n",
            "Loss: 0.1840895116329193\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 19000\n",
            "Loss: 0.1840895116329193\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "% Completed:\n",
            "20.0\n",
            "lr: \n",
            "1\n",
            "Epoch: 20000\n",
            "Loss: 0.1840895116329193\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 21000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 22000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 23000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 24000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 25000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 26000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 27000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 28000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 29000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "% Completed:\n",
            "30.0\n",
            "lr: \n",
            "1\n",
            "Epoch: 30000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 31000\n",
            "Loss: 0.18408602476119995\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 32000\n",
            "Loss: 0.18408602476119995\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 33000\n",
            "Loss: 0.18408602476119995\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 34000\n",
            "Loss: 0.18408602476119995\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 35000\n",
            "Loss: 0.18408602476119995\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 36000\n",
            "Loss: 0.18408602476119995\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 37000\n",
            "Loss: 0.18408602476119995\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 38000\n",
            "Loss: 0.18408602476119995\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 39000\n",
            "Loss: 0.18408602476119995\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "% Completed:\n",
            "40.0\n",
            "lr: \n",
            "1\n",
            "Epoch: 40000\n",
            "Loss: 0.18408602476119995\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 41000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 42000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 43000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 44000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 45000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 46000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 47000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 48000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 49000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "% Completed:\n",
            "50.0\n",
            "lr: \n",
            "1\n",
            "Epoch: 50000\n",
            "Loss: 0.18408599495887756\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 51000\n",
            "Loss: 0.18408596515655518\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 52000\n",
            "Loss: 0.18408596515655518\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 53000\n",
            "Loss: 0.18408596515655518\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 54000\n",
            "Loss: 0.18408596515655518\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 55000\n",
            "Loss: 0.18408596515655518\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 56000\n",
            "Loss: 0.18408596515655518\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 57000\n",
            "Loss: 0.18408596515655518\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 58000\n",
            "Loss: 0.18408596515655518\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 59000\n",
            "Loss: 0.18408596515655518\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "% Completed:\n",
            "60.0\n",
            "lr: \n",
            "1\n",
            "Epoch: 60000\n",
            "Loss: 0.18408596515655518\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 61000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 62000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 63000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 64000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 65000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 66000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 67000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 68000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 69000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "% Completed:\n",
            "70.0\n",
            "lr: \n",
            "1\n",
            "Epoch: 70000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 71000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 72000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 73000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 74000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 75000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 76000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 77000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 78000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 79000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "% Completed:\n",
            "80.0\n",
            "lr: \n",
            "1\n",
            "Epoch: 80000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 81000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 82000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 83000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 84000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 85000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 86000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 87000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 88000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 89000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "% Completed:\n",
            "90.0\n",
            "lr: \n",
            "1\n",
            "Epoch: 90000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 91000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 92000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 93000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 94000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 95000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 96000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 97000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 98000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n",
            "lr: \n",
            "1\n",
            "Epoch: 99000\n",
            "Loss: 0.1840859204530716\n",
            "Train Accuracy: tensor(1., device='cuda:0')\n"
          ]
        }
      ],
      "source": [
        "from re import T  # NOTE(review): looks like an accidental auto-import; T is not used in this cell\n",
        "\n",
        "# Training criterion. The project-defined LogRegCE / RidgeRegCE criteria are constructed with\n",
        "# the network; any other loss_function value leaves `criterion` untouched in this cell.\n",
        "if loss_function == \"RidgeRegCE\":\n",
        "    criterion = RidgeRegCE(net=net, lambda_ridgereg=lambda_ridgereg)\n",
        "elif loss_function == \"LogRegCE\":\n",
        "    criterion = LogRegCE(net=net, lambda_logreg=lambda_logreg)\n",
        "elif loss_function == \"CE\":\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "# Optimizer: \"Adam\" selects AdamW, every other value selects SGD (no momentum).\n",
        "# Both receive the same learning rate and weight decay.\n",
        "optimizer_kwargs = dict(lr=lr, weight_decay=weight_decay)\n",
        "if optim_algo != \"Adam\":\n",
        "    optimizer = optim.SGD(net.parameters(), **optimizer_kwargs)\n",
        "else:\n",
        "    optimizer = optim.AdamW(net.parameters(), **optimizer_kwargs)\n",
        "\n",
        "# Learning-rate schedule driven by the milestone epochs in epochs_lr_decay and the factor lr_decay.\n",
        "lr_scheduler = optim.lr_scheduler.MultiStepLR(\n",
        "    optimizer, milestones=epochs_lr_decay, gamma=lr_decay\n",
        ")\n",
        "# ---- Containers filled by the training loop -------------------------------------------\n",
        "# One slot per logged epoch for every tensor-valued metric.\n",
        "num_checkpoints = len(epoch_list)\n",
        "\n",
        "# List-valued neural-collapse metrics (appended to at each logged epoch).\n",
        "NC_1_geom, NC_2_geom, NC_3_geom, NC_4_geom, NC_4_original = [], [], [], [], []\n",
        "\n",
        "# Tensor-valued metrics, indexed by position in epoch_list.\n",
        "NC_5 = torch.zeros([K, num_checkpoints])\n",
        "NC_6_G_error = torch.zeros(num_checkpoints)\n",
        "NC_6_G_error_etf = torch.zeros(num_checkpoints)\n",
        "NC_6_GM_error = torch.zeros(num_checkpoints)\n",
        "NC_6_GM_error_etf = torch.zeros(num_checkpoints)\n",
        "NC_7_Z_error = torch.zeros(num_checkpoints)\n",
        "NC_6_G_error_new = torch.zeros(num_checkpoints)\n",
        "grad_align_error = torch.zeros(num_checkpoints)\n",
        "grad_norm = torch.zeros(num_checkpoints)\n",
        "foo_error = torch.zeros(num_checkpoints)\n",
        "\n",
        "# Errors against the nuclear-norm reference solution and its ETF / SELI counterparts.\n",
        "Z_error_ridgereg_nuc_norm = torch.zeros(num_checkpoints)\n",
        "GW_error_ridgereg_nuc_norm = torch.zeros(num_checkpoints)\n",
        "GM_error_ridgereg_nuc_norm = torch.zeros(num_checkpoints)\n",
        "Z_error_ridgereg_nuc_norm_etf = torch.zeros(num_checkpoints)\n",
        "GW_error_ridgereg_nuc_norm_etf = torch.zeros(num_checkpoints)\n",
        "GM_error_ridgereg_nuc_norm_etf = torch.zeros(num_checkpoints)\n",
        "Z_error_ridgereg_nuc_norm_seli = torch.zeros(num_checkpoints)\n",
        "GW_error_ridgereg_nuc_norm_seli = torch.zeros(num_checkpoints)\n",
        "GM_error_ridgereg_nuc_norm_seli = torch.zeros(num_checkpoints)\n",
        "\n",
        "# Loss / accuracy histories.\n",
        "train_accuracies, losses, reg_losses = [], [], []\n",
        "ce_lb_zhu_list, ce_logreg_loss_lb = [], []\n",
        "\n",
        "# Cosine histories between class means (mu) and classifier rows (W).\n",
        "mu_maj_maj_cos_list, mu_min_min_cos_list, mu_maj_min_cos_list = [], [], []\n",
        "W_maj_maj_cos_list, W_min_min_cos_list, W_maj_min_cos_list = [], [], []\n",
        "\n",
        "# Norm histories.\n",
        "W_norms_list, mu_norms_list = [], []\n",
        "h_G_norms_list, mu_c_norms_list = [], []\n",
        "w_G_norms_list, w_c_norms_list = [], []\n",
        "fc_bias_list = []\n",
        "\n",
        "# Feature / label snapshots kept for t-SNE.\n",
        "h_i_epochs_list_tsne, c_i_epochs_list_tsne = [], []\n",
        "\n",
        "# Pairwise margin histories; the nested per-class structure is created further down.\n",
        "margins, Delta_margins = {}, {}\n",
        "margins_logt_norm, margins_qmin_norm = {}, {}\n",
        "q_min = []\n",
        "\n",
        "cos = torch.nn.CosineSimilarity(dim=1, eps=1e-08)\n",
        "\n",
        "\n",
        "def _nuc_norm_reference_logits(v_own, v_cross, v_on_single, s_on_block, s_own, s_cross):\n",
        "    \"\"\"Build one 4 x 22 reference logit matrix from its six distinct entries.\n",
        "\n",
        "    Every tabulated matrix has the same layout: two rows made of two 10-column\n",
        "    blocks plus two trailing columns, and two rows that are constant over the\n",
        "    first 20 columns and differ only in the last two. Presumably the 10-column\n",
        "    blocks are the two majority classes and the last two columns the two\n",
        "    single-sample minority classes -- confirm against n_c.\n",
        "    \"\"\"\n",
        "    return torch.tensor([\n",
        "        [v_own] * 10 + [v_cross] * 10 + [v_on_single] * 2,\n",
        "        [v_cross] * 10 + [v_own] * 10 + [v_on_single] * 2,\n",
        "        [s_on_block] * 20 + [s_own, s_cross],\n",
        "        [s_on_block] * 20 + [s_cross, s_own],\n",
        "    ])\n",
        "\n",
        "\n",
        "# Distinct entries of the precomputed nuclear-norm solutions, keyed by regularisation\n",
        "# strength. The same tables serve RidgeRegCE (lambda_ridgereg) and CE (weight_decay).\n",
        "_NUC_NORM_REFERENCE_ENTRIES = {\n",
        "    1e-2: (0.1859, -0.0878, -0.0671, -0.0491, 0.1422, -0.0079),\n",
        "    1e-3: (0.1857, -0.0789, -0.0675, -0.0534, 0.1620, -0.0269),\n",
        "    1e-4: (0.1855, -0.0746, -0.0667, -0.0554, 0.1689, -0.0354),\n",
        "    1e-7: (0.1852, -0.0691, -0.0650, -0.0581, 0.1765, -0.0464),\n",
        "}\n",
        "\n",
        "compare_with_nuc_norm = 0\n",
        "_reg_strength = None\n",
        "_tabulated_strengths = ()\n",
        "if loss_function == \"RidgeRegCE\":\n",
        "    # Ridge-regularised CE: references exist only for R == 10, n == 10.\n",
        "    if R == 10 and n == 10:\n",
        "        _reg_strength = lambda_ridgereg\n",
        "        _tabulated_strengths = (1e-2, 1e-3, 1e-4)\n",
        "elif loss_function == \"CE\":\n",
        "    # Plain CE: the weight decay plays the role of the regularisation strength.\n",
        "    if R == 10:\n",
        "        _reg_strength = weight_decay\n",
        "        _tabulated_strengths = (1e-2, 1e-3, 1e-4, 1e-7)\n",
        "\n",
        "for _strength in _tabulated_strengths:\n",
        "    # Relative tolerance instead of exact float equality, so 0.001 and 10**(-3) both match.\n",
        "    if abs(_reg_strength - _strength) <= 1e-9 * _strength:\n",
        "        Z_norm_nuc_min = _nuc_norm_reference_logits(*_NUC_NORM_REFERENCE_ENTRIES[_strength])\n",
        "        compare_with_nuc_norm = 1\n",
        "        break\n",
        "\n",
        "if compare_with_nuc_norm == 1:\n",
        "  # Thin SVD of the transposed reference logits, truncated to the leading K-1 components.\n",
        "  rank_la = K - 1\n",
        "  u_la, s_la, vh_la = svd(Z_norm_nuc_min.T.cpu(), full_matrices=False)\n",
        "  u_la, s_la, vh_la = u_la[:, :rank_la], s_la[:rank_la], vh_la[:rank_la, :]\n",
        "\n",
        "  for factor_la in (s_la, u_la, vh_la):\n",
        "    print(factor_la.shape)\n",
        "\n",
        "  # Reference Gram matrices built from the right / left singular vectors,\n",
        "  # plus their Frobenius-normalised versions as tensors.\n",
        "  sigma_la = np.diag(s_la)\n",
        "  GW_hat_la = vh_la.T @ sigma_la @ vh_la\n",
        "  GH_hat_la = u_la @ sigma_la @ u_la.T\n",
        "  GW_hat_norm_la = torch.tensor(GW_hat_la / norm(GW_hat_la, 'fro'))\n",
        "  GH_hat_norm_la = torch.tensor(GH_hat_la / norm(GH_hat_la, 'fro'))\n",
        "  for gram_la in (GW_hat_la, GH_hat_la):\n",
        "    print(gram_la.shape)\n",
        "\n",
        "  # K x K Gram of the class representatives: pick the rows/columns of GH_hat_la listed in\n",
        "  # new_class_idc (defined outside this cell; presumably the first sample index of each class).\n",
        "  GM_hat_la = np.zeros((K, K))\n",
        "  for row in range(K):\n",
        "    src_row = new_class_idc[row]\n",
        "    for col in range(K):\n",
        "      GM_hat_la[row, col] = GH_hat_la[src_row, new_class_idc[col]]\n",
        "  GM_hat_norm_la = torch.tensor(GM_hat_la / norm(GM_hat_la, 'fro'))\n",
        "\n",
        "\n",
        "# Nested histories: <name>[c][c_prime] is an (initially empty) list for every ordered class pair.\n",
        "margins = {c: {c_prime: [] for c_prime in range(K)} for c in range(K)}\n",
        "Delta_margins = {c: {c_prime: [] for c_prime in range(K)} for c in range(K)}\n",
        "margins_logt_norm = {c: {c_prime: [] for c_prime in range(K)} for c in range(K)}\n",
        "margins_qmin_norm = {c: {c_prime: [] for c_prime in range(K)} for c in range(K)}\n",
        "\n",
        "# Per-class histories of logit standard deviations, one empty list per class.\n",
        "logits_std_dict = {c: [] for c in range(K)}\n",
        "logits_std_dict_i_maj_to_maj = {c: [] for c in range(K)}\n",
        "logits_std_dict_i_maj_to_min = {c: [] for c in range(K)}\n",
        "logits_std_dict_i_min_to_min = {c: [] for c in range(K)}\n",
        "logits_std_dict_i_min_to_maj = {c: [] for c in range(K)}\n",
        "\n",
        "epoch_list_idx = 0\n",
        "loss_list = []\n",
        "# Reporting periods, clamped to 1 so that runs with epochs < 10 (or < 100) do not\n",
        "# hit a modulo-by-zero.\n",
        "progress_every = max(1, epochs // 10)\n",
        "log_every = max(1, epochs // 100)\n",
        "for epoch in range(0, epochs):\n",
        "    if epoch % progress_every == 0:\n",
        "        print('% Completed:')\n",
        "        print(100*epoch/epochs)\n",
        "\n",
        "    # Optional weight-decay schedule: divide by 10 every ceil(epochs/10) epochs\n",
        "    # (the condition also holds at epoch 0).\n",
        "    if weight_decay_scheduler == 1:\n",
        "      if epoch % np.ceil(epochs/10) == 0:\n",
        "        optimizer.param_groups[0]['weight_decay'] = optimizer.param_groups[0]['weight_decay']/10\n",
        "\n",
        "    # Per-epoch scratch state for the feature statistics.\n",
        "    num_correct_train = 0\n",
        "    h_i = torch.zeros(N, feature_dim)\n",
        "    h_c = []\n",
        "    c_i = torch.zeros(N,1)\n",
        "    counter = 0\n",
        "    num_c = []\n",
        "    mu_c = torch.zeros(K, feature_dim)\n",
        "    # NOTE(review): h_G is never filled in this loop, so the class means below are not\n",
        "    # centred by a global mean -- confirm this is intended.\n",
        "    h_G = torch.zeros(1, feature_dim)\n",
        "\n",
        "    NCC_c = {}\n",
        "    for c in range(0,K):\n",
        "        NCC_c[c] = 0\n",
        "\n",
        "    mu_min_min_cos = []\n",
        "    mu_maj_maj_cos = []\n",
        "    mu_maj_min_cos = []\n",
        "    W_min_min_cos  = []\n",
        "    W_maj_maj_cos  = []\n",
        "    W_maj_min_cos  = []\n",
        "\n",
        "    W_norms  = []\n",
        "    mu_norms = []\n",
        "\n",
        "    # Forward pass on x_data; the regularised criteria additionally receive the network.\n",
        "    outputs = net(x_data)\n",
        "\n",
        "    if loss_function in (\"LogRegCE\", \"RidgeRegCE\"):\n",
        "      loss = criterion(outputs, y_data,net)\n",
        "    else:\n",
        "      loss = criterion(outputs, y_data)\n",
        "\n",
        "    # zero the parameter gradients\n",
        "    optimizer.zero_grad()\n",
        "    loss.backward()\n",
        "    # Gradient alignment: 1 - cosine between the stacked (last-layer weight, net.x) vector\n",
        "    # and the negative of its gradient, measured before the optimizer step.\n",
        "    vec_wh = torch.cat((torch.reshape(net.lastlayer.weight,(-1,)),torch.reshape(net.x,(-1,))))\n",
        "    vec_grad_wh = torch.cat((torch.reshape(net.lastlayer.weight.grad,(-1,)),torch.reshape(net.x.grad,(-1,))))\n",
        "    grad_align_error_val = 1 - torch.dot(vec_wh,-vec_grad_wh)/(torch.norm(vec_wh)*torch.norm(vec_grad_wh))\n",
        "    optimizer.step()\n",
        "\n",
        "    # Scheduler Step for this loss from NC paper\n",
        "    lr_scheduler.step()\n",
        "\n",
        "    if epoch in epoch_list:\n",
        "\n",
        "      # Extracting the features\n",
        "      features = net.x\n",
        "      # Computed before the printout below: previously the print ran first, which raised a\n",
        "      # NameError on a fresh kernel and otherwise showed the previous checkpoint's accuracy.\n",
        "      train_accuracy = (y_data == torch.max(outputs,1)[1]).sum()/y_data.shape[0]\n",
        "      train_accuracies.append(train_accuracy.to(\"cpu\"))\n",
        "\n",
        "      # Printing some statistics\n",
        "      if epoch % log_every == 0:\n",
        "        print(\"lr: \")\n",
        "        for param_group in optimizer.param_groups:\n",
        "          print(param_group['lr'])\n",
        "        print(\"Epoch: \" + str(epoch))\n",
        "        print(\"Loss: \" + str(loss.item()))\n",
        "        print(\"Train Accuracy: \" + str(train_accuracy))\n",
        "      loss_list.append(loss.item())\n",
        "\n",
        "      # Copy features and labels into the CPU buffers in one slice assignment\n",
        "      # (same result as the former per-sample loop).\n",
        "      num_samples = y_data.shape[0]\n",
        "      h_i[:num_samples, :] = features[:num_samples, :]\n",
        "      c_i[:num_samples, 0] = y_data[:num_samples]\n",
        "      counter = num_samples\n",
        "\n",
        "      # Calculating mu_c, h_c, n_c\n",
        "      for c in range(0,K):\n",
        "          c_idxs = (c_i == c).nonzero(as_tuple=True)[0]\n",
        "\n",
        "          if len(c_idxs) == 0: # If no class-c in this batch\n",
        "              continue\n",
        "\n",
        "          h_c_i = h_i[c_idxs,:]\n",
        "          h_c.append(h_c_i)\n",
        "\n",
        "          num_c.append(len(c_idxs))\n",
        "          # Divide by this class's own count; num_c[c] would be misindexed if an\n",
        "          # earlier class had been skipped.\n",
        "          mu_c[c,:] = torch.sum(h_c_i, 0) / len(c_idxs)\n",
        "\n",
        "    ###################################################################################################\n",
        "      # NC1: within-class covariance against the pseudo-inverse of the between-class covariance\n",
        "      simga_W = 0\n",
        "      sigma_B = 0\n",
        "\n",
        "      for c in range(0,K):\n",
        "          for i in range(0,n_c[c]):\n",
        "              h_c_i_bar = (h_c[c][i,:] - mu_c[c]).unsqueeze(1)\n",
        "              simga_W += h_c_i_bar @ h_c_i_bar.T\n",
        "      simga_W = simga_W / N\n",
        "\n",
        "      for c in range(0,K):\n",
        "          # Column vector (feature_dim, 1). Without unsqueeze the 1-D mean broadcast against\n",
        "          # h_G.T into a (feature_dim, feature_dim) matrix, so the product below was\n",
        "          # ||mu_c||^2 * ones instead of the outer product mu_c mu_c^T.\n",
        "          mu_c_bar = mu_c[c].unsqueeze(1) - h_G.T\n",
        "          sigma_B += mu_c_bar @ mu_c_bar.T\n",
        "      sigma_B = sigma_B / K\n",
        "\n",
        "      sigma_B = sigma_B.reshape((feature_dim,feature_dim)).detach().numpy()\n",
        "      simga_W = simga_W.detach().numpy()\n",
        "      within_class_variability = (np.trace(simga_W @ scilin.pinv(sigma_B)) / K)\n",
        "      NC_1_geom.append(within_class_variability.item())\n",
        "\n",
        "      W=net.lastlayer.weight.to(\"cpu\")\n",
        "######################################\n",
        "      # Logit margins of each class mean: (w_c - w_c') . mu_c (+ bias difference when bias is used)\n",
        "      for c in range(0, K):\n",
        "          for c_prime in range(0, K):\n",
        "              if bias == True:\n",
        "                  margins[c][c_prime].append((W[c,:] - W[c_prime,:]) @ mu_c[c].T + fc_bias[c] - fc_bias[c_prime])\n",
        "              else:\n",
        "                  margins[c][c_prime].append((W[c,:] - W[c_prime,:]) @ mu_c[c].T)\n",
        "                  # Margin normalised by log(epoch+1) and the class-mean norm.\n",
        "                  tmp = (W[c,:] - W[c_prime,:]) @ mu_c[c].T/torch.log(torch.tensor(epoch+1))/torch.norm(mu_c[c])\n",
        "                  margins_logt_norm[c][c_prime].append(tmp.detach().numpy())\n",
        "\n",
        "      # q_min: smallest strictly positive pairwise margin at this checkpoint (the zero\n",
        "      # self-margins are excluded by the > 0 test). Exactly one entry is appended per\n",
        "      # checkpoint; the former loop never updated its running minimum and appended once\n",
        "      # per positive margin.\n",
        "      positive_margins = [margins[ii][jj][-1]\n",
        "                          for ii in range(K) for jj in range(K)\n",
        "                          if margins[ii][jj][-1] > 0]\n",
        "      if positive_margins:\n",
        "        q_min.append(torch.min(torch.stack(positive_margins)))\n",
        "      else:\n",
        "        # No positive margin yet: record NaN so q_min stays aligned with the checkpoints.\n",
        "        q_min.append(torch.tensor(float('nan')))\n",
        "\n",
        "      for c in range(0, K):\n",
        "          for c_prime in range(0, K):\n",
        "            if bias == False:\n",
        "              tmp3 = margins[c][c_prime][-1]/q_min[-1]\n",
        "              margins_qmin_norm[c][c_prime].append(tmp3.detach().numpy())\n",
        "\n",
        "      # Relative Frobenius errors of the normalised classifier Gram against the reference Grams.\n",
        "      GW = W@W.T\n",
        "      GW_norm = GW/torch.norm(GW,p='fro')\n",
        "      NC_6_G_error[epoch_list_idx] = torch.norm(GW_hat_norm-GW_norm,p='fro')/torch.norm(GW_norm,p='fro')\n",
        "      NC_6_G_error_etf[epoch_list_idx] = torch.norm(GW_etf_norm-GW_norm,p='fro')/torch.norm(GW_norm,p='fro')\n",
        "\n",
        "      # Same for the Gram of the class means.\n",
        "      GM = mu_c@mu_c.T\n",
        "      GM_norm = GM/torch.norm(GM,p='fro')\n",
        "      NC_6_GM_error[epoch_list_idx] = torch.norm(GM_hat_norm-GM_norm,p='fro')/torch.norm(GM_norm,p='fro')\n",
        "      NC_6_GM_error_etf[epoch_list_idx] = torch.norm(GM_etf_norm-GM_norm,p='fro')/torch.norm(GM_norm,p='fro')\n",
        "      ###############################\n",
        "      # NC7: Logit matrix average converging to ETF\n",
        "      etf_K = (torch.eye(K) - 1 / K * torch.ones((K, K))) / pow(K - 1, 0.5)\n",
        "      Z_mat_avg = W@mu_c.T\n",
        "      Z_mat_avg_norm = Z_mat_avg/torch.norm(Z_mat_avg,'fro')\n",
        "      NC_7_Z_error[epoch_list_idx] = torch.norm(etf_K-Z_mat_avg_norm,p='fro')\n",
        "\n",
        "      Z_mat_full = W@features.to(\"cpu\").T\n",
        "      if compare_with_nuc_norm == 1:\n",
        "          Z_error_ridgereg_nuc_norm[epoch_list_idx] = torch.norm(Z_mat_full/torch.norm(Z_mat_full,'fro')-Z_norm_nuc_min/torch.norm(Z_norm_nuc_min,'fro'),'fro')\n",
        "          GW_error_ridgereg_nuc_norm[epoch_list_idx] = torch.norm(GW_hat_norm_la-GW_norm,p='fro')/torch.norm(GW_norm,p='fro')\n",
        "          GM_error_ridgereg_nuc_norm[epoch_list_idx] = torch.norm(GM_hat_norm_la-GM_norm,p='fro')/torch.norm(GM_norm,p='fro')\n",
        "          # NOTE(review): this expression is identical to the one stored in\n",
        "          # Z_error_ridgereg_nuc_norm_seli below -- confirm the intended ETF reference.\n",
        "          Z_error_ridgereg_nuc_norm_etf[epoch_list_idx] = torch.norm(Z_mat_full/torch.norm(Z_mat_full,'fro')-z_hat_transpose.to(\"cpu\").T/torch.norm(z_hat_transpose.to(\"cpu\").T,'fro'),'fro')\n",
        "          GW_error_ridgereg_nuc_norm_etf[epoch_list_idx] = torch.norm(GW_etf_norm-GW_norm,p='fro')/torch.norm(GW_norm,p='fro')\n",
        "          GM_error_ridgereg_nuc_norm_etf[epoch_list_idx] = torch.norm(GM_etf_norm-GM_norm,p='fro')/torch.norm(GM_norm,p='fro')\n",
        "          Z_error_ridgereg_nuc_norm_seli[epoch_list_idx] = torch.norm(Z_mat_full/torch.norm(Z_mat_full,'fro')-z_hat_transpose.to(\"cpu\").T/torch.norm(z_hat_transpose.to(\"cpu\").T,'fro'),'fro')\n",
        "          GW_error_ridgereg_nuc_norm_seli[epoch_list_idx] = torch.norm(GW_hat_norm-GW_norm,p='fro')/torch.norm(GW_norm,p='fro')\n",
        "          GM_error_ridgereg_nuc_norm_seli[epoch_list_idx] = torch.norm(GM_hat_norm-GM_norm,p='fro')/torch.norm(GM_norm,p='fro')\n",
        "      ###############################\n",
        "      # Gradient alignment error (value computed before the optimizer step above)\n",
        "\n",
        "      grad_align_error[epoch_list_idx] = grad_align_error_val\n",
        "\n",
        "      ###################################\n",
        "      # Gradient norm\n",
        "      grad_norm[epoch_list_idx] = torch.norm(vec_grad_wh)\n",
        "\n",
        "      ###################################\n",
        "      # CE LogReg lower bound, evaluated at rho = Frobenius norm of the full logit matrix\n",
        "      rho=torch.norm(W@features.to(\"cpu\").T,'fro').detach().numpy()\n",
        "      ce_logreg_loss_lb.append(np.log(1+(K-1)*np.exp(-rho*np.sqrt(K/(sum(n_c)*(K-1)))))+0.5*lambda_logreg*rho**2)\n",
        "\n",
        "\n",
        "      ###################################\n",
        "      # FO optimality error: Frobenius distance between W^T W and features^T features\n",
        "      foo_error[epoch_list_idx] = torch.norm(W.T@W-features.to(\"cpu\").T@features.to(\"cpu\"),'fro')\n",
        "\n",
        "      epoch_list_idx += 1 \n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "gBs7HLP9yj4U"
      },
      "outputs": [],
      "source": [
        "# Keep the logged epochs that do not exceed the last epoch value left by the training loop,\n",
        "# then drop the final one of those. This cell rewrites epoch_list and epochs in place, so\n",
        "# re-running it shortens epoch_list again.\n",
        "new_epoch_list = [logged_epoch for logged_epoch in epoch_list if logged_epoch <= epoch]\n",
        "epoch_list = new_epoch_list[:-1]\n",
        "epochs = epoch"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import pickle\n",
        "import time\n",
        "\n",
        "# Timestamped file name so successive runs do not overwrite each other.\n",
        "timestr = time.strftime(\"%Y%m%d-%H%M%S\")\n",
        "save_name_pkl = f'UFM_R{R}_K{K}_{timestr}.pkl'\n",
        "save_pth = path+save_name_pkl\n",
        "\n",
        "# Run configuration followed by the recorded metrics.\n",
        "save_dict = {\n",
        "    'n': n,\n",
        "    'K': K,\n",
        "    'R': R,\n",
        "    'n_c': n_c,\n",
        "    'y_data': y_data,\n",
        "    'lr_init': lr_init,\n",
        "    'weight_decay': weight_decay,\n",
        "    'epochs': epochs,\n",
        "    'epoch_list': epoch_list,\n",
        "    'loss_list': loss_list,\n",
        "    'train_accuracies': train_accuracies,\n",
        "    'loss_function': loss_function,\n",
        "    'lambda_logreg': lambda_logreg,\n",
        "    'grad_norm': grad_norm,\n",
        "    'NC_1_geom': NC_1_geom,\n",
        "    'margins': margins,\n",
        "    'NC_6_G_error': NC_6_G_error,\n",
        "    'NC_6_G_error_etf': NC_6_G_error_etf,\n",
        "    'NC_6_GM_error': NC_6_GM_error,\n",
        "    'NC_6_GM_error_etf': NC_6_GM_error_etf,\n",
        "    'NC_7_Z_error': NC_7_Z_error,\n",
        "    'foo_error': foo_error,\n",
        "    'Z_error_ridgereg_nuc_norm': Z_error_ridgereg_nuc_norm,\n",
        "    'GW_error_ridgereg_nuc_norm': GW_error_ridgereg_nuc_norm,\n",
        "    'GM_error_ridgereg_nuc_norm': GM_error_ridgereg_nuc_norm,\n",
        "}\n",
        "\n",
        "with open(save_pth, 'wb') as pkl_file:\n",
        "  pickle.dump(save_dict, pkl_file)"
      ],
      "metadata": {
        "id": "afi-aB2mkeuJ"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "\n",
        "# Scratch check of the file-name pattern with R = K = 2.\n",
        "# Note that it overwrites timestr and save_name_pkl from the save cell.\n",
        "timestr = time.strftime(\"%Y%m%d-%H%M%S\")\n",
        "print(timestr)\n",
        "save_name_pkl = 'UFM_R2_K2_' + timestr + '.pkl'\n",
        "save_name_pkl"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 55
        },
        "id": "a2oiOTHkCZdr",
        "outputId": "e3284c05-27a5-4a81-c3ba-454cd563545a"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "20220526-140859\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'UFM_R2_K2_20220526-140859.pkl'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "cbPfm2UlyNco",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "57275736-1d34-4867-8630-3f1dbff14fa9"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:136: UserWarning: Data has no positive values, and therefore cannot be log-scaled.\n",
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:152: UserWarning: Data has no positive values, and therefore cannot be log-scaled.\n",
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:168: UserWarning: Data has no positive values, and therefore cannot be log-scaled.\n"
          ]
        }
      ],
      "source": [
        "from matplotlib.lines import Line2D\n",
        "\n",
        "# Output location and styling constants shared by the figures in this cell.\n",
        "figure_number_additor = 0\n",
        "file_name = \"./graphs/\"\n",
        "file_name_postfix = \"_R_\" + str(R) + \"_Loss_\" + str(loss_function) + \".pdf\"\n",
        "leg_size, label_size, tick_size = 20, 25, 25\n",
        "colors = ['limegreen', 'red', 'dodgerblue']\n",
        "fig = plt.figure(figsize=(8,6))\n",
        "############################################################################################################\n",
        "# Training accuracy at the logged epochs, log-log axes, saved as a PDF.\n",
        "plt.figure(figure_number_additor + 1)\n",
        "num_logged = len(epoch_list)\n",
        "plt.plot(epoch_list, train_accuracies[:num_logged], linewidth=3, color=colors[2])\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "tick_font = 0.8*tick_size\n",
        "plt.xticks(fontsize=tick_font)\n",
        "plt.yticks(fontsize=tick_font)\n",
        "plt.savefig(file_name + 'train_accuracy' + file_name_postfix, format='pdf', dpi=200)\n",
        "\n",
        "############################################################################################################\n",
        "plt.figure(figure_number_additor + 2)\n",
        "# plt.clf()\n",
        "plt.semilogy(epoch_list, loss_list[0:len(epoch_list)],linewidth=3, color=colors[2])\n",
        "# plt.plot(epoch_list, ce_logreg_loss_lb[0:len(epoch_list)],linewidth=3, color=colors[1])\n",
        "plt.legend(['loss','LogRegCE lb'])\n",
        "# plt.xlabel('Epoch')\n",
        "# plt.ylabel('Loss')\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "plt.xticks(fontsize=0.8*tick_size)\n",
        "plt.yticks(fontsize=0.8*tick_size)\n",
        "#plt.ylim([0,150])\n",
        "# plt.title('Loss for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "plt.savefig(file_name + 'loss' + file_name_postfix, format='pdf', dpi=200)\n",
        "# plt.clf()\n",
        "\n",
        "############################################################################################################\n",
        "plt.figure(figure_number_additor + 3)\n",
        "# plt.clf()\n",
        "plt.semilogy(epoch_list, grad_norm[0:len(epoch_list)].detach().numpy(),linewidth=3, color=colors[2])\n",
        "# plt.legend(['loss','lb zhu'])\n",
        "# plt.xlabel('Epoch')\n",
        "# plt.ylabel('Gradient Norm')\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "plt.xticks(fontsize=0.8*tick_size)\n",
        "plt.yticks(fontsize=0.8*tick_size)\n",
        "#plt.ylim([0,150])\n",
        "# plt.title('Gradient norm for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "plt.savefig(file_name + 'grad_norm' + file_name_postfix, format='pdf', dpi=200)\n",
        "# plt.clf()\n",
        "\n",
        "############################################################################################################\n",
        "plt.figure(figure_number_additor + 4)\n",
        "# plt.clf()\n",
        "plt.plot(epoch_list, NC_1_geom[0:len(epoch_list)],linewidth=3, color=colors[2])\n",
        "# plt.xlabel('Epoch')\n",
        "# plt.ylabel('Variability Collapse')\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "plt.xticks(fontsize=0.8*tick_size)\n",
        "plt.yticks(fontsize=0.8*tick_size)\n",
        "#plt.ylim([-0.0001, 0.004])\n",
        "# plt.title('Variability Collapse for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "plt.savefig(file_name + 'NC1' + file_name_postfix, format='pdf', dpi=200)\n",
        "# plt.clf()\n",
        "\n",
        "graph_number = 4\n",
        "graph_number += 1\n",
        "plt.figure(figure_number_additor + graph_number)\n",
        "plt.plot(epoch_list,NC_6_G_error[0:len(epoch_list)].detach().numpy(),linewidth=3, color=colors[2])\n",
        "plt.plot(epoch_list,NC_6_G_error_etf[0:len(epoch_list)].detach().numpy(),linewidth=3, color=colors[1])\n",
        "# plt.xlabel('Epoch')\n",
        "# plt.ylabel('GW_error')\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "plt.xticks(fontsize=0.8*tick_size)\n",
        "plt.yticks(fontsize=0.8*tick_size)\n",
        "# plt.title('NC6: GW error for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "plt.savefig(file_name + 'NC6 GW error' + file_name_postfix, format='pdf', dpi=200)\n",
        "# plt.clf()\n",
        "\n",
        "graph_number += 1\n",
        "plt.figure(figure_number_additor + graph_number)\n",
        "plt.plot(epoch_list,NC_6_GM_error[0:len(epoch_list)].detach().numpy(),linewidth=3, color=colors[2])\n",
        "plt.plot(epoch_list,NC_6_GM_error_etf[0:len(epoch_list)].detach().numpy(),linewidth=3, color=colors[1])\n",
        "# plt.xlabel('Epoch')\n",
        "# plt.ylabel('GM_error')\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "plt.xticks(fontsize=0.8*tick_size)\n",
        "plt.yticks(fontsize=0.8*tick_size)\n",
        "# plt.title('NC6: GM error for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "plt.savefig(file_name + 'NC6 GM error' + file_name_postfix, format='pdf', dpi=200)\n",
        "# plt.clf()\n",
        "\n",
        "graph_number += 1\n",
        "plt.figure(figure_number_additor + graph_number)\n",
        "plt.plot(epoch_list,NC_7_Z_error[0:len(epoch_list)].detach().numpy(),linewidth=3, color=colors[2])\n",
        "# plt.xlabel('Epoch')\n",
        "# plt.ylabel('Z_error')\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "plt.xticks(fontsize=0.8*tick_size)\n",
        "plt.yticks(fontsize=0.8*tick_size)\n",
        "# plt.title('NC7: Logit Z error for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "plt.savefig(file_name + 'NC7 Z error' + file_name_postfix, format='pdf', dpi=200)\n",
        "# plt.clf()\n",
        "\n",
        "graph_number += 1\n",
        "plt.figure(figure_number_additor + graph_number)\n",
        "plt.plot(epoch_list,foo_error[0:len(epoch_list)].detach().numpy(),linewidth=3, color=colors[2])\n",
        "# plt.xlabel('Epoch')\n",
        "# plt.ylabel('FOO_error')\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "plt.xticks(fontsize=0.8*tick_size)\n",
        "plt.yticks(fontsize=0.8*tick_size)\n",
        "# plt.title('NC7: FOO error for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "plt.savefig(file_name + 'FOO error' + file_name_postfix, format='pdf', dpi=200)\n",
        "# plt.clf()\n",
        "\n",
        "graph_number += 1\n",
        "plt.figure(figure_number_additor + graph_number)\n",
        "plt.plot(epoch_list,Z_error_ridgereg_nuc_norm[0:len(epoch_list)].detach().numpy(),label='$\\lambda$-SELI',linewidth=3, color=colors[2])\n",
        "plt.plot(epoch_list,Z_error_ridgereg_nuc_norm_etf[0:len(epoch_list)].detach().numpy(),label='ETF',linewidth=3, color=colors[0])\n",
        "plt.plot(epoch_list,Z_error_ridgereg_nuc_norm_seli[0:len(epoch_list)].detach().numpy(),label='SELI',linewidth=3, color=colors[1])\n",
        "# plt.xlabel('Epoch')\n",
        "# plt.ylabel('RidgeReg Z error from nuc norm minimizer')\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "plt.legend(prop={'size': leg_size},handlelength=3.5)\n",
        "plt.xticks(fontsize=0.8*tick_size)\n",
        "plt.yticks(fontsize=0.8*tick_size)\n",
        "# plt.title('RidgeReg Z error for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "plt.savefig(file_name + 'RidgeReg Z error' + file_name_postfix, format='pdf', dpi=200)\n",
        "# plt.clf()\n",
        "\n",
        "graph_number += 1\n",
        "plt.figure(figure_number_additor + graph_number)\n",
        "plt.plot(epoch_list,GW_error_ridgereg_nuc_norm[0:len(epoch_list)].detach().numpy(),label='$\\lambda$-SELI',linewidth=3, color=colors[2])\n",
        "plt.plot(epoch_list,GW_error_ridgereg_nuc_norm_etf[0:len(epoch_list)].detach().numpy(),label='ETF',linewidth=3, color=colors[0])\n",
        "plt.plot(epoch_list,GW_error_ridgereg_nuc_norm_seli[0:len(epoch_list)].detach().numpy(),label='SELI',linewidth=3, color=colors[1])\n",
        "# plt.xlabel('Epoch')\n",
        "# plt.ylabel('RidgeReg Z error from nuc norm minimizer')\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "plt.legend(prop={'size': leg_size},handlelength=3.5)\n",
        "plt.xticks(fontsize=0.8*tick_size)\n",
        "plt.yticks(fontsize=0.8*tick_size)\n",
        "# plt.title('RidgeReg Z error for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "plt.savefig(file_name + 'RidgeReg GW error' + file_name_postfix, format='pdf', dpi=200)\n",
        "# plt.clf()\n",
        "\n",
        "graph_number += 1\n",
        "plt.figure(figure_number_additor + graph_number)\n",
        "plt.plot(epoch_list,GM_error_ridgereg_nuc_norm[0:len(epoch_list)].detach().numpy(),label='$\\lambda$-SELI',linewidth=3, color=colors[2])\n",
        "plt.plot(epoch_list,GM_error_ridgereg_nuc_norm_etf[0:len(epoch_list)].detach().numpy(),label='ETF',linewidth=3, color=colors[0])\n",
        "plt.plot(epoch_list,GM_error_ridgereg_nuc_norm_seli[0:len(epoch_list)].detach().numpy(),label='SELI',linewidth=3, color=colors[1])\n",
        "# plt.xlabel('Epoch')\n",
        "# plt.ylabel('RidgeReg Z error from nuc norm minimizer')\n",
        "plt.yscale(\"log\")\n",
        "plt.xscale(\"log\")\n",
        "plt.legend(prop={'size': leg_size},handlelength=3.5)\n",
        "plt.xticks(fontsize=0.8*tick_size)\n",
        "plt.yticks(fontsize=0.8*tick_size)\n",
        "# plt.title('RidgeReg Z error for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "plt.savefig(file_name + 'RidgeReg GM error' + file_name_postfix, format='pdf', dpi=200)\n",
        "# plt.clf()\n",
        "\n",
        "graph_number += 1\n",
        "# Color dictionary\n",
        "plt.clf()\n",
        "colors_dict = {}\n",
        "colors = ['red', 'maroon', 'navy', 'orangered', 'tomato', 'coral', 'blue', 'royalblue', 'darkslateblue', 'purple']\n",
        "for c in range(0, K):\n",
        "    colors_dict[c] = colors[c]\n",
        "leg_size = 20\n",
        "label_size = 25\n",
        "tick_size = 25\n",
        "fig = plt.figure(figsize=(8,6))\n",
        "for c in range(0,K):\n",
        "    graph_number += 1\n",
        "    label_list = []\n",
        "    for c_prime in range(0,K):\n",
        "        if c_prime != c:\n",
        "            plt.figure(figure_number_additor + graph_number)\n",
        "            plt.plot(epoch_list,torch.tensor(margins[c][c_prime])[0:len(epoch_list)].detach().numpy(),label=c_prime,color = colors[c_prime])\n",
        "            plt.xscale(\"log\")\n",
        "            label_list.append(c_prime)\n",
        "    # plt.xlabel('Epoch')        \n",
        "    # plt.ylabel('Margins')\n",
        "    plt.xticks(fontsize=0.8*tick_size)\n",
        "    plt.yticks(fontsize=0.8*tick_size)\n",
        "    plt.legend([str(c) for c in label_list])\n",
        "    # plt.title('Margins for class '+str(c)+ 'for R = ' + str(R) + \" with \" + loss_function + \" Loss\")\n",
        "    plt.savefig(file_name + 'margins for class' + str(c) + file_name_postfix, format='pdf', dpi=200)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "US-8iuNz3pfB"
      },
      "source": [
        "END END *************"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "machine_shape": "hm",
      "name": "UFM_GD.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}