{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "22cfdcec",
   "metadata": {},
   "source": [
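    "## Data\n",
    "\n",
    "Ground-truth factors `y` lie on a regular grid of $(n + 1)^3$ points in $[0, 1]^3$. Observations `x` are produced from them by a nonlinear map built from a random rotation and an elementwise exponential, min-max normalised to $[0, 1]$ per dimension. The `z_*` arrays are candidate representations of the factors, scored in the next section."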
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3ed222bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# seed NumPy's global RNG: z_random and the miniball jitter draw from it\n",
    "# (and, presumably, Rotation.random() when called without random_state)\n",
    "SEED = 2023\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e26824e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# index: all (n + 1) ** 3 points of a regular integer grid, one per column\n",
    "n = 10\n",
    "l = np.arange(n + 1)\n",
    "i = np.array(np.meshgrid(*(l,) * 3)).reshape(3, -1)\n",
    "\n",
    "# factor: grid indices rescaled to the unit cube [0, 1]^3\n",
    "y = np.divide(i, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0dbd0522",
   "metadata": {},
   "outputs": [],
   "source": [
    "# random rotation of R^3, used by the transformation and by z_rotation below\n",
    "# NOTE(review): with no random_state this presumably draws from NumPy's global\n",
    "# RNG and is therefore fixed by np.random.seed above - confirm for the SciPy in use\n",
    "from scipy.spatial.transform import Rotation\n",
    "r = Rotation.random()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1b8fe8ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "min [0. 0. 0.] max [1. 1. 1.]\n",
      "(3, 1331) (3, 1331)\n"
     ]
    }
   ],
   "source": [
    "# transformation: k rotates every point (column of y) and exponentiates it\n",
    "# elementwise, t applies k twice, and g min-max normalises each output\n",
    "# dimension of t using its range over the full grid y\n",
    "def k(y):\n",
    "    \"\"\"Rotate the columns of y with r, then take the elementwise exponential.\"\"\"\n",
    "    return np.exp(r.apply(y.squeeze().T).T)\n",
    "\n",
    "def t(y):\n",
    "    \"\"\"Apply k twice.\"\"\"\n",
    "    return k(k(y))\n",
    "\n",
    "t_grid = t(y)  # evaluate once and reuse for both statistics\n",
    "t_min = t_grid.min(1)\n",
    "# np.ptp rather than the ndarray.ptp method, which NumPy 2.0 removed\n",
    "t_ptp = np.ptp(t_grid, axis=1)\n",
    "assert np.all(t_ptp > 0), 'degenerate output dimension: cannot min-max normalise'\n",
    "\n",
    "def g(y):\n",
    "    \"\"\"Observation map: t(y) rescaled so the full grid spans [0, 1] per dimension.\"\"\"\n",
    "    return (t(y) - t_min[:, None]) / t_ptp[:, None]\n",
    "\n",
    "# observation\n",
    "x = g(y)\n",
    "print('min', x.min(1), 'max', x.max(1))\n",
    "print(x.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1f5dad64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# representation: one code per factor; iterating over each z_* (tuple or\n",
    "# stacked array) yields, for factor j, an array of shape (points, code_dim)\n",
    "z_entanglement = np.expand_dims(x.copy(), -1)\n",
    "z_rotation = np.expand_dims(r.apply(y.squeeze().T).T, -1)\n",
    "z_duplicate = (y.T, y.T, y[2:].T)\n",
    "z_complement = tuple(np.delete(y, j, axis=0).T for j in range(3))\n",
    "z_misalignment = tuple(y[[j]].T for j in (1, 2, 0))\n",
    "z_redundancy = (np.column_stack([y[0], -y[0]]), y[1, :, None], y[2, :, None])\n",
    "z_contraction = 0.01 * np.expand_dims(y, -1)\n",
    "z_nonlinear = np.square(np.expand_dims(y, -1))\n",
    "z_constant = np.zeros(x.shape + (1,))\n",
    "z_random = np.random.uniform(0, 1, size=x.shape)[:, :, None]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4b5eeed",
   "metadata": {},
   "source": [
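    "## Metrics\n",
    "\n",
    "`q_product` sums, over factors, an aggregate (max or mean) of how spread out the codes are among points that share the same value of that factor. The cells below define the spread measures. Scores are printed as $\\exp(-\\text{score})$, so larger values mean less spread, and `1.00` means (to two decimals) no spread at all."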
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "482ea8f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def q_product(y: np.ndarray, z: np.ndarray, aggregate, deviation):\n",
    "    \"\"\"Sum over factors of an aggregate of within-group code deviations.\n",
    "\n",
    "    Rows of ``y`` are paired by position with the entries of ``z`` (an array or\n",
    "    a tuple of per-factor codes). For each pair, the points are grouped by the\n",
    "    distinct values of that factor, ``deviation`` is applied to each group's\n",
    "    codes, and the per-group values are combined with ``aggregate``. The\n",
    "    per-factor results are summed.\n",
    "    \"\"\"\n",
    "    per_factor = []\n",
    "    for factor_values, codes in zip(y, z):\n",
    "        group_scores = [\n",
    "            deviation(codes[factor_values == level])\n",
    "            for level in np.unique(factor_values)\n",
    "        ]\n",
    "        per_factor.append(aggregate(group_scores))\n",
    "    return np.sum(per_factor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5b1e45f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# max of Euclidean distances\n",
    "import miniball\n",
    "\n",
    "def ball(z):\n",
    "    \"\"\"Smallest enclosing ball of the rows of z, as returned by miniball.\n",
    "\n",
    "    A 1e-12-scaled Gaussian jitter drawn from NumPy's global RNG is added first,\n",
    "    presumably so degenerate / duplicated points do not trip up the solver.\n",
    "    \"\"\"\n",
    "    jittered = z + 1e-12 * np.random.randn(*z.shape)\n",
    "    return miniball.get_bounding_ball(jittered)\n",
    "\n",
    "def center(z):\n",
    "    \"\"\"Centre of the smallest enclosing ball (recomputed, with fresh jitter).\"\"\"\n",
    "    return ball(z)[0]\n",
    "\n",
    "def radius(z):\n",
    "    \"\"\"Radius of the smallest enclosing ball; ball(z)[1] is treated as its square.\"\"\"\n",
    "    squared_radius = ball(z)[1]\n",
    "    return np.sqrt(squared_radius)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1ef14a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sum of Euclidean distances\n",
    "from geom_median.numpy import compute_geometric_median\n",
    "\n",
    "def median(z):\n",
    "    \"\"\"Geometric median of the rows of z, as computed by geom_median.\"\"\"\n",
    "    return compute_geometric_median(z).median\n",
    "\n",
    "def mean_absolute_deviation_around_median(z):\n",
    "    \"\"\"Mean Euclidean distance from the rows of z to their geometric median.\"\"\"\n",
    "    distances = np.linalg.norm(z - median(z), axis=-1)\n",
    "    return distances.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "43c2bf12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sum of squared Euclidean distances\n",
    "def mean(z):\n",
    "    \"\"\"Centroid of the rows of z: the point minimising the sum of squared distances.\n",
    "\n",
    "    axis=0 keeps the result a point, like center() and median(); a bare\n",
    "    z.mean() would collapse all coordinates into a single scalar.\n",
    "    \"\"\"\n",
    "    return z.mean(axis=0)\n",
    "\n",
    "def variance(z):\n",
    "    \"\"\"Mean squared Euclidean distance of the rows of z to their centroid.\"\"\"\n",
    "    return z.var(axis=0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "05d4bff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pairwise_distance(z):\n",
    "    \"\"\"Euclidean distance between every pair of rows of z (self-pairs included).\"\"\"\n",
    "    differences = z[:, None] - z  # (n, n, d) by broadcasting, for 2-D z\n",
    "    return np.linalg.norm(differences, axis=-1)\n",
    "\n",
    "def diameter(z):\n",
    "    \"\"\"Largest distance between any two rows of z.\"\"\"\n",
    "    return np.max(pairwise_distance(z))\n",
    "\n",
    "def mean_pairwise_distance(z):\n",
    "    \"\"\"Half the mean distance over all ordered pairs of rows, self-pairs included.\"\"\"\n",
    "    return 0.5 * np.mean(pairwise_distance(z))\n",
    "\n",
    "def mean_pairwise_squared_distance(z):\n",
    "    \"\"\"Half the mean squared distance over all ordered pairs; equals the variance.\"\"\"\n",
    "    return 0.5 * np.mean(np.square(pairwise_distance(z)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "86ed7026",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fnum(num):\n",
    "    \"\"\"Format exp(-num) with two decimals, padded to width 4.\"\"\"\n",
    "    return f'{np.exp(-num):4.2f}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3ef5b3b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "entanglement 0.44 0.75 0.96 0.19 0.82 \n",
      "rotation     0.22 0.51 0.80 0.05 0.64 \n",
      "duplicate    0.24 0.43 0.67 0.06 0.56 \n",
      "complement   0.12 0.28 0.55 0.01 0.42 \n",
      "misalignment 0.22 0.44 0.74 0.05 0.58 \n",
      "redundancy   1.00 1.00 1.00 1.00 1.00 \n",
      "contraction  1.00 1.00 1.00 1.00 1.00 \n",
      "nonlinear    1.00 1.00 1.00 1.00 1.00 \n",
      "constant     1.00 1.00 1.00 1.00 1.00 \n",
      "random       0.22 0.48 0.78 0.05 0.61 "
     ]
    }
   ],
   "source": [
    "# (code, label) pairs, one printed row each\n",
    "representations = [\n",
    "    (z_entanglement, 'entanglement'),\n",
    "    (z_rotation, 'rotation'),\n",
    "    (z_duplicate, 'duplicate'),\n",
    "    (z_complement, 'complement'),\n",
    "    (z_misalignment, 'misalignment'),\n",
    "    (z_redundancy, 'redundancy'),\n",
    "    (z_contraction, 'contraction'),\n",
    "    (z_nonlinear, 'nonlinear'),\n",
    "    (z_constant, 'constant'),\n",
    "    (z_random, 'random'),\n",
    "]\n",
    "\n",
    "# (aggregate over groups, deviation within a group), one printed column each\n",
    "metrics = [\n",
    "    (np.max, radius),\n",
    "    (np.mean, mean_absolute_deviation_around_median),\n",
    "    (np.mean, variance),\n",
    "    (np.max, diameter),\n",
    "    (np.mean, mean_pairwise_distance),\n",
    "]\n",
    "\n",
    "for z, z_name in representations:\n",
    "    print(f'\\n{z_name:<12}', end=' ')\n",
    "    for aggregate, deviation in metrics:\n",
    "        score = q_product(y, z, aggregate, deviation)\n",
    "        print(fnum(score), end=' ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a696ea1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
