{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "11f26523",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 228,
   "id": "8b6ca760",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so that Restart & Run All reproduces the same test vector.\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "# Budget handed to get_probs as its second argument.\n",
    "v = 10\n",
    "# Standard-normal background of 100 entries ...\n",
    "x = rng.standard_normal(100)\n",
    "# ... with three large-magnitude entries planted at known positions.\n",
    "x[0] = 10.1\n",
    "x[1] = 1000\n",
    "x[-1] = 1000\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 229,
   "id": "d796655e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_probs(x, v, verbose=False):\n",
    "    \"\"\"Return keep-probabilities proportional to |x| (capped at 1) within budget v.\n",
    "\n",
    "    Entries are ordered by magnitude. For a split index i, the i + 1 smallest\n",
    "    entries get probability |x_j| / sqrt_lmbda and the remaining entries get\n",
    "    probability 1, where\n",
    "    sqrt_lmbda = (sum of the i + 1 smallest |x_j|) / (v - number of remaining) + 1e-8.\n",
    "    A split is feasible when no scaled probability exceeds 1; the feasible\n",
    "    split with the largest i is returned.\n",
    "\n",
    "    Args:\n",
    "        x: 1-D array-like of real numbers.\n",
    "        v: Non-negative budget; the split search asserts sum(probs) <= v.\n",
    "        verbose: If True, print sum(x_j**2 / p_j) over the non-zero entries\n",
    "            for every feasible split (this used to be printed unconditionally).\n",
    "\n",
    "    Returns:\n",
    "        Tuple (probs, opt_i). probs is an array aligned with x. opt_i is the\n",
    "        index, in ascending-|x| order, of the last scaled entry: len(x) - 1\n",
    "        when every entry is scaled, -1 when none is.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If v is negative.\n",
    "    \"\"\"\n",
    "    if v < 0:\n",
    "        raise ValueError('v must be non-negative')\n",
    "\n",
    "    x_abs = np.abs(np.asarray(x))\n",
    "    n = len(x_abs)\n",
    "    total = np.sum(x_abs)\n",
    "\n",
    "    # Degenerate inputs (empty or all-zero x, or zero budget): every probability\n",
    "    # is 0. Handled up front so the division below never evaluates 0 / 0.\n",
    "    if total == 0 or v == 0:\n",
    "        return np.zeros(n), n - 1\n",
    "\n",
    "    # Easy case: scaling every entry by the same factor already keeps all\n",
    "    # probabilities <= 1. Returns the same tuple shape as the general case.\n",
    "    sqrt_lmbda = total / v\n",
    "    if np.all(x_abs <= sqrt_lmbda):\n",
    "        return x_abs / sqrt_lmbda, n - 1\n",
    "\n",
    "    index_sorted = np.argsort(x_abs)\n",
    "    # Inverse permutation: maps the sorted result back to the order of x.\n",
    "    index_reverse = np.argsort(index_sorted)\n",
    "    x_abs_sorted = x_abs[index_sorted]\n",
    "    prefix_sums = np.cumsum(x_abs_sorted)\n",
    "    positive = x_abs_sorted > 0\n",
    "\n",
    "    # Fallback when no split is feasible (can happen when v > len(x)):\n",
    "    # every non-zero entry is kept with probability 1.\n",
    "    probs = np.where(positive, 1.0, 0.0)\n",
    "    opt_i = -1\n",
    "    for i in range(n):\n",
    "        other = n - i - 1  # entries that would get probability 1\n",
    "        if other >= v:\n",
    "            # The probability-1 entries alone would use up the whole budget.\n",
    "            continue\n",
    "        # The epsilon keeps the scaled probabilities strictly inside the budget.\n",
    "        sqrt_lmbda = prefix_sums[i] / (v - other) + 1e-8\n",
    "        # Sorted ascending, so entry i is the largest scaled one; checking it\n",
    "        # alone is equivalent to checking the whole prefix.\n",
    "        if x_abs_sorted[i] <= sqrt_lmbda:\n",
    "            probs = np.ones(n)\n",
    "            probs[:i + 1] = x_abs_sorted[:i + 1] / sqrt_lmbda\n",
    "            opt_i = i\n",
    "            if verbose:\n",
    "                print(np.sum(x_abs_sorted[positive] ** 2 / probs[positive]))\n",
    "\n",
    "    assert np.sum(probs) <= v\n",
    "    return probs[index_reverse], opt_i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da8e48fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep both results in variables for later cells instead of printing the raw tuple.\n",
    "probs, opt_i = get_probs(x, v)\n",
    "# Summarise rather than dumping every probability.\n",
    "print(f'opt_i = {opt_i}, sum(probs) = {probs.sum():.4f}, budget v = {v}')\n",
    "print('five largest probabilities:', np.sort(probs)[-5:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c68584fe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
