{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed to generate_personas as n_rm below: this is the length of each\n",
    "# sampled Dirichlet vector, not the number of personas that get sampled.\n",
    "number_personas = 8\n",
    "\n",
    "import numpy as np\n",
    "from numba import jit, prange\n",
    "from scipy.stats import dirichlet\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "from tqdm import trange, tqdm\n",
    "\n",
    "@jit(nopython=True, parallel=True)\n",
    "def calculate_distances(persons):\n",
    "    \"\"\"Return the symmetric matrix of pairwise Euclidean distances between rows.\n",
    "\n",
    "    Args:\n",
    "        persons: Array with one persona per row (2-D as used in this notebook).\n",
    "\n",
    "    Returns:\n",
    "        (n, n) float64 array whose [i, j] entry is the Euclidean distance\n",
    "        between rows i and j. The diagonal is 0.\n",
    "    \"\"\"\n",
    "    n = len(persons)\n",
    "    # Zero-initialise: the loop below only writes off-diagonal entries, so an\n",
    "    # np.empty buffer would leave arbitrary values on the diagonal.\n",
    "    distances = np.zeros((n, n), dtype=np.float64)\n",
    "    # Only the outer loop is a prange. Numba runs just the outermost prange of\n",
    "    # a nest in parallel and treats inner ones as plain range loops.\n",
    "    for i in prange(n):\n",
    "        for j in range(i + 1, n):\n",
    "            dist = np.sqrt(np.sum((persons[i] - persons[j]) ** 2))\n",
    "            distances[i, j] = dist\n",
    "            distances[j, i] = dist\n",
    "    return distances\n",
    "\n",
    "@jit(nopython=True)\n",
    "def filter_similar_personas(persons, threshold=1e-5):\n",
    "    \"\"\"Greedily drop personas that lie too close to an earlier kept persona.\n",
    "\n",
    "    Rows are visited in order. Each kept row removes every later row whose\n",
    "    Euclidean distance to it is strictly below ``threshold``. Rows that were\n",
    "    removed never remove anything themselves.\n",
    "\n",
    "    Distances are computed on demand rather than through a full n x n matrix,\n",
    "    so extra memory is O(n) and no work is spent on rows already discarded.\n",
    "\n",
    "    Args:\n",
    "        persons: Array with one persona per row (2-D as used in this notebook).\n",
    "        threshold: Minimum distance a row must keep from every earlier kept row.\n",
    "\n",
    "    Returns:\n",
    "        The kept rows of ``persons``, in their original order.\n",
    "    \"\"\"\n",
    "    n = len(persons)\n",
    "    keep_mask = np.ones(n, dtype=np.bool_)\n",
    "\n",
    "    for i in range(n):\n",
    "        if not keep_mask[i]:\n",
    "            continue\n",
    "        for j in range(i + 1, n):\n",
    "            # A row that is already removed cannot change state again.\n",
    "            if not keep_mask[j]:\n",
    "                continue\n",
    "            dist = np.sqrt(np.sum((persons[i] - persons[j]) ** 2))\n",
    "            if dist < threshold:\n",
    "                keep_mask[j] = False\n",
    "\n",
    "    return persons[keep_mask]\n",
    "\n",
    "def generate_personas(alpha_values, n_rm, n_persons, filter_persona_threshold=None, same_alpha=True, random_alpha=False):\n",
    "    all_persons = []\n",
    "    random_state = 42\n",
    "\n",
    "    if same_alpha:\n",
    "        for alpha in alpha_values:\n",
    "            alphas = np.array([alpha] * n_rm)\n",
    "            persons = dirichlet.rvs(alphas, size=n_persons, random_state=random_state)\n",
    "            all_persons.append(persons)\n",
    "\n",
    "    if random_alpha:\n",
    "        alphas = np.random.choice(alpha_values, size=n_rm)\n",
    "        persons = dirichlet.rvs(alphas, size=n_persons, random_state=random_state)\n",
    "        all_persons.append(persons)\n",
    "\n",
    "    all_persons = np.vstack(all_persons)\n",
    "    if filter_persona_threshold:\n",
    "        all_persons = filter_similar_personas(all_persons, filter_persona_threshold)\n",
    "    return all_persons\n",
    "\n",
    "# Sampling configuration: draws per Dirichlet distribution and the\n",
    "# candidate concentration values.\n",
    "n_persons = 10_000\n",
    "alpha_values = [0.1, 0.5, 1.0, 5.0]\n",
    "\n",
    "persons = generate_personas(\n",
    "    alpha_values=alpha_values,\n",
    "    n_rm=number_personas,\n",
    "    n_persons=n_persons,\n",
    "    filter_persona_threshold=0.2,\n",
    "    same_alpha=True,\n",
    "    random_alpha=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "672"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of personas left after near-duplicate filtering.\n",
    "len(persons)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.07828757e-03, 8.76649576e-01, 1.69635230e-07, 8.67535647e-12,\n",
       "        1.22271967e-01, 2.71114771e-16],\n",
       "       [9.54650302e-01, 2.35958872e-07, 4.05972179e-05, 1.35078376e-03,\n",
       "        4.39310236e-02, 2.70570526e-05],\n",
       "       [3.59578609e-04, 9.30442171e-08, 4.91740803e-03, 6.32794062e-03,\n",
       "        1.25330942e-12, 9.88394980e-01],\n",
       "       ...,\n",
       "       [8.33887861e-02, 1.20410642e-01, 1.18616173e-01, 3.36975948e-03,\n",
       "        1.72858132e-01, 5.01356508e-01],\n",
       "       [2.57537108e-07, 6.85158004e-01, 1.68795675e-01, 1.45977129e-01,\n",
       "        4.25528425e-06, 6.46793522e-05],\n",
       "       [3.62259030e-05, 7.23197441e-01, 3.92218428e-05, 2.81336460e-04,\n",
       "        1.36040526e-01, 1.40405248e-01]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the persona matrix: one row per persona, one column per Dirichlet component.\n",
    "persons"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm_rl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
