{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing relevant libraries and defining the Net class (same class with which the networks were generated)\n",
    "import torch\n",
    "import torch.nn as nn \n",
    "import torch.nn.functional as F\n",
    "import pickle\n",
    "from numpy.random import RandomState\n",
    "import numpy as np\n",
    "from sklearn.svm import LinearSVC\n",
    "from sklearn.decomposition import PCA\n",
    "from matplotlib import pyplot as plt\n",
    "from typing import Any, Callable, Optional, Tuple\n",
    "from torchvision import datasets, transforms\n",
    "import torch.nn.functional as F\n",
    "import quadprog\n",
    "import copy\n",
    "\n",
    "from utils import LinCKA2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "k880Ih2e3PBf"
   },
   "outputs": [],
   "source": [
    "# Network class\n",
    "k = 1  # channel-width multiplier shared by every layer of Net\n",
    "\n",
    "\n",
    "class Net(torch.nn.Module):\n",
    "    \"\"\"Convolutional classifier built from Conv2d -> BatchNorm2d -> ReLU blocks.\n",
    "\n",
    "    ``self.layers`` holds eight such blocks followed by a global average pooling\n",
    "    layer; ``self.fc`` maps the 64*k pooled features to 10 outputs.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _conv_block(in_channels, out_channels, **conv_kwargs):\n",
    "        # One Conv2d -> BatchNorm2d -> ReLU unit; the kernel is 3x3 unless overridden\n",
    "        conv_kwargs.setdefault('kernel_size', 3)\n",
    "        return nn.Sequential(\n",
    "            nn.Conv2d(in_channels, out_channels, **conv_kwargs),\n",
    "            nn.BatchNorm2d(out_channels),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        block = self._conv_block\n",
    "        # Same modules, in the same order, as indices 0..8 of the saved networks\n",
    "        self.layers = nn.ModuleList([\n",
    "            block(3, 16*k),\n",
    "            block(16*k, 16*k),\n",
    "            block(16*k, 32*k, stride=2),\n",
    "            block(32*k, 32*k),\n",
    "            block(32*k, 32*k),\n",
    "            block(32*k, 64*k, stride=2),\n",
    "            block(64*k, 64*k, padding='valid'),\n",
    "            block(64*k, 64*k, kernel_size=1),\n",
    "            nn.AdaptiveAvgPool2d((1, 1)),\n",
    "        ])\n",
    "        self.fc = nn.Linear(64*k, 10)\n",
    "\n",
    "    def forward(self, x, acts_only=False, all_act=False):\n",
    "        \"\"\"Return the network output, or ``(activations, output)`` when ``all_act`` is set.\n",
    "\n",
    "        ``activations`` lists the output of every block before the pooling layer, so it\n",
    "        never contains the final output. ``acts_only`` is accepted but not used.\n",
    "        \"\"\"\n",
    "        hidden = []\n",
    "        for block in self.layers[:-1]:\n",
    "            x = block(x)\n",
    "            hidden.append(x)\n",
    "\n",
    "        # The pooling layer sits outside the loop so that it is not recorded\n",
    "        pooled = self.layers[-1](x)\n",
    "        output = self.fc(pooled.view(-1, 64*k))\n",
    "\n",
    "        return (hidden, output) if all_act else output\n",
    "\n",
    "    def forward_embed(self, x, layer_idx=-1):\n",
    "        \"\"\"Run ``x`` through ``self.layers[:layer_idx]`` and return the result.\"\"\"\n",
    "        for block in self.layers[:layer_idx]:\n",
    "            x = block(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DxrHSBk_Yhrq",
    "outputId": "45088feb-7b91-4ec4-e3e1-348e8228f309"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (layers): ModuleList(\n",
       "    (0): Sequential(\n",
       "      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))\n",
       "      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "    )\n",
       "    (1): Sequential(\n",
       "      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))\n",
       "      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "    )\n",
       "    (2): Sequential(\n",
       "      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2))\n",
       "      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "    )\n",
       "    (3): Sequential(\n",
       "      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))\n",
       "      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "    )\n",
       "    (4): Sequential(\n",
       "      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))\n",
       "      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "    )\n",
       "    (5): Sequential(\n",
       "      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))\n",
       "      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "    )\n",
       "    (6): Sequential(\n",
       "      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=valid)\n",
       "      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "    )\n",
       "    (7): Sequential(\n",
       "      (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "    )\n",
       "    (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "  )\n",
       "  (fc): Linear(in_features=64, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Importing the saved networks\n",
    "PATH = 'net_kornblith_'\n",
    "\n",
    "# Rebuild the architecture and restore the trained weights from PATH + 'all_1.zip'\n",
    "net_all1 = Net()\n",
    "net_all1.load_state_dict(torch.load(PATH + 'all_1.zip'))\n",
    "\n",
    "# eval() and cuda() both return the module itself, so the cell still displays the network\n",
    "net_all1.eval().cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Split and train experiment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_activations_by_class(data, labels, num_classes=10):\n",
    "    \"\"\"Group the rows of ``data`` by class label.\n",
    "\n",
    "    Args:\n",
    "        data: array indexable along its first axis with an integer index array.\n",
    "        labels: 1-D array-like of integer labels, aligned with the rows of ``data``.\n",
    "        num_classes: the labels 0 .. num_classes-1 are collected (default 10).\n",
    "\n",
    "    Returns:\n",
    "        datapoints: list with one entry per class, holding all rows of that class.\n",
    "        indexes: positions of those rows in ``data``. When every class has the same\n",
    "            number of examples this is the squeezed integer array of shape\n",
    "            (num_classes, n_per_class); otherwise it is a 1-D object array holding\n",
    "            one index array per class.\n",
    "    \"\"\"\n",
    "    labels = np.asarray(labels)\n",
    "    indexes = [np.flatnonzero(labels == label) for label in range(num_classes)]\n",
    "    datapoints = [data[class_indexes] for class_indexes in indexes]\n",
    "\n",
    "    if len({len(class_indexes) for class_indexes in indexes}) <= 1:\n",
    "        return datapoints, np.array(indexes).squeeze()\n",
    "\n",
    "    # Classes of different sizes cannot be stacked into a rectangular array\n",
    "    ragged = np.empty(len(indexes), dtype=object)\n",
    "    for label, class_indexes in enumerate(indexes):\n",
    "        ragged[label] = class_indexes\n",
    "    return datapoints, ragged\n",
    "\n",
    "def split_class_clusters(data, indexes, d, W, n_clusters, split_constant, ortho_d = True, experiment = 'one_class_one_point'):\n",
    "    \"\"\"Split one class into clusters along a direction and push the clusters apart.\n",
    "\n",
    "    The examples are projected on the direction, the projection range is cut into\n",
    "    ``n_clusters`` intervals, and cluster number c is translated by\n",
    "    ``c * split_constant`` along the direction.\n",
    "\n",
    "    Args:\n",
    "        data: array of shape n x d (n examples, d features); it is not modified.\n",
    "        indexes: array with the dataset position of each row of ``data``.\n",
    "        d: the direction vector; it is not modified.\n",
    "        W: comes from lin_svc.coef_ (n_classes, n_features), the normal vectors to the\n",
    "            decision hyperplanes. Only used when ``ortho_d`` is True.\n",
    "        n_clusters: number of clusters to create.\n",
    "        split_constant: translation step between consecutive clusters.\n",
    "        ortho_d: if True, the component of ``d`` lying in the span of the rows of ``W``\n",
    "            is removed before splitting.\n",
    "        experiment: if it contains 'one_point' and ``n_clusters`` is 2, the second\n",
    "            cluster only holds the example(s) with the largest projection.\n",
    "\n",
    "    Returns:\n",
    "        The translated copy of ``data`` and the dataset indexes of each cluster: a\n",
    "        squeezed integer array when all clusters have the same size, otherwise a 1-D\n",
    "        object array with one index array per cluster.\n",
    "    \"\"\"\n",
    "    data_ = np.copy(data)\n",
    "\n",
    "    # Float copy shaped (1, n_features): the caller's vector is never written to, and\n",
    "    # the same code path works whether or not the direction is orthogonalised\n",
    "    d_ = np.array(d, dtype=float).reshape(1, -1)\n",
    "    if ortho_d:\n",
    "        Q, _ = np.linalg.qr(W.T)\n",
    "        d_ = d_ - np.matmul(d_, np.matmul(Q, Q.T))\n",
    "\n",
    "    projections = np.matmul(data_, d_.T).reshape(-1)\n",
    "    min_ = np.min(projections)\n",
    "    max_ = np.max(projections)\n",
    "\n",
    "    separators = np.linspace(min_, max_, n_clusters+1)\n",
    "    if 'one_point' in experiment:\n",
    "        if n_clusters != 2:\n",
    "            print(\"ERROR: one point experiment won't work because number of clusters is not 2\")\n",
    "        else:\n",
    "            # Separators [min, max]: only the maximum falls in the second cluster\n",
    "            separators = np.linspace(min_, max_, n_clusters)\n",
    "\n",
    "    translations = np.zeros(data_.shape[0])\n",
    "    idxs = []\n",
    "    for cluster_idx in range(n_clusters):\n",
    "        in_cluster = projections >= separators[cluster_idx]\n",
    "        if cluster_idx < (n_clusters-1):\n",
    "            # Every cluster except the last one is bounded above by the next separator\n",
    "            in_cluster &= projections < separators[cluster_idx+1]\n",
    "        idxs.append(np.flatnonzero(in_cluster))\n",
    "        translations[idxs[-1]] = cluster_idx*split_constant\n",
    "\n",
    "    # Row i moves by translations[i] * d_ (broadcasting avoids an n x n diagonal matrix)\n",
    "    data_ += translations[:, None] * d_\n",
    "\n",
    "    cluster_indexes = [indexes[i] for i in idxs]\n",
    "    if len({len(members) for members in cluster_indexes}) <= 1:\n",
    "        return data_, np.array(cluster_indexes).squeeze()\n",
    "\n",
    "    # Clusters of different sizes cannot be stacked into a rectangular array\n",
    "    ragged = np.empty(len(cluster_indexes), dtype=object)\n",
    "    for cluster_idx, members in enumerate(cluster_indexes):\n",
    "        ragged[cluster_idx] = members\n",
    "    return data_, ragged\n",
    "\n",
    "def translate_class_clusters(data_per_classes, ortho_d, translation_constant, experiment = ''):\n",
    "    \"\"\"Translate all classes, one class or one random point along ``ortho_d``.\n",
    "\n",
    "    Args:\n",
    "        data_per_classes: list of length num_classes containing n x d arrays\n",
    "            (n examples, d features); the inputs are not modified.\n",
    "        ortho_d: the direction vector, with d entries.\n",
    "        translation_constant: length of the translation, in multiples of ``ortho_d``.\n",
    "        experiment: 'all_classes', 'one_class_x' with x being the index of the class\n",
    "            to translate, or 'one_point' to translate a single example selected at\n",
    "            random with numpy's global generator.\n",
    "\n",
    "    Returns:\n",
    "        new_data: the concatenation of all classes after translation.\n",
    "        pt_idx: array holding the row of the translated point for 'one_point' (needed\n",
    "            to include the outlier in the CKA calculations), None otherwise.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``experiment`` is not one of the values above.\n",
    "    \"\"\"\n",
    "    if experiment == 'all_classes':\n",
    "        shifted_classes = set(range(len(data_per_classes)))\n",
    "    elif 'one_class' in experiment:\n",
    "        try:\n",
    "            shifted_classes = {int(experiment.split('_')[2])}\n",
    "        except (IndexError, ValueError):\n",
    "            raise ValueError(\"experiment {!r} must look like 'one_class_<class index>'\".format(experiment))\n",
    "    elif experiment == 'one_point':\n",
    "        shifted_classes = set()\n",
    "    else:\n",
    "        # Falling through used to hand back the un-concatenated list of classes\n",
    "        raise ValueError(\"unknown experiment {!r}: expected 'all_classes', 'one_class_<class index>' or 'one_point'\".format(experiment))\n",
    "\n",
    "    new_data = copy.deepcopy(data_per_classes)\n",
    "    shift = translation_constant * np.asarray(ortho_d).reshape(1, -1)\n",
    "\n",
    "    for class_idx, class_data in enumerate(new_data):\n",
    "        if class_idx in shifted_classes:\n",
    "            class_data += shift  # broadcast over every example of the class\n",
    "    new_data = np.concatenate(new_data)\n",
    "\n",
    "    pt_idx = None\n",
    "    if experiment == 'one_point':\n",
    "        pt_idx = np.random.choice(len(new_data), 1)\n",
    "        new_data[pt_idx] += shift\n",
    "\n",
    "    return new_data, pt_idx # need to return pt_idx to include the outlier into the CKA calculations\n",
    "\n",
    "def test_cka_lin_sep(data_per_classes, indexes, lin_svc,\n",
    "                     num_clusters = 2,\n",
    "                     distance = 100,\n",
    "                     splitting_dir='num_solve',\n",
    "                     num_pts_cka = 10000,\n",
    "                     seed = 0,\n",
    "                     mod = 'split', # 'split' or 'translate'\n",
    "                     experiment = \"one_class_one_point\"):\n",
    "    \"\"\"Modify the class clusters, then measure linear separability and CKA.\n",
    "\n",
    "    The classes are either split into clusters that are pushed apart ('split') or\n",
    "    translated ('translate'). The fitted classifier is scored on the modified data and\n",
    "    LinCKA2 is computed between a sample of the original and of the modified points.\n",
    "\n",
    "    Args:\n",
    "        data_per_classes: list with one (n_examples, n_features) array per class.\n",
    "        indexes: per-class dataset positions, as returned by get_activations_by_class.\n",
    "        lin_svc: fitted linear classifier exposing ``coef_`` and ``score``.\n",
    "        num_clusters: number of clusters per class ('split' only).\n",
    "        distance: translation length ('translate'), or translation step between\n",
    "            consecutive clusters ('split').\n",
    "        splitting_dir: 'num_solve' (direction obtained from a quadratic program built\n",
    "            on the classifier normals), or, for 'split' only, 'random' or 'pc<i>' (the\n",
    "            i-th principal component of each class).\n",
    "        num_pts_cka: total number of points used for CKA, drawn evenly per class with\n",
    "            np.random.choice (assumes equal number of examples per class).\n",
    "        seed: seed given to numpy's global random generator.\n",
    "        mod: 'split' or 'translate'.\n",
    "        experiment: forwarded to split_class_clusters / translate_class_clusters. In\n",
    "            'split' mode, a value containing 'one_class' only pushes apart the clusters\n",
    "            of class 0.\n",
    "\n",
    "    Returns:\n",
    "        lin_sep, cka, the modified data (numpy array), np.concatenate(indexes) and the\n",
    "        per-class list of cluster indexes (empty for 'translate').\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unknown ``mod`` or a ``splitting_dir`` that yields no direction.\n",
    "    \"\"\"\n",
    "    if mod not in ('split', 'translate'):\n",
    "        raise ValueError(\"mod must be 'split' or 'translate', got {!r}\".format(mod))\n",
    "\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    if mod == 'split':\n",
    "        print('Split; Number of clusters: {}; Distance between clusters: {}; Splitting direction: '.format(num_clusters, distance)+splitting_dir+'; Number of points to compute CKA: {}'.format(num_pts_cka))\n",
    "    else:\n",
    "        print('Translate; Experiment: '+experiment+'; Distance: {}'.format(distance))\n",
    "\n",
    "    split_indexes = []\n",
    "    outlier_idx = None\n",
    "    direction = None\n",
    "\n",
    "    if splitting_dir == 'num_solve':\n",
    "        print('Numerically find a distance for all classes')\n",
    "\n",
    "        dim = data_per_classes[0].shape[1]\n",
    "        Q, _ = np.linalg.qr(lin_svc.coef_.T)\n",
    "\n",
    "        M = Q.T\n",
    "        epsilon = 1e-7\n",
    "        # M.T M alone is not positive definite, hence the epsilon * identity term\n",
    "        P = np.dot(M.T, M) + epsilon*np.eye(M.shape[1])\n",
    "        q = -np.dot(M.T, np.zeros(M.shape[0]))\n",
    "        G = -np.eye(dim)\n",
    "        # A zero bound returns the all-zeros solution, so every entry must be >= 0.1\n",
    "        h = -np.ones(dim)*0.1\n",
    "\n",
    "        direction = quadprog_solve_qp(P, q, G, h)\n",
    "        direction = direction/np.sum(direction**2)**(0.5)\n",
    "\n",
    "    if mod == 'split':\n",
    "        split_data = []\n",
    "        # Iterate the splitting for each class\n",
    "        for class_idx, class_data in enumerate(data_per_classes):\n",
    "\n",
    "            # Define the direction along which to split the data\n",
    "            if 'pc' in splitting_dir:\n",
    "                component = int(splitting_dir[2:])\n",
    "                direction = PCA(n_components = component).fit(class_data).components_[component-1]\n",
    "            if splitting_dir == 'random':\n",
    "                direction = np.random.normal(0, 1, class_data.shape[1])\n",
    "                direction = direction/np.sum(direction**2)**(0.5)\n",
    "                direction = np.absolute(direction)\n",
    "            if direction is None:\n",
    "                raise ValueError(\"unknown splitting_dir {!r}: expected 'num_solve', 'random' or 'pc<i>'\".format(splitting_dir))\n",
    "\n",
    "            # Clusters are pushed apart by `distance`; in the 'one_class' experiments\n",
    "            # only class 0 is, the other classes are split without being moved\n",
    "            dist_clusters = distance\n",
    "            if 'one_class' in experiment and class_idx > 0:\n",
    "                dist_clusters = 0\n",
    "\n",
    "            splits = split_class_clusters(class_data, indexes[class_idx], direction, lin_svc.coef_, num_clusters, dist_clusters, experiment = experiment)\n",
    "            split_data.append(splits[0])\n",
    "            split_indexes.append([i for i in splits[1]])\n",
    "\n",
    "        mod_data = np.concatenate(split_data)\n",
    "\n",
    "    else:\n",
    "        if direction is None:\n",
    "            raise ValueError(\"mod='translate' needs splitting_dir='num_solve', got {!r}\".format(splitting_dir))\n",
    "        mod_data, outlier_idx = translate_class_clusters(data_per_classes, direction, distance, experiment)\n",
    "\n",
    "    # One label per example, in the order in which the classes are concatenated\n",
    "    labels = np.concatenate([np.full(len(class_data), class_idx) for class_idx, class_data in enumerate(data_per_classes)])\n",
    "\n",
    "    lin_sep = lin_svc.score(mod_data, labels)\n",
    "    print(\"Accuracy of the linear SVM classifier on the split data: {}\".format(lin_sep))\n",
    "\n",
    "    # Use the GPU when there is one, otherwise stay on the CPU\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    mod_data = torch.Tensor(mod_data).to(device)\n",
    "    original_data = torch.Tensor(np.concatenate(data_per_classes)).to(device)\n",
    "\n",
    "    # CKA values\n",
    "    CKA = LinCKA2()\n",
    "    num_classes = len(data_per_classes)\n",
    "    num_pts_cka_per_class = int(num_pts_cka/num_classes)\n",
    "    num_pts_per_class = int(len(data_per_classes[0])) # assumes equal number of images per class\n",
    "    perm = np.concatenate([np.random.choice(np.arange(i*num_pts_per_class, (i+1)*num_pts_per_class), num_pts_cka_per_class) for i in range(num_classes)])\n",
    "    # The translated outlier has to be part of the points compared by CKA\n",
    "    if mod == 'translate' and experiment == 'one_point': perm = np.concatenate([perm, outlier_idx])\n",
    "    cka = CKA(original_data[perm], mod_data[perm]).item()\n",
    "    print(\"Cka between {} original vs split pts: {}\".format(num_pts_cka, cka))\n",
    "    return lin_sep, cka, mod_data.cpu().numpy(), np.concatenate(indexes), split_indexes\n",
    "\n",
    "def quadprog_solve_qp(P, q, G=None, h=None, A=None, b=None):\n",
    "    \"\"\"Solve  min 1/2 x^T P x + q^T x  subject to  G x <= h  and  A x = b  with quadprog.\n",
    "\n",
    "    quadprog.solve_qp minimises 1/2 x^T G x - a^T x subject to C^T x >= b, the first\n",
    "    ``meq`` constraints being equalities; the arguments are converted to that form.\n",
    "\n",
    "    Args:\n",
    "        P, q: quadratic and linear terms of the objective.\n",
    "        G, h: optional inequality constraints.\n",
    "        A, b: optional equality constraints.\n",
    "\n",
    "    Returns:\n",
    "        The solution vector x.\n",
    "    \"\"\"\n",
    "    qp_G = .5 * (P + P.T)   # make sure P is symmetric\n",
    "    qp_a = -q\n",
    "\n",
    "    if A is None and G is None:  # unconstrained problem\n",
    "        return quadprog.solve_qp(qp_G, qp_a)[0]\n",
    "\n",
    "    if A is None:  # no equality constraint\n",
    "        qp_C = -G.T\n",
    "        qp_b = -h\n",
    "        meq = 0\n",
    "    elif G is None:  # equality constraints only\n",
    "        qp_C = -A.T\n",
    "        qp_b = -b\n",
    "        meq = A.shape[0]\n",
    "    else:  # equalities must come first, see meq\n",
    "        qp_C = -np.vstack([A, G]).T\n",
    "        qp_b = -np.hstack([b, h])\n",
    "        meq = A.shape[0]\n",
    "    return quadprog.solve_qp(qp_G, qp_a, qp_C, qp_b, meq)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Import data \n",
    "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                  std=[0.229, 0.224, 0.225])\n",
    "\n",
    "transform_val = transforms.Compose([transforms.ToTensor(), normalize]) \n",
    "transform_train =  transforms.Compose([transforms.ToTensor(), normalize]) \n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "##### Cifar Data\n",
    "cifar_data = datasets.CIFAR10(root='data/',train=True, transform=transform_train, download=True)\n",
    "cifar_data_test = datasets.CIFAR10(root='data/',train=False, transform=transform_val, download=True)\n",
    "n=10000\n",
    "\n",
    "# The batch size is large enough for each loader to return its whole split as one batch\n",
    "val_loaderx = torch.utils.data.DataLoader(cifar_data_test,\n",
    "                                           batch_size=50000, \n",
    "                                           shuffle=False)\n",
    "# next(iterator) is used because loader iterators have no .next() method in current PyTorch\n",
    "data, labels = next(iter(val_loaderx))\n",
    "data = data.to(device)\n",
    "\n",
    "train_loaderx = torch.utils.data.DataLoader(cifar_data,\n",
    "                                           batch_size=50000, \n",
    "                                           shuffle=False)\n",
    "train_data, train_labels = next(iter(train_loaderx))\n",
    "train_data = train_data.to(device)\n",
    "train_labels = train_labels.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of a linear SVM classifier on the original data: 0.9099\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/cvmfs/ai.mila.quebec/apps/x86_64/debian/anaconda/3/lib/python3.7/site-packages/sklearn/svm/base.py:929: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  \"the number of iterations.\", ConvergenceWarning)\n"
     ]
    }
   ],
   "source": [
    "layer_idx = -1\n",
    "\n",
    "# Inference only: without no_grad the autograd graph of the whole training batch\n",
    "# would be kept alive together with the activations\n",
    "with torch.no_grad():\n",
    "    train_activations, _ = net_all1(train_data, all_act=True)\n",
    "\n",
    "# Keep the selected layer, both as a tensor and flattened to (n_examples, n_features)\n",
    "train_act_tensors = train_activations[layer_idx]\n",
    "train_act = train_act_tensors.reshape(train_act_tensors.shape[0],-1).detach().cpu().numpy()\n",
    "del train_activations\n",
    "\n",
    "# Linear separability:\n",
    "lin_svc = LinearSVC()\n",
    "lin_svc.fit(train_act, train_labels)\n",
    "original_lin_sep = lin_svc.score(train_act, train_labels)\n",
    "print(\"Accuracy of a linear SVM classifier on the original data: {}\".format(original_lin_sep))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test (don't run)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Translate; Experiment: all_classes; Distance: 10000\n",
      "Numerically find a distance for all classes\n",
      "10000\n",
      "Accuracy of the linear SVM classifier on the split data: 0.91304\n"
     ]
    }
   ],
   "source": [
    "# Settings of a single trial run; the keyword arguments left out below keep the\n",
    "# defaults of test_cka_lin_sep\n",
    "distance = 10000\n",
    "splitting_dir = 'num_solve'\n",
    "num_pts_cka = 10000\n",
    "seed = 0\n",
    "mod = 'translate'\n",
    "experiment = 'all_classes' # in translate mode: 'all_classes', 'one_class_<i>' or 'one_point'\n",
    "\n",
    "data_per_classes, indexes = get_activations_by_class(train_act, train_labels)\n",
    "\n",
    "# NOTE: `indexes` is overwritten with the concatenated indexes returned by the call\n",
    "lin_sep, cka, split_embeds, indexes, split_indexes = test_cka_lin_sep(\n",
    "    data_per_classes, indexes, lin_svc,\n",
    "    distance = distance,\n",
    "    splitting_dir = splitting_dir,\n",
    "    num_pts_cka = num_pts_cka,\n",
    "    seed = seed,\n",
    "    mod = mod,\n",
    "    experiment = experiment)\n",
    "\n",
    "# Disabled: save the modified embeddings in dataset order, reshaped like the layer output.\n",
    "# It needs num_clusters and dist_clusters, which this cell does not define.\n",
    "# if layer_idx == -1 or layer_idx == -2:\n",
    "#     sorted_split_embeds = torch.Tensor(split_embeds[np.argsort(indexes)]).reshape([50000, 64, 2, 2])\n",
    "# elif layer_idx == -3:\n",
    "#     sorted_split_embeds = torch.Tensor(split_embeds[np.argsort(indexes)]).reshape([50000, 64, 4, 4])\n",
    "# elif layer_idx == -4:\n",
    "#     sorted_split_embeds = torch.Tensor(split_embeds[np.argsort(indexes)]).reshape([50000, 32, 9, 9])\n",
    "# elif layer_idx == -5:\n",
    "#     sorted_split_embeds = torch.Tensor(split_embeds[np.argsort(indexes)]).reshape([50000, 32, 11, 11])\n",
    "# torch.save(sorted_split_embeds, 'data/cifar10_sorted_split_layer{}_embeds_{}num-clusters_{}dist-clusters_'.format(layer_idx, num_clusters, dist_clusters)+splitting_dir+'_{}pts-cka_{}seed'.format(num_pts_cka,seed)+experiment+'.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Translation of a whole class in a direction that doesn't affect linear separability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed 0\n",
      "Translate; Experiment: all_classes; Distance: 1\n",
      "Numerically find a distance for all classes\n",
      "Accuracy of the linear SVM classifier on the split data: 0.9099\n"
     ]
    }
   ],
   "source": [
    "data_per_classes, indexes = get_activations_by_class(train_act, train_labels)\n",
    "\n",
    "num_pts_cka = 10000\n",
    "num_seeds = 10\n",
    "c_list = [1, 5, 10, 25, 50, 100, 500, 1000, 2500, 5000, 7500, 1e4, 1.5e4, 2e4]\n",
    "# Translate every class at once, then each of the ten classes on its own\n",
    "experiments = ['all_classes'] + [f'one_class_{i}' for i in range(10)]\n",
    "mod = 'translate'\n",
    "\n",
    "# data[experiment, seed, distance] = (linear separability, CKA)\n",
    "data = np.zeros([len(experiments), num_seeds, len(c_list), 2])\n",
    "for i1, experiment in enumerate(experiments):\n",
    "    for seed in range(num_seeds):\n",
    "        print(f'seed {seed}')\n",
    "        for i2, c in enumerate(c_list):\n",
    "            # Only the first two return values (lin_sep, cka) are stored\n",
    "            data[i1, seed, i2] = test_cka_lin_sep(\n",
    "                data_per_classes, indexes, lin_svc,\n",
    "                distance = c,\n",
    "                num_pts_cka = num_pts_cka,\n",
    "                seed = seed,\n",
    "                mod = mod,\n",
    "                experiment = experiment)[:2]\n",
    "\n",
    "np.save('cifar10_translation_data_lincka2_layer{}.npy'.format(layer_idx), data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This still seems to hold with LinCKA2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Translation of a single point (outlier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_per_classes, indexes = get_activations_by_class(train_act, train_labels)\n",
    "\n",
    "num_pts_cka = 10000\n",
    "num_seeds = 10\n",
    "c_list = [1, 5, 10, 25, 50, 100, 500, 1000, 2500, 5000, 7500, 1e4, 1.5e4, 2e4, 5e4, 1e5]\n",
    "experiments = ['one_point']\n",
    "mod = 'translate'\n",
    "\n",
    "# data[experiment, seed, distance] = (linear separability, CKA)\n",
    "data = np.zeros([len(experiments), num_seeds, len(c_list), 2])\n",
    "for i1, experiment in enumerate(experiments):\n",
    "    for seed in range(num_seeds):\n",
    "        print(f'seed {seed}')\n",
    "        for i2, c in enumerate(c_list):\n",
    "            # Only the first two return values (lin_sep, cka) are stored\n",
    "            data[i1, seed, i2] = test_cka_lin_sep(\n",
    "                data_per_classes, indexes, lin_svc,\n",
    "                distance = c,\n",
    "                num_pts_cka = num_pts_cka,\n",
    "                seed = seed,\n",
    "                mod = mod,\n",
    "                experiment = experiment)[:2]\n",
    "\n",
    "# A single experiment: drop the leading axis before saving\n",
    "data = data.squeeze()\n",
    "np.save('cifar10_one_pt_translation_data_lincka2_layer{}.npy'.format(layer_idx), data)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Basic_CKA_ForSharing.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (main_env)",
   "language": "python",
   "name": "main_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "522e0f024813495bb45905626fdccd14": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_8de220001a4c42e6b2bf369d63e17658",
       "IPY_MODEL_9a5b5204bd3847d485f8a139949a79f4",
       "IPY_MODEL_b0283df5d40143789293e11c9428d543"
      ],
      "layout": "IPY_MODEL_916a6726fe5b4b8a949d1d6b1d4bfec1"
     }
    },
    "6d746b9790c744ac91ff6c7cc852f9f7": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "70d14a62a0e94362bea840e44b5719e3": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "890787b726f4485199dcb2a059ef92a3": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "8de220001a4c42e6b2bf369d63e17658": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_6d746b9790c744ac91ff6c7cc852f9f7",
      "placeholder": "​",
      "style": "IPY_MODEL_ad439d60a52447469bcdce66a8e3c0bd",
      "value": ""
     }
    },
    "8ec5fa1d3675438ba3e5097f6f605325": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "916a6726fe5b4b8a949d1d6b1d4bfec1": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9a5b5204bd3847d485f8a139949a79f4": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_e34af8bd1fe3408cbd99341d4271f845",
      "max": 170498071,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_890787b726f4485199dcb2a059ef92a3",
      "value": 170498071
     }
    },
    "ad439d60a52447469bcdce66a8e3c0bd": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "b0283df5d40143789293e11c9428d543": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_70d14a62a0e94362bea840e44b5719e3",
      "placeholder": "​",
      "style": "IPY_MODEL_8ec5fa1d3675438ba3e5097f6f605325",
      "value": " 170499072/? [00:05&lt;00:00, 32259211.76it/s]"
     }
    },
    "e34af8bd1fe3408cbd99341d4271f845": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
