{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.\n",
      "StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "# Project helpers; presumably provides get_instrumented_model, device, inst and the\n",
    "# evaluate_*_consistency functions used below (they are not defined in this notebook).\n",
    "from notebook_init import *\n",
    "\n",
    "# Explicit imports for names this notebook uses directly, so a fresh kernel does not\n",
    "# depend on what notebook_init happens to re-export.\n",
    "import os\n",
    "import pickle\n",
    "import time\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "out_root = Path('out/consistency')\n",
    "os.makedirs(out_root, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://drive.google.com/uc?export=download&id=1JzA5iiS3qPrztVofQAjbb0N4xKdjOOyV\n"
     ]
    }
   ],
   "source": [
    "# Generator under study. Combinations used in this study:\n",
    "#   StyleGAN2: dataset 'ffhq' or 'cat',  instrumented at layer 'style'\n",
    "#   StyleGAN : dataset 'ffhq' or 'cats', instrumented at layer 'g_mapping' (truncation NOT IMPLEMENTED)\n",
    "use_w = True\n",
    "model_name = 'StyleGAN2'\n",
    "dataset = 'ffhq'\n",
    "probe_layer = {'StyleGAN2': 'style', 'StyleGAN': 'g_mapping'}[model_name]\n",
    "\n",
    "inst = get_instrumented_model(model_name, dataset, probe_layer, device, inst=inst, use_w=use_w)\n",
    "model = inst.model\n",
    "model.truncation = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precomputed global directions (Note! only FFHQ / StyleGAN2 directions are provided).\n",
    "gs_dir = torch.from_numpy(np.load('./global_directions/ganspace_directions_ffhq_stylegan2_style-8.npy')).to(device)\n",
    "sf_dir = torch.from_numpy(np.load('./global_directions/sefa_directions_ffhq_stylegan2.npy')).to(device)\n",
    "\n",
    "\n",
    "class compare_basis_config:\n",
    "    '''Default evaluation settings; later cells mutate an instance of this per run.'''\n",
    "    n_samples = 50\n",
    "    seed = 0\n",
    "    subspace_dim = 2\n",
    "    rankEst = False\n",
    "    sv_thres_ratio = 0.001\n",
    "    last_layer_name = '8'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local bases are derived from a Jacobian (see get_random_local_basis), presumably via\n",
    "# autograd, so gradient tracking is switched (back) on for the rest of the notebook.\n",
    "torch.set_grad_enabled(True)\n",
    "eval_config = compare_basis_config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['1', '2', '3', '4', '5', '6', '7', '8']\n"
     ]
    }
   ],
   "source": [
    "# Layer names probed for StyleGAN2, and the singular-value threshold ratios to sweep.\n",
    "subLayerNames = list(map(str, range(1, 9)))\n",
    "sv_thres_ratio_candi = [0.0005, 0.001, 0.005, 0.01]\n",
    "print(subLayerNames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Grassmannian Metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Grassmannian Metric between two random O(n) matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['1', '2', '3', '4', '5', '6', '7', '8'])\n",
      "dict_keys([0.0005, 0.001, 0.005, 0.01])\n"
     ]
    }
   ],
   "source": [
    "import pickle  # first use of pickle in this notebook; do not rely on star-import leakage\n",
    "\n",
    "# LayerThres2rank_dict[layer][sv_thres_ratio] -> estimated local ranks (averaged with np.mean below)\n",
    "with open('./subnetwork_stats/LayerThres2rank_dict.dill', 'rb') as f:\n",
    "    LayerThres2rank_dict = pickle.load(f)  # trusted local file; never unpickle untrusted data\n",
    "\n",
    "print(LayerThres2rank_dict.keys())\n",
    "print(LayerThres2rank_dict['8'].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "eval_config.n_samples = 100\n",
    "eval_config.rankEst = True\n",
    "return_samples = True\n",
    "\n",
    "# Baseline: Grassmannian metric between two random orthogonal matrices whose dimension\n",
    "# is the mean estimated local rank for each (layer, threshold) pair.\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    consistency_dict[last_layer_name] = {}\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi:\n",
    "        consistency_dict[last_layer_name][sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "        eval_config.subspace_dim = int(np.mean(LayerThres2rank_dict[last_layer_name][sv_thres_ratio]).round())\n",
    "\n",
    "        for metric_type in ['geodesic', 'proj']:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type} for dim {eval_config.subspace_dim}')\n",
    "            timer = time.time()\n",
    "            consistency = evaluate_random_basis_consistency(model, eval_config, metric_type=metric_type, return_samples=return_samples)\n",
    "            consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "            print(f'Evaluation took {round(time.time() - timer, 2)}')\n",
    "\n",
    "with open(f'./out/consistency/consistency_o(n)_KN_n_{eval_config.n_samples}_samples.dill', 'wb') as f:\n",
    "    pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grassmannian Metric between two random w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layer names must match the generator loaded in the model cell (StyleGAN2 by default);\n",
    "# StyleGAN (v1) uses 'dense*_act' names instead of '1'..'8'.\n",
    "use_stylegan1 = False  # set True only when the StyleGAN (v1) model is loaded\n",
    "\n",
    "if use_stylegan1:\n",
    "    subLayerNames = [f'dense{a}_act' for a in range(0, 8)]\n",
    "else:\n",
    "    subLayerNames = [str(a) for a in range(1, 9)]\n",
    "sv_thres_ratio_candi = [0.0005, 0.001, 0.005, 0.01]\n",
    "print(subLayerNames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense0_act Thres 0.0005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:49<00:00,  3.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 289.62\n",
      "Evaluating Layer dense0_act Thres 0.001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:43<00:00,  4.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 223.81\n",
      "Evaluating Layer dense0_act Thres 0.005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:18<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 138.02\n",
      "Evaluating Layer dense0_act Thres 0.01 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:57<00:00,  8.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 117.44\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense1_act Thres 0.0005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:30<00:00,  2.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 450.85\n",
      "Evaluating Layer dense1_act Thres 0.001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [06:21<00:00,  2.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 381.62\n",
      "Evaluating Layer dense1_act Thres 0.005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:39<00:00,  4.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 219.19\n",
      "Evaluating Layer dense1_act Thres 0.01 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:02<00:00,  5.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 182.02\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense2_act Thres 0.0005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:24<00:00,  3.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 324.61\n",
      "Evaluating Layer dense2_act Thres 0.001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:42<00:00,  3.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 282.93\n",
      "Evaluating Layer dense2_act Thres 0.005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:14<00:00,  5.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 194.94\n",
      "Evaluating Layer dense2_act Thres 0.01 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:50<00:00,  5.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 170.7\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense3_act Thres 0.0005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:11<00:00,  3.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 251.77\n",
      "Evaluating Layer dense3_act Thres 0.001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:46<00:00,  4.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 226.06\n",
      "Evaluating Layer dense3_act Thres 0.005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:53<00:00,  5.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 173.04\n",
      "Evaluating Layer dense3_act Thres 0.01 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:35<00:00,  6.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 155.3\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense4_act Thres 0.0005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:25<00:00,  4.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 205.07\n",
      "Evaluating Layer dense4_act Thres 0.001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:53<00:00,  5.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 173.57\n",
      "Evaluating Layer dense4_act Thres 0.005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:24<00:00,  6.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 144.05\n",
      "Evaluating Layer dense4_act Thres 0.01 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:15<00:00,  7.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 135.36\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense5_act Thres 0.0005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:47<00:00,  5.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 167.67\n",
      "Evaluating Layer dense5_act Thres 0.001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:39<00:00,  6.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 159.95\n",
      "Evaluating Layer dense5_act Thres 0.005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:29<00:00,  6.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 149.81\n",
      "Evaluating Layer dense5_act Thres 0.01 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:13<00:00,  7.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 133.4\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense6_act Thres 0.0005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:09<00:00,  7.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 129.68\n",
      "Evaluating Layer dense6_act Thres 0.001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:02<00:00,  8.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 122.55\n",
      "Evaluating Layer dense6_act Thres 0.005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:48<00:00,  9.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 108.76\n",
      "Evaluating Layer dense6_act Thres 0.01 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:44<00:00,  9.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 104.39\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense7_act Thres 0.0005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:02<00:00,  8.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 122.98\n",
      "Evaluating Layer dense7_act Thres 0.001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:57<00:00,  8.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 117.21\n",
      "Evaluating Layer dense7_act Thres 0.005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:46<00:00,  9.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 106.24\n",
      "Evaluating Layer dense7_act Thres 0.01 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:43<00:00,  9.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 103.02\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "eval_config.n_samples = 1000\n",
    "eval_config.rankEst = True\n",
    "return_samples = True\n",
    "metric_types = ['geodesic']  # add 'proj' to also evaluate that metric\n",
    "# Must describe the generator loaded in the model cell; it becomes part of the output filename.\n",
    "run_tag = 'stylegan2-ffhq'\n",
    "\n",
    "# Grassmannian metric between the local bases at two random w.\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    consistency_dict[last_layer_name] = {}\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi:\n",
    "        consistency_dict[last_layer_name][sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "\n",
    "        for metric_type in metric_types:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            timer = time.time()\n",
    "            consistency = evaluate_basis_consistency(model, eval_config, metric_type=metric_type, return_samples=return_samples)\n",
    "            consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "            print(f'Evaluation took {round(time.time() - timer, 2)}')\n",
    "\n",
    "with open(f'./out/consistency/consistency_random_KN_n_{eval_config.n_samples}_samples_{run_tag}.dill', 'wb') as f:\n",
    "    pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grassmannian Metric between two close w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['1', '2', '3', '4', '5', '6', '7', '8']\n",
      "Evaluating Layer 1 Thres 0.0005 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:42<00:00,  4.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 222.29\n",
      "Evaluating Layer 1 Thres 0.001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████████████                                                                 | 178/1000 [00:31<02:25,  5.65it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_33344\\1876538062.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     20\u001b[0m             \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     21\u001b[0m             \u001b[0mtimer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 22\u001b[1;33m             \u001b[0mconsistency\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mevaluate_basis_consistency_local\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meval_config\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meps\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0meps\u001b[0m\u001b[1;33m,\u001b[0m  \u001b[0mmetric_type\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmetric_type\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreturn_samples\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_samples\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     23\u001b[0m             \u001b[0mconsistency_dict\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mlast_layer_name\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0msv_thres_ratio\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mmetric_type\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mconsistency\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     24\u001b[0m             \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf'Evaluation took {round(time.time() - timer, 2)}'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Desktop\\Research\\localbasis_rank_estimate_final_dev_ver\\notebooks\\localbasis_utils.py\u001b[0m in \u001b[0;36mevaluate_basis_consistency_local\u001b[1;34m(model, eval_config, eps, metric_type, return_samples)\u001b[0m\n\u001b[0;32m    525\u001b[0m                                                                             \u001b[0mlast_layer_name\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0meval_config\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlast_layer_name\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    526\u001b[0m                                                                             \u001b[0mrankEst\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0meval_config\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrankEst\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 527\u001b[1;33m                                                                             sv_thres_ratio=eval_config.sv_thres_ratio)\n\u001b[0m\u001b[0;32m    528\u001b[0m         \u001b[0mlocal_basis_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mz_local_basis\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    529\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Desktop\\Research\\localbasis_rank_estimate_final_dev_ver\\notebooks\\localbasis_utils.py\u001b[0m in \u001b[0;36mget_random_local_basis\u001b[1;34m(model, random_state, noise, last_layer_name, rankEst, sv_thres_ratio, alpha)\u001b[0m\n\u001b[0;32m    136\u001b[0m     \u001b[1;34m''' Get local basis'''\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    137\u001b[0m     \u001b[1;31m# jacobian \\approx torch.mm(torch.mm(z_basis, torch.diag(s)), noise_basis.t())\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 138\u001b[1;33m     \u001b[0mz_basis\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0ms\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnoise_basis\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msvd\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mjacobian\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    139\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    140\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mrankEst\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "eval_config.n_samples = 1000\n",
    "eval_config.rankEst = True\n",
    "return_samples = True\n",
    "eps = 1e-1  # presumably controls how far apart the two close w are\n",
    "metric_types = ['geodesic']  # add 'proj' to also evaluate that metric\n",
    "# Must describe the generator loaded in the model cell; it becomes part of the output filename.\n",
    "run_tag = 'stylegan2-ffhq'\n",
    "\n",
    "# Grassmannian metric between the local bases at two close w.\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    consistency_dict[last_layer_name] = {}\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi:\n",
    "        consistency_dict[last_layer_name][sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "\n",
    "        for metric_type in metric_types:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            timer = time.time()\n",
    "            consistency = evaluate_basis_consistency_local(model, eval_config, eps=eps, metric_type=metric_type, return_samples=return_samples)\n",
    "            consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "            print(f'Evaluation took {round(time.time() - timer, 2)}')\n",
    "\n",
    "with open(f'./out/consistency/consistency_local_KN_eps_{eps}_n_{eval_config.n_samples}_samples_{run_tag}.dill', 'wb') as f:\n",
    "    pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grassmannian Metric between random w and GANSpace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "eval_config.n_samples = 1000\n",
    "eval_config.rankEst = True\n",
    "return_samples = True\n",
    "\n",
    "# Grassmannian metric between the local basis at random w and the GANSpace directions\n",
    "# of the same layer (GANSpace directions are only provided for StyleGAN2 FFHQ).\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    consistency_dict[last_layer_name] = {}\n",
    "    gs_dir = torch.from_numpy(np.load(f'./global_directions/ganspace_directions_ffhq_stylegan2_style-{last_layer_name}.npy')).to(device)\n",
    "    global_basis = gs_dir.squeeze().t().cpu()\n",
    "\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi:\n",
    "        consistency_dict[last_layer_name][sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "\n",
    "        for metric_type in ['geodesic', 'proj']:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            timer = time.time()\n",
    "            consistency = evaluate_basis_consistency_to_global(model, eval_config, global_basis, metric_type=metric_type, return_samples=return_samples)\n",
    "            consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "            print(f'Evaluation took {round(time.time() - timer, 2)}')\n",
    "\n",
    "with open(f'./out/consistency/consistency_to_global_ganspace_KN_n_{eval_config.n_samples}_samples.dill', 'wb') as f:\n",
    "    pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grassmannian Metric between random w and SeFa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_config.n_samples = 1000\n",
    "global_basis = sf_dir.squeeze().t().cpu()\n",
    "dims = range(1, 51)\n",
    "\n",
    "for metric_type in ['geodesic', 'proj']:\n",
    "    # Grassmannian metric between the local basis at a random w and the global SeFa basis,\n",
    "    # swept over the subspace dimension.\n",
    "    consistency_list = []\n",
    "    time_list = []\n",
    "\n",
    "    for dim in dims:\n",
    "        print(f'Evaluating dim {dim}')\n",
    "        time.sleep(0.2)\n",
    "        eval_config.subspace_dim = dim\n",
    "\n",
    "        timer = time.time()\n",
    "        consistency = evaluate_basis_consistency_to_global(model, eval_config, global_basis, metric_type=metric_type)\n",
    "        consistency_list.append(consistency)\n",
    "        time_list.append(round(time.time() - timer, 2))\n",
    "\n",
    "    # eps plays no role in this experiment, so it is kept out of both the .png and .dill names.\n",
    "    out_stem = f'./out/consistency/consistency_to_global_sefa_metric_{metric_type}_n_{eval_config.n_samples}'\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(list(dims), consistency_list)\n",
    "    ax.set(title=f'Consistency to SeFa ({metric_type})', xlabel='subspace dim', ylabel='consistency')\n",
    "    fig.savefig(f'{out_stem}.png')\n",
    "    plt.show()\n",
    "\n",
    "    with open(f'{out_stem}.dill', 'wb') as f:\n",
    "        pickle.dump(consistency_list, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ablation study on eps for close w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "eval_config.n_samples = 1000\n",
    "eps_candidates = [1e-2, 1e-1, 0.25, 5e-1]\n",
    "dims = range(1, 51)\n",
    "\n",
    "for metric_type in ['geodesic', 'proj']:\n",
    "    # A dedicated loop name keeps this sweep from overwriting the global `eps`\n",
    "    # that other cells of this notebook read.\n",
    "    for eps_value in eps_candidates:\n",
    "        consistency_list = []\n",
    "        time_list = []\n",
    "\n",
    "        for dim in dims:\n",
    "            print(f'Evaluating dim {dim}')\n",
    "            time.sleep(0.2)\n",
    "            eval_config.subspace_dim = dim\n",
    "\n",
    "            timer = time.time()\n",
    "            consistency = evaluate_basis_consistency_local(model, eval_config, eps=eps_value, metric_type=metric_type)\n",
    "            consistency_list.append(consistency)\n",
    "            time_list.append(round(time.time() - timer, 2))\n",
    "\n",
    "        out_stem = f'./out/consistency/consistency_local_metric_{metric_type}_eps_{eps_value}_n_{eval_config.n_samples}'\n",
    "        fig, ax = plt.subplots()\n",
    "        ax.plot(list(dims), consistency_list)\n",
    "        ax.set(title=f'Local consistency ({metric_type}, eps={eps_value})', xlabel='subspace dim', ylabel='consistency')\n",
    "        fig.savefig(f'{out_stem}.png')\n",
    "        plt.show()\n",
    "\n",
    "        with open(f'{out_stem}.dill', 'wb') as f:\n",
    "            pickle.dump(consistency_list, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert dict to csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library first, then third-party packages.\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['consistency_local_KN_eps_0.1_n_1000_samples_stylegan1-cat.dill', 'consistency_random_KN_n_1000_samples_stylegan1-cat.dill']\n"
     ]
    }
   ],
   "source": [
    "# Collect the pickled result dicts of one model/dataset (sorted for a deterministic order).\n",
    "dict_dir = './out/consistency/'\n",
    "model_tag = 'stylegan1-cat'\n",
    "dict_files = sorted(\n",
    "    f for f in os.listdir(dict_dir)\n",
    "    if model_tag in f\n",
    "    and 'samples' in f\n",
    "    and 'layer12' not in f\n",
    "    # Keep only .dill dumps: other artifacts (.png, .csv, ...) must not be unpickled,\n",
    "    # and the conversion step replaces this extension with .csv.\n",
    "    and f.endswith('.dill')\n",
    ")\n",
    "print(dict_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "consistency_local_KN_eps_0.1_n_1000_samples_stylegan1-cat.dill\n",
      "consistency_random_KN_n_1000_samples_stylegan1-cat.dill\n"
     ]
    }
   ],
   "source": [
    "# Flatten each result dict into a long-format CSV saved next to its .dill file.\n",
    "# Expected layout: result_dict[last_layer_name][sv_thres_ratio][metric_type] -> scores.\n",
    "for dict_file in dict_files:\n",
    "    dict_path = os.path.join(dict_dir, dict_file)\n",
    "    # Only unpickle result files produced by this notebook (pickle can execute code).\n",
    "    with open(dict_path, 'rb') as f:\n",
    "        result_dict = pickle.load(f)\n",
    "\n",
    "    csv_list = []\n",
    "    for last_layer_name, layer_dict in result_dict.items():\n",
    "        for sv_thres_ratio, layer_thres_dict in layer_dict.items():\n",
    "            for metric_type, metric_samples in layer_thres_dict.items():\n",
    "                if metric_samples is None:\n",
    "                    print(f\"Metric None at :{dict_path} {[last_layer_name, sv_thres_ratio, metric_type, metric_samples]}\")\n",
    "                    continue\n",
    "                try:\n",
    "                    samples = list(metric_samples)\n",
    "                except TypeError:\n",
    "                    # Presumably a single score (run without return_samples): treat it as one sample.\n",
    "                    samples = [metric_samples]\n",
    "                for sample in samples:\n",
    "                    csv_list.append([last_layer_name, sv_thres_ratio, metric_type, sample])\n",
    "\n",
    "    df = pd.DataFrame(csv_list, columns = ['last_layer_name', 'sv_thres_ratio', 'metric_type', 'score'])\n",
    "    # splitext swaps the real extension instead of assuming a 4-character one.\n",
    "    df.to_csv(os.path.splitext(dict_path)[0] + '.csv')\n",
    "    print(dict_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CSV paths written by the conversion step, derived from dict_files instead of the\n",
    "# loop variable that leaked out of the previous cell.\n",
    "[os.path.splitext(os.path.join(dict_dir, f))[0] + '.csv' for f in dict_files]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stiefel Metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stiefel Metric between two Random w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_config.n_samples = 10\n",
    "eval_config.rankEst = True\n",
    "sv_thres_ratio_candi_select = [0.01]\n",
    "return_samples = True\n",
    "\n",
    "# Stiefel metric between the local bases of two random w, for every layer and\n",
    "# each selected singular-value threshold ratio.\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    layer_results = consistency_dict[last_layer_name] = {}\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi_select:\n",
    "        thres_results = layer_results[sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "\n",
    "        for metric_type in ['stiefel']:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            start = time.time()\n",
    "            thres_results[metric_type] = evaluate_basis_consistency(model, eval_config, metric_type = metric_type, return_samples=return_samples)\n",
    "            print(f'Evaluation took {round(time.time() - start, 2)}')\n",
    "\n",
    "# Saving is currently disabled:\n",
    "# with open(f'./out/consistency/consistency_random_KN_stiefel_n_{eval_config.n_samples}_samples.dill', 'wb') as f:\n",
    "#     pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Stiefel Metric between two random O(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rank_stats_path = './subnetwork_stats/LayerThres2rank_dict.dill'\n",
    "with open(rank_stats_path, 'rb') as f:\n",
    "    LayerThres2rank_dict = pickle.load(f)\n",
    "\n",
    "# Layer names, then the threshold ratios recorded for layer '8'.\n",
    "print(LayerThres2rank_dict.keys())\n",
    "print(LayerThres2rank_dict['8'].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "eval_config.n_samples = 10\n",
    "eval_config.rankEst = True\n",
    "sv_thres_ratio_candi_select = [0.01]\n",
    "\n",
    "# Stiefel metric between two random orthogonal matrices; the subspace dim is the rounded\n",
    "# mean local rank stored for each (layer, threshold) pair.\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    layer_results = consistency_dict[last_layer_name] = {}\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi_select:\n",
    "        thres_results = layer_results[sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "        local_ranks = LayerThres2rank_dict[eval_config.last_layer_name][eval_config.sv_thres_ratio]\n",
    "        eval_config.subspace_dim = int(np.mean(local_ranks).round())\n",
    "\n",
    "        for metric_type in ['stiefel']:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type} for dim {eval_config.subspace_dim}')\n",
    "            start = time.time()\n",
    "            thres_results[metric_type] = evaluate_random_basis_consistency(model, eval_config, metric_type = metric_type)\n",
    "            print(f'Evaluation took {round(time.time() - start, 2)}')\n",
    "\n",
    "# Saving is currently disabled:\n",
    "#     with open(f'./out/consistency/consistency_o(n)_till_{last_layer_name}_KN_n_{eval_config.n_samples}.dill', 'wb') as f:\n",
    "#         pickle.dump(consistency_dict, f)\n",
    "# with open(f'./out/consistency/consistency_o(n)_KN_stiefel_n_{eval_config.n_samples}.dill', 'wb') as f:\n",
    "#         pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stiefel Metric between two Close w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "eval_config.n_samples = 100\n",
    "eval_config.rankEst = True\n",
    "eps = 1e-1\n",
    "sv_thres_ratio_candi_select = [0.01]\n",
    "\n",
    "# Stiefel metric between the local bases of two close w (presumably w and an eps-perturbed\n",
    "# copy of it), for every layer and each selected singular-value threshold ratio.\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    consistency_dict[last_layer_name] = {}\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi_select:\n",
    "        consistency_dict[last_layer_name][sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "\n",
    "        for metric_type in ['stiefel']:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            timer = time.time()\n",
    "            consistency = evaluate_basis_consistency_local(model, eval_config, eps=eps, metric_type=metric_type)\n",
    "            consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "            print(f'Evaluation took {round(time.time() - timer, 2)}')\n",
    "\n",
    "with open(f'./out/consistency/consistency_local_KN_stiefel_eps_{eps}_n_{eval_config.n_samples}.dill', 'wb') as f:\n",
    "    pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stiefel Metric between random w and GANSpace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "eval_config.n_samples = 100\n",
    "eval_config.rankEst = True\n",
    "sv_thres_ratio_candi_select = [0.01]\n",
    "\n",
    "# Stiefel metric between the local basis at a random w and the global GANSpace basis,\n",
    "# for every layer and each selected singular-value threshold ratio.\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    consistency_dict[last_layer_name] = {}\n",
    "    # Note! Only ffhq GANSpace directions are provided.\n",
    "    gs_dir = np.load(f'./global_directions/ganspace_directions_ffhq_stylegan2_style-{last_layer_name}.npy')\n",
    "    # global_basis ends up on the CPU either way, so skip the device round trip.\n",
    "    global_basis = torch.from_numpy(gs_dir).squeeze().t()\n",
    "\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi_select:\n",
    "        consistency_dict[last_layer_name][sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "\n",
    "        for metric_type in ['stiefel']:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            timer = time.time()\n",
    "            consistency = evaluate_basis_consistency_to_global(model, eval_config, global_basis, metric_type=metric_type)\n",
    "            consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "            print(f'Evaluation took {round(time.time() - timer, 2)}')\n",
    "\n",
    "with open(f'./out/consistency/consistency_to_global_ganspace_KN_stiefel_n_{eval_config.n_samples}.dill', 'wb') as f:\n",
    "    pickle.dump(consistency_dict, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "local_basis",
   "language": "python",
   "name": "local_basis"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
