{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.\n",
      "StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "# The star import supplies the names this notebook uses without importing them\n",
    "# (Path, makedirs, torch, device, get_instrumented_model, the evaluate_* helpers).\n",
    "# TODO: replace with explicit imports.\n",
    "from notebook_init import *\n",
    "import os\n",
    "import pickle\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Consistency results (.dill / .csv) are written under this directory.\n",
    "out_root = Path('out/consistency')\n",
    "makedirs(out_root, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading ..\\models\\checkpoints\\stylegan2\\stylegan2_ffhq_1024.pt\n"
     ]
    }
   ],
   "source": [
    "## StyleGAN2 on FFHQ, instrumented at layer 'style'\n",
    "use_w = True  # presumably: work with W-space latents rather than Z -- confirm in get_instrumented_model\n",
    "inst = get_instrumented_model('StyleGAN2', 'ffhq', 'style', device,\n",
    "                              use_w=use_w, inst=inst)  # `inst` must already exist (presumably from notebook_init) -- TODO confirm\n",
    "model = inst.model\n",
    "model.truncation = 1.0\n",
    "\n",
    "## Alternative: StyleGAN1 (swap in for the block above)\n",
    "# use_w = True\n",
    "# dataset = 'ffhq'\n",
    "# inst = get_instrumented_model('StyleGAN', dataset, 'g_mapping', device, use_w=use_w, inst=inst)\n",
    "# model = inst.model\n",
    "# model.truncation = 1.0  # NOT IMPLEMENTED"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precomputed global directions (only FFHQ files are provided).\n",
    "# GANSpace: the 'style-8' file is used; the un-suffixed file is kept for reference:\n",
    "# gs_dir = np.load('./global_directions/ganspace_directions_ffhq_stylegan2.npy')\n",
    "gs_dir = torch.from_numpy(np.load('./global_directions/ganspace_directions_ffhq_stylegan2_style-8.npy')).to(device)\n",
    "# SeFa directions.\n",
    "sf_dir = torch.from_numpy(np.load('./global_directions/sefa_directions_ffhq_stylegan2.npy')).to(device)\n",
    "\n",
    "\n",
    "class compare_basis_config:\n",
    "    # Default evaluation settings; the evaluation cells overwrite these on the instance.\n",
    "    n_samples = 50\n",
    "    seed = 0\n",
    "    subspace_dim = 2\n",
    "    rankEst = False  # forwarded as normalize_dim to the Stiefel metric by evaluate_basis_consistency\n",
    "    sv_thres_ratio = 0.001\n",
    "    last_layer_name = '8'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable autograd globally (the evaluation helpers below presumably need gradients).\n",
    "torch.autograd.set_grad_enabled(True)\n",
    "# One shared config object; each evaluation cell overwrites its fields before running.\n",
    "eval_config = compare_basis_config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['1', '2', '3', '4', '5', '6', '7', '8']\n"
     ]
    }
   ],
   "source": [
    "# Layer names '1'..'8' to evaluate, and the candidate sv_thres_ratio values.\n",
    "subLayerNames = list(map(str, range(1, 9)))\n",
    "sv_thres_ratio_candi = [0.0005, 0.001, 0.005, 0.01]\n",
    "print(subLayerNames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert result dicts to CSV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dict_dir = './out/consistency/'\n",
    "# Result files to convert: name contains 'stylegan1' and 'samples', not 'layer12', and is not already a CSV.\n",
    "# NOTE(review): the model loaded above is StyleGAN2 -- confirm 'stylegan1' is the intended filter.\n",
    "dict_files = [\n",
    "    fname for fname in os.listdir(dict_dir)\n",
    "    if 'stylegan1' in fname\n",
    "    and 'samples' in fname\n",
    "    and 'layer12' not in fname\n",
    "    and not fname.endswith('csv')\n",
    "]\n",
    "print(dict_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten each result dict, shaped as\n",
    "#   result_dict[last_layer_name][sv_thres_ratio][metric_type] -> iterable of per-sample scores (or None),\n",
    "# into one row per sample, and save it as a .csv next to the source file.\n",
    "for dict_file in dict_files:\n",
    "    dict_path = os.path.join(dict_dir, dict_file)\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load result files produced by this project.\n",
    "    with open(dict_path, 'rb') as f:\n",
    "        result_dict = pickle.load(f)\n",
    "\n",
    "    csv_list = []\n",
    "    for last_layer_name, layer_dict in result_dict.items():\n",
    "        for sv_thres_ratio, layer_thres_dict in layer_dict.items():\n",
    "            for metric_type, metric_samples in layer_thres_dict.items():\n",
    "                if metric_samples is None:\n",
    "                    print(f\"Metric None at :{dict_path} {[last_layer_name, sv_thres_ratio, metric_type, metric_samples]}\")\n",
    "                    continue\n",
    "                csv_list.extend([last_layer_name, sv_thres_ratio, metric_type, sample] for sample in metric_samples)\n",
    "\n",
    "    df = pd.DataFrame(csv_list, columns=['last_layer_name', 'sv_thres_ratio', 'metric_type', 'score'])\n",
    "    # Swap the file extension for '.csv' regardless of its length (e.g. '.dill' or '.pkl');\n",
    "    # slicing off a fixed 4 characters turns 'x.pkl' into 'xcsv'.\n",
    "    csv_path = os.path.splitext(dict_path)[0] + '.csv'\n",
    "    df.to_csv(csv_path)\n",
    "    print(dict_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CSV path for each converted result file. Computed from dict_files (not from the loop\n",
    "# variable `dict_path` leaked by an earlier cell), swapping the extension with os.path.splitext.\n",
    "[os.path.splitext(os.path.join(dict_dir, dict_file))[0] + '.csv' for dict_file in dict_files]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stiefel Metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stiefel Metric between two Random w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Layer 1 Thres 0.01 Metric stiefel\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                           | 0/10 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "Can't instantiate abstract class RealStiefel with abstract methods inner_product, projection, random_point, random_tangent_vector, zero_vector",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_32908\\4223644158.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     18\u001b[0m             \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     19\u001b[0m             \u001b[0mtimer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 20\u001b[1;33m             \u001b[0mconsistency\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mevaluate_basis_consistency\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meval_config\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmetric_type\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmetric_type\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreturn_samples\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_samples\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     21\u001b[0m             \u001b[0mconsistency_dict\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mlast_layer_name\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0msv_thres_ratio\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mmetric_type\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mconsistency\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     22\u001b[0m             \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf'Evaluation took {round(time.time() - timer, 2)}'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Desktop\\Research\\localbasis_rank_estimate_final_dev_ver\\notebooks\\localbasis_utils.py\u001b[0m in \u001b[0;36mevaluate_basis_consistency\u001b[1;34m(model, eval_config, metric_type, return_samples)\u001b[0m\n\u001b[0;32m    503\u001b[0m             \u001b[1;31m#print('Trying stiefel')\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    504\u001b[0m             metric = compute_stiefel_metric(local_basis_list[0], local_basis_list[1], d = evaluated_dim,\n\u001b[1;32m--> 505\u001b[1;33m                                metric_type = metric_type, normalize_dim = eval_config.rankEst)\n\u001b[0m\u001b[0;32m    506\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    507\u001b[0m             metric = compute_grsm_metric(local_basis_list[0], local_basis_list[1], d = evaluated_dim,\n",
      "\u001b[1;32m~\\Desktop\\Research\\localbasis_rank_estimate_final_dev_ver\\notebooks\\localbasis_utils.py\u001b[0m in \u001b[0;36mcompute_stiefel_metric\u001b[1;34m(local_basis_1, local_basis_2, d, metric_type, normalize_dim)\u001b[0m\n\u001b[0;32m    461\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mcompute_stiefel_metric\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlocal_basis_1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlocal_basis_2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0md\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmetric_type\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'stiefel'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnormalize_dim\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    462\u001b[0m     \u001b[1;32massert\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmetric_type\u001b[0m \u001b[1;32min\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;34m'stiefel'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 463\u001b[1;33m     \u001b[0mmetric\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_metric_by_stiefel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlocal_basis_1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlocal_basis_2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msubspace_dim\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0md\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnormalize_dim\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnormalize_dim\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    464\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mmetric\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    465\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Desktop\\Research\\localbasis_rank_estimate_final_dev_ver\\notebooks\\localbasis_utils.py\u001b[0m in \u001b[0;36m_metric_by_stiefel\u001b[1;34m(local_basis_1, local_basis_2, subspace_dim, normalize_dim)\u001b[0m\n\u001b[0;32m    469\u001b[0m     \u001b[0malpha\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m.8\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    470\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 471\u001b[1;33m     \u001b[0mmfd\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mRealStiefel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mambient_dim\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msubspace_dim\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0malpha\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0malpha\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlog_stats\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlog_method\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'trust-krylov'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    472\u001b[0m     \u001b[1;34m'''Calculate Stiefel Distance'''\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    473\u001b[0m     \u001b[1;31m#print(f'Evaluating Stiefel for dim {subspace_dim}')\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mTypeError\u001b[0m: Can't instantiate abstract class RealStiefel with abstract methods inner_product, projection, random_point, random_tangent_vector, zero_vector"
     ]
    }
   ],
   "source": [
    "# Stiefel metric between local bases at two random w, for every layer.\n",
    "eval_config.n_samples = 10\n",
    "eval_config.rankEst = True\n",
    "sv_thres_ratio_candi_select = [0.01]\n",
    "return_samples = True\n",
    "\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    layer_results = consistency_dict[last_layer_name] = {}\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi_select:\n",
    "        thres_results = layer_results[sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "\n",
    "        for metric_type in ['stiefel']:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            t0 = time.time()\n",
    "            thres_results[metric_type] = evaluate_basis_consistency(\n",
    "                model, eval_config, metric_type=metric_type, return_samples=return_samples)\n",
    "            print(f'Evaluation took {round(time.time() - t0, 2)}')\n",
    "\n",
    "# Saving is disabled; uncomment to persist the per-sample results.\n",
    "# with open(f'./out/consistency/consistency_random_KN_stiefel_n_{eval_config.n_samples}_samples.dill', 'wb') as f:\n",
    "#     pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Stiefel Metric between two random orthogonal bases (O(n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimated local ranks per layer and threshold; used below to size the random baseline subspaces.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load files produced by this project.\n",
    "rank_dict_path = './subnetwork_stats/LayerThres2rank_dict.dill'\n",
    "with open(rank_dict_path, 'rb') as f:\n",
    "    LayerThres2rank_dict = pickle.load(f)\n",
    "\n",
    "print(LayerThres2rank_dict.keys())\n",
    "print(LayerThres2rank_dict['8'].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Baseline: Stiefel metric between two random orthogonal bases whose dimension is the\n",
    "# rounded mean estimated local rank of each layer (from LayerThres2rank_dict).\n",
    "eval_config.n_samples = 10\n",
    "eval_config.rankEst = True\n",
    "sv_thres_ratio_candi_select = [0.01]\n",
    "\n",
    "consistency_dict = {}\n",
    "# This cell overrides eval_config.subspace_dim per layer. Restore the previous value afterwards\n",
    "# (even on error) so cells run later do not silently inherit the last layer's rank.\n",
    "default_subspace_dim = eval_config.subspace_dim\n",
    "try:\n",
    "    for last_layer_name in subLayerNames:\n",
    "        consistency_dict[last_layer_name] = {}\n",
    "        for sv_thres_ratio in sv_thres_ratio_candi_select:\n",
    "            consistency_dict[last_layer_name][sv_thres_ratio] = {}\n",
    "            eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "            eval_config.last_layer_name = last_layer_name\n",
    "            eval_config.subspace_dim = int(np.mean(LayerThres2rank_dict[last_layer_name][sv_thres_ratio]).round())\n",
    "\n",
    "            for metric_type in ['stiefel']:\n",
    "                print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type} for dim {eval_config.subspace_dim}')\n",
    "                timer = time.time()\n",
    "                consistency = evaluate_random_basis_consistency(model, eval_config, metric_type=metric_type)\n",
    "                consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "                print(f'Evaluation took {round(time.time() - timer, 2)}')\n",
    "finally:\n",
    "    eval_config.subspace_dim = default_subspace_dim\n",
    "\n",
    "# Saving is disabled; uncomment to persist the results.\n",
    "# with open(f'./out/consistency/consistency_o(n)_KN_stiefel_n_{eval_config.n_samples}.dill', 'wb') as f:\n",
    "#     pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stiefel Metric between two Close w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Stiefel metric between local bases at two close w (closeness set by eps; see\n",
    "# evaluate_basis_consistency_local), for every layer. Results are saved to a .dill file.\n",
    "eval_config.n_samples = 100\n",
    "eval_config.rankEst = True\n",
    "eps = 1e-1\n",
    "sv_thres_ratio_candi_select = [0.01]\n",
    "\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    layer_results = consistency_dict[last_layer_name] = {}\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi_select:\n",
    "        thres_results = layer_results[sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "\n",
    "        for metric_type in ['stiefel']:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            t0 = time.time()\n",
    "            thres_results[metric_type] = evaluate_basis_consistency_local(\n",
    "                model, eval_config, eps=eps, metric_type=metric_type)\n",
    "            print(f'Evaluation took {round(time.time() - t0, 2)}')\n",
    "\n",
    "out_path = f'./out/consistency/consistency_local_KN_stiefel_eps_{eps}_n_{eval_config.n_samples}.dill'\n",
    "with open(out_path, 'wb') as f:\n",
    "    pickle.dump(consistency_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stiefel Metric between random w and GANSpace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Stiefel metric between the local basis at random w and the GANSpace basis of the same layer.\n",
    "# Results are saved to a .dill file.\n",
    "eval_config.n_samples = 100\n",
    "eval_config.rankEst = True\n",
    "sv_thres_ratio_candi_select = [0.01]\n",
    "\n",
    "consistency_dict = {}\n",
    "for last_layer_name in subLayerNames:\n",
    "    consistency_dict[last_layer_name] = {}\n",
    "    # GANSpace directions for this layer (only FFHQ is provided). The basis is passed on as a CPU\n",
    "    # tensor, so build it on the CPU directly instead of copying to `device` and back; the local\n",
    "    # name also keeps the global `gs_dir` loaded earlier from being overwritten.\n",
    "    gs_dir_np = np.load(f'./global_directions/ganspace_directions_ffhq_stylegan2_style-{last_layer_name}.npy')\n",
    "    global_basis = torch.from_numpy(gs_dir_np).squeeze().t()\n",
    "\n",
    "    for sv_thres_ratio in sv_thres_ratio_candi_select:\n",
    "        consistency_dict[last_layer_name][sv_thres_ratio] = {}\n",
    "        eval_config.sv_thres_ratio = sv_thres_ratio\n",
    "        eval_config.last_layer_name = last_layer_name\n",
    "\n",
    "        for metric_type in ['stiefel']:\n",
    "            print(f'Evaluating Layer {last_layer_name} Thres {sv_thres_ratio} Metric {metric_type}')\n",
    "            timer = time.time()\n",
    "            consistency = evaluate_basis_consistency_to_global(model, eval_config, global_basis, metric_type=metric_type)\n",
    "            consistency_dict[last_layer_name][sv_thres_ratio][metric_type] = consistency\n",
    "            print(f'Evaluation took {round(time.time() - timer, 2)}')\n",
    "\n",
    "with open(f'./out/consistency/consistency_to_global_ganspace_KN_stiefel_n_{eval_config.n_samples}.dill', 'wb') as f:\n",
    "    pickle.dump(consistency_dict, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "local_basis",
   "language": "python",
   "name": "local_basis"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
