{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from utils import *\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import special_ortho_group\n",
    "import sys\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2 Gaussians (or 2 cubes) translation exp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "cuda = torch.device('cuda')\n",
    "\n",
    "def get_cka_test(mean1 = 0,\n",
    "                 mean2 = 0,\n",
    "                 var1 = 1,\n",
    "                 var2 = 1,\n",
    "                 num_dims = 100,\n",
    "                 num_pts = 1000,\n",
    "                 seed = 0,\n",
    "                 c = 1000,\n",
    "                 verbose = False,\n",
    "                 distribution = 'gaussian',\n",
    "                 median = median):\n",
    "    np.random.seed(seed)\n",
    "    \n",
    "    d = np.random.normal(0,1,[num_dims])\n",
    "    d /= np.linalg.norm(d)\n",
    "    \n",
    "    if distribution == 'gaussian':\n",
    "        X = np.concatenate( [np.random.normal(mean1, var1, [num_pts, num_dims]), np.random.normal(mean2, var2, [num_pts, num_dims])], axis = 0)\n",
    "        Y = torch.Tensor(X + np.concatenate([np.zeros([num_pts, num_dims]), c*np.matmul(np.ones([num_pts,1]), d.reshape([1,num_dims]))], axis = 0)).to(cuda)\n",
    "    elif distribution == 'uniform':\n",
    "        # in this case var = side and mean = center\n",
    "        X = np.concatenate([var1*(np.random.rand(num_pts, num_dims)-var1*0.5*np.ones([num_pts,num_dims]))+mean1*np.concatenate([np.ones([num_pts,1]),np.zeros([num_pts,num_dims-1])], axis=1), var2*(np.random.rand(num_pts, num_dims)-var2*0.5*np.ones([num_pts,num_dims]))+mean2*np.concatenate([np.ones([num_pts,1]),np.zeros([num_pts,num_dims-1])], axis=1)], axis = 0)\n",
    "        Y = torch.Tensor(X + np.concatenate([np.zeros([num_pts, num_dims]), c*np.matmul(np.ones([num_pts,1]), d.reshape([1,num_dims]))], axis = 0)).to(cuda)\n",
    "    \n",
    "    X = torch.Tensor(X).to(cuda)\n",
    "    \n",
    "    CKA = rbfCKA(median=median)\n",
    "    if verbose:\n",
    "        return CKA(X,Y).item(), torch.where(X==Y)\n",
    "    else:\n",
    "        return CKA(X,Y).item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multiple seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed 0\n",
      "seed 1\n",
      "seed 2\n",
      "seed 3\n",
      "seed 4\n",
      "seed 5\n",
      "seed 6\n",
      "seed 7\n",
      "seed 8\n",
      "seed 9\n"
     ]
    }
   ],
   "source": [
    "num_pts = 10000\n",
    "num_dims = 1000\n",
    "num_seeds = 10\n",
    "c_list = [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]\n",
    "\n",
    "for median in [0.2, 0.5, 1]:\n",
    "    diff = []\n",
    "    data = np.zeros([num_seeds, len(c_list)])\n",
    "    for seed in range(num_seeds):\n",
    "        print(f'seed {seed}')\n",
    "        for i, c in enumerate(c_list):\n",
    "            data[seed, i], v = get_cka_test(mean2=1.1, num_dims = num_dims, num_pts = num_pts, c = c, seed = seed, distribution = 'uniform', verbose = True, median = median)\n",
    "            diff.append(v)\n",
    "    \n",
    "    if median == 1:\n",
    "        np.save('two_cubes_exp_median_rbfcka2__means_0_1.1_{}seeds_v2.npy'.format(num_seeds), data)\n",
    "    else:\n",
    "        np.save('two_cubes_exp_median{}_rbfcka2__means_0_1.1_{}seeds_v2.npy'.format(median, num_seeds), data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
