{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fcbe0c08ab0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from labproject.data import get_dataset, DATASETS\n",
    "from labproject.metrics import METRICS\n",
    "import numpy as np\n",
    "\n",
    "from labproject.metrics.utils import get_metric\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'random': <function labproject.data.random_dataset(n=1000, d=10)>,\n",
       " 'multivariate_normal': <function labproject.data.multivariate_normal(n=3000, dims=100, means=None, vars=None, distort=None)>,\n",
       " 'toy_2d': <function labproject.data.toy_2d(n=1000, d=2)>,\n",
       " 'cifar10_train': <function labproject.data.cifar10_train(n=1000, d=2048, save_path='data', device='cpu', return_labels=False)>,\n",
       " 'cifar10_test': <function labproject.data.cifar10_test(n=1000, d=2048, save_path='data', device='cpu', return_labels=False)>,\n",
       " 'imagenet_real_embeddings': <function labproject.data.imagenet_real_embeddings(n=1000, d=2048)>,\n",
       " 'imagenet_uncond_embeddings': <function labproject.data.imagenet_uncond_embeddings(n=1000, d=2048)>,\n",
       " 'imagenet_unconditional_model_embedding': <function labproject.data.imagenet_unconditional_model_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_test_embedding': <function labproject.data.imagenet_test_embedding(n, d=2048, device='cpu', save_path='data')>,\n",
       " 'imagenet_validation_embedding': <function labproject.data.imagenet_validation_embedding(n, d=2048, device='cpu', save_path='data')>,\n",
       " 'imagenet_conditional_model': <function labproject.data.imagenet_conditional_model(n, d=2048, label: Optional[int] = None, device='cpu', permute_if_no_label=True, save_path='data')>,\n",
       " 'imagenet_biggan_embedding': <function labproject.data.imagenet_biggan_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_sdv4_embedding': <function labproject.data.imagenet_sdv4_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_sdv5_embedding': <function labproject.data.imagenet_sdv5_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_vqdm_embedding': <function labproject.data.imagenet_vqdm_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_wukong_embedding': <function labproject.data.imagenet_wukong_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_adm_embedding': <function labproject.data.imagenet_adm_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_midjourney_embedding': <function labproject.data.imagenet_midjourney_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_cs1_embedding': <function labproject.data.imagenet_cs1_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_cs3_embedding': <function labproject.data.imagenet_cs3_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_cs10_embedding': <function labproject.data.imagenet_cs10_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,\n",
       " 'imagenet_cs100_embedding': <function labproject.data.imagenet_cs10_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "DATASETS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'mmd_rbf': <function labproject.metrics.MMD_torch.compute_rbf_mmd(x, y, bandwidth=1.0)>,\n",
       " 'mmd_rbf_median_heuristic': <function labproject.metrics.MMD_torch.compute_rbf_mmd_median_heuristic(x, y)>,\n",
       " 'mmd_rbf_auto': <function labproject.metrics.MMD_torch.compute_rbf_mmd_auto(x, y, bandwidth=1.0)>,\n",
       " 'mmd_polynomial': <function labproject.metrics.MMD_torch.compute_polynomial_mmd(x, y, degree=2, bias=0)>,\n",
       " 'mmd_linear_naive': <function labproject.metrics.MMD_torch.compute_linear_mmd_naive(x, y)>,\n",
       " 'mmd_linear': <function labproject.metrics.MMD_torch.compute_linear_mmd(x, y)>,\n",
       " 'mmd_energy': <function labproject.metrics.MMD_torch.compute_energy_mmd(x, y)>,\n",
       " 'c2st_nn': <function labproject.metrics.c2st.c2st_nn(X: torch.Tensor, Y: torch.Tensor, seed: int = 1, n_folds: int = 5, metric: str = 'accuracy', z_score: bool = True, activation: Literal['identity', 'logistic', 'tanh', 'relu'] = 'relu', clf_kwargs: dict[str, typing.Any] = {}) -> torch.Tensor>,\n",
       " 'c2st_rf': <function labproject.metrics.c2st.c2st_rf(X: torch.Tensor, Y: torch.Tensor, seed: int = 1, n_folds: int = 5, metric: str = 'accuracy', z_score: bool = True, n_estimators: int = 100, clf_kwargs: dict[str, typing.Any] = {}) -> torch.Tensor>,\n",
       " 'c2st_knn': <function labproject.metrics.c2st.c2st_knn(X: torch.Tensor, Y: torch.Tensor, seed: int = 1, n_folds: int = 5, metric: str = 'accuracy', z_score: bool = True, n_neighbors: int = 5, clf_kwargs: dict = {}) -> torch.Tensor>,\n",
       " 'gaussian_kl_divergence': <function labproject.metrics.gaussian_kl.gaussian_kl_divergence(real_samples: torch.Tensor, fake_samples: torch.Tensor) -> torch.Tensor>,\n",
       " 'wasserstein_gauss_squared': <function labproject.metrics.gaussian_squared_wasserstein.gaussian_squared_w2_distance(real_samples: torch.Tensor, fake_samples: torch.Tensor, real_mu=None, real_cov=None) -> torch.Tensor>,\n",
       " 'sliced_wasserstein': <function labproject.metrics.sliced_wasserstein.sliced_wasserstein_distance(encoded_samples: torch.Tensor, distribution_samples: torch.Tensor, num_projections: int = 50, p: int = 2, device: str = 'cpu') -> torch.Tensor>,\n",
       " 'wasserstein_kuhn': <function labproject.metrics.wasserstein_kuhn.wasserstein_kuhn(x: torch.Tensor, y: torch.Tensor, norm: Union[Callable, str, int] = 2) -> torch.Tensor>,\n",
       " 'wasserstein_sinkhorn': <function labproject.metrics.wasserstein_sinkhorn.sinkhorn_loss(x: torch.Tensor, y: torch.Tensor, epsilon: float = 0.001, niter: int = 1000, p: int = 2) -> torch.Tensor>}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "METRICS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_fn = get_metric(\"wasserstein_gauss_squared\")\n",
    "metric_fn2 = get_metric(\"sliced_wasserstein\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"imagenet_unconditional_model_embedding\",\"imagenet_cs1_embedding\", \"imagenet_cs10_embedding\", \"imagenet_biggan_embedding\", \"imagenet_sdv4_embedding\", \"imagenet_sdv5_embedding\", \"imagenet_vqdm_embedding\", \"imagenet_wukong_embedding\", \"imagenet_adm_embedding\", \"imagenet_midjourney_embedding\"]\n",
    "metrics = [\"wasserstein_gauss_squared\", \"sliced_wasserstein\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "testset_fn = get_dataset(\"imagenet_test_embedding\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "testset = testset_fn(100_000, 2048)\n",
    "idx = torch.randperm(len(testset))\n",
    "testset = testset[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_fid = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "metric = metrics[0]\n",
    "metric_fn = get_metric(metric)\n",
    "for dname in datasets:\n",
    "    metric_values = []\n",
    "    for j in range(5):\n",
    "        data_test = testset[j*20_000:(j+1)*20_000]\n",
    "        if dname == \"imagenet_midjourney_embedding\":\n",
    "            data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n",
    "        else:\n",
    "            data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n",
    "        m = metric_fn(data_test, data_syn)\n",
    "        metric_values.append(m)\n",
    "    results_fid[dname] = np.array(metric_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'imagenet_unconditional_model_embedding': array([6.2195807, 6.1363096, 6.1779237, 6.1323247, 6.1145873],\n",
       "       dtype=float32),\n",
       " 'imagenet_cs1_embedding': array([6.430565 , 6.3383527, 6.3908176, 6.392403 , 6.3483315],\n",
       "       dtype=float32),\n",
       " 'imagenet_cs10_embedding': array([6.9664445, 6.934556 , 6.985526 , 6.980446 , 6.945889 ],\n",
       "       dtype=float32),\n",
       " 'imagenet_biggan_embedding': array([12.701342, 12.732329, 12.704566, 12.742165, 12.572661],\n",
       "       dtype=float32),\n",
       " 'imagenet_sdv4_embedding': array([17.149105, 17.233383, 17.262657, 17.13218 , 17.054472],\n",
       "       dtype=float32),\n",
       " 'imagenet_sdv5_embedding': array([17.27084 , 17.254538, 17.438782, 17.217953, 17.170486],\n",
       "       dtype=float32),\n",
       " 'imagenet_vqdm_embedding': array([11.207735, 11.090878, 11.292017, 11.206196, 11.167131],\n",
       "       dtype=float32),\n",
       " 'imagenet_wukong_embedding': array([18.644157, 18.674633, 18.625652, 18.643795, 18.433655],\n",
       "       dtype=float32),\n",
       " 'imagenet_adm_embedding': array([12.630351 , 12.518478 , 12.667346 , 12.5699835, 12.462262 ],\n",
       "       dtype=float32),\n",
       " 'imagenet_midjourney_embedding': array([17.360374, 17.406929, 17.517849, 17.314253, 17.256212],\n",
       "       dtype=float32)}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_fid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"results_fid.npy\", results_fid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_sw = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting  imagenet_midjourney_embedding\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "metric = \"sliced_wasserstein\"\n",
    "metric_fn = get_metric(metric)\n",
    "for dname in datasets[-1:]:\n",
    "    print(\"Starting \", dname)\n",
    "    metric_values = []\n",
    "    for j in range(5):\n",
    "        data_test = testset[j*20_000:(j+1)*20_000]\n",
    "        if dname == \"imagenet_midjourney_embedding\":\n",
    "            data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n",
    "            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n",
    "        else:\n",
    "            data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n",
    "        m = metric_fn(data_test, data_syn, num_projections=5000)\n",
    "        metric_values.append(m)\n",
    "    results_sw[dname] = np.array(metric_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'imagenet_unconditional_model_embedding': array([0.02358142, 0.02302841, 0.02342947, 0.02363985, 0.02274173],\n",
       "       dtype=float32),\n",
       " 'imagenet_cs1_embedding': array([0.02153669, 0.02145257, 0.02175897, 0.02185666, 0.02145214],\n",
       "       dtype=float32),\n",
       " 'imagenet_cs10_embedding': array([0.02553654, 0.02503867, 0.02539671, 0.0255399 , 0.02508791],\n",
       "       dtype=float32),\n",
       " 'imagenet_biggan_embedding': array([0.05029042, 0.05024597, 0.05106073, 0.05075597, 0.05070167],\n",
       "       dtype=float32),\n",
       " 'imagenet_sdv4_embedding': array([0.05571224, 0.05621962, 0.05671528, 0.056228  , 0.05603858],\n",
       "       dtype=float32),\n",
       " 'imagenet_sdv5_embedding': array([0.05642236, 0.05595   , 0.0562194 , 0.05640132, 0.05647483],\n",
       "       dtype=float32),\n",
       " 'imagenet_vqdm_embedding': array([0.04172998, 0.04103179, 0.04130031, 0.04112783, 0.04078227],\n",
       "       dtype=float32),\n",
       " 'imagenet_wukong_embedding': array([0.05560957, 0.05517614, 0.05603101, 0.05579789, 0.05510735],\n",
       "       dtype=float32),\n",
       " 'imagenet_adm_embedding': array([0.05062867, 0.05020323, 0.05051676, 0.05036184, 0.05050731],\n",
       "       dtype=float32),\n",
       " 'imagenet_midjourney_embedding': array([0.0481902 , 0.04775954, 0.04859642, 0.0480745 , 0.04758868],\n",
       "       dtype=float32)}"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_sw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"results_sw.npy\", results_sw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_mmd_rbf64 = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting  imagenet_unconditional_model_embedding\n",
      "Starting  imagenet_cs1_embedding\n",
      "Starting  imagenet_cs10_embedding\n",
      "Starting  imagenet_biggan_embedding\n",
      "Starting  imagenet_sdv4_embedding\n",
      "Starting  imagenet_sdv5_embedding\n",
      "Starting  imagenet_vqdm_embedding\n",
      "Starting  imagenet_wukong_embedding\n",
      "Starting  imagenet_adm_embedding\n",
      "Starting  imagenet_midjourney_embedding\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "metric = \"mmd_rbf\"\n",
    "metric_fn = get_metric(metric)\n",
    "for dname in datasets:\n",
    "    print(\"Starting \", dname)\n",
    "    metric_values = []\n",
    "    for j in range(5):\n",
    "        data_test = testset[j*20_000:(j+1)*20_000]\n",
    "        if dname == \"imagenet_midjourney_embedding\":\n",
    "            data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n",
    "            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n",
    "        else:\n",
    "            data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n",
    "        m = metric_fn(data_test, data_syn, bandwidth=64.0)\n",
    "        metric_values.append(m)\n",
    "    results_mmd_rbf64[dname] = np.array(metric_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'imagenet_unconditional_model_embedding': array([7.1763992e-05, 6.2704086e-05, 7.2360039e-05, 6.4253807e-05,\n",
       "        6.0319901e-05], dtype=float32),\n",
       " 'imagenet_cs1_embedding': array([7.1883202e-05, 5.7458878e-05, 5.9604645e-05, 6.4969063e-05,\n",
       "        5.7697296e-05], dtype=float32),\n",
       " 'imagenet_cs10_embedding': array([9.0479851e-05, 8.4877014e-05, 8.2612038e-05, 8.7141991e-05,\n",
       "        8.1181526e-05], dtype=float32),\n",
       " 'imagenet_biggan_embedding': array([0.00019705, 0.00018573, 0.00018668, 0.00018907, 0.00017405],\n",
       "       dtype=float32),\n",
       " 'imagenet_sdv4_embedding': array([0.00021684, 0.00020289, 0.00020945, 0.00021148, 0.00020015],\n",
       "       dtype=float32),\n",
       " 'imagenet_sdv5_embedding': array([0.00021875, 0.00020254, 0.00021124, 0.00021267, 0.00020623],\n",
       "       dtype=float32),\n",
       " 'imagenet_vqdm_embedding': array([0.00015748, 0.0001334 , 0.0001483 , 0.00015378, 0.00014079],\n",
       "       dtype=float32),\n",
       " 'imagenet_wukong_embedding': array([0.0002085 , 0.00019038, 0.00019991, 0.00020087, 0.00018454],\n",
       "       dtype=float32),\n",
       " 'imagenet_adm_embedding': array([0.00020254, 0.00019121, 0.00019884, 0.00019133, 0.00017893],\n",
       "       dtype=float32),\n",
       " 'imagenet_midjourney_embedding': array([0.00018728, 0.00016236, 0.00017667, 0.00017893, 0.00017178],\n",
       "       dtype=float32)}"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_mmd_rbf64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"results_mmd_rbf64.npy\", results_mmd_rbf64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_mmd_lin = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting  imagenet_unconditional_model_embedding\n",
      "Starting  imagenet_cs1_embedding\n",
      "Starting  imagenet_cs10_embedding\n",
      "Starting  imagenet_biggan_embedding\n",
      "Starting  imagenet_sdv4_embedding\n",
      "Starting  imagenet_sdv5_embedding\n",
      "Starting  imagenet_vqdm_embedding\n",
      "Starting  imagenet_wukong_embedding\n",
      "Starting  imagenet_adm_embedding\n",
      "Starting  imagenet_midjourney_embedding\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "metric = \"mmd_linear\"\n",
    "metric_fn = get_metric(metric)\n",
    "for dname in datasets:\n",
    "    print(\"Starting \", dname)\n",
    "    metric_values = []\n",
    "    for j in range(5):\n",
    "        data_test = testset[j*20_000:(j+1)*20_000]\n",
    "        if dname == \"imagenet_midjourney_embedding\":\n",
    "            data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n",
    "            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n",
    "        else:\n",
    "            data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n",
    "        m = metric_fn(data_test, data_syn)\n",
    "        metric_values.append(m)\n",
    "    results_mmd_lin[dname] = np.array(metric_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"results_mmd_lin.npy\", results_mmd_lin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'imagenet_unconditional_model_embedding': array([0.27161348, 0.23375021, 0.27461752, 0.23648134, 0.22176446],\n",
       "       dtype=float32),\n",
       " 'imagenet_cs1_embedding': array([0.2784213 , 0.2160877 , 0.2256283 , 0.24824372, 0.21793067],\n",
       "       dtype=float32),\n",
       " 'imagenet_cs10_embedding': array([0.3461548 , 0.32180998, 0.31091288, 0.33070827, 0.30549592],\n",
       "       dtype=float32),\n",
       " 'imagenet_biggan_embedding': array([0.6735612, 0.6250146, 0.6302967, 0.6348198, 0.574922 ],\n",
       "       dtype=float32),\n",
       " 'imagenet_sdv4_embedding': array([0.7234071 , 0.6639475 , 0.6890489 , 0.69806087, 0.6532722 ],\n",
       "       dtype=float32),\n",
       " 'imagenet_sdv5_embedding': array([0.73029166, 0.66163504, 0.6962174 , 0.70317215, 0.6768427 ],\n",
       "       dtype=float32),\n",
       " 'imagenet_vqdm_embedding': array([0.55959   , 0.45710978, 0.51890296, 0.5427842 , 0.4893559 ],\n",
       "       dtype=float32),\n",
       " 'imagenet_wukong_embedding': array([0.6912052 , 0.61376953, 0.6529302 , 0.6583701 , 0.5910167 ],\n",
       "       dtype=float32),\n",
       " 'imagenet_adm_embedding': array([0.6968005, 0.6499678, 0.6796049, 0.6463373, 0.5955869],\n",
       "       dtype=float32),\n",
       " 'imagenet_midjourney_embedding': array([0.6486083 , 0.54306364, 0.6036365 , 0.61384624, 0.583929  ],\n",
       "       dtype=float32)}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_mmd_lin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_mmd_poly_kid = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting  imagenet_unconditional_model_embedding\n",
      "Starting  imagenet_cs1_embedding\n",
      "Starting  imagenet_cs10_embedding\n",
      "Starting  imagenet_biggan_embedding\n",
      "Starting  imagenet_sdv4_embedding\n",
      "Starting  imagenet_sdv5_embedding\n",
      "Starting  imagenet_vqdm_embedding\n",
      "Starting  imagenet_wukong_embedding\n",
      "Starting  imagenet_adm_embedding\n",
      "Starting  imagenet_midjourney_embedding\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "metric = \"mmd_polynomial\"\n",
    "metric_fn = get_metric(metric)\n",
    "for dname in datasets:\n",
    "    print(\"Starting \", dname)\n",
    "    metric_values = []\n",
    "    for j in range(5):\n",
    "        data_test = testset[j*20_000:(j+1)*20_000]\n",
    "        if dname == \"imagenet_midjourney_embedding\":\n",
    "            data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n",
    "            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n",
    "        else:\n",
    "            data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n",
    "        m = metric_fn(data_test, data_syn, degree=3, bias=1.)\n",
    "        metric_values.append(m)\n",
    "    results_mmd_poly_kid[dname] = np.array(metric_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'imagenet_unconditional_model_embedding': array([13181., 10870., 12008., 10130.,  9628.], dtype=float32),\n",
       " 'imagenet_cs1_embedding': array([21638., 15387., 14466., 17863., 15719.], dtype=float32),\n",
       " 'imagenet_cs10_embedding': array([17963., 15701., 13769., 15598., 14699.], dtype=float32),\n",
       " 'imagenet_biggan_embedding': array([33609., 29164., 29311., 30098., 26772.], dtype=float32),\n",
       " 'imagenet_sdv4_embedding': array([39859., 34397., 34759., 37063., 34212.], dtype=float32),\n",
       " 'imagenet_sdv5_embedding': array([39529., 34206., 34731., 36771., 35564.], dtype=float32),\n",
       " 'imagenet_vqdm_embedding': array([25483., 19751., 22153., 24993., 22030.], dtype=float32),\n",
       " 'imagenet_wukong_embedding': array([42367., 35982., 37357., 38585., 34325.], dtype=float32),\n",
       " 'imagenet_adm_embedding': array([37571., 34480., 34681., 33519., 29859.], dtype=float32),\n",
       " 'imagenet_midjourney_embedding': array([40835., 32620., 35570., 36513., 35667.], dtype=float32)}"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_mmd_poly_kid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"results_mmd_poly_kid.npy\", results_mmd_poly_kid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_c2st_knn = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting  imagenet_unconditional_model_embedding\n",
      "Starting  imagenet_cs1_embedding\n",
      "Starting  imagenet_cs10_embedding\n",
      "Starting  imagenet_biggan_embedding\n",
      "Starting  imagenet_sdv4_embedding\n",
      "Starting  imagenet_sdv5_embedding\n",
      "Starting  imagenet_vqdm_embedding\n",
      "Starting  imagenet_wukong_embedding\n",
      "Starting  imagenet_adm_embedding\n",
      "Starting  imagenet_midjourney_embedding\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "metric = \"c2st_knn\"\n",
    "metric_fn = get_metric(metric)\n",
    "for dname in datasets:\n",
    "    print(\"Starting \", dname)\n",
    "    metric_values = []\n",
    "    for j in range(5):\n",
    "        data_test = testset[j*20_000:(j+1)*20_000]\n",
    "        if dname == \"imagenet_midjourney_embedding\":\n",
    "            data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n",
    "            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n",
    "        else:\n",
    "            data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n",
    "        m = metric_fn(data_test, data_syn, n_folds=2)\n",
    "        metric_values.append(m)\n",
    "    results_c2st_knn[dname] = np.array(metric_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'imagenet_unconditional_model_embedding': array([[0.65345 ],\n",
       "        [0.64765 ],\n",
       "        [0.650725],\n",
       "        [0.652725],\n",
       "        [0.652275]], dtype=float32),\n",
       " 'imagenet_cs1_embedding': array([[0.63035 ],\n",
       "        [0.62975 ],\n",
       "        [0.63135 ],\n",
       "        [0.629375],\n",
       "        [0.63405 ]], dtype=float32),\n",
       " 'imagenet_cs10_embedding': array([[0.6517 ],\n",
       "        [0.6506 ],\n",
       "        [0.6573 ],\n",
       "        [0.654  ],\n",
       "        [0.65605]], dtype=float32),\n",
       " 'imagenet_biggan_embedding': array([[0.75415 ],\n",
       "        [0.749725],\n",
       "        [0.7522  ],\n",
       "        [0.750975],\n",
       "        [0.754475]], dtype=float32),\n",
       " 'imagenet_sdv4_embedding': array([[0.7913  ],\n",
       "        [0.78645 ],\n",
       "        [0.78665 ],\n",
       "        [0.78685 ],\n",
       "        [0.790925]], dtype=float32),\n",
       " 'imagenet_sdv5_embedding': array([[0.7921  ],\n",
       "        [0.786525],\n",
       "        [0.7872  ],\n",
       "        [0.790975],\n",
       "        [0.792425]], dtype=float32),\n",
       " 'imagenet_vqdm_embedding': array([[0.774775],\n",
       "        [0.76665 ],\n",
       "        [0.76835 ],\n",
       "        [0.773775],\n",
       "        [0.77215 ]], dtype=float32),\n",
       " 'imagenet_wukong_embedding': array([[0.790025],\n",
       "        [0.786625],\n",
       "        [0.7873  ],\n",
       "        [0.7881  ],\n",
       "        [0.790225]], dtype=float32),\n",
       " 'imagenet_adm_embedding': array([[0.76605 ],\n",
       "        [0.7602  ],\n",
       "        [0.763725],\n",
       "        [0.760025],\n",
       "        [0.7645  ]], dtype=float32),\n",
       " 'imagenet_midjourney_embedding': array([[0.796975],\n",
       "        [0.791325],\n",
       "        [0.791325],\n",
       "        [0.79765 ],\n",
       "        [0.79655 ]], dtype=float32)}"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_c2st_knn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"results_c2st_knn.npy\", results_c2st_knn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_c2st_nn = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting  imagenet_unconditional_model_embedding\n",
      "Starting  imagenet_cs1_embedding\n",
      "Starting  imagenet_cs10_embedding\n",
      "Starting  imagenet_biggan_embedding\n",
      "Starting  imagenet_sdv4_embedding\n",
      "Starting  imagenet_sdv5_embedding\n",
      "Starting  imagenet_vqdm_embedding\n",
      "Starting  imagenet_wukong_embedding\n",
      "Starting  imagenet_adm_embedding\n",
      "Starting  imagenet_midjourney_embedding\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "metric = \"c2st_nn\"\n",
    "metric_fn = get_metric(metric)\n",
    "for dname in datasets:\n",
    "    print(\"Starting \", dname)\n",
    "    metric_values = []\n",
    "    for j in range(5):\n",
    "        data_test = testset[j*20_000:(j+1)*20_000]\n",
    "        if dname == \"imagenet_midjourney_embedding\":\n",
    "            data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n",
    "            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n",
    "        else:\n",
    "            data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n",
    "        m = metric_fn(data_test, data_syn, n_folds=2)\n",
    "        metric_values.append(m)\n",
    "    results_c2st_nn[dname] = np.array(metric_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"results_c2st_nn.npy\", results_c2st_nn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'imagenet_unconditional_model_embedding': array([[0.7191  ],\n",
       "        [0.71785 ],\n",
       "        [0.72095 ],\n",
       "        [0.716875],\n",
       "        [0.7206  ]], dtype=float32),\n",
       " 'imagenet_cs1_embedding': array([[0.771325],\n",
       "        [0.761475],\n",
       "        [0.771525],\n",
       "        [0.767575],\n",
       "        [0.7738  ]], dtype=float32),\n",
       " 'imagenet_cs10_embedding': array([[0.7608  ],\n",
       "        [0.7621  ],\n",
       "        [0.764   ],\n",
       "        [0.7643  ],\n",
       "        [0.757775]], dtype=float32),\n",
       " 'imagenet_biggan_embedding': array([[0.8693  ],\n",
       "        [0.864775],\n",
       "        [0.862925],\n",
       "        [0.861475],\n",
       "        [0.854925]], dtype=float32),\n",
       " 'imagenet_sdv4_embedding': array([[0.92205 ],\n",
       "        [0.920025],\n",
       "        [0.932025],\n",
       "        [0.92415 ],\n",
       "        [0.92075 ]], dtype=float32),\n",
       " 'imagenet_sdv5_embedding': array([[0.920525],\n",
       "        [0.930475],\n",
       "        [0.923075],\n",
       "        [0.920125],\n",
       "        [0.92495 ]], dtype=float32),\n",
       " 'imagenet_vqdm_embedding': array([[0.8448 ],\n",
       "        [0.8455 ],\n",
       "        [0.84815],\n",
       "        [0.8496 ],\n",
       "        [0.84925]], dtype=float32),\n",
       " 'imagenet_wukong_embedding': array([[0.920125],\n",
       "        [0.914625],\n",
       "        [0.92035 ],\n",
       "        [0.917275],\n",
       "        [0.91415 ]], dtype=float32),\n",
       " 'imagenet_adm_embedding': array([[0.85505 ],\n",
       "        [0.855725],\n",
       "        [0.854125],\n",
       "        [0.8574  ],\n",
       "        [0.849125]], dtype=float32),\n",
       " 'imagenet_midjourney_embedding': array([[0.94325 ],\n",
       "        [0.9454  ],\n",
       "        [0.9513  ],\n",
       "        [0.94335 ],\n",
       "        [0.940975]], dtype=float32)}"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_c2st_nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_c2st_rf = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting  imagenet_unconditional_model_embedding\n",
      "Starting  imagenet_cs1_embedding\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[50], line 14\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     13\u001b[0m         data_syn \u001b[38;5;241m=\u001b[39m get_dataset(dname)(\u001b[38;5;241m20_000\u001b[39m, \u001b[38;5;241m2048\u001b[39m, permute\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m---> 14\u001b[0m     m \u001b[38;5;241m=\u001b[39m \u001b[43mmetric_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_syn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_folds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     15\u001b[0m     metric_values\u001b[38;5;241m.\u001b[39mappend(m)\n\u001b[1;32m     16\u001b[0m results_c2st_rf[dname] \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(metric_values)\n",
      "File \u001b[0;32m/mnt/c/Users/manug/OneDrive/Desktop_backup/labproject/labproject/labproject/metrics/utils.py:24\u001b[0m, in \u001b[0;36mregister_metric.<locals>.decorator.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     21\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m     22\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m     23\u001b[0m     \u001b[38;5;66;03m# Call the original function\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m     metric \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     26\u001b[0m     \u001b[38;5;66;03m# Convert output to tensor\u001b[39;00m\n\u001b[1;32m     27\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(metric, torch\u001b[38;5;241m.\u001b[39mTensor):\n",
      "File \u001b[0;32m/mnt/c/Users/manug/OneDrive/Desktop_backup/labproject/labproject/labproject/metrics/c2st.py:168\u001b[0m, in \u001b[0;36mc2st_rf\u001b[0;34m(X, Y, seed, n_folds, metric, z_score, n_estimators, clf_kwargs)\u001b[0m\n\u001b[1;32m    165\u001b[0m clf_class \u001b[38;5;241m=\u001b[39m RandomForestClassifier\n\u001b[1;32m    166\u001b[0m clf_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_estimators\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m n_estimators\n\u001b[0;32m--> 168\u001b[0m scores_ \u001b[38;5;241m=\u001b[39m \u001b[43mc2st_scores\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    169\u001b[0m \u001b[43m    \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    170\u001b[0m \u001b[43m    \u001b[49m\u001b[43mY\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    171\u001b[0m \u001b[43m    \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mseed\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    172\u001b[0m \u001b[43m    \u001b[49m\u001b[43mn_folds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_folds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    173\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmetric\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    174\u001b[0m \u001b[43m    \u001b[49m\u001b[43mz_score\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mz_score\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    175\u001b[0m \u001b[43m    \u001b[49m\u001b[43mnoise_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    176\u001b[0m \u001b[43m    \u001b[49m\u001b[43mverbosity\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    177\u001b[0m \u001b[43m    \u001b[49m\u001b[43mclf_class\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclf_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    178\u001b[0m \u001b[43m    \u001b[49m\u001b[43mclf_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclf_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    179\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    181\u001b[0m scores \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmean(scores_)\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m    182\u001b[0m value \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfrom_numpy(np\u001b[38;5;241m.\u001b[39matleast_1d(scores))\n",
      "File \u001b[0;32m/mnt/c/Users/manug/OneDrive/Desktop_backup/labproject/labproject/labproject/metrics/c2st.py:357\u001b[0m, in \u001b[0;36mc2st_scores\u001b[0;34m(X, Y, seed, n_folds, metric, z_score, noise_scale, verbosity, clf_class, clf_kwargs)\u001b[0m\n\u001b[1;32m    354\u001b[0m target \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mconcatenate((np\u001b[38;5;241m.\u001b[39mzeros((X\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m],)), np\u001b[38;5;241m.\u001b[39mones((Y\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m],))))\n\u001b[1;32m    356\u001b[0m shuffle \u001b[38;5;241m=\u001b[39m KFold(n_splits\u001b[38;5;241m=\u001b[39mn_folds, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, random_state\u001b[38;5;241m=\u001b[39mseed)\n\u001b[0;32m--> 357\u001b[0m scores \u001b[38;5;241m=\u001b[39m \u001b[43mcross_val_score\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcv\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mshuffle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscoring\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbosity\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    359\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m scores\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/_param_validation.py:213\u001b[0m, in \u001b[0;36mvalidate_params.<locals>.decorator.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    207\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    208\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m    209\u001b[0m         skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m    210\u001b[0m             prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m    211\u001b[0m         )\n\u001b[1;32m    212\u001b[0m     ):\n\u001b[0;32m--> 213\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    214\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m InvalidParameterError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    215\u001b[0m     \u001b[38;5;66;03m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[1;32m    216\u001b[0m     \u001b[38;5;66;03m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[1;32m    217\u001b[0m     \u001b[38;5;66;03m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[1;32m    218\u001b[0m     \u001b[38;5;66;03m# message to avoid confusion.\u001b[39;00m\n\u001b[1;32m    219\u001b[0m     msg \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msub(\n\u001b[1;32m    220\u001b[0m         \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mw+ must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    221\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    222\u001b[0m         \u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m    223\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/model_selection/_validation.py:714\u001b[0m, in \u001b[0;36mcross_val_score\u001b[0;34m(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, params, pre_dispatch, error_score)\u001b[0m\n\u001b[1;32m    711\u001b[0m \u001b[38;5;66;03m# To ensure multimetric format is not supported\u001b[39;00m\n\u001b[1;32m    712\u001b[0m scorer \u001b[38;5;241m=\u001b[39m check_scoring(estimator, scoring\u001b[38;5;241m=\u001b[39mscoring)\n\u001b[0;32m--> 714\u001b[0m cv_results \u001b[38;5;241m=\u001b[39m \u001b[43mcross_validate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    715\u001b[0m \u001b[43m    \u001b[49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    716\u001b[0m \u001b[43m    \u001b[49m\u001b[43mX\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    717\u001b[0m \u001b[43m    \u001b[49m\u001b[43my\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    718\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgroups\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    719\u001b[0m \u001b[43m    \u001b[49m\u001b[43mscoring\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mscore\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mscorer\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    720\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcv\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    721\u001b[0m \u001b[43m    \u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_jobs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    722\u001b[0m \u001b[43m    \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    723\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfit_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfit_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    724\u001b[0m \u001b[43m    \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    725\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    726\u001b[0m \u001b[43m    \u001b[49m\u001b[43merror_score\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merror_score\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    727\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    728\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m cv_results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_score\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/_param_validation.py:213\u001b[0m, in \u001b[0;36mvalidate_params.<locals>.decorator.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    207\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    208\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m    209\u001b[0m         skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m    210\u001b[0m             prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m    211\u001b[0m         )\n\u001b[1;32m    212\u001b[0m     ):\n\u001b[0;32m--> 213\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    214\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m InvalidParameterError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    215\u001b[0m     \u001b[38;5;66;03m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[1;32m    216\u001b[0m     \u001b[38;5;66;03m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[1;32m    217\u001b[0m     \u001b[38;5;66;03m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[1;32m    218\u001b[0m     \u001b[38;5;66;03m# message to avoid confusion.\u001b[39;00m\n\u001b[1;32m    219\u001b[0m     msg \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msub(\n\u001b[1;32m    220\u001b[0m         \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mw+ must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    221\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    222\u001b[0m         \u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m    223\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/model_selection/_validation.py:425\u001b[0m, in \u001b[0;36mcross_validate\u001b[0;34m(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, params, pre_dispatch, return_train_score, return_estimator, return_indices, error_score)\u001b[0m\n\u001b[1;32m    422\u001b[0m \u001b[38;5;66;03m# We clone the estimator to make sure that all the folds are\u001b[39;00m\n\u001b[1;32m    423\u001b[0m \u001b[38;5;66;03m# independent, and that it is pickle-able.\u001b[39;00m\n\u001b[1;32m    424\u001b[0m parallel \u001b[38;5;241m=\u001b[39m Parallel(n_jobs\u001b[38;5;241m=\u001b[39mn_jobs, verbose\u001b[38;5;241m=\u001b[39mverbose, pre_dispatch\u001b[38;5;241m=\u001b[39mpre_dispatch)\n\u001b[0;32m--> 425\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[43mparallel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    426\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdelayed\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_fit_and_score\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    427\u001b[0m \u001b[43m        \u001b[49m\u001b[43mclone\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    428\u001b[0m \u001b[43m        \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    429\u001b[0m \u001b[43m        \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    430\u001b[0m \u001b[43m        \u001b[49m\u001b[43mscorer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mscorers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    431\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    432\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtest\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    433\u001b[0m \u001b[43m        \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    434\u001b[0m \u001b[43m        \u001b[49m\u001b[43mparameters\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    435\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfit_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrouted_params\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    436\u001b[0m \u001b[43m        \u001b[49m\u001b[43mscore_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrouted_params\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscorer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscore\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    437\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreturn_train_score\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_train_score\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    438\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreturn_times\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    439\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreturn_estimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_estimator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    440\u001b[0m \u001b[43m        \u001b[49m\u001b[43merror_score\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merror_score\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    441\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    442\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\n\u001b[1;32m    443\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    445\u001b[0m _warn_or_raise_about_fit_failures(results, error_score)\n\u001b[1;32m    447\u001b[0m \u001b[38;5;66;03m# For callable scoring, the return type is only know after calling. If the\u001b[39;00m\n\u001b[1;32m    448\u001b[0m \u001b[38;5;66;03m# return type is a dictionary, the error scores can now be inserted with\u001b[39;00m\n\u001b[1;32m    449\u001b[0m \u001b[38;5;66;03m# the correct key.\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/parallel.py:67\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m     62\u001b[0m config \u001b[38;5;241m=\u001b[39m get_config()\n\u001b[1;32m     63\u001b[0m iterable_with_config \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m     64\u001b[0m     (_with_config(delayed_func, config), args, kwargs)\n\u001b[1;32m     65\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m delayed_func, args, kwargs \u001b[38;5;129;01min\u001b[39;00m iterable\n\u001b[1;32m     66\u001b[0m )\n\u001b[0;32m---> 67\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43miterable_with_config\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/joblib/parallel.py:1863\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1861\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_sequential_output(iterable)\n\u001b[1;32m   1862\u001b[0m     \u001b[38;5;28mnext\u001b[39m(output)\n\u001b[0;32m-> 1863\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m output \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_generator \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1865\u001b[0m \u001b[38;5;66;03m# Let's create an ID that uniquely identifies the current call. If the\u001b[39;00m\n\u001b[1;32m   1866\u001b[0m \u001b[38;5;66;03m# call is interrupted early and that the same instance is immediately\u001b[39;00m\n\u001b[1;32m   1867\u001b[0m \u001b[38;5;66;03m# re-used, this id will be used to prevent workers that were\u001b[39;00m\n\u001b[1;32m   1868\u001b[0m \u001b[38;5;66;03m# concurrently finalizing a task from the previous call to run the\u001b[39;00m\n\u001b[1;32m   1869\u001b[0m \u001b[38;5;66;03m# callback.\u001b[39;00m\n\u001b[1;32m   1870\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock:\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/joblib/parallel.py:1792\u001b[0m, in \u001b[0;36mParallel._get_sequential_output\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1790\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_dispatched_batches \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m   1791\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_dispatched_tasks \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m-> 1792\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1793\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_completed_tasks \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m   1794\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint_progress()\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/parallel.py:129\u001b[0m, in \u001b[0;36m_FuncWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    127\u001b[0m     config \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m    128\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mconfig):\n\u001b[0;32m--> 129\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/model_selection/_validation.py:890\u001b[0m, in \u001b[0;36m_fit_and_score\u001b[0;34m(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, score_params, return_train_score, return_parameters, return_n_test_samples, return_times, return_estimator, split_progress, candidate_progress, error_score)\u001b[0m\n\u001b[1;32m    888\u001b[0m         estimator\u001b[38;5;241m.\u001b[39mfit(X_train, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mfit_params)\n\u001b[1;32m    889\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 890\u001b[0m         \u001b[43mestimator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfit_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    892\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[1;32m    893\u001b[0m     \u001b[38;5;66;03m# Note fit time as time until error\u001b[39;00m\n\u001b[1;32m    894\u001b[0m     fit_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime() \u001b[38;5;241m-\u001b[39m start_time\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/base.py:1351\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1344\u001b[0m     estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[1;32m   1346\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m   1347\u001b[0m     skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m   1348\u001b[0m         prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m   1349\u001b[0m     )\n\u001b[1;32m   1350\u001b[0m ):\n\u001b[0;32m-> 1351\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/ensemble/_forest.py:489\u001b[0m, in \u001b[0;36mBaseForest.fit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m    478\u001b[0m trees \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m    479\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_estimator(append\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, random_state\u001b[38;5;241m=\u001b[39mrandom_state)\n\u001b[1;32m    480\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_more_estimators)\n\u001b[1;32m    481\u001b[0m ]\n\u001b[1;32m    483\u001b[0m \u001b[38;5;66;03m# Parallel loop: we prefer the threading backend as the Cython code\u001b[39;00m\n\u001b[1;32m    484\u001b[0m \u001b[38;5;66;03m# for fitting the trees is internally releasing the Python GIL\u001b[39;00m\n\u001b[1;32m    485\u001b[0m \u001b[38;5;66;03m# making threading more efficient than multiprocessing in\u001b[39;00m\n\u001b[1;32m    486\u001b[0m \u001b[38;5;66;03m# that case. However, for joblib 0.12+ we respect any\u001b[39;00m\n\u001b[1;32m    487\u001b[0m \u001b[38;5;66;03m# parallel_backend contexts set at a higher level,\u001b[39;00m\n\u001b[1;32m    488\u001b[0m \u001b[38;5;66;03m# since correctness does not rely on using threads.\u001b[39;00m\n\u001b[0;32m--> 489\u001b[0m trees \u001b[38;5;241m=\u001b[39m \u001b[43mParallel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    490\u001b[0m \u001b[43m    \u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_jobs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    491\u001b[0m \u001b[43m    \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    492\u001b[0m \u001b[43m    \u001b[49m\u001b[43mprefer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mthreads\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    493\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    494\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdelayed\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_parallel_build_trees\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    495\u001b[0m \u001b[43m        \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    496\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbootstrap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    497\u001b[0m \u001b[43m        \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    498\u001b[0m \u001b[43m        \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    499\u001b[0m \u001b[43m        \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    500\u001b[0m \u001b[43m        \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    501\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrees\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    502\u001b[0m \u001b[43m        \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    503\u001b[0m \u001b[43m        \u001b[49m\u001b[43mclass_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    504\u001b[0m \u001b[43m        \u001b[49m\u001b[43mn_samples_bootstrap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_samples_bootstrap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    505\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmissing_values_in_feature_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmissing_values_in_feature_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    506\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    507\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrees\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    508\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    510\u001b[0m \u001b[38;5;66;03m# Collect newly grown trees\u001b[39;00m\n\u001b[1;32m    511\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mestimators_\u001b[38;5;241m.\u001b[39mextend(trees)\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/parallel.py:67\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m     62\u001b[0m config \u001b[38;5;241m=\u001b[39m get_config()\n\u001b[1;32m     63\u001b[0m iterable_with_config \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m     64\u001b[0m     (_with_config(delayed_func, config), args, kwargs)\n\u001b[1;32m     65\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m delayed_func, args, kwargs \u001b[38;5;129;01min\u001b[39;00m iterable\n\u001b[1;32m     66\u001b[0m )\n\u001b[0;32m---> 67\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43miterable_with_config\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/joblib/parallel.py:1863\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1861\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_sequential_output(iterable)\n\u001b[1;32m   1862\u001b[0m     \u001b[38;5;28mnext\u001b[39m(output)\n\u001b[0;32m-> 1863\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m output \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_generator \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1865\u001b[0m \u001b[38;5;66;03m# Let's create an ID that uniquely identifies the current call. If the\u001b[39;00m\n\u001b[1;32m   1866\u001b[0m \u001b[38;5;66;03m# call is interrupted early and that the same instance is immediately\u001b[39;00m\n\u001b[1;32m   1867\u001b[0m \u001b[38;5;66;03m# re-used, this id will be used to prevent workers that were\u001b[39;00m\n\u001b[1;32m   1868\u001b[0m \u001b[38;5;66;03m# concurrently finalizing a task from the previous call to run the\u001b[39;00m\n\u001b[1;32m   1869\u001b[0m \u001b[38;5;66;03m# callback.\u001b[39;00m\n\u001b[1;32m   1870\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock:\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/joblib/parallel.py:1792\u001b[0m, in \u001b[0;36mParallel._get_sequential_output\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1790\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_dispatched_batches \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m   1791\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_dispatched_tasks \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m-> 1792\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1793\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_completed_tasks \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m   1794\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint_progress()\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/utils/parallel.py:129\u001b[0m, in \u001b[0;36m_FuncWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    127\u001b[0m     config \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m    128\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mconfig):\n\u001b[0;32m--> 129\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/ensemble/_forest.py:192\u001b[0m, in \u001b[0;36m_parallel_build_trees\u001b[0;34m(tree, bootstrap, X, y, sample_weight, tree_idx, n_trees, verbose, class_weight, n_samples_bootstrap, missing_values_in_feature_mask)\u001b[0m\n\u001b[1;32m    189\u001b[0m     \u001b[38;5;28;01melif\u001b[39;00m class_weight \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbalanced_subsample\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m    190\u001b[0m         curr_sample_weight \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m=\u001b[39m compute_sample_weight(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbalanced\u001b[39m\u001b[38;5;124m\"\u001b[39m, y, indices\u001b[38;5;241m=\u001b[39mindices)\n\u001b[0;32m--> 192\u001b[0m     \u001b[43mtree\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    193\u001b[0m \u001b[43m        \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    194\u001b[0m \u001b[43m        \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    195\u001b[0m \u001b[43m        \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcurr_sample_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    196\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcheck_input\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    197\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmissing_values_in_feature_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmissing_values_in_feature_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    198\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    199\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    200\u001b[0m     tree\u001b[38;5;241m.\u001b[39m_fit(\n\u001b[1;32m    201\u001b[0m         X,\n\u001b[1;32m    202\u001b[0m         y,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    205\u001b[0m         missing_values_in_feature_mask\u001b[38;5;241m=\u001b[39mmissing_values_in_feature_mask,\n\u001b[1;32m    206\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/envs/labproject/lib/python3.9/site-packages/sklearn/tree/_classes.py:472\u001b[0m, in \u001b[0;36mBaseDecisionTree._fit\u001b[0;34m(self, X, y, sample_weight, check_input, missing_values_in_feature_mask)\u001b[0m\n\u001b[1;32m    461\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    462\u001b[0m     builder \u001b[38;5;241m=\u001b[39m BestFirstTreeBuilder(\n\u001b[1;32m    463\u001b[0m         splitter,\n\u001b[1;32m    464\u001b[0m         min_samples_split,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    469\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmin_impurity_decrease,\n\u001b[1;32m    470\u001b[0m     )\n\u001b[0;32m--> 472\u001b[0m \u001b[43mbuilder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbuild\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtree_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmissing_values_in_feature_mask\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    474\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_outputs_ \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m is_classifier(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    475\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_classes_ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_classes_[\u001b[38;5;241m0\u001b[39m]\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "metric = \"c2st_rf\"\n",
    "metric_fn = get_metric(metric)\n",
    "for dname in datasets:\n",
    "    print(\"Starting \", dname)\n",
    "    metric_values = []\n",
    "    for j in range(5):\n",
    "        data_test = testset[j*20_000:(j+1)*20_000]\n",
    "        if dname == \"imagenet_midjourney_embedding\":\n",
    "            data_syn = get_dataset(dname)(10_000, 2048, permute=False)\n",
    "            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]\n",
    "        else:\n",
    "            data_syn = get_dataset(dname)(20_000, 2048, permute=True)\n",
    "        m = metric_fn(data_test, data_syn, n_folds=2)\n",
    "        metric_values.append(m)\n",
    "    results_c2st_rf[dname] = np.array(metric_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"results_c2st_rf.npy\", results_c2st_rf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_c2st_rf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9410])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metric_fn = get_metric(\"c2st_nn\")\n",
    "m = metric_fn(data_test, data_syn, n_folds=2)\n",
    "m"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "labproject",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
