{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "58116340-2165-45ce-9d93-1cc30e1ea824",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from metrics import dist_cov, dist_cov_, dist_var, dist_var_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a3dbce96-6699-41ae-ba1f-d702e57dfbc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "t1 = np.zeros((20, 30, 10))\n",
    "t2 = np.zeros((20, 30, 10))\n",
    "assert dist_cov(t1, t2) == 0\n",
    "assert dist_var(t1) == 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eab3ffb7-1ba4-476c-9ccc-8789ccea1da8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# very slow but as easy as it gets\n",
    "# we imply rbf kernel with gamma=1\n",
    "def unit_test_var(tensor):\n",
    "    n, m, _ = tensor.shape\n",
    "    n_inner = 0\n",
    "    n_outer = 0\n",
    "    inner = 0\n",
    "    outer = 0\n",
    "    for i in range(n):\n",
    "        for j in range(m):\n",
    "            for s in range(n):\n",
    "                for t in range(m):\n",
    "                    # within clusters\n",
    "                    if (i==s) & (j!=t):\n",
    "                        inner += np.exp(-(tensor[i,j,0]-tensor[s,t,0])**2)\n",
    "                        n_inner += 1\n",
    "                    # between clusters\n",
    "                    if i!=s:\n",
    "                        outer += np.exp(-(tensor[i,j,0]-tensor[s,t,0])**2)\n",
    "                        n_outer += 1\n",
    "    # avg within clusters minus avg between clusters\n",
    "    return (inner / n_inner) - (outer / n_outer)\n",
    "\n",
    "def unit_test_cov(tensor1, tensor2):\n",
    "    assert tensor1.shape == tensor2.shape\n",
    "    n, m, _ = tensor1.shape\n",
    "    n_inner = 0\n",
    "    n_outer = 0\n",
    "    inner = 0\n",
    "    outer = 0\n",
    "    for i in range(n):\n",
    "        for j in range(m):\n",
    "            for s in range(n):\n",
    "                for t in range(m):\n",
    "                    # within clusters\n",
    "                    if i==s:\n",
    "                        inner += np.exp(-(tensor1[i,j,0]-tensor2[s,t,0])**2)\n",
    "                        n_inner += 1\n",
    "                    # between clusters\n",
    "                    if i!=s:\n",
    "                        outer += np.exp(-(tensor1[i,j,0]-tensor2[s,t,0])**2)\n",
    "                        n_outer += 1\n",
    "    # avg within clusters minus avg between clusters\n",
    "    return (inner / n_inner) - (outer / n_outer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4bac7330-316b-4bbe-a6df-ea69983bfc8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[0]\n",
      "  [1]]\n",
      "\n",
      " [[2]\n",
      "  [3]]]\n",
      "(2, 2, 1)\n"
     ]
    }
   ],
   "source": [
    "t3 = np.arange(0, 4).reshape(2, 2, 1)\n",
    "t4 = np.arange(5, 9).reshape(2, 2, 1)\n",
    "print(t3)\n",
    "print(t3.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "073ed786-8dda-466e-8c74-db7ba003ea81",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.266720908983193\n",
      "-0.0023202932382726015\n"
     ]
    }
   ],
   "source": [
    "print(unit_test_var(t3))\n",
    "print(unit_test_cov(t3, t4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2ac1d7a9-df40-498d-bd42-09355d1b08d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.266720908983193\n",
      "-0.002320293238272602\n"
     ]
    }
   ],
   "source": [
    "print(dist_var(t3))\n",
    "print(dist_cov(t3, t4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ee0e5cef-edd9-4528-adee-acb3decd3462",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.266720908983193\n",
      "-0.0023202932382726015\n"
     ]
    }
   ],
   "source": [
    "print(dist_var_(t3))\n",
    "print(dist_cov_(t3, t4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4b9c5b03-b41d-4eed-a914-8e6a1f63ded9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[ 0]\n",
      "  [ 1]\n",
      "  [ 2]]\n",
      "\n",
      " [[ 3]\n",
      "  [ 4]\n",
      "  [ 5]]\n",
      "\n",
      " [[ 6]\n",
      "  [ 7]\n",
      "  [ 8]]\n",
      "\n",
      " [[ 9]\n",
      "  [10]\n",
      "  [11]]]\n",
      "(4, 3, 1)\n"
     ]
    }
   ],
   "source": [
    "t3 = np.arange(0, 12).reshape(4, 3, 1)\n",
    "t4 = np.arange(12, 24).reshape(4, 3, 1)\n",
    "print(t3)\n",
    "print(t3.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6be90df9-663b-4161-992f-f923e68b8e1f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2288647710501299\n",
      "-0.0037489018386232475\n"
     ]
    }
   ],
   "source": [
    "print(unit_test_var(t3))\n",
    "print(unit_test_cov(t3, t4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "77c59ec5-f970-43f8-9f39-7f8f8341bd6e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2288647710501299\n",
      "-0.0037489018386232475\n"
     ]
    }
   ],
   "source": [
    "print(dist_var(t3))\n",
    "print(dist_cov(t3, t4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ec694c05-a681-4e27-a4c5-2b98aad3fce9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2288647710501299\n",
      "-0.003748901838623247\n"
     ]
    }
   ],
   "source": [
    "print(dist_var_(t3))\n",
    "print(dist_cov_(t3, t4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6fdfc805-e8d9-411d-a6c9-39f6128ea49f",
   "metadata": {},
   "outputs": [],
   "source": [
    "### RUNTIME COMPARISON"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a92bdd6e-6a6d-412d-b73e-9aa15436986f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 22.2 s, sys: 58.6 ms, total: 22.2 s\n",
      "Wall time: 22.3 s\n"
     ]
    }
   ],
   "source": [
    "# ~22s\n",
    "%%time\n",
    "rng = np.random.default_rng(0)\n",
    "for _ in range(100):\n",
    "    t3 = rng.random(400).reshape(20, 20, 1)\n",
    "    t4 = rng.random(400).reshape(20, 20, 1)\n",
    "\n",
    "    _ = unit_test_var(t3)\n",
    "    _ = unit_test_cov(t3, t4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c0cc43a7-8ef7-4b67-9ad6-54e824fd924b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 227 ms, sys: 4.47 ms, total: 231 ms\n",
      "Wall time: 231 ms\n"
     ]
    }
   ],
   "source": [
    "# ~0.23s\n",
    "%%time\n",
    "rng = np.random.default_rng(0)\n",
    "for _ in range(100):\n",
    "    t3 = rng.random(400).reshape(20, 20, 1)\n",
    "    t4 = rng.random(400).reshape(20, 20, 1)\n",
    "\n",
    "    _ = dist_var_(t3)\n",
    "    _ = dist_cov_(t3, t4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "49e0556e-a24a-4426-a260-661dc6bcd3f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 185 ms, sys: 5.09 ms, total: 190 ms\n",
      "Wall time: 189 ms\n"
     ]
    }
   ],
   "source": [
    "# ~0.19s\n",
    "%%time\n",
    "rng = np.random.default_rng(0)\n",
    "for _ in range(100):\n",
    "    t3 = rng.random(400).reshape(20, 20, 1)\n",
    "    t4 = rng.random(400).reshape(20, 20, 1)\n",
    "\n",
    "    _ = dist_var(t3)\n",
    "    _ = dist_cov(t3, t4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b23b8e8-1be8-40c8-997e-4f5ea10aacc1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
