{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from utils import LinCKA2\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import special_ortho_group\n",
    "import sys\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2 Gaussians (or 2 cubes) translation exp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "cuda = torch.device('cuda')\n",
    "\n",
    "def get_cka_test(mean1 = 0,\n",
    "                 mean2 = 0,\n",
    "                 var1 = 1,\n",
    "                 var2 = 1,\n",
    "                 num_dims = 100,\n",
    "                 num_pts = 1000,\n",
    "                 seed = 0,\n",
    "                 c = 1000,\n",
    "                 verbose = False,\n",
    "                 distribution = 'gaussian'):\n",
    "    np.random.seed(seed)\n",
    "    \n",
    "    d = np.random.normal(0,1,[num_dims])\n",
    "    d /= np.linalg.norm(d)\n",
    "    \n",
    "    if distribution == 'gaussian':\n",
    "        X = np.concatenate( [np.random.normal(mean1, var1, [num_pts, num_dims]), np.random.normal(mean2, var2, [num_pts, num_dims])], axis = 0)\n",
    "        Y = torch.Tensor(X + np.concatenate([np.zeros([num_pts, num_dims]), c*np.matmul(np.ones([num_pts,1]), d.reshape([1,num_dims]))], axis = 0)).to(cuda)\n",
    "    elif distribution == 'uniform':\n",
    "        # in this case var = side and mean = center\n",
    "        X = np.concatenate([var1*(np.random.rand(num_pts, num_dims)-var1*0.5*np.ones([num_pts,num_dims]))+mean1*np.concatenate([np.ones([num_pts,1]),np.zeros([num_pts,num_dims-1])], axis=1), var2*(np.random.rand(num_pts, num_dims)-var2*0.5*np.ones([num_pts,num_dims]))+mean2*np.concatenate([np.ones([num_pts,1]),np.zeros([num_pts,num_dims-1])], axis=1)], axis = 0)\n",
    "        Y = torch.Tensor(X + np.concatenate([np.zeros([num_pts, num_dims]), c*np.matmul(np.ones([num_pts,1]), d.reshape([1,num_dims]))], axis = 0)).to(cuda)\n",
    "    \n",
    "    X = torch.Tensor(X).to(cuda)\n",
    "    \n",
    "    CKA = LinCKA2()\n",
    "    if verbose:\n",
    "        return CKA(X,Y).item(), torch.where(X==Y)\n",
    "    else:\n",
    "        return CKA(X,Y).item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Translations (multiple seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed 0\n",
      "seed 1\n",
      "seed 2\n",
      "seed 3\n",
      "seed 4\n",
      "seed 5\n",
      "seed 6\n",
      "seed 7\n",
      "seed 8\n",
      "seed 9\n"
     ]
    }
   ],
   "source": [
    "num_pts = 10000\n",
    "num_dims = 1000\n",
    "num_seeds = 10\n",
    "c_list = [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]\n",
    "\n",
    "diff = []\n",
    "data = np.zeros([num_seeds, len(c_list)])\n",
    "for seed in range(num_seeds):\n",
    "    print(f'seed {seed}')\n",
    "    for i, c in enumerate(c_list):\n",
    "        data[seed, i], v = get_cka_test(mean2=1.1, num_dims = num_dims, num_pts = num_pts, c = c, seed = seed, distribution = 'uniform', verbose = True)\n",
    "        diff.append(v)\n",
    "            \n",
    "np.save('two_cubes_exp_lincka2_means_0_1.1_{}seeds_v2.npy'.format(num_seeds), data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# APPENDIX: Invertible linear transformations (with cubes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inv_lin_cka_cubes(mean1 = 0,\n",
    "                var1 = 1,\n",
    "                mean2 = 0,\n",
    "                var2 = 1,\n",
    "                transform_mean = 0,\n",
    "                transform_var = 1,\n",
    "                num_dims = 200,\n",
    "                num_pts = 10000,\n",
    "                seed = 0,\n",
    "                distribution = 'gaussian'):\n",
    "    \n",
    "    np.random.seed(seed)\n",
    "    rotation_matrix = np.random.normal(transform_mean, transform_var,[num_dims, num_dims])\n",
    "    while np.linalg.cond(rotation_matrix) >= 1/sys.float_info.epsilon:\n",
    "        rotation_matrix = np.random.normal(transform_mean, transform_var,[num_dims, num_dims])\n",
    "\n",
    "    cuda = torch.device('cuda')\n",
    "    rotation_matrix = torch.Tensor(rotation_matrix).to(cuda)\n",
    "    \n",
    "    if distribution == 'gaussian':\n",
    "        X = np.random.normal(mean1, var1, [int(num_pts*2), num_dims])\n",
    "    elif distribution == 'uniform':\n",
    "        # in this case var = side and mean = center\n",
    "        X = np.concatenate([var1*(np.random.rand(num_pts, num_dims)-var1*0.5*np.ones([num_pts,num_dims]))+mean1*np.concatenate([np.ones([num_pts,1]),np.zeros([num_pts,num_dims-1])], axis=1), var2*(np.random.rand(num_pts, num_dims)-var2*0.5*np.ones([num_pts,num_dims]))+mean2*np.concatenate([np.ones([num_pts,1]),np.zeros([num_pts,num_dims-1])], axis=1)], axis = 0)\n",
    "\n",
    "    X = torch.Tensor(X).to(cuda)\n",
    "    Y = torch.mm(X,rotation_matrix)\n",
    "\n",
    "    CKA = LinCKA2()\n",
    "    return CKA(X,Y).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed 0\n",
      "seed 1\n",
      "seed 2\n",
      "seed 3\n",
      "seed 4\n",
      "seed 5\n",
      "seed 6\n",
      "seed 7\n",
      "seed 8\n",
      "seed 9\n"
     ]
    }
   ],
   "source": [
    "num_pts = 10000#, 10000 and 20000 points makes it crash have other things running as well...\n",
    "num_dims = 500\n",
    "num_seeds = 10\n",
    "mu_list = [1, 5, 10, 25, 50, 100, 500, 1000, 5000, 1e4, 1e5]\n",
    "sigma_list = [1, 5, 10, 25, 50, 100, 500, 1000, 5000, 1e4, 1e5]\n",
    "\n",
    "\n",
    "data = np.zeros([num_seeds, len(mu_list), len(sigma_list)])\n",
    "for seed in range(num_seeds):\n",
    "    print(f'seed {seed}')\n",
    "    for i1, mu in enumerate(mu_list):\n",
    "        for i2, sigma in enumerate(sigma_list):\n",
    "            data[seed, i1, i2] = inv_lin_cka_cubes(mean1=0, mean2=1.1,num_dims = num_dims, num_pts = num_pts, transform_mean=mu, transform_var=sigma, seed = seed, distribution = 'uniform')\n",
    "            \n",
    "np.save('inv_lin_cka_fig_data_lincka2_10k_cubes_means0_1.1.npy', data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
