{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e8ab445f-d7c3-454f-92be-381f23f4d5fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/nkotelevskii/github/uncertainty_from_proper_scoring_rules/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from source.metrics import get_risk_approximation\n",
    "from source.source import vectorizer_uncertainty_scores\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "08f4aa2f-dc4a-4b6e-bd30-66a383b43065",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_ensemble = 20\n",
    "N_optim_searches = 1000000\n",
    "n_classes = 2\n",
    "max_concentration = 5\n",
    "\n",
    "n_exp = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "90cd2c8b-cd0f-43e7-abbc-0f312f56c227",
   "metadata": {},
   "outputs": [],
   "source": [
    "def central_prediction_spherical(\n",
    "    logits_: np.ndarray,\n",
    "):\n",
    "    probs = safe_softmax(logits_)\n",
    "    K = logits_.shape[-1]\n",
    "\n",
    "    norms = np.linalg.norm(probs, axis=-1, keepdims=True, ord=2)\n",
    "    x = np.mean(probs / norms, axis=0, keepdims=True)\n",
    "\n",
    "    x0 = np.ones(K).reshape(1, 1, K) / K\n",
    "    x0_norm = np.linalg.norm(x0, ord=2, axis=-1).ravel()\n",
    "\n",
    "    y_orth = x - np.sum(x * x0, axis=-1, keepdims=True) * (x0 / x0_norm**2)\n",
    "    y_orth_norm = np.linalg.norm(y_orth, ord=2, axis=-1).ravel()\n",
    "\n",
    "    central_pred = x0 + (y_orth / np.sqrt(1 - y_orth_norm**2)) * x0_norm\n",
    "    return central_pred\n",
    "\n",
    "\n",
    "def central_prediction_logscore(\n",
    "    logits_: np.ndarray,\n",
    "):\n",
    "    # probs = safe_softmax(logits_)\n",
    "    return safe_softmax(np.mean(logits_, axis=0, keepdims=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e70e59d2-a5a1-431e-8071-cdf148eacc8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "FUNC_MINIMIZE = pairwise_prob_diff  # pairwise_brier pairwise_kl pairwise_prob_diff pairwise_spherical\n",
    "THEORETICAL_CENTRAL_PREDICTION = central_prediction_neglog  # central_prediction_neglog central_prediction_brier central_prediction_logscore central_prediction_maxprob central_prediction_spherical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2e4e10a8-c513-49f0-94a0-181ce49b2560",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:45<00:00, 45.40s/it]\n"
     ]
    }
   ],
   "source": [
    "l1_err = []\n",
    "\n",
    "for exp_id in tqdm(range(n_exp)):\n",
    "    # sampled_vectors = np.random.dirichlet(\n",
    "    #     alpha=np.random.rand(n_classes) * max_concentration,\n",
    "    #     size=N_ensemble\n",
    "    # ).reshape(N_ensemble, 1, n_classes)\n",
    "    sampled_vectors = np.array(\n",
    "        [\n",
    "            [[0.89914033, 0.10085967]],\n",
    "            [[0.4389841, 0.5610159]],\n",
    "            [[0.99207102, 0.00792898]],\n",
    "            [[0.57802041, 0.42197959]],\n",
    "            [[0.89409964, 0.10590036]],\n",
    "            [[0.60881584, 0.39118416]],\n",
    "            [[0.9533906, 0.0466094]],\n",
    "            [[0.91530455, 0.08469545]],\n",
    "            [[0.24826006, 0.75173994]],\n",
    "            [[0.36068223, 0.63931777]],\n",
    "            [[0.60374964, 0.39625036]],\n",
    "            [[0.96204757, 0.03795243]],\n",
    "            [[0.99746462, 0.00253538]],\n",
    "            [[0.51398327, 0.48601673]],\n",
    "            [[0.88041703, 0.11958297]],\n",
    "            [[0.44736386, 0.55263614]],\n",
    "            [[0.44013821, 0.55986179]],\n",
    "            [[0.77684085, 0.22315915]],\n",
    "            [[0.40832816, 0.59167184]],\n",
    "            [[0.21596403, 0.78403597]],\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    loss = []\n",
    "    p_vals = []\n",
    "\n",
    "    for rep in range(N_optim_searches + 1):\n",
    "        if rep == N_optim_searches:\n",
    "            central_prediction = np.array([0.99577723, 0.00422277]).reshape(\n",
    "                1, 1, n_classes\n",
    "            )\n",
    "        central_prediction = np.random.dirichlet(\n",
    "            alpha=max_concentration * np.ones(n_classes), size=1\n",
    "        ).reshape(1, 1, n_classes)\n",
    "        loss.append(\n",
    "            FUNC_MINIMIZE(np.log(central_prediction), np.log(sampled_vectors)).mean()\n",
    "        )\n",
    "        p_vals.append(central_prediction.ravel())\n",
    "\n",
    "    central_prediction_theory = THEORETICAL_CENTRAL_PREDICTION(np.log(sampled_vectors))\n",
    "    p_vals = np.vstack(p_vals)\n",
    "\n",
    "    err = np.linalg.norm(\n",
    "        np.abs(p_vals[np.argmin(loss)].ravel() - central_prediction_theory.ravel()),\n",
    "        ord=1,\n",
    "        axis=-1,\n",
    "    )\n",
    "\n",
    "    l1_err.append(err)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f4d86ae-9f8c-404a-93ba-ffba339c143c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a34eec66-4c6c-4d8e-8528-2f12fd7bf7fb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.43357605, 0.00128463])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.abs(p_vals[np.argmin(loss)].ravel() - central_prediction_theory.ravel())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1ec71f7-628b-49d8-bd1b-fdb78aca80de",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7def1a06-a4c8-4f39-a1fa-0c40b4f5d1fb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "08d95697-befb-4f88-8510-7bef4dc0196f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.53354849, 0.03159083])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "central_prediction_theory.ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e84803e8-2ddd-43ee-b510-7ec649301d97",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.96712454, 0.03287546])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p_vals[np.argmin(loss)].ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ac2d541f-c8d9-4bf2-a2d3-e1fdaec0729f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.02508621, 0.97491379])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p_vals[np.argmax(loss)].ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3d91b50-7e24-4b8e-ac5a-5fe2eed6ae83",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be3a9ce9-1aa4-49a0-ad17-99fdbd07d58f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bd968fc8-fd6b-49b2-80bf-d4a9665d4b70",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.06658520691916528"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(l1_err)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e852b298-26e9-4c6a-acc1-953ba2387841",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "740dafcc-174b-43dd-ab93-b302508c6365",
   "metadata": {},
   "outputs": [],
   "source": [
    "sampled_vectors = np.array(\n",
    "    [\n",
    "        [[0.89914033, 0.10085967]],\n",
    "        [[0.4389841, 0.5610159]],\n",
    "        [[0.99207102, 0.00792898]],\n",
    "        [[0.57802041, 0.42197959]],\n",
    "        [[0.89409964, 0.10590036]],\n",
    "        [[0.60881584, 0.39118416]],\n",
    "        [[0.9533906, 0.0466094]],\n",
    "        [[0.91530455, 0.08469545]],\n",
    "        [[0.24826006, 0.75173994]],\n",
    "        [[0.36068223, 0.63931777]],\n",
    "        [[0.60374964, 0.39625036]],\n",
    "        [[0.96204757, 0.03795243]],\n",
    "        [[0.99746462, 0.00253538]],\n",
    "        [[0.51398327, 0.48601673]],\n",
    "        [[0.88041703, 0.11958297]],\n",
    "        [[0.44736386, 0.55263614]],\n",
    "        [[0.44013821, 0.55986179]],\n",
    "        [[0.77684085, 0.22315915]],\n",
    "        [[0.40832816, 0.59167184]],\n",
    "        [[0.21596403, 0.78403597]],\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d3a0cf6a-84a3-416a-98ff-5464c6004251",
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = sampled_vectors.shape[-1]\n",
    "x_0 = np.ones(dim) / dim\n",
    "x = np.mean(1 / sampled_vectors, axis=0)\n",
    "x_0 = x_0.reshape(*x.shape)\n",
    "x_parallel = x_0 * np.sum(x_0 * x) / np.linalg.norm(x_0, ord=2) ** 2\n",
    "x_perp = x - x_parallel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "61f027e0-e700-44ef-b190-01dd6caf5ef3",
   "metadata": {},
   "outputs": [],
   "source": [
    "k = (dim - np.sum(x_parallel * x_0, axis=-1)) / np.linalg.norm(x_perp) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6694f00a-dad2-457e-9a16-4c2431a0475f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.03329541647082493"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "k[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "08a0b749-5bfc-4378-8875-58161375a61e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1.8742439, 31.6547535]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "be3a8424-2d19-45da-9efb-69a886b1c726",
   "metadata": {},
   "outputs": [],
   "source": [
    "z = x_0 + k[0] * x_perp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "91055ce1-17a8-4725-9e68-16975ce9f18b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.99577723, 0.00422277]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d72016ed-b5fe-4ca5-9e51-b2f930fb2e12",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9999999999999997"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "88b8ed8d-71eb-4840-8cf8-636f575a641b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.8742439016279302"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "753ee1a1-1f01-4326-a30b-d370efbd2433",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = x.ravel()[1] - x.ravel()[0]\n",
    "answ_anal = (A - 2 + np.sqrt((A - 2) ** 2 + 4 * A)) / (2 * A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "83f262fc-3a80-4b3c-abbf-1c3758c6bb49",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.96754727, 0.03245273])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z_anal = np.array([answ_anal, 1 - answ_anal])\n",
    "z_anal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fd2362b-0a17-4d0e-a9db-e54b1afa3205",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
