{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import wandb\n",
    "import datetime\n",
    "import pandas as pd\n",
    "import pickle\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from os.path import exists\n",
    "\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, loader, optimizer, loss_fn, device):\n",
    "    '''Run one training epoch and return the mean per-batch training loss.\n",
    "\n",
    "    Args:\n",
    "        model: network to optimise; put into train mode (Dropout active).\n",
    "        loader: iterable of (x_batch, y_batch) pairs that supports len().\n",
    "        optimizer: optimiser over the model's parameters.\n",
    "        loss_fn: callable(prediction, target) returning a scalar loss tensor.\n",
    "        device: device each batch is moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        Sum of per-batch losses divided by len(loader). Each batch loss is\n",
    "        also logged to wandb under 'Train batch loss'.\n",
    "    '''\n",
    "    model.train()\n",
    "    cumu_loss = 0.0\n",
    "    for x_batch, y_batch in loader:\n",
    "        # Transfer tensors to the target device\n",
    "        x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n",
    "\n",
    "        # Zero gradients, forward pass, compute loss\n",
    "        optimizer.zero_grad()\n",
    "        y_hat_batch = model(x_batch)\n",
    "        loss = loss_fn(y_hat_batch, y_batch)\n",
    "\n",
    "        # Backward pass and parameter update\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Accumulate each batch loss exactly once so the returned value is a\n",
    "        # true mean (it used to be added twice, doubling the epoch loss).\n",
    "        batch_loss = loss.item()\n",
    "        cumu_loss += batch_loss\n",
    "        wandb.log({\n",
    "            'Train batch loss'  : batch_loss,\n",
    "        })\n",
    "\n",
    "    return cumu_loss / len(loader)\n",
    "\n",
    "def validate(model, loader, loss_fn, device):\n",
    "    '''Evaluate model on loader and return the mean per-batch loss.\n",
    "\n",
    "    Puts the model in eval mode (Dropout disabled) and runs without gradient\n",
    "    tracking. The model is left in eval mode; train() calls model.train()\n",
    "    at the start of the next epoch.\n",
    "\n",
    "    Args:\n",
    "        model: network to evaluate.\n",
    "        loader: iterable of (x_batch, y_batch) pairs that supports len().\n",
    "        loss_fn: callable(prediction, target) returning a scalar loss tensor.\n",
    "        device: device each batch is moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        Sum of per-batch losses divided by len(loader). Each batch loss is\n",
    "        also logged to wandb under 'Val batch loss'.\n",
    "    '''\n",
    "    # Without eval(), Dropout stays active and the validation loss is measured\n",
    "    # on a randomly thinned network, making it noisy for early stopping.\n",
    "    model.eval()\n",
    "    cumu_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for x_batch, y_batch in loader:\n",
    "            # Transfer tensors to the target device\n",
    "            x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n",
    "\n",
    "            y_hat_batch = model(x_batch)\n",
    "            batch_loss = loss_fn(y_hat_batch, y_batch).item()\n",
    "            cumu_loss += batch_loss\n",
    "\n",
    "            wandb.log({\n",
    "                'Val batch loss' : batch_loss\n",
    "            })\n",
    "\n",
    "    return cumu_loss / len(loader)\n",
    "\n",
    "# Dataset pairing feature rows with their SHAP-value targets\n",
    "class ShapDataset(Dataset):\n",
    "    '''Map-style dataset yielding aligned (feature row, target row) pairs.\n",
    "\n",
    "    Feat and Shap are indexed as [index, :] and converted with .float(), so\n",
    "    they are presumably 2-D torch tensors with matching row counts (TODO\n",
    "    confirm). The dataset length is taken from Feat alone.\n",
    "    '''\n",
    "    def __init__(self, Feat, Shap):\n",
    "        self.Feat, self.Shap = Feat, Shap\n",
    "\n",
    "    def __len__(self):\n",
    "        '''Number of samples, i.e. the number of rows in Feat.'''\n",
    "        num_rows = self.Feat.shape[0]\n",
    "        return num_rows\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        '''Return (feature row, target row) for sample `index`, both cast to float32.'''\n",
    "        feat_row = self.Feat[index, :]\n",
    "        shap_row = self.Shap[index, :]\n",
    "        return feat_row.float(), shap_row.float()\n",
    "\n",
    "class MLP_toshap(nn.Module):\n",
    "    '''Feed-forward MLP mapping input_size features to output_size outputs.\n",
    "\n",
    "    Architecture: (n_layers - 1) hidden blocks of Linear -> ReLU -> Dropout,\n",
    "    all of width hidden_size, followed by a final Linear projection. With\n",
    "    n_layers <= 1 the network is a single Linear(input_size, output_size).\n",
    "    '''\n",
    "    def __init__(self, input_size, hidden_size, output_size, n_layers, drop_prob):\n",
    "        super().__init__()\n",
    "        # Layer widths: the input, then one entry per hidden block\n",
    "        widths = [input_size] + [hidden_size] * (n_layers - 1)\n",
    "        modules = []\n",
    "        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
    "            modules.extend([\n",
    "                nn.Linear(fan_in, fan_out),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Dropout(drop_prob),\n",
    "            ])\n",
    "        # Output projection (no activation)\n",
    "        modules.append(nn.Linear(widths[-1], output_size))\n",
    "        self.layers = nn.Sequential(*modules)\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''Apply the layer stack to x.'''\n",
    "        return self.layers(x)\n",
    "\n",
    "def count_parameters(model):\n",
    "    '''Count trainable scalar parameters (those with requires_grad=True) in model.'''\n",
    "    trainable = filter(lambda p: p.requires_grad, model.parameters())\n",
    "    return sum(map(lambda p: p.numel(), trainable))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(config=None):\n",
    "    '''Train one sweep configuration and checkpoint it if it beats the stored best.\n",
    "\n",
    "    Relies on globals defined in the configuration cell: feats_train, shap_train,\n",
    "    NUM_FEATS, NUM_EPOCHS, DEVICE, CHECKPOINT_PATH and TRAIN_SIZE.\n",
    "\n",
    "    Args:\n",
    "        config: optional hyperparameters forwarded to wandb.init. Under\n",
    "            wandb.agent the sweep populates wandb.config instead. Must provide\n",
    "            batch_size, hidden_size, n_layers, drop_prob and lr.\n",
    "\n",
    "    Side effects:\n",
    "        Logs metrics to wandb and may overwrite the checkpoint file for\n",
    "        TRAIN_SIZE with the best epoch of this run.\n",
    "    '''\n",
    "    import copy\n",
    "\n",
    "    with wandb.init(config=config):\n",
    "        config = wandb.config\n",
    "\n",
    "        # Hold out 20% of the training pool for validation / early stopping\n",
    "        train_feats, val_feats, train_shap, val_shap = train_test_split(\n",
    "            feats_train,\n",
    "            shap_train,\n",
    "            train_size=0.8,\n",
    "            random_state=42\n",
    "        )\n",
    "\n",
    "        train_set = ShapDataset(train_feats, train_shap)\n",
    "        val_set = ShapDataset(val_feats, val_shap)\n",
    "\n",
    "        train_loader = DataLoader(\n",
    "            train_set,\n",
    "            batch_size=config.batch_size,\n",
    "            shuffle=True,\n",
    "            drop_last=True\n",
    "        )\n",
    "\n",
    "        # Validation order does not matter. Keep the final partial batch so\n",
    "        # every validation sample is scored; drop_last=True discarded up to\n",
    "        # batch_size - 1 of them.\n",
    "        val_loader = DataLoader(\n",
    "            val_set,\n",
    "            batch_size=config.batch_size,\n",
    "            shuffle=False,\n",
    "            drop_last=False,\n",
    "        )\n",
    "\n",
    "        # Build model\n",
    "        model = MLP_toshap(\n",
    "            input_size=NUM_FEATS,\n",
    "            hidden_size=config.hidden_size,\n",
    "            n_layers=config.n_layers,\n",
    "            output_size=NUM_FEATS,\n",
    "            drop_prob=config.drop_prob,\n",
    "            ).to(DEVICE)\n",
    "\n",
    "        # Set loss and optimizer\n",
    "        loss_fn = nn.MSELoss()\n",
    "        optimizer = torch.optim.Adam(\n",
    "            model.parameters(),\n",
    "            lr=config.lr,\n",
    "        )\n",
    "\n",
    "        # Early stopping: stop after INIT_PATIENCE consecutive epochs without a\n",
    "        # new best validation loss. Patience is initialised once, outside the\n",
    "        # loop; resetting it every epoch meant training never stopped early.\n",
    "        INIT_PATIENCE = 5\n",
    "        BEST_VAL_LOSS = np.inf\n",
    "        patience = INIT_PATIENCE\n",
    "        best = None  # snapshot of the best epoch seen so far\n",
    "        epochs_run = 0\n",
    "\n",
    "        start_train = time.time()\n",
    "        for epoch in range(NUM_EPOCHS):\n",
    "            avg_train_loss = train(model, train_loader, optimizer, loss_fn, DEVICE)\n",
    "            avg_val_loss = validate(model, val_loader, loss_fn, DEVICE)\n",
    "            epochs_run = epoch + 1\n",
    "\n",
    "            wandb.log({\n",
    "                'Train loss'    : avg_train_loss,\n",
    "                'Val   loss'    : avg_val_loss,\n",
    "                'Epoch'         : epoch\n",
    "            })\n",
    "\n",
    "            if avg_val_loss < BEST_VAL_LOSS:\n",
    "                BEST_VAL_LOSS = avg_val_loss\n",
    "                patience = INIT_PATIENCE\n",
    "                # deepcopy: state_dict() holds references to the live tensors,\n",
    "                # which later epochs keep updating.\n",
    "                best = {\n",
    "                    'epoch'                 : epoch,\n",
    "                    'train_loss'            : avg_train_loss,\n",
    "                    'val_loss'              : avg_val_loss,\n",
    "                    'model_state_dict'      : copy.deepcopy(model.state_dict()),\n",
    "                    'optimizer_state_dict'  : copy.deepcopy(optimizer.state_dict()),\n",
    "                }\n",
    "            else:\n",
    "                patience -= 1\n",
    "                if patience == 0:\n",
    "                    break\n",
    "\n",
    "        end_train = time.time()\n",
    "\n",
    "        if best is None:\n",
    "            # No epoch improved on inf (e.g. NUM_EPOCHS == 0 or NaN losses)\n",
    "            print('No epoch produced a valid validation loss; nothing to save.')\n",
    "            return\n",
    "\n",
    "        checkpoint_file = f'{CHECKPOINT_PATH}_best_model_trainsize{TRAIN_SIZE}.pth'\n",
    "\n",
    "        if exists(checkpoint_file):\n",
    "            # map_location: stored tensors may live on a GPU that is unavailable here\n",
    "            prev_best_val_loss = torch.load(checkpoint_file, map_location='cpu')['val_loss']\n",
    "        else:\n",
    "            prev_best_val_loss = np.inf\n",
    "\n",
    "        best_val_loss = best['val_loss']\n",
    "        if best_val_loss < prev_best_val_loss:\n",
    "            print(f'Saving model with val loss: {best_val_loss:.4f} | prev: {prev_best_val_loss:.4f}')\n",
    "\n",
    "            # torch.save does not create missing directories\n",
    "            os.makedirs(os.path.dirname(checkpoint_file) or '.', exist_ok=True)\n",
    "            torch.save(\n",
    "                obj = {\n",
    "                    'sweep_name'            : wandb.run.name,\n",
    "                    'epochs'                : NUM_EPOCHS,\n",
    "                    'epochs_run'            : epochs_run,\n",
    "                    'best_epoch'            : best['epoch'],\n",
    "                    'n_layers'              : config.n_layers, # Architecture hyperparameters stored with the weights\n",
    "                    'hidden_size'           : config.hidden_size,\n",
    "                    'batch_size'            : config.batch_size,\n",
    "                    'lr'                    : config.lr,\n",
    "                    'drop_prob'             : config.drop_prob,\n",
    "                    'num_params'            : count_parameters(model),\n",
    "                    'model_state_dict'      : best['model_state_dict'],\n",
    "                    'optimizer_state_dict'  : best['optimizer_state_dict'],\n",
    "                    'train_loss'            : best['train_loss'],\n",
    "                    'val_loss'              : best_val_loss,\n",
    "                    'train_time'            : end_train - start_train,\n",
    "                },\n",
    "                f = checkpoint_file,\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_SWEEPS = 20  # the grid below has 6 x 3 = 18 combinations\n",
    "NUM_EPOCHS = 5000  # upper bound; main() stops early on validation plateau\n",
    "NUM_FEATS = 13\n",
    "TRAIN_SET_SIZES = [1000, 1500, 2000, 2500, 3000, 3500]\n",
    "\n",
    "BASE = '< BASE_PATH >'\n",
    "\n",
    "AVAILABLE_CUDA = 1\n",
    "DEVICE = torch.device(f'cuda:{AVAILABLE_CUDA}' if torch.cuda.is_available() else 'cpu')\n",
    "# os.path.join works whether or not BASE ends with a separator. The trailing ''\n",
    "# keeps a separator at the end because main() appends the file name directly.\n",
    "CHECKPOINT_PATH = os.path.join(BASE, 'HousingExperiment', 'Aug1Checkpoints', '')\n",
    "\n",
    "# Load dataset: features and their SHAP-value targets (row-aligned)\n",
    "feats_complete = torch.load(os.path.join(BASE, 'Data', f'feats_housing_f{NUM_FEATS}.pt'))\n",
    "# FIXME(review): this previously re-loaded the *features* file, so the model was\n",
    "# trained to reproduce its own input. Confirm the actual SHAP-values filename.\n",
    "shapvals_complete = torch.load(os.path.join(BASE, 'Data', f'shap_housing_f{NUM_FEATS}.pt'))\n",
    "# The model maps NUM_FEATS inputs to NUM_FEATS outputs, so both must align\n",
    "assert shapvals_complete.shape == feats_complete.shape, 'features and SHAP targets must have identical shapes'\n",
    "\n",
    "# Split to training/validation set and test (store rest for testing).\n",
    "# Only the first entry of TRAIN_SET_SIZES is used in this run.\n",
    "TRAIN_SIZE = TRAIN_SET_SIZES[0]\n",
    "assert TRAIN_SIZE <= feats_complete.shape[0], 'TRAIN_SIZE exceeds the number of available samples'\n",
    "\n",
    "# Data is already shuffled, take the first N samples for train/validation\n",
    "feats_train = feats_complete[:TRAIN_SIZE, :]\n",
    "shap_train = shapvals_complete[:TRAIN_SIZE, :]\n",
    "\n",
    "# Login, start the sweep\n",
    "wandb.login()\n",
    "\n",
    "sweep_config = {\n",
    "    'name': f'Train Shap Models on {TRAIN_SIZE}/{feats_complete.shape[0]} samples',\n",
    "    'method': 'grid',\n",
    "    # Must match a key logged by main(); 'loss' was never logged\n",
    "    'metric': {'name': 'Val   loss', 'goal': 'minimize'},\n",
    "    }\n",
    "\n",
    "params_dict = {\n",
    "    'batch_size': {\n",
    "        'values': [128]\n",
    "        },\n",
    "    'hidden_size': {\n",
    "        'values': [16, 32, 64, 96, 192, 384]\n",
    "        },\n",
    "    'n_layers' : {\n",
    "        'values' : [2, 3, 4]\n",
    "    },\n",
    "    'drop_prob': {\n",
    "        'values': [0.1]\n",
    "        },\n",
    "    'lr': {\n",
    "        'values': [1e-4]\n",
    "        },\n",
    "}\n",
    "sweep_config['parameters'] = params_dict\n",
    "\n",
    "# Set the sweep id\n",
    "sweep_id = wandb.sweep(sweep_config, project=f'Shap Model {str(datetime.date.today())}')\n",
    "\n",
    "# Start the sweeps!\n",
    "wandb.agent(sweep_id, main, count=NUM_SWEEPS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('rebuttal')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "8b4c61b0c0dfa50a12dd4948992e6e1a04659874706c6edcf781c314bdb1c462"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
