{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.\n",
      "StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "from notebook_init import *\n",
    "import time, os\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "out_root = Path('out/consistency')\n",
    "makedirs(out_root, exist_ok=True)\n",
    "# rand = lambda : np.random.randint(np.iinfo(np.int32).max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # StyleGAN2\n",
    "# use_w = True\n",
    "# inst = get_instrumented_model('StyleGAN2', 'ffhq', 'style', device, inst=inst, use_w=use_w)\n",
    "# model = inst.model\n",
    "# model.truncation = 1.0\n",
    "\n",
    "# ## StyleGAN2 - LSUN-CAT\n",
    "# use_w = True\n",
    "# dataset = 'cat'\n",
    "# inst = get_instrumented_model('StyleGAN2', dataset, 'style', device, inst=inst, use_w=use_w)\n",
    "# model = inst.model\n",
    "# model.truncation = 1.0\n",
    "\n",
    "### StyleGAN1\n",
    "use_w = True\n",
    "dataset = 'ffhq'\n",
    "inst = get_instrumented_model('StyleGAN', dataset, 'g_mapping', device, use_w=use_w, inst=inst)\n",
    "model = inst.model\n",
    "model.truncation = 1.0 # NOT IMPLEMENTED\n",
    "\n",
    "# ### StyleGAN1 - LSUN-CAT\n",
    "# use_w = True\n",
    "# dataset = 'cats'\n",
    "# inst = get_instrumented_model('StyleGAN', dataset, 'g_mapping', device, use_w=use_w, inst=inst)\n",
    "# model = inst.model\n",
    "# model.truncation = 1.0 # NOT IMPLEMENTED\n",
    "\n",
    "# ### StyleGAN1 - LSUN-CAT\n",
    "# use_w = True\n",
    "# dataset = 'cats'\n",
    "# inst = get_instrumented_model('StyleGAN', dataset, 'g_mapping', device, use_w=use_w, inst=inst)\n",
    "# model = inst.model\n",
    "# model.truncation = 1.0 # NOT IMPLEMENTED\n",
    "\n",
    "# ## StyleGAN2 - LSUN-CAR\n",
    "# use_w = True\n",
    "# dataset = 'car'\n",
    "# inst = get_instrumented_model('StyleGAN2', dataset, 'style', device, inst=inst, use_w=use_w)\n",
    "# model = inst.model\n",
    "# model.truncation = 1.0\n",
    "\n",
    "# ## StyleGAN2 - LSUN-horse\n",
    "# use_w = True\n",
    "# dataset = 'horse'\n",
    "# inst = get_instrumented_model('StyleGAN2', dataset, 'style', device, inst=inst, use_w=use_w)\n",
    "# model = inst.model\n",
    "# model.truncation = 1.0\n",
    "\n",
    "# ## StyleGAN2 - LSUN-church\n",
    "# use_w = True\n",
    "# dataset = 'church'\n",
    "# inst = get_instrumented_model('StyleGAN2', dataset, 'style', device, inst=inst, use_w=use_w)\n",
    "# model = inst.model\n",
    "# model.truncation = 1.0\n",
    "\n",
    "model_name = 'StyleGAN1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class compare_basis_config:\n",
    "    n_samples = 50\n",
    "    seed = 0\n",
    "    subspace_dim = 2\n",
    "    rankEst = False\n",
    "    sv_thres_ratio =  0.001\n",
    "    last_layer_name = '8'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.autograd.set_grad_enabled(True)\n",
    "eval_config = compare_basis_config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['1', '2', '3', '4', '5', '6', '7', '8']\n"
     ]
    }
   ],
   "source": [
    "subLayerNames = [str(a) for a in range(1, 9)]\n",
    "sv_thres_ratio_candi = [0.0005, 0.001, 0.005, 0.01]\n",
    "print(subLayerNames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Grassmannian Metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grassmannain Metric between two Random w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''StyleGAN2'''\n",
    "# subLayerNames = [str(a) for a in range(1, 9)]\n",
    "#sv_thres_ratio_candi = [0.0005, 0.001, 0.005, 0.01]\n",
    "sv_thres_ratio_candi = [0.00005, 0.0001]\n",
    "# print(subLayerNames)\n",
    "\n",
    "'''StyleGAN1'''\n",
    "subLayerNames = [f'dense{a}_act' for a in range(0, 8)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense0_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [09:29<00:00,  1.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 569.58\n",
      "Evaluating Layer dense0_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [08:12<00:00,  2.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 492.33\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense1_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [12:19<00:00,  1.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 739.84\n",
      "Evaluating Layer dense1_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [11:44<00:00,  1.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 704.15\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense2_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [09:59<00:00,  1.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 599.14\n",
      "Evaluating Layer dense2_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [09:17<00:00,  1.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 557.42\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense3_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [08:19<00:00,  2.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 499.9\n",
      "Evaluating Layer dense3_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:40<00:00,  2.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 460.14\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense4_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [06:39<00:00,  2.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 399.8\n",
      "Evaluating Layer dense4_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [06:05<00:00,  2.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 365.81\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense5_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:19<00:00,  3.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 319.47\n",
      "Evaluating Layer dense5_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:46<00:00,  3.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 286.69\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense6_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:36<00:00,  3.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 276.28\n",
      "Evaluating Layer dense6_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:56<00:00,  4.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 236.29\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense7_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:55<00:00,  4.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 235.46\n",
      "Evaluating Layer dense7_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:37<00:00,  4.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 217.15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "eval_config.n_samples = 1000\n",
    "eval_config.rankEst = True\n",
    "return_samples = True\n",
    "\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    print(subLayerNames)\n",
    "    consistency_dict[last_layer_name] = {}\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi:\n",
    "        consistency_dict[last_layer_name][sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "        \n",
    "        #for metric_type in ['geodesic', 'proj']:\n",
    "        for metric_type in ['geodesic']:\n",
    "            '''\n",
    "            Grassmannain Metric between two Random w\n",
    "            '''\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            timer = time.time()\n",
    "            consistency = evaluate_basis_consistency(model, eval_config, metric_type = metric_type, return_samples=return_samples)\n",
    "            consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "            print(f'Evaluation took {round(time.time() - timer, 2)}')    \n",
    "\n",
    "with open(f'./out/consistency/consistency_random_KN_n_{eval_config.n_samples}_samples_{model_name}_smaller.dill', 'wb') as f:\n",
    "    pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grassmannain Metric between two Close w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense0_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [08:54<00:00,  1.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 534.64\n",
      "Evaluating Layer dense0_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:43<00:00,  2.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 463.12\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense1_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [12:00<00:00,  1.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 720.12\n",
      "Evaluating Layer dense1_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [11:47<00:00,  1.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 707.11\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense2_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [09:57<00:00,  1.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 597.27\n",
      "Evaluating Layer dense2_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [09:20<00:00,  1.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 560.41\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense3_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [08:24<00:00,  1.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 504.14\n",
      "Evaluating Layer dense3_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:58<00:00,  2.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 478.15\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense4_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:08<00:00,  2.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 428.17\n",
      "Evaluating Layer dense4_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:56<00:00,  2.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 356.44\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense5_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:21<00:00,  3.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 321.04\n",
      "Evaluating Layer dense5_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:46<00:00,  3.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 286.73\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense6_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:14<00:00,  5.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 194.54\n",
      "Evaluating Layer dense6_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:52<00:00,  5.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 172.41\n",
      "['dense0_act', 'dense1_act', 'dense2_act', 'dense3_act', 'dense4_act', 'dense5_act', 'dense6_act', 'dense7_act']\n",
      "Evaluating Layer dense7_act Thres 5e-05 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:52<00:00,  5.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 172.79\n",
      "Evaluating Layer dense7_act Thres 0.0001 Metric geodesic\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:42<00:00,  6.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation took 162.1\n"
     ]
    }
   ],
   "source": [
    "eval_config.n_samples = 1000\n",
    "eval_config.rankEst = True\n",
    "return_samples = True\n",
    "eps = 1e-1\n",
    "\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    print(subLayerNames)\n",
    "    consistency_dict[last_layer_name] = {}\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi:\n",
    "        consistency_dict[last_layer_name][sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "        \n",
    "        #for metric_type in ['geodesic', 'proj']:\n",
    "        for metric_type in ['geodesic']:\n",
    "            '''\n",
    "            Grassmannain Metric between two Random w\n",
    "            '''\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            timer = time.time()\n",
    "            consistency = evaluate_basis_consistency_local(model, eval_config, eps = eps,  metric_type = metric_type, return_samples=return_samples)\n",
    "            consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "            print(f'Evaluation took {round(time.time() - timer, 2)}') \n",
    "            \n",
    "with open(f'./out/consistency/consistency_local_KN_eps_{eps}_n_{eval_config.n_samples}_samples_{model_name}_smaller.dill', 'wb') as f:\n",
    "    pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert dict to csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['consistency_local_KN_eps_0.1_n_1000_samples_StyleGAN1-cat_smaller.dill', 'consistency_local_KN_eps_0.1_n_1000_samples_StyleGAN1_smaller.dill', 'consistency_local_KN_eps_0.1_n_1000_samples_StyleGAN2-church_smaller.dill', 'consistency_local_KN_eps_0.1_n_1000_samples_StyleGAN2-horse_smaller.dill', 'consistency_random_KN_n_1000_samples_StyleGAN1-cat_smaller.dill', 'consistency_random_KN_n_1000_samples_StyleGAN1_smaller.dill', 'consistency_random_KN_n_1000_samples_StyleGAN2-church_smaller.dill', 'consistency_random_KN_n_1000_samples_StyleGAN2-horse_smaller.dill']\n"
     ]
    }
   ],
   "source": [
    "dict_dir = './out/consistency/'\n",
    "dict_files = []\n",
    "for f in  os.listdir(dict_dir):\n",
    "    #if 'StyleGAN2' in f:\n",
    "    if 'smaller' in f:\n",
    "        if ('samples' in f) and ('layer12' not in f) and (not f.endswith('csv')):\n",
    "            dict_files.append(f)\n",
    "print(dict_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "consistency_local_KN_eps_0.1_n_1000_samples_StyleGAN1-cat_smaller.dill\n",
      "consistency_local_KN_eps_0.1_n_1000_samples_StyleGAN1_smaller.dill\n",
      "consistency_local_KN_eps_0.1_n_1000_samples_StyleGAN2-church_smaller.dill\n",
      "consistency_local_KN_eps_0.1_n_1000_samples_StyleGAN2-horse_smaller.dill\n",
      "consistency_random_KN_n_1000_samples_StyleGAN1-cat_smaller.dill\n",
      "consistency_random_KN_n_1000_samples_StyleGAN1_smaller.dill\n",
      "consistency_random_KN_n_1000_samples_StyleGAN2-church_smaller.dill\n",
      "consistency_random_KN_n_1000_samples_StyleGAN2-horse_smaller.dill\n"
     ]
    }
   ],
   "source": [
    "#consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "\n",
    "for dict_file in dict_files:\n",
    "    dict_path = os.path.join(dict_dir, dict_file)\n",
    "    with open(dict_path, 'rb') as f:\n",
    "        result_dict = pickle.load(f)\n",
    "        \n",
    "    csv_list = []\n",
    "    #print(result_dict.keys())\n",
    "    for last_layer_name, layer_dict in result_dict.items():\n",
    "        #print(layer_dict.keys())\n",
    "        for sv_thres_ratio, layer_thres_dict in layer_dict.items():\n",
    "            #print(layer_thres_dict.keys())\n",
    "            for metric_type, metric_samples in layer_thres_dict.items():\n",
    "                if metric_samples is None:\n",
    "                    print(f\"Metric None at :{dict_path} {[last_layer_name, sv_thres_ratio, metric_type, metric_samples]}\")\n",
    "                    continue\n",
    "                for sample in metric_samples:\n",
    "                    csv_list.append([last_layer_name, sv_thres_ratio, metric_type, sample])\n",
    "                #print(csv_list[:10])\n",
    "                \n",
    "    df = pd.DataFrame(csv_list, columns = ['last_layer_name', 'sv_thres_ratio', 'metric_type', 'score'])\n",
    "    df.to_csv(dict_path[:-4]+'csv')\n",
    "    print(dict_file)\n",
    "        \n",
    "            "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('newvae')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "92cd103f401c12a28b3e8548c09f6f8b155d74d48e2549a619b115437b656304"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
