{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eca56209",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13f09d4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7db5a13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def truthful_utility_calculation(a_func,p_func,valuations):\n",
    "    #Given bids, find out the revenue of the auctioneer given paymt and alloc funcs\n",
    "    alloc_truthful = a_func(valuations)\n",
    "    paymt_truthful = p_func(valuations)\n",
    "\n",
    "    utility_truthful = torch.sum(alloc_truthful * (valuations - paymt_truthful * valuations),-1)\n",
    "    return utility_truthful # (batchsize x n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26072b9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def truthful_revenue_calculation(a_func,p_func,b):\n",
    "    #Given bids, find out the revenue of the auctioneer given paymt and alloc funcs\n",
    "    alloc_truthful = a_func(b)\n",
    "    paymt_truthful = p_func(b)\n",
    "\n",
    "    revenue_orig = torch.sum(torch.sum(alloc_truthful * b * paymt_truthful,dim = -1),dim = -1)\n",
    "    return revenue_orig # (batchsize, )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b3c7617",
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine_bids_BNIC(bids,fake_bids):\n",
    "    #take in a bunch of bids [batch_size x n x m], and a bunch of lies [batch_size x n x m]\n",
    "    #output them put together [batch_size x n x n x m]\n",
    "    batch_size = bids.shape[0]\n",
    "    num_bidders = bids.shape[1]\n",
    "    num_items = bids.shape[2]\n",
    "    \n",
    "    bids_fake_bids = torch.zeros(batch_size,num_bidders,num_bidders,num_items).to(device)\n",
    "    for i in range(num_bidders):\n",
    "        bids_fake_bids[:,i,0:i,:] = bids[:,0:i,:]\n",
    "        bids_fake_bids[:,i,i,:] = fake_bids[:,i,:]\n",
    "        bids_fake_bids[:,i,i+1:,:] = bids[:,i+1:,:]\n",
    "\n",
    "    return bids_fake_bids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58b73837",
   "metadata": {},
   "outputs": [],
   "source": [
    "def misreport_utility_calculation_BNIC(a_func,p_func,valuations, misreports):\n",
    "    #\n",
    "    #given a set of faked bids(already put together as repetitive n x n x m vector, as auctioneer is not lying everyone is seeing the same),\n",
    "\n",
    "    batch_size = valuations.shape[0]\n",
    "    num_bidders = valuations.shape[1]\n",
    "    num_items = valuations.shape[2]\n",
    "    \n",
    "    combined_bids = combine_bids_BNIC(valuations, misreports)\n",
    "    combined_bids_flat = torch.reshape(combined_bids,[num_bidders*batch_size,num_bidders,num_items])\n",
    "\n",
    "    alloc_flat = a_func(combined_bids_flat)\n",
    "    paymt_flat = p_func(combined_bids_flat)\n",
    "    \n",
    "\n",
    "    #after getting stuff out of NN, reprocess it\n",
    "    alloc_temp = torch.reshape(alloc_flat,[batch_size,num_bidders,num_bidders,num_items])\n",
    "    paymt_temp = torch.reshape(paymt_flat,[batch_size,num_bidders,num_bidders,1])\n",
    "\n",
    "    alloc = torch.zeros(batch_size,num_bidders,num_items).to(device)\n",
    "    paymt = torch.zeros(batch_size,num_bidders,1).to(device)\n",
    "\n",
    "    #only leave the i-th alloc and paymt for the i-th bidder, because other alloc and paymt are part of the \"lie\"\n",
    "    for i in range(num_bidders):\n",
    "        alloc[:,i,:] = alloc_temp[:,i,i,:]\n",
    "        paymt[:,i,:] = paymt_temp[:,i,i,:]\n",
    "\n",
    "\n",
    "    utilities = alloc * (valuations - paymt * misreports)\n",
    "    bidder_utilities = torch.sum(utilities, axis = -1)\n",
    "    \n",
    "    \n",
    "    return bidder_utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eea01eaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_max_regret_misreport_BNIC(a_func,p_func,m_func,valuations,min_num_iter, max_num_iter,lr,verbose = False):\n",
    "    #Take in bids, allocation and payment function, maximizes regret using Adam\n",
    "    batch_size = valuations.shape[0]\n",
    "    nBidders = valuations.shape[1]\n",
    "    nItems = valuations.shape[2]\n",
    "    \n",
    "    records = torch.ones(max_num_iter).to(device)\n",
    "\n",
    "    utility_truthful = truthful_utility_calculation(a_func,p_func,valuations)\n",
    "    utility_truthful_mean = torch.mean(utility_truthful)\n",
    "    \n",
    "    #m_func.reset()\n",
    "\n",
    "    opt = torch.optim.AdamW(m_func.parameters(),lr)\n",
    "\n",
    "    for j in range(max_num_iter):\n",
    "        opt.zero_grad()\n",
    "\n",
    "        misreport = m_func(valuations) * valuations\n",
    "        \n",
    "        utility_misreport = misreport_utility_calculation_BNIC(a_func,p_func, valuations, misreport)\n",
    "        utility_misreport_mean = torch.mean(utility_misreport)\n",
    "        \n",
    "        records[j] = utility_misreport_mean\n",
    "        q,r = divmod(j,min_num_iter)\n",
    "\n",
    "        if q > 0 and r == 0 and torch.mean(records[j-min_num_iter:j]) >= utility_misreport_mean:\n",
    "            #print(torch.mean(records[j-min_num_iter:j]).item())\n",
    "            break\n",
    "\n",
    "        rgt = 10 - utility_misreport_mean\n",
    "        rgt.backward()\n",
    "        opt.step()\n",
    "        \n",
    "        \n",
    "    misreport_to_ret = torch.reshape(misreport.detach().clone(),[batch_size,nBidders,nItems])\n",
    "    truthful_better = torch.reshape(torch.where(utility_misreport <  utility_truthful,1,0),(batch_size,nBidders,1))\n",
    "    truthful_better = torch.tile(truthful_better,(1,1,nItems))\n",
    "    misreport_better = torch.ones(truthful_better.shape).to(device) - truthful_better\n",
    "\n",
    "    to_ret = misreport_to_ret * misreport_better + valuations * truthful_better\n",
    "\n",
    "\n",
    "    if (verbose):\n",
    "        print(records[:j].detach())\n",
    "        \n",
    "    return to_ret\n",
    "        \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87d5cc01",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(a_func,p_func,m_func,valuations):\n",
    "    #find revenue and regret of this auction\n",
    "    misreports = find_max_regret_misreport_BNIC(a_func,p_func,m_func,valuations,100, 300,0.00001)\n",
    "    \n",
    "    u_orig = truthful_utility_calculation(ANet,PNet,valuations)\n",
    "    u_new = misreport_utility_calculation_BNIC(ANet,PNet, valuations, misreports)\n",
    "        \n",
    "    rgt = torch.mean(u_new - u_orig)\n",
    "    rev = torch.mean(truthful_revenue_calculation(ANet,PNet,valuations))\n",
    "\n",
    "    #print(rgt)\n",
    "    \n",
    "    return rev.item(),rgt.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c82b28d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_train_or_test_set(batch_size,n,m,prob_single):\n",
    "    prob = torch.tile(prob_single,(batch_size,1))\n",
    "    participation = torch.rand(batch_size,n).to(device)\n",
    "    participation_bin = torch.where(participation > prob, 0, 1)\n",
    "    participation_bin = torch.reshape(participation_bin, (batch_size,n,1))\n",
    "    participation_bin = torch.tile(participation_bin,(1,1,m))\n",
    "    \n",
    "    return torch.rand([batch_size,n,m]).to(device) * participation_bin"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
