{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed counter: a fresh kernel starts at 0, and every re-run of this cell\n",
    "# moves on to the next integer seed.\n",
    "SEED = SEED + 1 if 'SEED' in globals() else 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import *\n",
    "import numpy as np\n",
    "import torch\n",
    "from exp_utils import get_data\n",
    "import torch.nn as nn\n",
    "from monotonic import MonotonicLinear\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import trange\n",
    "import random\n",
    "\n",
    "# Seed every random number generator the notebook touches (PyTorch CPU,\n",
    "# NumPy, all CUDA devices, Python's `random`) with the same SEED.\n",
    "for _seed_fn in (torch.manual_seed, np.random.seed, torch.cuda.manual_seed_all, random.seed):\n",
    "    _seed_fn(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature mask, one entry per input column.\n",
    "#   0      -> the column is not treated as a monotonic input\n",
    "#   non-0  -> `split_input` returns the column in its second output, multiplied\n",
    "#             by the entry (so -1 negates it); presumably this encodes a\n",
    "#             decreasing relation to the target -- confirm against the dataset\n",
    "mask_mono = np.array(\n",
    "    [\n",
    "        0,\n",
    "        -1, -1, -1,\n",
    "        0, 0, 0,\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_input(inputs, mask=None):\n",
    "    \"\"\"Split a feature batch into free columns and sign-adjusted monotonic columns.\n",
    "\n",
    "    Args:\n",
    "        inputs: 2-D tensor laid out as (batch, n_features).\n",
    "        mask: per-feature vector; 0 marks a free column, any non-zero entry marks\n",
    "            a monotonic column that is multiplied by that entry. Defaults to the\n",
    "            notebook-level `mask_mono`.\n",
    "\n",
    "    Returns:\n",
    "        (free, mono): tensors of shape (batch, n_free) and (batch, n_mono).\n",
    "    \"\"\"\n",
    "    mask = np.asarray(mask_mono if mask is None else mask)\n",
    "    # Select on `mask == 0` explicitly: the former `1 - mask` test is also truthy\n",
    "    # for -1 entries, which leaked the monotonic columns into the free output.\n",
    "    free_idx = torch.as_tensor(np.flatnonzero(mask == 0), device=inputs.device)\n",
    "    mono_idx = torch.as_tensor(np.flatnonzero(mask != 0), device=inputs.device)\n",
    "    signs = torch.as_tensor(mask[mask != 0], dtype=torch.float32, device=inputs.device)\n",
    "    # Plain column indexing keeps the batch axis, so no squeeze() is needed\n",
    "    # (squeeze() collapsed batches of one and single-column selections).\n",
    "    return inputs[:, free_idx], inputs[:, mono_idx] * signs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Upload skipped, file /home/alberto_sinigaglia/jupyter_notebooks/inverseCDF/data/train_auto.csv exists.\n",
      "Upload skipped, file /home/alberto_sinigaglia/jupyter_notebooks/inverseCDF/data/test_auto.csv exists.\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from torch import Tensor\n",
    "\n",
    "def _frame_to_tensors(df):\n",
    "    \"\"\"Return (features, target) tensors on `device`.\n",
    "\n",
    "    The target is the 'ground_truth' column, kept 2-D; the features are every\n",
    "    other column of `df`.\n",
    "    \"\"\"\n",
    "    is_target = df.columns == 'ground_truth'\n",
    "    features = torch.tensor(df.loc[:, ~is_target].values).to(device)\n",
    "    target = torch.tensor(df.loc[:, is_target].values).to(device)\n",
    "    return features, target\n",
    "\n",
    "\n",
    "train_df, val_df = get_data(\"auto\")\n",
    "X_train, y_train = _frame_to_tensors(train_df)\n",
    "X_val, y_val = _frame_to_tensors(val_df)\n",
    "\n",
    "# drop_last=True keeps every batch at exactly 8 rows.\n",
    "train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=8, shuffle=True, drop_last=True)\n",
    "# Validation is not shuffled: with drop_last=True a shuffled loader discards a\n",
    "# different remainder each epoch, so per-epoch validation losses would not be\n",
    "# computed on the same rows.\n",
    "val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=8, shuffle=False, drop_last=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):\n",
    "    \"\"\"Train `model` and validate it after every epoch.\n",
    "\n",
    "    Args:\n",
    "        model: module called as model(free_inputs, mono_inputs).\n",
    "        train_loader: iterable of (inputs, labels) batches used for optimisation.\n",
    "        val_loader: iterable of (inputs, labels) batches handed to validate_model.\n",
    "        criterion: loss callable taking (outputs, labels).\n",
    "        optimizer: optimiser whose step() updates the model parameters.\n",
    "        num_epochs: number of passes over train_loader.\n",
    "\n",
    "    Returns:\n",
    "        (losses, val_losses): per-epoch mean training loss and per-epoch\n",
    "        validation loss, one entry per epoch.\n",
    "    \"\"\"\n",
    "    losses = []\n",
    "    val_losses = []\n",
    "    for _ in trange(num_epochs):\n",
    "        model.train()\n",
    "        epoch_losses = []\n",
    "        for inputs, labels in train_loader:\n",
    "            inputs_free, inputs_mono = split_input(inputs)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs_free.float(), inputs_mono.float())\n",
    "            loss = criterion(outputs, labels.float())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Keep a plain float rather than the loss tensor so the list does not\n",
    "            # hold references to autograd state; matches validate_model.\n",
    "            epoch_losses.append(loss.item())\n",
    "        losses.append(np.mean(epoch_losses))\n",
    "        val_losses.append(validate_model(model, val_loader, criterion))\n",
    "\n",
    "    return losses, val_losses\n",
    "\n",
    "def validate_model(model, val_loader, criterion):\n",
    "    \"\"\"Return the mean of the per-batch losses of `model` over `val_loader`.\n",
    "\n",
    "    Switches the model to eval mode and runs without gradient tracking.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    batch_losses = []\n",
    "    with torch.no_grad():\n",
    "        for batch_inputs, batch_labels in val_loader:\n",
    "            free_part, mono_part = split_input(batch_inputs)\n",
    "            predictions = model(free_part.float(), mono_part.float())\n",
    "            batch_losses.append(criterion(predictions, batch_labels.float()).item())\n",
    "    return np.mean(batch_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReLUMonoModel(torch.nn.Module):\n",
    "    \"\"\"ReLU MLP over the free features feeding a stack of MonotonicLinear layers.\n",
    "\n",
    "    The free features pass through `num_layers_pre_mono` LazyLinear+ReLU layers;\n",
    "    the result is concatenated with the monotonic features and sent through\n",
    "    `num_layers_mono + 2` MonotonicLinear layers ending in a single output unit.\n",
    "    MonotonicLinear is imported from the project's `monotonic` module; its exact\n",
    "    constraint semantics are not visible in this notebook.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size_mono, num_layers_mono, num_layers_pre_mono, num_neurons_mono, num_neurons_pre_mono) -> None:\n",
    "        super().__init__()\n",
    "        # Unconstrained branch: LazyLinear infers its input width on the first call.\n",
    "        free_layers = [torch.nn.LazyLinear(num_neurons_pre_mono) for _ in range(num_layers_pre_mono)]\n",
    "        self.pre_mono = torch.nn.ModuleList(free_layers)\n",
    "\n",
    "        # Constrained branch, built in order: input layer, hidden layers, output layer.\n",
    "        mono_layers = [\n",
    "            MonotonicLinear(input_size_mono + num_neurons_pre_mono, num_neurons_mono, pre_activation=nn.Identity())\n",
    "        ]\n",
    "        for _ in range(num_layers_mono):\n",
    "            mono_layers.append(MonotonicLinear(num_neurons_mono, num_neurons_mono, pre_activation=nn.ReLU()))\n",
    "        mono_layers.append(MonotonicLinear(num_neurons_mono, 1, pre_activation=nn.ReLU()))\n",
    "        self.mono = torch.nn.ModuleList(mono_layers)\n",
    "\n",
    "    def forward(self, x, x_mono):\n",
    "        \"\"\"Map free features `x` and monotonic features `x_mono` to the model output.\"\"\"\n",
    "        hidden = x\n",
    "        for dense in self.pre_mono:\n",
    "            hidden = torch.relu(dense(hidden))\n",
    "\n",
    "        # Join the processed free features with the monotonic inputs on the last axis.\n",
    "        hidden = torch.cat((hidden, x_mono), dim=-1)\n",
    "        for mono_layer in self.mono:\n",
    "            hidden = mono_layer(hidden)\n",
    "\n",
    "        return hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One monotonic input per non-zero mask entry; 3 layers of 16 units in each branch.\n",
    "model = ReLUMonoModel(\n",
    "    input_size_mono=(mask_mono != 0).sum(),\n",
    "    num_layers_mono=3,\n",
    "    num_layers_pre_mono=3,\n",
    "    num_neurons_mono=16,\n",
    "    num_neurons_pre_mono=16,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:04<00:00,  4.75it/s]\n"
     ]
    }
   ],
   "source": [
    "# Regression set-up: mean-squared-error loss, Adam with learning rate 1e-3, 20 epochs.\n",
    "criterion = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "losses, val_losses = train_model(\n",
    "    model,\n",
    "    train_loader,\n",
    "    val_loader,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    num_epochs=20,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(7.217586716016133)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Lowest validation loss reached in any epoch.\n",
    "np.asarray(val_losses).min()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
