{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2cd448be",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import collections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a226beda",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a4fe3c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self, dimInput, dimOutput, nLayers, width):\n",
    "        super().__init__()\n",
    "        self.layers = collections.OrderedDict()\n",
    "        \n",
    "        self.layers[\"fc1\"] = nn.Linear(dimInput, width).to(device)\n",
    "        \n",
    "        for i in range(2,nLayers):\n",
    "            self.layers[\"tanh\" + str(i - 1)] = nn.Tanh().to(device)\n",
    "            self.layers[\"fc\" + str(i)] = nn.Linear(width, width).to(device)\n",
    "\n",
    "        self.layers[\"tanh\" + str(nLayers - 1)] = nn.Tanh().to(device)\n",
    "        self.layers[\"fc\" + str(nLayers)] = nn.Linear(width, dimOutput).to(device)\n",
    "        \n",
    "        self.model = nn.Sequential(self.layers).to(device)\n",
    "\n",
    "    def forward(self, input):\n",
    "        input = input.to(device)\n",
    "        return self.model(input)\n",
    "    \n",
    "\n",
    "class AllocationNet(nn.Module):\n",
    "    #takes in bids (n x m), returns an allocation for each item (n x m)\n",
    "    \n",
    "    #first decide whether to allocate the item or not\n",
    "    \n",
    "    def __init__(self,n,m,nLayers,width):\n",
    "        super().__init__()\n",
    "        self.mlp1 = MLP(n*m,m,nLayers,width)\n",
    "        self.mlp2 = MLP(n*m,n*m,nLayers,width)\n",
    "        self.n = n\n",
    "        self.m = m\n",
    "\n",
    "    def forward(self, x):\n",
    "        batch_size = x.shape[0]\n",
    "        x = torch.flatten(x,start_dim=-2,end_dim=-1)\n",
    "        \n",
    "        alloc_or_not = torch.reshape(torch.sigmoid(self.mlp1(x)),[batch_size,1,self.m])\n",
    "        allocate_to_who = torch.softmax(torch.reshape(self.mlp2(x),[batch_size,self.n,self.m]),dim = -2)\n",
    "        \n",
    "        return alloc_or_not * allocate_to_who \n",
    "    \n",
    "    def reset(self):\n",
    "        for layer in self.fcs:\n",
    "            layer.reset_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deb1aadf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Initializes a payment function\n",
    "class PaymentNet(nn.Module):\n",
    "    def __init__(self,n,m,nLayers,width):\n",
    "        super().__init__()\n",
    "        self.mlp = MLP(n*m,n,nLayers,width)\n",
    "        self.n = n\n",
    "        self.m = m\n",
    "\n",
    "    def forward(self, x):\n",
    "        batch_size = x.shape[0]\n",
    "        x = torch.flatten(x,start_dim=-2,end_dim=-1)\n",
    "\n",
    "        paymt = torch.reshape(torch.sigmoid(self.mlp(x)),[batch_size,self.n,1])\n",
    "        \n",
    "        return paymt\n",
    "\n",
    "\n",
    "    def reset(self):\n",
    "        for layer in self.fcs:\n",
    "            layer.reset_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee163765",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MisreportNetBNIC(nn.Module):\n",
    "    #takes in bids (n x m), returns misreport for each bid as a ratio (n x m)\n",
    "    def __init__(self,n,m,nLayers,width):\n",
    "        super().__init__()\n",
    "        self.mlp = MLP(n*m,n *m,nLayers,width)\n",
    "        self.n = n\n",
    "        self.m = m\n",
    "\n",
    "    def forward(self, x):\n",
    "        batch_size = x.shape[0]\n",
    "        x = torch.flatten(x,start_dim=-2,end_dim=-1)\n",
    "\n",
    "        mreprt = torch.reshape(torch.sigmoid(self.mlp(x)),[batch_size,self.n,self.m])\n",
    "        \n",
    "        return mreprt\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
