{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_gt(theta_init, KL_lambda, w, rm):\n",
    "    \"\"\"Batched Gibbs-form maximiser of sum_g w_g <rm_g, theta> - KL_lambda * KL(theta || theta_init).\n",
    "\n",
    "    theta = theta_init * exp(sum_g w_g * rm_g / KL_lambda), renormalised over outcomes.\n",
    "    Shapes: theta_init (1, outcome_size), KL_lambda scalar, w (B, group_num, 1),\n",
    "    rm (B, group_num, outcome_size); returns (B, 1, outcome_size).\n",
    "    \"\"\"\n",
    "    logits = (w * rm).sum(1, keepdim=True) / KL_lambda\n",
    "    # Subtracting the per-row max cancels in the normalisation below but keeps exp()\n",
    "    # from overflowing to inf (and the result from becoming nan) when KL_lambda is small.\n",
    "    logits = logits - logits.amax(-1, keepdim=True)\n",
    "    theta = theta_init * torch.exp(logits)\n",
    "    return theta / theta.sum(-1, keepdim=True)\n",
    "\n",
    "def kl_distance(theta, theta_init):\n",
    "    \"\"\"KL(theta || theta_init) over the last (outcome) axis, squeezed (to (B,) when B > 1).\n",
    "\n",
    "    Shapes: theta (B, 1, outcome_size), theta_init (1, outcome_size).\n",
    "    torch.xlogy treats 0 * log(0) as 0, so outcomes whose probability is exactly zero\n",
    "    contribute nothing instead of turning the whole divergence into nan.\n",
    "    \"\"\"\n",
    "    ratio = theta / theta_init.unsqueeze(0)\n",
    "    return torch.xlogy(theta, ratio).sum(-1, keepdim=True).squeeze()\n",
    "\n",
    "def adjust_group_zero(tensor, epsilon=0.0):\n",
    "    \"\"\"Return a copy of `tensor` in which group 0 moves report mass from its min to its max outcome.\n",
    "\n",
    "    tensor: (B, group_num, outcome_size); only index 0 along dim 1 is changed, on a clone.\n",
    "    The amount moved per row is the row minimum when epsilon == 0, otherwise\n",
    "    min(epsilon, row minimum), so entries of a non-negative row never go below zero.\n",
    "    Every entry tied for the row min (max) is decreased (increased) by that amount.\n",
    "    \"\"\"\n",
    "    adjusted = tensor.clone()\n",
    "    row = adjusted[:, 0, :]  # a view: the in-place += below writes straight into `adjusted`\n",
    "\n",
    "    lowest = row.min(dim=1, keepdim=True).values   # (B, 1)\n",
    "    highest = row.max(dim=1, keepdim=True).values  # (B, 1)\n",
    "    is_lowest = (row == lowest).float()\n",
    "    is_highest = (row == highest).float()\n",
    "\n",
    "    if epsilon == 0:\n",
    "        shift = lowest\n",
    "    else:\n",
    "        shift = torch.where(lowest < epsilon, lowest, epsilon)\n",
    "\n",
    "    row += shift * is_highest - shift * is_lowest\n",
    "    return adjusted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:01<00:00, 805.18it/s]\n",
      "100%|██████████| 1000/1000 [00:01<00:00, 758.88it/s]\n",
      "100%|██████████| 1000/1000 [00:01<00:00, 762.68it/s]\n",
      "100%|██████████| 1000/1000 [00:01<00:00, 808.81it/s]\n",
      "100%|██████████| 1000/1000 [00:01<00:00, 756.04it/s]\n",
      "100%|██████████| 1000/1000 [00:01<00:00, 820.97it/s]\n",
      "100%|██████████| 1000/1000 [00:01<00:00, 818.53it/s]\n"
     ]
    }
   ],
   "source": [
    "outcome_size = 10\n",
    "batch_size = 4096\n",
    "group_num = 3\n",
    "KL_lambda = 0.2\n",
    "n_iters = 1000\n",
    "seed = 0\n",
    "device = 'cpu'\n",
    "epsilons = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1]\n",
    "theta_init = torch.ones(1, outcome_size, device=device) / outcome_size\n",
    "\n",
    "# One row per epsilon: [mean value gap, max value gap, mean utility gap, min utility gap],\n",
    "# where gap = truthful outcome - outcome when group 0 misreports via adjust_group_zero.\n",
    "results = []\n",
    "for epsilon in epsilons:\n",
    "    # Re-seed per epsilon so the sweep is reproducible and every epsilon is\n",
    "    # evaluated on the same random draws (common random numbers).\n",
    "    torch.manual_seed(seed)\n",
    "    value_gaps, utility_gaps = [], []\n",
    "    for _ in tqdm(range(n_iters), desc=f'epsilon={epsilon}'):\n",
    "        w = torch.randint(1, 11, (batch_size, group_num, 1), device=device)\n",
    "        rm = torch.rand(batch_size, group_num, outcome_size, device=device)\n",
    "        rm = rm / rm.sum(-1, keepdim=True)\n",
    "\n",
    "        # Truthful report: group 0's value, the others' welfare with and without group 0.\n",
    "        theta = calculate_gt(theta_init, KL_lambda, w, rm)  # (B, 1, outcome_size)\n",
    "        value = (w[:, [0]] * rm[:, [0]] * theta).sum(-1)  # (B, 1)\n",
    "        SW_minus = (w[:, 1:] * rm[:, 1:] * theta).sum(-1).sum(-1) - KL_lambda * kl_distance(theta, theta_init)  # (B,)\n",
    "        theta_minus = calculate_gt(theta_init, KL_lambda, w[:, 1:], rm[:, 1:])\n",
    "        SW_minus_star = (w[:, 1:] * rm[:, 1:] * theta_minus).sum(-1).sum(-1) - KL_lambda * kl_distance(theta_minus, theta_init)  # (B,)\n",
    "        utility = value.squeeze() - (SW_minus_star - SW_minus)\n",
    "\n",
    "        # Misreport: group 0 reports rm_adjusted but is still evaluated with its true rm.\n",
    "        rm_adjusted = adjust_group_zero(rm, epsilon)\n",
    "        theta_adjusted = calculate_gt(theta_init, KL_lambda, w, rm_adjusted)\n",
    "        value_adjusted = (w[:, [0]] * rm[:, [0]] * theta_adjusted).sum(-1)\n",
    "        SW_minus_adjusted = (w[:, 1:] * rm[:, 1:] * theta_adjusted).sum(-1).sum(-1) - KL_lambda * kl_distance(theta_adjusted, theta_init)\n",
    "        utility_adjusted = value_adjusted.squeeze() - (SW_minus_star - SW_minus_adjusted)\n",
    "\n",
    "        value_gaps.append(value - value_adjusted)\n",
    "        utility_gaps.append(utility - utility_adjusted)\n",
    "\n",
    "    value_gaps = torch.stack(value_gaps)\n",
    "    utility_gaps = torch.stack(utility_gaps)\n",
    "    results.append([value_gaps.mean(), value_gaps.max(), utility_gaps.mean(), utility_gaps.min()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[tensor(-0.0011),\n",
       "  tensor(3.5763e-07),\n",
       "  tensor(1.0647e-05),\n",
       "  tensor(-1.9073e-06)],\n",
       " [tensor(-0.0021),\n",
       "  tensor(4.7684e-07),\n",
       "  tensor(4.1365e-05),\n",
       "  tensor(-1.3113e-06)],\n",
       " [tensor(-0.0049), tensor(3.5763e-07), tensor(0.0002), tensor(-1.5497e-06)],\n",
       " [tensor(-0.0087), tensor(4.7684e-07), tensor(0.0008), tensor(-1.6689e-06)],\n",
       " [tensor(-0.0137), tensor(4.7684e-07), tensor(0.0024), tensor(-1.3113e-06)],\n",
       " [tensor(-0.0180), tensor(4.7684e-07), tensor(0.0053), tensor(-1.3113e-06)],\n",
       " [tensor(-0.0182), tensor(4.7684e-07), tensor(0.0056), tensor(-1.2517e-06)]]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per epsilon in the sweep: [mean value gap, max value gap, mean utility gap, min utility gap]\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:01<00:00, 806.30it/s]\n"
     ]
    }
   ],
   "source": [
    "# Single-epsilon run of the misreport experiment. Self-contained: every parameter,\n",
    "# including epsilon, is set here, and the sweep's `results` list is left untouched.\n",
    "outcome_size = 10\n",
    "batch_size = 4096\n",
    "group_num = 3\n",
    "KL_lambda = 0.2\n",
    "epsilon = 0.1  # explicit; equals the last (largest) epsilon of the sweep\n",
    "n_iters = 1000\n",
    "seed = 0\n",
    "device = 'cpu'\n",
    "theta_init = torch.ones(1, outcome_size, device=device) / outcome_size\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "value_gaps, utility_gaps = [], []\n",
    "for _ in tqdm(range(n_iters), desc=f'epsilon={epsilon}'):\n",
    "    w = torch.randint(1, 11, (batch_size, group_num, 1), device=device)\n",
    "    rm = torch.rand(batch_size, group_num, outcome_size, device=device)\n",
    "    rm = rm / rm.sum(-1, keepdim=True)\n",
    "\n",
    "    # Truthful report.\n",
    "    theta = calculate_gt(theta_init, KL_lambda, w, rm)  # (B, 1, outcome_size)\n",
    "    value = (w[:, [0]] * rm[:, [0]] * theta).sum(-1)  # (B, 1)\n",
    "    SW_minus = (w[:, 1:] * rm[:, 1:] * theta).sum(-1).sum(-1) - KL_lambda * kl_distance(theta, theta_init)  # (B,)\n",
    "    theta_minus = calculate_gt(theta_init, KL_lambda, w[:, 1:], rm[:, 1:])\n",
    "    SW_minus_star = (w[:, 1:] * rm[:, 1:] * theta_minus).sum(-1).sum(-1) - KL_lambda * kl_distance(theta_minus, theta_init)  # (B,)\n",
    "    utility = value.squeeze() - (SW_minus_star - SW_minus)\n",
    "\n",
    "    # Misreport by group 0, evaluated with its true rm.\n",
    "    rm_adjusted = adjust_group_zero(rm, epsilon)\n",
    "    theta_adjusted = calculate_gt(theta_init, KL_lambda, w, rm_adjusted)\n",
    "    value_adjusted = (w[:, [0]] * rm[:, [0]] * theta_adjusted).sum(-1)\n",
    "    SW_minus_adjusted = (w[:, 1:] * rm[:, 1:] * theta_adjusted).sum(-1).sum(-1) - KL_lambda * kl_distance(theta_adjusted, theta_init)\n",
    "    utility_adjusted = value_adjusted.squeeze() - (SW_minus_star - SW_minus_adjusted)\n",
    "\n",
    "    # Accumulate every batch, not just the last one.\n",
    "    value_gaps.append(value - value_adjusted)\n",
    "    utility_gaps.append(utility - utility_adjusted)\n",
    "\n",
    "value_gaps = torch.stack(value_gaps)\n",
    "utility_gaps = torch.stack(utility_gaps)\n",
    "# [mean value gap, max value gap, mean utility gap, min utility gap], same layout as a row of `results`\n",
    "single_eps_summary = [value_gaps.mean(), value_gaps.max(), utility_gaps.mean(), utility_gaps.min()]\n",
    "single_eps_summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_gt_maxmin(theta_init, KL_lambda, w, rm):\n",
    "    \"\"\"Candidate KL-regularised max-min distribution for a single (unbatched) instance.\n",
    "\n",
    "    theta = theta_init * exp(min_g(w_g * rm_g) / KL_lambda), renormalised, with the min taken\n",
    "    per outcome over groups. Shapes: theta_init (outcome_size,), KL_lambda scalar,\n",
    "    w (group_num, 1), rm (group_num, outcome_size); returns (outcome_size,).\n",
    "    NOTE(review): this maximises sum_o min_g(w_g rm_g[o]) theta_o - KL_lambda * KL, which is a lower\n",
    "    bound on min_g <w_g rm_g, theta> - KL_lambda * KL; whether it is the exact max-min optimum\n",
    "    is only probed empirically by the random-search cell.\n",
    "    \"\"\"\n",
    "    logits = (w * rm).min(0).values / KL_lambda\n",
    "    # Max-subtraction cancels in the normalisation and prevents exp() overflow for small KL_lambda.\n",
    "    theta = theta_init * torch.exp(logits - logits.max())\n",
    "    return theta / theta.sum()\n",
    "\n",
    "def kl_distance(theta, theta_init):\n",
    "    \"\"\"KL(theta || theta_init) summed over the last (outcome) axis, then squeezed.\n",
    "\n",
    "    This shadows the batched kl_distance from the helper cell, so it is written to agree with it:\n",
    "    1-D inputs (outcome_size,) give a 0-d tensor, and batched theta (B, 1, outcome_size) with\n",
    "    theta_init (1, outcome_size) still gives one value per row instead of collapsing the batch.\n",
    "    torch.xlogy treats 0 * log(0) as 0, so zero-probability outcomes do not produce nan.\n",
    "    \"\"\"\n",
    "    return torch.xlogy(theta, theta / theta_init).sum(-1).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "outcome_size = 10\n",
    "group_num = 3\n",
    "KL_lambda = 0.2\n",
    "n_samples = 100000\n",
    "seed = 0\n",
    "device = 'cpu'\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "theta_init = torch.ones(outcome_size, device=device) / outcome_size\n",
    "w = torch.randint(1, 11, (group_num, 1), device=device)\n",
    "rm = torch.rand(group_num, outcome_size, device=device)\n",
    "rm = rm / rm.sum(-1, keepdim=True)\n",
    "\n",
    "def maxmin_welfare_of(thetas):\n",
    "    # thetas (..., outcome_size) -> min_g w_g <rm_g, theta> - KL_lambda * KL(theta || theta_init), shape (...).\n",
    "    # KL is computed inline so the result does not depend on which kl_distance definition is active.\n",
    "    kl = torch.xlogy(thetas, thetas / theta_init).sum(-1)\n",
    "    return (w * rm * thetas.unsqueeze(-2)).sum(-1).min(-1).values - KL_lambda * kl\n",
    "\n",
    "theta = calculate_gt_maxmin(theta_init, KL_lambda, w, rm)  # (outcome_size,)\n",
    "maxmin_welfare = maxmin_welfare_of(theta)\n",
    "\n",
    "# Uniform draws normalised onto the simplex, all samples at once instead of a Python loop.\n",
    "random_thetas = torch.rand(n_samples, outcome_size, device=device)\n",
    "random_thetas = random_thetas / random_thetas.sum(-1, keepdim=True)\n",
    "random_welfare = maxmin_welfare_of(random_thetas)  # (n_samples,)\n",
    "\n",
    "violations = random_welfare > maxmin_welfare\n",
    "for better in random_welfare[violations]:\n",
    "    print(maxmin_welfare, better)\n",
    "print(f'{int(violations.sum())} / {n_samples} random thetas beat the max-min candidate')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
