{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1da0f33",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cvxpy as cp\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "import re\n",
    "import scipy\n",
    "import seaborn as sns\n",
    "\n",
    "from itertools import product\n",
    "from sklearn.cluster import KMeans\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656e0afa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data\n",
    "data_dir = \"../data/AAMAS\"\n",
    "dset_idx = 3\n",
    "which_dset = \"00037-0000000%d.cat\" % dset_idx\n",
    "\n",
    "dset_sizes = [(201, 613), (161, 442), (667, 526)]\n",
    "# Yes, maybe, no answer, no\n",
    "# rating_scores = [1.0, .5, 0.0, -1.0]\n",
    "rating_scores = [1.0, .5, 0.0, 0.01]\n",
    "\n",
    "if dset_idx == 3:\n",
    "    # Yes, maybe, no, conflict\n",
    "    rating_scores = [1.0, .5, 0.01, 0.0]\n",
    "#     rating_scores = [1.0, .5, -1.0, 0.0]\n",
    "\n",
    "agent_idx = 0\n",
    "with open(os.path.join(data_dir, which_dset)) as f:\n",
    "    ratings = np.zeros(dset_sizes[dset_idx-1])\n",
    "    for l in f.readlines():\n",
    "        if not l.startswith(\"#\"):\n",
    "            l = re.sub(\"[0-9]*: \", \"\", l)\n",
    "            bracket_list = re.compile(\"\\{[0-9, ]*\\}|[0-9]+\")\n",
    "            lists = bracket_list.findall(l)\n",
    "            if len(lists) != 4:\n",
    "                print(lists)\n",
    "                print(l)\n",
    "                print(\"PROBLEM\")\n",
    "            for idx, list_of_prefs in enumerate(lists):\n",
    "                if list_of_prefs != '{}':\n",
    "                    list_of_prefs = re.sub(\"[\\{\\}]\", \"\", list_of_prefs)\n",
    "                    prefs = [int(x)-1 for x in list_of_prefs.split(\",\")]\n",
    "                    ratings[agent_idx, prefs] = rating_scores[idx]\n",
    "            agent_idx += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc6acdc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "if dset_idx == 3:\n",
    "    gen = np.random.default_rng(seed=0)\n",
    "    ratings[ratings == 0.01] = np.where(gen.random(ratings[ratings == 0.01].shape) < .98, 0, ratings[ratings == 0.01])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f59852b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.sum(ratings == 0.01)/ratings.size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54c05a30",
   "metadata": {},
   "outputs": [],
   "source": [
    "ratings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bd676f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(os.path.join(data_dir, \"ratings_%d.npy\" % dset_idx), ratings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f453acb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binarize the labels\n",
    "binary_ratings = ratings.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f28c765e",
   "metadata": {},
   "outputs": [],
   "source": [
    "binary_ratings[ratings > .1] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76ff99d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(binary_ratings.flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dc37be3",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = np.random.default_rng(seed=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ed3c225",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: We can apply a mask to have a held-out test set\n",
    "test_frac = .2\n",
    "coi_mask = np.load(os.path.join(data_dir, \"coi_mask_%d.npy\" % dset_idx))\n",
    "coi_mask = np.ones_like(coi_mask)\n",
    "held_out_for_test = gen.random(binary_ratings.shape) < coi_mask*(binary_ratings > 0)*test_frac"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7706445d",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.sum(held_out_for_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd33a8d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "n_rows, n_cols = binary_ratings.shape\n",
    "k = 20\n",
    "U = torch.tensor(gen.normal(size=(n_rows, k)), requires_grad=True)\n",
    "V = torch.tensor(gen.normal(size=(k, n_cols)), requires_grad=True)\n",
    "\n",
    "# X = torch.tensor(gen.normal(size=(n_rows, n_cols)), requires_grad=True)\n",
    "# y = torch.tensor(gen.normal(size=(n_rows, n_cols)), requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d8dd339",
   "metadata": {},
   "outputs": [],
   "source": [
    "observed_ones = np.where((1-held_out_for_test)*(binary_ratings > .9)*coi_mask)\n",
    "observed_minus_ones = np.where((1-held_out_for_test)*(binary_ratings < .1)*(binary_ratings > 0)*coi_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "693b9e81",
   "metadata": {},
   "outputs": [],
   "source": [
    "xe_loss = torch.nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdc532c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss = -torch.sum(torch.log(torch.special.expit(X[observed_ones])))\n",
    "# loss = -torch.sum(torch.log(torch.special.expit(1-X[observed_minus_ones])))\n",
    "# loss = xe_loss(X[observed_ones], torch.ones(X[observed_ones].shape))\n",
    "# # loss += 0.1 * cp.norm(X, 'nuc')\n",
    "# loss += .1 * torch.trace(torch.sqrt(X.T @ X))\n",
    "\n",
    "step_size = .1\n",
    "optimizer = torch.optim.Adam([U, V], lr=step_size)\n",
    "# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f367fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in tqdm(range(int(3001))):\n",
    "    optimizer.zero_grad()\n",
    "    x_pred = torch.special.expit(torch.mm(U, V))\n",
    "    tgt_ones = torch.ones(ratings[observed_ones].shape).to(torch.long)\n",
    "    loss = xe_loss(torch.concat((-x_pred[observed_ones].reshape((-1,1)), x_pred[observed_ones].reshape((-1, 1))), axis=1), tgt_ones)\n",
    "    tgt_zeros = torch.zeros(ratings[observed_minus_ones].shape).to(torch.long)\n",
    "    loss += xe_loss(torch.concat((-x_pred[observed_minus_ones].reshape((-1,1)), x_pred[observed_minus_ones].reshape((-1, 1))), axis=1), tgt_zeros)\n",
    "    loss += .0005*torch.norm(U)\n",
    "    loss += .0005*torch.norm(V)\n",
    "#     loss = xe_loss(torch.concat((y[observed_ones].reshape((1, -1)), X[observed_ones].reshape((1, -1))), axis=0).T, torch.ones(X[observed_ones].shape).to(torch.long))\n",
    "# #     loss += xe_loss(X[observed_minus_ones], torch.zeros(X[observed_minus_ones].shape))\n",
    "#     loss += xe_loss(torch.concat((y[observed_minus_ones].reshape((1, -1)), X[observed_minus_ones].reshape((1, -1))), axis=0).T, torch.zeros(X[observed_minus_ones].shape).to(torch.long))\n",
    "\n",
    "#     loss = -torch.sum(torch.log(torch.special.expit(1-X[observed_minus_ones])))\n",
    "    # loss += 0.1 * cp.norm(X, 'nuc')\n",
    "#     loss += .1 * torch.trace(torch.sqrt(torch.mm(X.T, X)))\n",
    "    if i % 1000 == 0:\n",
    "        print(loss)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "#     scheduler.step(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ddf4bd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_mat = torch.special.expit(torch.mm(U, V))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf81a80f",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(pred_mat[observed_ones].flatten().detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd4a5b4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(pred_mat[observed_minus_ones].flatten().detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d8ed898",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(pred_mat.flatten().detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e3dd9e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now compute the bound on the cross entropy loss for each group under each delta value\n",
    "groups = np.load(os.path.join(data_dir, \"groups_%d.npy\" % dset_idx))\n",
    "coi_mask = np.load(os.path.join(data_dir, \"coi_mask_%d.npy\" % dset_idx))\n",
    "coi_mask = np.ones_like(coi_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c486c60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each group, pull out the entries that were held out. Then compute mean and standard deviation of cross entropy loss\n",
    "ngroups = len(set(groups))\n",
    "\n",
    "means = []\n",
    "stds = []\n",
    "\n",
    "nonaggloss = torch.nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "for gidx in range(ngroups):\n",
    "    cm = coi_mask[:, np.where(groups==gidx)[0]]\n",
    "    test_entries = binary_ratings[:, np.where(groups == gidx)[0]]\n",
    "    predictions = pred_mat[:, np.where(groups == gidx)[0]]\n",
    "    observed_entries = np.where((np.abs(test_entries) > 0)*cm)\n",
    "    print(observed_entries[0].shape)\n",
    "#     print(test_entries[observed_entries])\n",
    "#     print(predictions[observed_entries])\n",
    "    tgt = test_entries[observed_entries]\n",
    "    tgt[tgt == .01] = 0\n",
    "    tgt = torch.tensor(tgt, dtype=int)\n",
    "    \n",
    "    p = predictions[observed_entries]\n",
    "    p0 = torch.reshape(1-p, (-1,1))\n",
    "    p1 = torch.reshape(p, (-1,1))\n",
    "    preds_tensor = torch.concat((p0, p1), 1)\n",
    "#     print(preds_tensor)\n",
    "    preds_tensor = torch.log(preds_tensor)\n",
    "#     print(preds_tensor.shape)\n",
    "    \n",
    "    xe_test = nonaggloss(preds_tensor, tgt)\n",
    "    xe_test = xe_test.detach().cpu().numpy()\n",
    "    \n",
    "    mean, std = np.mean(xe_test), np.std(xe_test, ddof=1)\n",
    "    print(mean, std)\n",
    "    means.append(mean)\n",
    "    stds.append(std)\n",
    "#     print(test_entries.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11576d46",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_groups = len(set(groups))\n",
    "\n",
    "delta_to_normal_bd = {}\n",
    "for delta in [.3, .2, .1, .05, .01]:\n",
    "    delta_to_normal_bd[delta] = []\n",
    "    for gidx in range(n_groups):\n",
    "        ub = scipy.stats.norm.ppf(1-(delta/n_groups), loc=means[gidx], scale=stds[gidx])\n",
    "        delta_to_normal_bd[delta].append(ub)\n",
    "#     xis.append(xe_test.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae27f29f",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_to_normal_bd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f95ee18",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(os.path.join(data_dir, \"prob_up_%d.npy\" % dset_idx), pred_mat.detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "586b6c71",
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(delta_to_normal_bd, open(os.path.join(data_dir, \"delta_to_normal_bd_%d.pkl\" % dset_idx), 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b45af204",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(os.path.join(data_dir, \"coi_mask_%d.npy\" % dset_idx), coi_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ca976a6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e357734",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a5e2412",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_group_egal_linear(a_l, b_l, phat_l, C_l, rhs_bd_per_group, loads, covs_lb_l, covs_ub_l, milp=False):\n",
    "    ngroups = len(phat_l)\n",
    "    model = gp.Model()\n",
    "\n",
    "    t = model.addVar(lb=-gp.GRB.INFINITY, ub=gp.GRB.INFINITY,vtype=gp.GRB.CONTINUOUS, name='t')\n",
    "\n",
    "    e_vals = []\n",
    "    c_vals = []\n",
    "    f_vals = []\n",
    "    x_vals = []\n",
    "    Allocs = []\n",
    "\n",
    "    for gdx in range(ngroups):\n",
    "        n_agents = phat_l[gdx].shape[0]\n",
    "        n_items = phat_l[gdx].shape[1]\n",
    "        phat = phat_l[gdx].flatten()\n",
    "        C = C_l[gdx].flatten()\n",
    "        covs_lb = covs_lb_l[gdx].flatten()\n",
    "        covs_ub = covs_ub_l[gdx].flatten()\n",
    "\n",
    "        A_multiplier = (a_l[gdx] - b_l[gdx])\n",
    "        if milp==False:\n",
    "            A = model.addMVar(len(phat_l[gdx].flatten()),lb=0, ub=1, vtype=gp.GRB.CONTINUOUS, name='Alloc' + str(gdx))\n",
    "        else:\n",
    "            A = model.addMVar(len(phat_l[gdx].flatten()),lb=0, ub=1, vtype=gp.GRB.INTEGER, name='Alloc' + str(gdx))\n",
    "        Allocs.append(A)\n",
    "\n",
    "        eps = 1e-6\n",
    "\n",
    "        log_p_phat = np.log(phat).flatten()\n",
    "        log_one_minus_phat = np.log(1-phat).flatten()\n",
    "        rhs_bd = rhs_bd_per_group[gdx]\n",
    "\n",
    "        mn = int(n_agents*n_items)\n",
    "        c_val = np.sum(C)\n",
    "\n",
    "        e = -1.0 * (c_val * rhs_bd + np.sum(C*log_one_minus_phat))\n",
    "        neg_ones = -1*np.ones(mn)\n",
    "        c= np.vstack((np.array([e]).reshape(1,1),neg_ones.flatten().reshape(-1,1))).flatten()\n",
    "        f =  C*(log_p_phat - log_one_minus_phat).flatten()\n",
    "\n",
    "        x = model.addMVar(mn+1, lb=0, ub=gp.GRB.INFINITY, vtype=gp.GRB.CONTINUOUS, name=\"pval\")\n",
    "        e_vals.append(e)\n",
    "        c_vals.append(c)\n",
    "        f_vals.append(f)\n",
    "        x_vals.append(x)\n",
    "\n",
    "        model.addConstrs(A[i] <= C[i] for i in range(mn))\n",
    "\n",
    "        model.addConstrs(gp.quicksum(A[jdx * n_items + idx] for jdx in range(n_agents)) <= covs_ub[idx] for idx in\n",
    "                         range(n_items))\n",
    "\n",
    "        model.addConstrs(gp.quicksum(A[jdx * n_items + idx] for jdx in range(n_agents)) >= covs_lb[idx] for idx in\n",
    "                         range(n_items))\n",
    "\n",
    "        model.addConstrs((f[jdx]*x[0] - x[jdx+1] <= A_multiplier*A[jdx]/n_items   for jdx in range(mn)),name='ctr'+ str(gdx))\n",
    "        model.addConstr(t<= c@x, name='min_w'+ str(gdx))\n",
    "\n",
    "    load_sum = model.addMVar(loads.size, lb=0, ub=gp.GRB.INFINITY, obj=0.0, vtype=gp.GRB.CONTINUOUS, name='load_sum')\n",
    "\n",
    "    model.addConstrs(load_sum[idx] == gp.quicksum(\n",
    "        Allocs[gdx][idx * phat_l[gdx].shape[1]:(idx + 1) * (phat_l[gdx].shape[1])].sum() for gdx in range(ngroups)) for\n",
    "                     idx in range(loads.size))\n",
    "    total_agents = loads.size\n",
    "    model.addConstrs(load_sum[idx] <= loads[idx] for idx in range(total_agents))\n",
    "\n",
    "    model.setObjective(t, gp.GRB.MAXIMIZE)\n",
    "    model.setParam('OutputFlag', 1)\n",
    "\n",
    "    model.optimize()\n",
    "    final_allocs = []\n",
    "    for idx in range(ngroups):\n",
    "        final_allocs.append(Allocs[idx].X)\n",
    "\n",
    "    obj = model.getObjective()\n",
    "\n",
    "    return final_allocs, obj.getValue()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b3538bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_revs = gen.choice(range(pred_mat.shape[0]), int(.1*pred_mat.shape[0]))\n",
    "selected_paps = gen.choice(range(pred_mat.shape[1]), int(.1*pred_mat.shape[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d798218",
   "metadata": {},
   "outputs": [],
   "source": [
    "coi_mask = np.load(os.path.join(data_dir, \"coi_mask_%d.npy\" % dset_idx))\n",
    "coi_mask = np.ones_like(coi_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93e7a5f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "small_pred_mat = pred_mat[selected_revs, :][:, selected_paps].detach().numpy()\n",
    "small_coi = coi_mask[selected_revs, :][:, selected_paps]\n",
    "small_groups = groups[selected_paps]\n",
    "phat_l = []\n",
    "C_l = []\n",
    "cov_l = []\n",
    "\n",
    "for gidx in range(4):\n",
    "    phat_l.append(small_pred_mat[:, small_groups == gidx])\n",
    "    C_l.append(small_coi[:, small_groups == gidx])\n",
    "    cov_l.append(np.array([2]*np.sum(small_groups == gidx)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78cda987",
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_group_egal_linear([1]*4, [0]*4, phat_l, C_l, [0.9394841194775468,\n",
    "  0.8941658975527695,\n",
    "  0.990867096534835,\n",
    "  1.0065408293391225], np.array([10]*small_pred_mat.shape[0]), cov_l, cov_l, milp=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45f43bab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ffaeb9e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0be3f6c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d053d286",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use https://dl.acm.org/doi/pdf/10.1145/1553374.1553452 to get prob model for CVaR objectives\n",
    "gen = np.random.default_rng(seed=0)\n",
    "q = 20\n",
    "x = gen.normal(loc=0, scale=1e-3, size=(ratings.shape[0], q))\n",
    "sig = .05"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53d04593",
   "metadata": {},
   "outputs": [],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea8c7f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_iter = 100\n",
    "\n",
    "lr = 1e-4\n",
    "\n",
    "for i in tqdm(range(n_iter)):\n",
    "#     item_idx = gen.integers(ratings.shape[1])\n",
    "    for item_idx in range(ratings.shape[1]):\n",
    "        # pick out the users where this item was rated\n",
    "        observed = np.where(np.abs(ratings[:, item_idx]) > 1e-4)[0]\n",
    "#         print(observed)\n",
    "#         print(x[observed])\n",
    "        if len(observed):\n",
    "            Cj = np.matmul(x[observed], x[observed].T) + (sig**2)*np.eye(len(observed))\n",
    "#             print(Cj)\n",
    "            Cinv = np.linalg.inv(Cj)\n",
    "            yobs = ratings[observed, item_idx]\n",
    "    #         print(yobs)\n",
    "            G = np.outer(yobs, yobs)\n",
    "            G = np.matmul(Cinv, G)\n",
    "            G = np.matmul(G, Cinv)\n",
    "            G -= Cinv\n",
    "            grad = np.matmul(-G, x[observed])\n",
    "\n",
    "            x[observed] -= lr*grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7245fdc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each item, we can sample the rating for each user. It is by section 3.1, the prediction of user rating section.\n",
    "print(x.shape)\n",
    "k = np.matmul(x, x.T)\n",
    "s = k + (sig**2)*np.eye(k.shape[0])\n",
    "mu_matrix = np.zeros(ratings.shape)\n",
    "zeta_matrix = np.zeros(ratings.shape)\n",
    "for item_idx in tqdm(range(ratings.shape[1])):\n",
    "    for user_idx in range(ratings.shape[0]):\n",
    "        observed = np.where(np.abs(ratings[:, item_idx]) > 1e-4)[0]\n",
    "        sobs = s[observed, :][:, observed]\n",
    "        final_s = np.matmul(np.linalg.inv(sobs), k[observed, user_idx])\n",
    "        mu = np.dot(final_s, ratings[observed, item_idx])\n",
    "        mu_matrix[user_idx, item_idx] = mu\n",
    "        zeta_matrix[user_idx, item_idx] = k[user_idx, user_idx] + sig**2 - np.dot(k[observed, user_idx], final_s)\n",
    "        \n",
    "# cov_mat = np.matmul(x, x.T) + (sig**2)*np.eye(x.shape[0])\n",
    "# zero_vec = np.zeros(x.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a16b872",
   "metadata": {},
   "outputs": [],
   "source": [
    "mu_clip = np.clip(mu_matrix, 0.01, np.inf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40053f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(os.path.join(data_dir, \"mu_matrix_%d.npy\" % dset_idx), mu_clip)\n",
    "np.save(os.path.join(data_dir, \"zeta_matrix_%d.npy\" % dset_idx), zeta_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e3f9ecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(mu_clip.flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34a9af06",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(ratings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2059493",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c2703a",
   "metadata": {},
   "outputs": [],
   "source": [
    "ratings[np.where(ratings > .5)].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a8b1d89",
   "metadata": {},
   "outputs": [],
   "source": [
    "mu_matrix[np.where(ratings == 0)][:300]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59f68e1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(mu_matrix.flatten(), bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "812e1fb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(zeta_matrix.flatten(), bins=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c642949a",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.max(zeta_matrix, axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83ad3a1b",
   "metadata": {},
   "source": [
    "# Now we'll cluster these and make groups for both getting COIs and for the GESW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dfec1c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cluster ratings\n",
    "dset_idx = 3\n",
    "ratings = np.load(os.path.join(data_dir, \"ratings_%d.npy\" % dset_idx))\n",
    "\n",
    "ratings[ratings < 0] = 0\n",
    "ratings[ratings > 0] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f8d13bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(ratings.flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f0017e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from sklearn.cluster import SpectralBiclustering, KMeans\n",
    "from scipy.spatial.distance import cdist\n",
    "from sklearn.manifold import SpectralEmbedding\n",
    "n, m = ratings.shape\n",
    "affinities = np.zeros((n+m, n+m))\n",
    "affinities[:n, :][:, n:] = ratings\n",
    "affinities[n:, :][:, :n] = ratings.T\n",
    "affinities += 1e-5\n",
    "embedding = SpectralEmbedding(n_components=5, affinity='precomputed')\n",
    "X_transformed = embedding.fit_transform(affinities)\n",
    "# >>> X_transformed.shape\n",
    "# clustering = SpectralBiclustering(n_clusters=4, random_state=1, method=\"log\").fit(ratings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "228860f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_transformed.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "420f9a0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "k=4\n",
    "clusters = KMeans(n_clusters=k, random_state=0).fit(X_transformed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a99712f",
   "metadata": {},
   "outputs": [],
   "source": [
    "rc = clusters.labels_[:n]\n",
    "cc = clusters.labels_[n:]\n",
    "print(Counter(rc), Counter(cc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56390f75",
   "metadata": {},
   "outputs": [],
   "source": [
    "clusters.cluster_centers_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c886559",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gurobipy as gp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7879497",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterate between assigning row and column cluster ids where each cluster must have at least a certain number of points,\n",
    "# then recompute cluster centers.\n",
    "centers = clusters.cluster_centers_\n",
    "row_embs = X_transformed[:n]\n",
    "col_embs = X_transformed[n:]\n",
    "\n",
    "min_rc_size = int(.8*(n/k))\n",
    "min_cc_size = int(.8*(m/k))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "324118c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(min_rc_size, min_cc_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85fb8914",
   "metadata": {},
   "outputs": [],
   "source": [
    "for _ in range(100):\n",
    "    # First we reassign with balance contraints\n",
    "    row_dist_to_ctrs = cdist(row_embs, centers)\n",
    "    col_dist_to_ctrs = cdist(col_embs, centers)\n",
    "    \n",
    "    m = gp.Model()\n",
    "\n",
    "    matching_rows = m.addMVar(row_dist_to_ctrs.shape, vtype=gp.GRB.BINARY)\n",
    "    matching_cols = m.addMVar(col_dist_to_ctrs.shape, vtype=gp.GRB.BINARY)\n",
    "    \n",
    "    m.addConstr(matching_rows.sum(axis=1) == 1)\n",
    "    m.addConstr(matching_rows.sum(axis=0) >= min_rc_size)\n",
    "    \n",
    "    m.addConstr(matching_cols.sum(axis=1) == 1)\n",
    "    m.addConstr(matching_cols.sum(axis=0) >= min_cc_size)\n",
    "\n",
    "    obj = (matching_rows*row_dist_to_ctrs).sum() + (matching_cols*col_dist_to_ctrs).sum()\n",
    "\n",
    "    m.setObjective(obj)\n",
    "    m.optimize()\n",
    "    \n",
    "    print(obj.getValue())\n",
    "\n",
    "    row_clusters = np.where(matching_rows.x)[1]\n",
    "    col_clusters = np.where(matching_cols.x)[1]\n",
    "    \n",
    "    # Then we recompute cluster centers\n",
    "    for cidx in range(k):\n",
    "        row_pts = row_embs[np.where(row_clusters == cidx)[0], :]\n",
    "        col_pts = col_embs[np.where(col_clusters == cidx)[0], :]\n",
    "        centers[cidx] = np.mean(np.vstack((row_pts, col_pts)), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea58637c",
   "metadata": {},
   "outputs": [],
   "source": [
    "Counter(row_clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1204a25b",
   "metadata": {},
   "outputs": [],
   "source": [
    "Counter(col_clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bfa5b97",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(os.path.join(data_dir, \"groups_%d.npy\" % dset_idx), col_clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "088dc3ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "coi_mask = (np.reshape(col_clusters, (1,-1)) == np.reshape(row_clusters, (-1, 1))).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de0792c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.sum(coi_mask)/coi_mask.size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dec34bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.sum(coi_mask, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3902e1d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb221446",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(os.path.join(data_dir, \"coi_mask_%d.npy\" % dset_idx), coi_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09b785c3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d26a286",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d018e4bf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ffa48f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 4\n",
    "min_cluster_size = ratings.shape[0]//(k+1)\n",
    "print(min_cluster_size)\n",
    "\n",
    "row_clusters = KMeans(n_clusters=4, random_state=0).fit(ratings)\n",
    "row_centers = row_clusters.cluster_centers_\n",
    "distance_matrix = cdist(ratings, row_centers)\n",
    "print(distance_matrix.shape)\n",
    "# Now match each row to a column such that the total number of rows matching to columns is roughly equal.\n",
    "m = gp.Model()\n",
    "\n",
    "matching = m.addMVar(distance_matrix.shape, vtype=gp.GRB.BINARY)\n",
    "m.addConstr(matching.sum(axis=1) == 1)\n",
    "m.addConstr(matching.sum(axis=0) >= min_cluster_size)\n",
    "\n",
    "obj = (matching*distance_matrix).sum()\n",
    "\n",
    "m.setObjective(obj)\n",
    "m.optimize()\n",
    "\n",
    "row_clusters = np.where(matching.x)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ef88f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Do the same for the columns\n",
    "k = 4\n",
    "min_cluster_size = ratings.shape[1]//(k+1)\n",
    "print(min_cluster_size)\n",
    "\n",
    "col_clusters = KMeans(n_clusters=4, random_state=0).fit(ratings.T)\n",
    "col_centers = col_clusters.cluster_centers_\n",
    "distance_matrix = cdist(ratings.T, col_centers)\n",
    "print(distance_matrix.shape)\n",
    "# Now match each row to a column such that the total number of rows matching to columns is roughly equal.\n",
    "import gurobipy as gp\n",
    "m = gp.Model()\n",
    "\n",
    "matching = m.addMVar(distance_matrix.shape, vtype=gp.GRB.BINARY)\n",
    "m.addConstr(matching.sum(axis=1) == 1)\n",
    "m.addConstr(matching.sum(axis=0) >= min_cluster_size)\n",
    "\n",
    "obj = (matching*distance_matrix).sum()\n",
    "\n",
    "m.setObjective(obj)\n",
    "m.optimize()\n",
    "\n",
    "col_clusters = np.where(matching.x)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "773ede4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "Counter(row_clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6261b777",
   "metadata": {},
   "outputs": [],
   "source": [
    "Counter(col_clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35ba1466",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Match up the row clusters with the column clusters. \n",
    "# For each pair of clusters, we take the submatrix corresponding to those rows and columns. We can then\n",
    "# compute the average similarity. Then compute a maximum matching basically.\n",
    "cluster_link_scores = np.zeros((k,k))\n",
    "for ridx, cidx in product(range(k), range(k)):\n",
    "    cluster_link_scores[ridx, cidx] = np.mean(ratings[row_clusters == ridx, :][:, col_clusters == cidx])\n",
    "print(cluster_link_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb5d05f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import linear_sum_assignment\n",
    "row_ind, col_ind = linear_sum_assignment(-1*cluster_link_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afabd2a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(row_ind, col_ind)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1d603b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_col_inds = row_ind[np.argsort(col_ind)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f6050af",
   "metadata": {},
   "outputs": [],
   "source": [
    "remapped_col_clusters = [new_col_inds[i] for i in col_clusters]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db60c3a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "Counter(remapped_col_clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06140e26",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.sum(ratings[:, 13] > .5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5817b9c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "ratings.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c903a6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(os.path.join(data_dir, \"groups_%d.npy\" % dset_idx), remapped_col_clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "924e2b9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "coi_mask = (np.reshape(remapped_col_clusters, (1,-1)) == np.reshape(row_clusters, (-1, 1))).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e810b0f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.sum(coi_mask)/coi_mask.size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06519583",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.sum(coi_mask, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f381e6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff2abe7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(os.path.join(data_dir, \"coi_mask_%d.npy\" % dset_idx), coi_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1344c970",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gurobipy as gp\n",
    "m = gp.Model(\"TPMS\")\n",
    "\n",
    "mu_matrix = np.load(\"../data/AAMAS/mu_matrix_%d.npy\" % dset_idx)\n",
    "coi_mask = np.load(\"../data/AAMAS/coi_mask_%d.npy\" % dset_idx)\n",
    "\n",
    "cs = [3, 2, 2]\n",
    "ls = [15, 15, 4]\n",
    "covs_lb = cs[dset_idx-1] * np.ones(mu_matrix.shape[1])\n",
    "covs_ub = covs_lb\n",
    "loads = ls[dset_idx - 1] * np.ones(mu_matrix.shape[0])\n",
    "\n",
    "covs_lb = np.minimum(covs_lb, np.sum(coi_mask, axis=0))\n",
    "\n",
    "\n",
    "alloc = m.addMVar(mu_matrix.shape, vtype=gp.GRB.BINARY, name='alloc')\n",
    "\n",
    "m.addConstr(alloc.sum(axis=0) >= covs_lb)\n",
    "m.addConstr(alloc.sum(axis=0) <= covs_ub)\n",
    "m.addConstr(alloc.sum(axis=1) <= loads)\n",
    "m.addConstr(alloc <= coi_mask)\n",
    "\n",
    "obj = (alloc*mu_matrix).sum()\n",
    "m.setObjective(obj, gp.GRB.MAXIMIZE)\n",
    "\n",
    "m.optimize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7cbfc57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.sum(coi_mask, axis=0)\n",
    "np.sum(alloc.x, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdccf902",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.sum(coi_mask, axis=0)[:30]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f02031c",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.sum(mu_matrix > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18cd9782",
   "metadata": {},
   "outputs": [],
   "source": [
    "(alloc.x*mu_matrix)[np.where(alloc.x > .5)][:50]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2a7685e",
   "metadata": {},
   "outputs": [],
   "source": [
    "alloc.x[np.where(alloc.x > .5)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "602dd9c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(ratings.flatten())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5e6a4f1",
   "metadata": {},
   "source": [
    "# Lets figure out what the uncertainty sets look like for each of these 3 conferences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57c32677",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "dset_idx = 1\n",
    "mu = np.load(os.path.join(data_dir, \"mu_matrix_%d.npy\" % dset_idx))\n",
    "sig = np.load(os.path.join(data_dir, \"zeta_matrix_%d.npy\" % dset_idx))\n",
    "cois = np.load(os.path.join())\n",
    "print(mu, sig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "325d9847",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "14faf10a",
   "metadata": {},
   "source": [
    "# Old code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4e9107",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use logistic matrix factorization to fill in the missing entries\n",
    "n_rows, n_cols = binary_ratings.shape\n",
    "\n",
    "X = cp.Variable((n_rows, n_cols))\n",
    "\n",
    "observed_ones = np.where(binary_ratings > .9)\n",
    "observed_minus_ones = np.where(binary_ratings < -.9)\n",
    "\n",
    "loss = -cp.sum(cp.logistic(-X[observed_ones]))\n",
    "loss -= cp.sum(cp.logistic(X[observed_minus_ones]))\n",
    "loss -= 0.1 * cp.norm(X, 'nuc')\n",
    "\n",
    "print(loss.is_dcp())\n",
    "\n",
    "objective = cp.Maximize(loss)\n",
    "problem = cp.Problem(objective)\n",
    "problem.solve(solver=cp.MOSEK, verbose=True)\n",
    "\n",
    "recovered_matrix = X.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33b18188",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_adv_usw_ellipsoidal(allocation, central_estimate, std_devs, rhs_bd_per_group, groups):\n",
    "    m = gp.Model()\n",
    "\n",
    "    ngroups = len(set(groups))\n",
    "\n",
    "    obj_terms = []\n",
    "\n",
    "    for gidx in range(ngroups):\n",
    "        print(\"setting up group \", gidx)\n",
    "        gmask = np.where(groups == gidx)[0]\n",
    "\n",
    "        a = allocation[:, gmask]\n",
    "        ce = central_estimate[:, gmask]\n",
    "        sd = std_devs[:, gmask]\n",
    "        rhs_bd = rhs_bd_per_group[gidx]\n",
    "\n",
    "        v = m.addMVar(ce.shape)\n",
    "\n",
    "        m.addConstr(((v - ce)*(1/sd)*(v-ce)).sum() <= rhs_bd**2)\n",
    "\n",
    "        m.addConstr(v >= 0)\n",
    "        obj_terms.append((a * v).sum())\n",
    "    obj = gp.quicksum(t for t in obj_terms)\n",
    "    m.setObjective(obj)\n",
    "    m.optimize()\n",
    "    m.setParam('OutputFlag', 1)\n",
    "\n",
    "    return obj.getValue()/allocation.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73cfbea1",
   "metadata": {},
   "outputs": [],
   "source": [
    "ce = np.load(\"../data/AAMAS/mu_matrix_1.npy\")\n",
    "sd = np.load(\"../data/AAMAS/zeta_matrix_1.npy\")\n",
    "coi_mask = np.load(\"../data/AAMAS/coi_mask_1.npy\")\n",
    "# rhs_bd_per_group = pickle.load(open(\"../data/AAMAS/delta_to_normal_bd_1.pkl\", 'rb'))[.3]\n",
    "groups = np.load(\"../data/AAMAS/groups_1.npy\")\n",
    "\n",
    "rhs_bd_per_group = []\n",
    "for gidx in range(ngroups):\n",
    "    gmask = np.where(groups == gidx)[0]\n",
    "    c_value = np.sum(coi_mask[:, gmask])\n",
    "    rhs_bd_per_group.append(np.sqrt(chi2.ppf(1-(delta/ngroups), df=c_value)))\n",
    "alloc = np.load(\"../outputs/outputs/gAAMAS1/adv_usw_0.20_alloc.npy\")\n",
    "compute_adv_usw_ellipsoidal(alloc, ce, sd, rhs_bd_per_group, groups)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10fca351",
   "metadata": {},
   "outputs": [],
   "source": [
    "ce = np.load(\"../data/AAMAS/mu_matrix_1.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ef60a37",
   "metadata": {},
   "outputs": [],
   "source": [
    "std_devs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "070a59f5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
