{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    SEED\n",
    "except NameError:\n",
    "    SEED = 0\n",
    "else:\n",
    "    SEED += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import *\n",
    "import numpy as np\n",
    "import torch\n",
    "from exp_utils import get_data\n",
    "import torch.nn as nn\n",
    "from monotonic import MonotonicLinear\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import trange\n",
    "import random\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)\n",
    "random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask_mono = np.array([1 if i in range(50, 54) or i in range(55, 59) else 0 for i in range(276)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_input(inputs):\n",
    "    return inputs[:, np.where(1-mask_mono)].squeeze(), \\\n",
    "        inputs[:, np.where(mask_mono)].squeeze() * torch.tensor(mask_mono[np.where(mask_mono)][None,:], dtype=torch.float32).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Upload skipped, file /home/alberto_sinigaglia/jupyter_notebooks/inverseCDF/data/train_blog.csv exists.\n",
      "Upload skipped, file /home/alberto_sinigaglia/jupyter_notebooks/inverseCDF/data/test_blog.csv exists.\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from torch import Tensor\n",
    "\n",
    "train_df, val_df = get_data(\"blog\")\n",
    "X_train = torch.tensor(train_df.loc[:, train_df.columns != 'ground_truth'].values).to(device)\n",
    "X_val = torch.tensor(val_df.loc[:, val_df.columns != 'ground_truth'].values).to(device)\n",
    "y_train = torch.tensor(train_df.loc[:, train_df.columns == 'ground_truth'].values).to(device)\n",
    "y_val = torch.tensor(val_df.loc[:, val_df.columns == 'ground_truth'].values).to(device)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(TensorDataset(X_train, y_train), batch_size=256, shuffle=True, drop_last=True)\n",
    "val_loader = torch.utils.data.DataLoader(TensorDataset(X_val, y_val), batch_size=256, shuffle=True, drop_last=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):\n",
    "    losses = []\n",
    "    val_losses = []\n",
    "    best_val_acc = 0.\n",
    "    for epoch in trange(num_epochs):\n",
    "        model.train()\n",
    "        total = 0\n",
    "        losses_buffer = []\n",
    "        for inputs, labels in train_loader:\n",
    "            inputs_free, inputs_mono = split_input(inputs)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs_free.float(), inputs_mono.float())\n",
    "            loss = criterion(outputs, labels.float())\n",
    "            losses_buffer.append(loss)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            total += labels.size(0)\n",
    "        losses.append(np.mean([el.detach().cpu() for el in losses_buffer]))\n",
    "        \n",
    "        val_loss = validate_model(model, val_loader, criterion)\n",
    "        val_losses.append(val_loss)\n",
    "    \n",
    "    return losses, val_losses\n",
    "\n",
    "def validate_model(model, val_loader, criterion):\n",
    "    model.eval()\n",
    "    val_loss = []\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in val_loader:\n",
    "            inputs_free, inputs_mono = split_input(inputs)\n",
    "            outputs = model(inputs_free.float(), inputs_mono.float())\n",
    "            loss = criterion(outputs, labels.float())\n",
    "            val_loss += [loss.item()]\n",
    "    return np.mean(val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReLUMonoModel(torch.nn.Module):\n",
    "    def __init__(self, input_size_mono, num_layers_mono, num_layers_pre_mono, num_neurons_mono, num_neurons_pre_mono) -> None:\n",
    "        super().__init__()\n",
    "        self.pre_mono = torch.nn.ModuleList([torch.nn.LazyLinear(num_neurons_pre_mono) for _ in range(num_layers_pre_mono)])\n",
    "        self.mono = torch.nn.ModuleList(\n",
    "            [\n",
    "                MonotonicLinear(input_size_mono + num_neurons_pre_mono, num_neurons_mono, pre_activation=nn.Identity()),\n",
    "                *[MonotonicLinear(num_neurons_mono, num_neurons_mono, pre_activation=nn.ReLU()) for _ in range(num_layers_mono)],\n",
    "                MonotonicLinear(num_neurons_mono, 1, pre_activation=nn.ReLU()),\n",
    "            ]\n",
    "        )\n",
    "    def forward(self, x, x_mono):\n",
    "        for layer in self.pre_mono:\n",
    "            x = torch.nn.functional.relu(layer(x))\n",
    "        \n",
    "        x = torch.cat((x, x_mono), dim=-1)\n",
    "        for layer in self.mono:\n",
    "            x = layer(x)\n",
    "        \n",
    "        return x\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ReLUMonoModel((mask_mono!=0).sum(),3,3, 12,4).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:51<00:00,  1.04s/it]\n"
     ]
    }
   ],
   "source": [
    "criterion = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), 1e-2)\n",
    "losses, val_losses = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(0.15865363907864322)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sqrt(np.min(val_losses))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
