{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75858edf-817e-45d3-b9af-51123484dbd2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Dependencies. Uncomment to install; `%pip install` (instead of `!pip install`) targets the\n",
    "# environment of the running kernel. NOTE(review): later cells pass `gpus=`,\n",
    "# `checkpoint_callback=` and `weights_summary=` to `Trainer`, which suggests an older\n",
    "# pytorch-lightning 1.x release is required -- confirm and pin the versions.\n",
    "#!pip install torch torchvision torchaudio\n",
    "#!pip install pytorch-lightning\n",
    "#!pip install scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b32bfa2-58bd-4e4a-ae99-8d428bf92674",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn \n",
    "from torchvision.datasets import FashionMNIST, MNIST, EMNIST, CIFAR100\n",
    "from torchvision import transforms\n",
    "from pytorch_lightning import Trainer\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "import pytorch_lightning as pl\n",
    "from models import *\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from scipy.optimize import minimize\n",
    "from joblib import Parallel, delayed\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "import matplotlib as mpl\n",
    "\n",
    "# Plot styling: start from the 'classic' style, then override sizes, strokes and fonts.\n",
    "mpl.style.use(\"classic\")\n",
    "mpl.rcParams.update({\n",
    "    \"figure.figsize\": [5, 3],\n",
    "    \"figure.facecolor\": \"w\",\n",
    "    # thin strokes for axes, grid, lines and patches\n",
    "    \"axes.linewidth\": 0.75,\n",
    "    \"grid.linewidth\": 0.75,\n",
    "    \"lines.linewidth\": 0.75,\n",
    "    \"patch.linewidth\": 0.75,\n",
    "    \"xtick.major.size\": 3,\n",
    "    \"ytick.major.size\": 3,\n",
    "    # embed TrueType (type 42) fonts in PDF / PostScript output\n",
    "    \"pdf.fonttype\": 42,\n",
    "    \"ps.fonttype\": 42,\n",
    "    \"font.size\": 9,\n",
    "    \"axes.titlesize\": \"medium\",\n",
    "    \"legend.fontsize\": \"medium\",\n",
    "})\n",
    "\n",
    "import os\n",
    "import warnings\n",
    "\n",
    "# Ignore all warnings for the rest of the session.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Reproducibility: seed the global NumPy and PyTorch random generators.\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# First GPU when one is visible, otherwise the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print('The available device is :', device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1942cb8-c098-4c10-bf9d-d4aa438b3a16",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Registry of the supported torchvision datasets, looked up by name below.\n",
    "dict_datasets = dict(\n",
    "    MNIST=MNIST,\n",
    "    FashionMNIST=FashionMNIST,\n",
    "    EMNIST=EMNIST,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be30e9be-b9f2-4d91-8d45-3540d8630c4d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Dataset used by the whole experiment; must be one of the keys of dict_datasets.\n",
    "name = 'MNIST'\n",
    "if name not in dict_datasets:\n",
    "    # Fail fast with the valid choices instead of a bare KeyError.\n",
    "    raise ValueError('unknown dataset %r; expected one of %s' % (name, sorted(dict_datasets)))\n",
    "print('dataset used is :', name)\n",
    "dataset = dict_datasets[name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "833519a4-b3b8-484a-ae6f-7eb3a612838f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# EMNIST additionally requires a split name; MNIST and FashionMNIST take none.\n",
    "extra_kwargs = {'split': 'balanced'} if name == 'EMNIST' else {}\n",
    "mnist_train = dataset(os.getcwd(), train=True, download=True, transform=transforms.ToTensor(), **extra_kwargs)\n",
    "mnist_test = dataset(os.getcwd(), train=False, download=True, transform=transforms.ToTensor(), **extra_kwargs)\n",
    "\n",
    "\n",
    "def _to_contexts(ds):\n",
    "    \"\"\"Return (X, Y): `ds.data` flattened to one float row per sample, and a copy of `ds.targets`.\n",
    "\n",
    "    Each row of X is rescaled to unit L2 norm. `normalize` clamps the norm at a small eps,\n",
    "    so an all-zero image stays zero instead of becoming NaN (as `x / ||x||` would).\n",
    "    `torch.as_tensor` accepts tensors and arrays alike, without the copy-construct\n",
    "    warning that `torch.tensor(existing_tensor)` emits.\n",
    "    \"\"\"\n",
    "    X = torch.as_tensor(ds.data).float().reshape(len(ds.data), -1)\n",
    "    Y = torch.as_tensor(ds.targets).clone()\n",
    "    return nn.functional.normalize(X, p=2, dim=-1), Y\n",
    "\n",
    "\n",
    "# contexts are flattened and rescaled to unit L2 norm\n",
    "X_train, Y_train = _to_contexts(mnist_train)\n",
    "X_test, Y_test = _to_contexts(mnist_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a7168fd-dcc4-4372-95f8-5cf3fbc3cd88",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Data subsampling: 5% of the training set (x0, y0) trains the logging policy; the remaining\n",
    "# 95% (X_log, Y_log) is used to collect the bandit dataset. A fixed random_state keeps the\n",
    "# split identical even when this cell is re-run on its own.\n",
    "x0, X_log, y0, Y_log = train_test_split(X_train, Y_train, train_size=0.05, random_state=seed)\n",
    "\n",
    "N = len(X_log)\n",
    "N_test = len(X_test)\n",
    "\n",
    "print('Train : dimension of X is :', X_train.shape, 'dimension of Y is :', len(Y_train))\n",
    "print('Test : dimension of X_test is :', X_test.shape, 'dimension of Y is :', N_test)\n",
    "context_dim = X_log.shape[1]\n",
    "# Count classes on the full training labels, not on a random subset that could miss one.\n",
    "num_actions = len(torch.unique(Y_train))\n",
    "print('num_actions: ', num_actions)\n",
    "\n",
    "subsample_pt = TensorDataset(x0, y0)\n",
    "subsample_dataloader = DataLoader(subsample_pt, batch_size=128, shuffle=True)\n",
    "\n",
    "# create the logging split\n",
    "logging_split = TensorDataset(X_log, Y_log)\n",
    "logging_split_dataloader = DataLoader(logging_split, batch_size=128, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c26c6d8-c41e-4afb-a795-921b02621567",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Training a logging policy\n",
    "# Inverse temperature parameters (the higher eta, the better the performance of the logging policy).\n",
    "# NOTE(review): an earlier setting used np.linspace(0, 1, 10).\n",
    "etas = np.round(np.linspace(-0.5, 0.5, 10), 2)\n",
    "\n",
    "dict_results = defaultdict(list)\n",
    "epochs_logging = 5  # epochs used to train the logging policy (an earlier setting used 10)\n",
    "epochs = 20  # epochs used to train each learning policy\n",
    "\n",
    "# Request a GPU only when one is visible; `gpus=1` would fail on CPU-only machines even\n",
    "# though `device` falls back to the CPU.\n",
    "n_gpus = 1 if torch.cuda.is_available() else 0\n",
    "\n",
    "# NOTE(review): an earlier setting used reg=1e-6.\n",
    "logging_policy = SupervisedPolicy(n_actions=num_actions, context_dim=context_dim, softmax=True, reg=0, device=device)\n",
    "trainer = Trainer(max_epochs=epochs_logging, gpus=n_gpus, checkpoint_callback=False, weights_summary=None, logger=None)\n",
    "trainer.fit(logging_policy, subsample_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58597fab-5939-44d8-8aaf-9b9ea85e71bd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Hyper-parameter grid shared by every estimator; loop-invariant, so it is built once.\n",
    "taus = np.linspace(1e-5, 1, 10)\n",
    "# NOTE(review): a grid such as [0.1, 0.5, 1, 10, 100, 500, 1000, 1e5, 1e7, 1e10] was sketched\n",
    "# for shrinkage (hyper-parameter in (0, inf)), but no estimator below uses it.\n",
    "\n",
    "# (result-key prefix, name shown in the log, estimator class), trained in this order for each tau.\n",
    "estimators = [\n",
    "    ('ix', 'IX', IX),                       # Implicit Exploration\n",
    "    ('smoothing', 'Smoothing', Smoothing),  # Exponential Smoothing\n",
    "    ('clipping', 'clipping', Clipping),\n",
    "    ('harmonic', 'harmonic', Harmonic),\n",
    "]\n",
    "\n",
    "# Request a GPU only when one is visible, consistent with `device`.\n",
    "n_gpus = 1 if torch.cuda.is_available() else 0\n",
    "\n",
    "\n",
    "def fit_and_score(model_cls, display_name, tau, loc_weight, dataloader):\n",
    "    \"\"\"Train one estimator on the bandit data and return its reward on the test set.\n",
    "\n",
    "    The model is built with the shared settings (num_actions, context_dim, N, device),\n",
    "    trained for `epochs` epochs and scored with `test_risk_exact_probit`; the returned\n",
    "    reward is the negated risk.\n",
    "    \"\"\"\n",
    "    model = model_cls(n_actions=num_actions, context_dim=context_dim, tau=tau, N=N, loc_weight=loc_weight, device=device)\n",
    "    trainer = Trainer(max_epochs=epochs, gpus=n_gpus, checkpoint_callback=False, weights_summary=None, logger=None)\n",
    "    trainer.fit(model, dataloader)\n",
    "\n",
    "    model = model.to(device)\n",
    "    with torch.no_grad():\n",
    "        risk_after_train = test_risk_exact_probit(X_test, Y_test, model)\n",
    "\n",
    "    print('Reward of %s after training  :' % display_name, -risk_after_train)\n",
    "    print('################################################################################')\n",
    "    return -risk_after_train\n",
    "\n",
    "\n",
    "# Start from an empty table so re-running this cell does not append duplicate rows.\n",
    "dict_results = defaultdict(list)\n",
    "\n",
    "for eta in etas:\n",
    "    logging_policy = logging_policy.to(device)\n",
    "    print('eta parameter : ', eta)\n",
    "    logging_policy.alpha = eta\n",
    "    with torch.no_grad():\n",
    "        risk_logging = test_risk_exact_probit(X_test, Y_test, logging_policy)\n",
    "    print('The reward of the logging policy: ', -risk_logging)\n",
    "\n",
    "    dict_results['eta'].append(eta)\n",
    "    dict_results['logging_reward'].append(-risk_logging)\n",
    "\n",
    "    # Collect a bandit dataset\n",
    "    f, a, p, c = build_bandit_dataset(logging_split_dataloader, logging_policy, replay_count = 1)\n",
    "\n",
    "    print('max', p.max(dim = 0)[0].mean().item())\n",
    "    print('min', p.min(dim = 0)[0].mean().item())\n",
    "\n",
    "    bandit_train_posterior = TensorDataset(f, a, p, c)\n",
    "    bandit_train_posterior_dataloader = DataLoader(bandit_train_posterior, batch_size=128, shuffle=True)\n",
    "\n",
    "    # Logging-policy weights scaled by eta, passed to every estimator as `loc_weight`.\n",
    "    mu_0 = eta * logging_policy.linear.weight.data\n",
    "\n",
    "    for j, tau in enumerate(taus):\n",
    "        for key, display_name, model_cls in estimators:\n",
    "            reward = fit_and_score(model_cls, display_name, tau, mu_0, bandit_train_posterior_dataloader)\n",
    "            dict_results[key + '_' + str(j)].append(reward)\n",
    "\n",
    "    print(dict_results)\n",
    "\n",
    "df = pd.DataFrame(dict_results)\n",
    "print(df)\n",
    "# Create the output folder if it is missing; otherwise to_csv would raise after all training is done.\n",
    "out_dir = os.path.join('results', 'bound_optimization')\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "df.to_csv(os.path.join(out_dir, 'varying_' + name + '.csv'), index = False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python Kernel (MOAB #58857)",
   "language": "python",
   "name": "python-kernel-58857"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
