{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import scipy\n",
    "from scipy.stats import ortho_group\n",
    "import scipy.linalg\n",
    "from models import NeuralNet, CNN, CNN_FC, CNN_OneD\n",
    "import torch\n",
    "from torch import nn\n",
    "from advertorch.attacks import LinfPGDAttack, L2PGDAttack, DDNL2Attack\n",
    "import ipdb\n",
    "import itertools\n",
    "# from ipynb.fs.full.orthogonal_vectors import train, test_pgd\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, data, labels, epochs = 1000):\n",
    "    model.train()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
    "    batch_size = data.shape[0]\n",
    "    for epoch in range(epochs):\n",
    "        if epoch==epochs//2:\n",
    "            for param_group in optimizer.param_groups:\n",
    "                param_group['lr'] /=10\n",
    "        correct = 0\n",
    "        total_loss = 0\n",
    "#         ipdb.set_trace()\n",
    "        for idx in range(data.shape[0]//batch_size):\n",
    "            batch_data, batch_labels = data[idx*batch_size:(idx+1)*batch_size], labels[idx*batch_size:(idx+1)*batch_size]\n",
    "            out = model(batch_data)\n",
    "            loss = criterion(out, batch_labels)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "#             ipdb.set_trace()\n",
    "            pred = out.argmax(1)\n",
    "            correct += (pred == batch_labels).sum().item()\n",
    "            total_loss += loss.item()*batch_size\n",
    "        if epoch%(epochs/5)==0:\n",
    "            print(\"Epoch: {}, Accuracy: {:.2f}, Loss: {:.5f}\".format(epoch, 100*correct/data.shape[0], total_loss/data.shape[0]))\n",
    "    return model\n",
    "\n",
    "def test_pgd(model, data, labels, epsilon, res_row):\n",
    "    model.eval()\n",
    "    criterion = nn.CrossEntropyLoss(reduction='sum')\n",
    "    adversary = L2PGDAttack(model, loss_fn=criterion, nb_iter=100, eps_iter=epsilon/50,\n",
    "                                rand_init=True, eps=epsilon, clip_min=data.min().item(), clip_max=data.max().item(), targeted=False)\n",
    "    # Set requires_grad attribute of tensor. Important for Attack\n",
    "    data.requires_grad = True\n",
    "#     ipdb.set_trace()\n",
    "    perturbed_data = adversary.perturb(data, labels)\n",
    "    new_out = model(perturbed_data)\n",
    "#     ipdb.set_trace()\n",
    "    pred = new_out.argmax(1)\n",
    "    correct = (pred==labels).sum()\n",
    "    res_row.append(100*float(correct)/labels.shape[0])\n",
    "#     print(\"Epsilon: {}, Accuracy: {}\".format(epsilon, 100*float(correct)/labels.shape[0]))\n",
    "    \n",
    "def test_ddn(model, data, labels, res_row):\n",
    "    model.eval()\n",
    "    criterion = nn.CrossEntropyLoss(reduction='sum')\n",
    "    adversary = DDNL2Attack(model, nb_iter=1000, quantize = False, clip_min=data.min().item(), clip_max=data.max().item(),\n",
    "                           loss_fn=criterion,)\n",
    "#     adversary = L2PGDAttack(model, loss_fn=nn.MSELoss(reduction=\"sum\"), nb_iter=40, eps_iter=epsilon/20,\n",
    "#                                 rand_init=True, eps=epsilon, clip_min=data.min().item(), clip_max=data.max().item(), targeted=False)\n",
    "    # Set requires_grad attribute of tensor. Important for Attack\n",
    "    data.requires_grad = True\n",
    "#     ipdb.set_trace()\n",
    "    perturbed_data = adversary.perturb(data, labels)\n",
    "    new_out = model(perturbed_data)\n",
    "    pred = new_out.argmax(1)\n",
    "    correct = (pred==labels).sum()\n",
    "#     ipdb.set_trace()\n",
    "    l2_distances = np.linalg.norm((data - perturbed_data).detach().cpu().numpy(),\n",
    "                                  ord=2, axis=1)\n",
    "    res_row.append(np.mean(l2_distances))\n",
    "    res_row.append(np.median(l2_distances))\n",
    "#     print(\"Samples: {}, Median L2: {},\"\n",
    "#           \"Mean L2: {}, Accuracy: {}\".format(l2_distances.shape[0], l2_distances.median(),\n",
    "#                                              l2_distances.mean(), 100*float(correct)/labels.shape[0]))\n",
    "  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.random.ran"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def orthogonal_frequencies(N, d=1000):\n",
    "    x = np.linspace(0, 1, d)\n",
    "#     import ipdb\n",
    "#     ipdb.set_trace()\n",
    "    X = np.array([], dtype=np.float64).reshape((0, d))\n",
    "    for k in range(1, N//2+1):\n",
    "        vec_sin = np.sin(2*k*np.pi*x)\n",
    "        vec_sin /= np.linalg.norm(vec_sin, ord=2)\n",
    "        vec_cos = np.cos(2*k*np.pi*x)\n",
    "        vec_cos /= np.linalg.norm(vec_cos, ord=2)\n",
    "        X = np.vstack([vec_sin, vec_cos, X])\n",
    "    for k in range(N//2+1, N+1):\n",
    "        vec_sin = np.sin(2*k*np.pi*x)\n",
    "        vec_sin /= np.linalg.norm(vec_sin, ord=2)\n",
    "        vec_cos = np.cos(2*k*np.pi*x)\n",
    "        vec_cos /= np.linalg.norm(vec_cos, ord=2)\n",
    "        X = np.vstack([vec_sin, vec_cos, X])\n",
    "    return X\n",
    "\n",
    "def orthogonal_frequencies_odd_even(N, d=1000):\n",
    "    x = np.linspace(0, 1, d)\n",
    "#     import ipdb\n",
    "#     ipdb.set_trace()\n",
    "    X = np.array([], dtype=np.float64).reshape((0, d))\n",
    "    for k in range(1, N, 2):\n",
    "        vec_sin = np.sin(2*k*np.pi*x)\n",
    "        vec_sin /= np.linalg.norm(vec_sin, ord=2)\n",
    "        vec_cos = np.cos(2*k*np.pi*x)\n",
    "        vec_cos /= np.linalg.norm(vec_cos, ord=2)\n",
    "        X = np.vstack([vec_sin, vec_cos, X])\n",
    "    for k in range(2, N+1, 2):\n",
    "        vec_sin = np.sin(2*k*np.pi*x)\n",
    "        vec_sin /= np.linalg.norm(vec_sin, ord=2)\n",
    "        vec_cos = np.cos(2*k*np.pi*x)\n",
    "        vec_cos /= np.linalg.norm(vec_cos, ord=2)\n",
    "        X = np.vstack([vec_sin, vec_cos, X])\n",
    "    return X\n",
    "\n",
    "def orthogonal_frequencies_random(N, d=1000):\n",
    "    x = np.linspace(0, 1, d)\n",
    "#     import ipdb\n",
    "#     ipdb.set_trace()\n",
    "    X = np.array([], dtype=np.float64).reshape((0, d))\n",
    "    random_permutation = np.random.permutation(np.arange(1, N+1))\n",
    "    for k in random_permutation:\n",
    "        vec_sin = np.sin(2*k*np.pi*x)\n",
    "        vec_sin /= np.linalg.norm(vec_sin, ord=2)\n",
    "        vec_cos = np.cos(2*k*np.pi*x)\n",
    "        vec_cos /= np.linalg.norm(vec_cos, ord=2)\n",
    "        X = np.vstack([vec_sin, vec_cos, X])\n",
    "\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4000 Dimension, 50 Samples per class\n",
      "Epoch: 0, Accuracy: 50.00, Loss: 0.69335\n",
      "Epoch: 400, Accuracy: 100.00, Loss: 0.00010\n",
      "Epoch: 800, Accuracy: 100.00, Loss: 0.00003\n",
      "Epoch: 1200, Accuracy: 100.00, Loss: 0.00002\n",
      "Epoch: 1600, Accuracy: 100.00, Loss: 0.00002\n",
      "0.67422,0.33807\n",
      "-----------------------\n",
      "4000 Dimension, 100 Samples per class\n",
      "Epoch: 0, Accuracy: 50.00, Loss: 0.69359\n",
      "Epoch: 400, Accuracy: 100.00, Loss: 0.00013\n",
      "Epoch: 800, Accuracy: 100.00, Loss: 0.00004\n",
      "Epoch: 1200, Accuracy: 100.00, Loss: 0.00003\n",
      "Epoch: 1600, Accuracy: 100.00, Loss: 0.00002\n",
      "0.76736,0.32183\n",
      "-----------------------\n",
      "4000 Dimension, 200 Samples per class\n",
      "Epoch: 0, Accuracy: 50.00, Loss: 0.69343\n",
      "Epoch: 400, Accuracy: 100.00, Loss: 0.00019\n",
      "Epoch: 800, Accuracy: 100.00, Loss: 0.00006\n",
      "Epoch: 1200, Accuracy: 100.00, Loss: 0.00004\n",
      "Epoch: 1600, Accuracy: 100.00, Loss: 0.00003\n",
      "0.83469,0.30428\n",
      "-----------------------\n"
     ]
    }
   ],
   "source": [
    "# samples = [50, 100, 150, 200, 250, 300, 400, 450]\n",
    "# samples = [200, 300, 400, 450]\n",
    "samples = [50, 100, 200]\n",
    "# samples = [100]\n",
    "input_shape = 4000\n",
    "data_all = []\n",
    "for k in samples:\n",
    "    res_row = []\n",
    "#     X = orthogonal_frequencies(k, input_shape)\n",
    "    X = orthogonal_frequencies_odd_even(k, input_shape)\n",
    "#     X = orthogonal_frequencies_random(k, input_shape)\n",
    "    d1, d2 = X[:k], X[k:]\n",
    "#     d1, d2 = x[:k],x[input_shape//2:input_shape//2+k]\n",
    "#     d1, d2 = add_linear_seperator(d1, d2, p)\n",
    "#         d1, d2 = x[:k], -x[:k]\n",
    "    data = torch.tensor(np.concatenate((d1, d2))).float()\n",
    "    labels = torch.tensor(np.concatenate(([0]*d1.shape[0], [1]*d2.shape[0])))\n",
    "    idx_random = np.random.permutation(data.shape[0])\n",
    "    data, labels = data[idx_random], labels[idx_random]\n",
    "    print(\"{} Dimension, {} Samples per class\".format(input_shape, k))\n",
    "#     model = NeuralNet(input_shape, 1, 100, 2, no_final=True)\n",
    "    model = CNN_OneD(hidden_channels=128, kernel_size=input_shape//2,\n",
    "                         num_classes = 2, no_final=True)\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    model = model.to(device)\n",
    "    data = data.to(device)\n",
    "    labels = labels.to(device)\n",
    "    model = train(model, data, labels, 2000)\n",
    "#     print(\"-----------------------\")\n",
    "#     print(\"-----------------------\")\n",
    "#     for epsilon in [0.04, 0.06, 0.07, 0.08, 0.1, 0.15, 0.2]:\n",
    "#         test_pgd(model, data, labels, epsilon, res_row)\n",
    "#     print(\"-----------------------\")\n",
    "#     print(\"-----------------------\")\n",
    "    test_ddn(model, data, labels, res_row)\n",
    "    data_all.append(res_row)\n",
    "    print(','.join(['{:.5f}'.format(x) for x in res_row]))\n",
    "    print(\"-----------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "FC\n",
    "0.11425,0.09780\n",
    "0.09180,0.08079\n",
    "0.07655,0.07280"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.00000000e+00,  1.86000927e-02,  3.00887079e-02, ...,\n",
       "        -3.00887079e-02, -1.86000927e-02,  2.48539593e-16],\n",
       "       [ 3.16148739e-02,  2.55711281e-02,  9.75062993e-03, ...,\n",
       "         9.75062993e-03,  2.55711281e-02,  3.16148739e-02],\n",
       "       [ 0.00000000e+00,  1.84388973e-02,  2.99636804e-02, ...,\n",
       "        -2.99636804e-02, -1.84388973e-02,  1.61253138e-15],\n",
       "       ...,\n",
       "       [ 3.16148739e-02,  3.16134684e-02,  3.16092519e-02, ...,\n",
       "         3.16092519e-02,  3.16134684e-02,  3.16148739e-02],\n",
       "       [ 0.00000000e+00,  9.94202750e-05,  1.98839568e-04, ...,\n",
       "        -1.98839568e-04, -9.94202750e-05, -7.74728349e-18],\n",
       "       [ 3.16148739e-02,  3.16147177e-02,  3.16142492e-02, ...,\n",
       "         3.16142492e-02,  3.16147177e-02,  3.16148739e-02]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = orthogonal_frequencies(50, 4000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100, 4000)"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_dot = -1\n",
    "i, j = -1, -1\n",
    "for i in range(1,50):\n",
    "    for j in range(1, 50):\n",
    "        if i==j:\n",
    "            continue\n",
    "        curr_dot = np.dot(X[i], X[j])\n",
    "        if curr_dot > max_dot:\n",
    "            max_dot = curr_dot\n",
    "            max_i, max_j = i, j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0004998750312485997\n"
     ]
    }
   ],
   "source": [
    "print(max_dot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max_i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max_j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-fast_adversarial] *",
   "language": "python",
   "name": "conda-env-.conda-fast_adversarial-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
