{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e60ada53",
   "metadata": {},
   "source": [
    "### This notebook replicates results for the Covertype simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bf638bc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Sequence,Iterable, Dict, Any\n",
    "\n",
    "from coba.contexts import CobaContext\n",
    "from coba.random import CobaRandom\n",
    "from coba.environments import Environments, EnvironmentFilter, OpenmlSource, SimulatedInteraction, Take\n",
    "from coba.experiments import Experiment, EvaluationTask\n",
    "from coba.pipes import Pipes\n",
    "\n",
    "from coba.encodings import OneHotEncoder\n",
    "from coba.random import CobaRandom\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import FormatStrFormatter\n",
    "\n",
    "import pandas as pd\n",
    "from operator import truediv, sub, gt\n",
    "from itertools import chain, repeat, accumulate, groupby, count\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "28926c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "class InteractionGroundedLearner:\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "#this turns a supervised problem into our IGL problem with normal/bizzaro users and good/bad feedback words\n",
    "class ToRecommenderIGL(EnvironmentFilter):\n",
    "\n",
    "    def __init__(self, n_users: int, n_normal:int, n_words:int, n_good:int, seed:int) -> None:\n",
    "        self._n_users  = n_users\n",
    "        self._n_normal = n_normal\n",
    "        self._n_words  = n_words\n",
    "        self._n_good   = n_good\n",
    "        self._seed     = seed\n",
    "        self.userids   = list(range(self._n_users))\n",
    "        self.normalids = self.userids[:self._n_normal]\n",
    "        self.wordids = list(range(self._n_words))\n",
    "        self.good_words = self.wordids[:self._n_good]\n",
    "        self.bad_words = self.wordids[self._n_good:]\n",
    "\n",
    "    @property\n",
    "    def params(self) -> Dict[str, Any]:\n",
    "        return {\n",
    "            \"n_users\" : self._n_users,\n",
    "            \"n_normal\": self._n_normal,\n",
    "            \"n_good\"  : self._n_good,\n",
    "            \"n_words\" : self._n_words,\n",
    "            \"igl_seed\": self._seed\n",
    "        }\n",
    "\n",
    "    def filter(self, interactions: Iterable[SimulatedInteraction]) -> Iterable[SimulatedInteraction]:\n",
    "        rng = CobaRandom(self._seed)\n",
    "\n",
    "        for interaction in interactions:\n",
    "            userid = rng.choice(self.userids)\n",
    "            good, bad = (self.good_words, self.bad_words) if userid in self.normalids else (self.bad_words, self.good_words)\n",
    "            words = [ (rng.choice(good),) if r==1 else (rng.choice(bad),) for r in interaction.rewards ]\n",
    "            kwargs = { \"userid\":userid, \"feedbacks\":words, \"isnormal\": userid in self.normalids }\n",
    "            yield SimulatedInteraction((userid,)+interaction.context, interaction.actions, interaction.rewards, **kwargs)\n",
    "\n",
    "#this takes care of actually evaluating our IGL Learner against an IGL environment\n",
    "class EvalRecommenderIGL(EvaluationTask):\n",
    "\n",
    "    def __init__(self, seed:int = 1) -> None:\n",
    "        self._seed = seed\n",
    "        self._savelearner = None\n",
    "\n",
    "    def process(self, learner: InteractionGroundedLearner, interactions: Iterable[SimulatedInteraction]) -> Iterable[Dict[Any,Any]]:\n",
    "        self._savelearner = learner\n",
    "        \n",
    "        rng = CobaRandom(self._seed)\n",
    "\n",
    "        for interaction in interactions:\n",
    "\n",
    "            context   = interaction.context\n",
    "            actions   = interaction.actions\n",
    "            rewards   = interaction.rewards\n",
    "            feedbacks = interaction.kwargs['feedbacks']\n",
    "            userid    = interaction.kwargs['userid']\n",
    "            isnormal  = interaction.kwargs['isnormal']\n",
    "\n",
    "            probabilities = learner.predict(context=context, actions=interaction.actions)\n",
    "\n",
    "            action       = rng.choice(actions, probabilities)\n",
    "            action_index = actions.index(action)\n",
    "            reward       = rewards[action_index]\n",
    "            feedback     = feedbacks[action_index]\n",
    "            probability  = probabilities[action_index]\n",
    "\n",
    "            stuff = learner.learn(context=context, actions=actions, action=action, feedback=feedback, probability=probability)\n",
    "\n",
    "            yield { \"reward\": reward, \"feedback\": feedback, \"userid\": userid, \"isnormal\": isnormal, \"probability\": probability } | stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7b95b39d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "class CauchyMulticlass(torch.nn.Module):\n",
    "    def __init__(self, *, dobs, n_rff, sigma, n_classes, device='cpu'):\n",
    "        from math import pi, sqrt\n",
    "        \n",
    "        super().__init__()\n",
    "        assert n_rff == int(n_rff) and n_rff > 0\n",
    "        assert 0 < sigma\n",
    "        \n",
    "        self.rffW = torch.nn.Parameter(torch.empty(dobs, n_rff).cauchy_(sigma = sigma).to(device), \n",
    "                                       requires_grad=False)\n",
    "        self.rffb = torch.nn.Parameter((2 * pi * torch.rand(n_rff)).to(device),\n",
    "                                       requires_grad=False)\n",
    "        self.sqrtrff = torch.nn.Parameter(torch.Tensor([sqrt(n_rff)]).to(device), \n",
    "                                          requires_grad=False)\n",
    "        self.linear = torch.nn.Linear(in_features=n_rff, out_features=n_classes, device=device)\n",
    "        self.softmax = torch.nn.Softmax(dim=1)\n",
    "               \n",
    "    def logits(self, Xs):\n",
    "        with torch.no_grad():\n",
    "            rff = (torch.matmul(Xs, self.rffW) + self.rffb).cos() / self.sqrtrff\n",
    "        return self.linear(rff)\n",
    "\n",
    "    def probs(self, Xs):\n",
    "        return self.softmax(self.logits(Xs))\n",
    "    \n",
    "class BilinearCauchyMulticlass(torch.nn.Module):\n",
    "    def __init__(self, *, n_users, n_feedbacks, d_embed, **kwargs):\n",
    "        super().__init__()\n",
    "        \n",
    "        assert n_users == int(n_users) and n_users > 0\n",
    "        assert n_feedbacks == int(n_feedbacks) and n_feedbacks > 0\n",
    "        \n",
    "        self.cauchy = CauchyMulticlass(**kwargs)\n",
    "        self.embedU = torch.nn.Embedding(n_users, d_embed)\n",
    "        self.embedF = torch.nn.Embedding(n_feedbacks, d_embed)\n",
    "        self.softmax = torch.nn.Softmax(dim=1)\n",
    "    \n",
    "    def logits(self, Xs):\n",
    "        cauchylogits = self.cauchy.logits(Xs[:,2:])\n",
    "        Fs = self.embedF(Xs[:,0].long())\n",
    "        Us = self.embedU(Xs[:,1].long())\n",
    "        UsdotFs = torch.inner(Us, Fs)\n",
    "        return torch.mul(UsdotFs, cauchylogits)\n",
    "    \n",
    "    def probs(self, Xs):\n",
    "        return self.softmax(self.logits(Xs))\n",
    "    \n",
    "class PythonIGL(InteractionGroundedLearner):\n",
    "    def __init__(self, *, n_users, n_feedbacks, d_embed, sigma, n_rff, lr, epsilon, epsilon_t0):\n",
    "        super().__init__()\n",
    "        \n",
    "        assert n_users == int(n_users) and n_users > 0 \n",
    "        assert n_feedbacks == int(n_feedbacks) and n_feedbacks > 0 \n",
    "        assert d_embed == int(d_embed) and d_embed > 0\n",
    "        assert 0 < sigma\n",
    "        assert n_rff == int(n_rff) and n_rff > 0\n",
    "        assert 0 < lr\n",
    "        \n",
    "        self.regressor = None\n",
    "        self.epsilon = epsilon\n",
    "        self.epsilon_t0 = epsilon_t0\n",
    "        self.t = epsilon_t0\n",
    "        self.ikmodel = None\n",
    "        self.sigma = sigma\n",
    "        self.n_rff = n_rff\n",
    "        self.lr = lr\n",
    "        self.n_users = n_users\n",
    "        self.n_feedbacks = n_feedbacks\n",
    "        self.d_embed = d_embed\n",
    "    \n",
    "    def getUserEmbedding(self, userids):\n",
    "        return self.ikmodel.embedU(userids)\n",
    "    \n",
    "    def getFeedbackEmbedding(self, feedbackids):\n",
    "        return self.ikmodel.embedF(feedbackids)\n",
    "        \n",
    "    def __setupRegressor(self, *, context, actions):\n",
    "        if self.regressor is None:\n",
    "            self.regressor = CauchyMulticlass(dobs=len(context), sigma=self.sigma, \n",
    "                                              n_rff=self.n_rff, n_classes=len(actions))\n",
    "            self.regloss = torch.nn.BCELoss()\n",
    "            self.regopt = torch.optim.Adam((p for p in self.regressor.parameters() if p.requires_grad), lr=self.lr)\n",
    "        \n",
    "    def __myeps(self):\n",
    "        return self.epsilon * (self.epsilon_t0 / self.t)**(1/3)\n",
    "        \n",
    "    def predict(self, *, context, actions) -> Sequence[float]:\n",
    "        context = tuple([ v for c in context for v in ( c if isinstance(c, tuple) else (c,) ) ])\n",
    "        self.__setupRegressor(context=context, actions=actions)\n",
    "        with torch.no_grad():\n",
    "            logits = self.regressor.logits(torch.Tensor(context).unsqueeze(0))\n",
    "            predict = torch.argmax(logits)\n",
    "            myeps = self.__myeps()\n",
    "            paction = [myeps/len(actions)]*len(actions)\n",
    "            paction[predict] += 1 - myeps\n",
    "            self.t += 1\n",
    "            return paction\n",
    "        \n",
    "    def __setupInverseKinematics(self, *, context, actions, feedback):\n",
    "        if self.ikmodel is None:\n",
    "            self.ikmodel = BilinearCauchyMulticlass(n_users=self.n_users, n_feedbacks=self.n_feedbacks, d_embed=self.d_embed, \n",
    "                                                    dobs=len(feedback)+len(context)-2, sigma=self.sigma, \n",
    "                                                    n_rff=self.n_rff, n_classes=len(actions))\n",
    "            self.ikloss = torch.nn.CrossEntropyLoss()\n",
    "            self.ikopt = torch.optim.Adam((p for p in self.ikmodel.parameters() if p.requires_grad), lr=self.lr)\n",
    "    \n",
    "    def learn(self, *, context, actions, action, feedback, probability):\n",
    "        aindex = actions.index(action)\n",
    "        actioniw = 1 / (len(actions)*probability)\n",
    "        context = tuple([ v for c in context for v in ( c if isinstance(c, tuple) else (c,) ) ])\n",
    "        self.__setupInverseKinematics(context=context, actions=actions, feedback=feedback)\n",
    "        with torch.no_grad():\n",
    "            logits = self.ikmodel.logits(torch.Tensor(feedback+context).unsqueeze(0))\n",
    "            predict = torch.argmax(logits)\n",
    "            ikacc = 1 if predict.item() == aindex else 0\n",
    "            paction = self.ikmodel.softmax(logits)[0]\n",
    "            fakereward = 1 if paction[aindex].item() * len(actions) > 2 else 0            \n",
    "        \n",
    "        self.ikopt.zero_grad()\n",
    "        loss = actioniw * self.ikloss(self.ikmodel.logits(torch.Tensor(feedback+context).unsqueeze(0)), torch.Tensor([aindex]).long())\n",
    "        loss.backward()\n",
    "        self.ikopt.step()\n",
    "        \n",
    "        self.regopt.zero_grad()\n",
    "        regpred = self.regressor.probs(torch.Tensor(context).unsqueeze(0))\n",
    "        regloss = actioniw * self.regloss(regpred[:,aindex], torch.Tensor([fakereward]))\n",
    "        regloss.backward()\n",
    "        self.regopt.step()\n",
    "        \n",
    "        return {'ikloss': loss.item(), 'ikacc': ikacc, 'fakereward': fakereward, 'epsilon': self.__myeps(), 'regloss': regloss.item() }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0dca4f0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_progress(df, y, label, span=1000):\n",
    "    idxs = df[\"index\"].values.tolist()\n",
    "    values = df[y].values.tolist()\n",
    "\n",
    "    window_sums  = accumulate(map(sub, values, chain(repeat(0,span),values)))\n",
    "    window_sizes = chain(range(1,span), repeat(span))\n",
    "    moving_averages_list = list(map(truediv,window_sums,window_sizes))\n",
    "    \n",
    "    fig, ax = plt.subplots(1, 1)\n",
    "    fig.set_size_inches(8, 6)\n",
    "    \n",
    "    ax.plot(idxs,moving_averages_list)\n",
    "    ax.set_ylabel(label,fontsize=18)\n",
    "    ax.set_xlabel('Interactions',fontsize=18)\n",
    "    \n",
    "    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))\n",
    "\n",
    "    ax.tick_params(labelsize=14)\n",
    "    for axis in ['top','bottom','left','right']:\n",
    "        ax.spines[axis].set_linewidth(2)\n",
    "    ax.tick_params(width=2)\n",
    "\n",
    "    plt.savefig(y+'.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe09c061",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-09-29 06:39:49 -- Processing chunk...\n",
      "2022-09-29 06:39:49 --   * Recording Learner 0 parameters... (0.0 seconds) (completed)\n",
      "2022-09-29 06:40:01 --   * Loading {'openml_data': 150, 'openml_task': None, 'openml_target': None, 'cat_as_str': False, 'drop_missing': True, 'take': 500000, 'label_type': None, 'type': 'SupervisedSimulation'}... (11.52 seconds) (completed)\n",
      "2022-09-29 06:40:32 --   * Creating Environment 0 from Loaded Source... (31.36 seconds) (completed)\n",
      "2022-09-29 06:40:33 --   * Recording Environment 0 statistics... (0.0 seconds) (completed)\n"
     ]
    }
   ],
   "source": [
    "def runExperiment():\n",
    "    import matplotlib.pyplot as plt\n",
    "    from matplotlib.font_manager import FontProperties\n",
    "    #this will cache the openml dataset so you don't have to download it more than once\n",
    "    CobaContext.cacher.cache_directory = './.coba_cache'\n",
    "\n",
    "    config = {\"processes\": 1, \"chunk_by\":'task', 'maxchunksperchild': 0 }\n",
    "    log_file = None\n",
    "\n",
    "    covertype_id = 150\n",
    "    ndata = 500000\n",
    "    \n",
    "    n_users = 100\n",
    "    n_words = 100\n",
    "    to_recommender_igl = ToRecommenderIGL(n_users=n_users, n_normal=n_users//2, n_words=n_words, n_good=n_words//2, seed=1)\n",
    "\n",
    "    environments = Environments.from_supervised(Pipes.join(OpenmlSource(data_id=150), Take(ndata))).scale(shift='min').shuffle(seeds=2).filter(to_recommender_igl)\n",
    "    learner = PythonIGL(n_users=n_users, n_feedbacks=n_words, d_embed=2, sigma=0.2, n_rff=512, lr=2.5e-3, epsilon=1, epsilon_t0=1000)\n",
    "\n",
    "    evaluator = EvalRecommenderIGL()\n",
    "    res = Experiment(environments, [learner], evaluation_task=evaluator).config(**config).run(log_file)\n",
    "\n",
    "    plot_progress(res.interactions.to_pandas(),'reward','Reward')\n",
    "    plot_progress(res.interactions.to_pandas(),'epsilon','Epsilon')\n",
    "    plot_progress(res.interactions.to_pandas(),'ikacc','IK Accuracy')\n",
    "    pdata = res.interactions.to_pandas()\n",
    "    print(pdata.groupby(['reward', 'fakereward'])[['reward', 'fakereward']].count())\n",
    "    with torch.no_grad():\n",
    "        uembed = evaluator._savelearner.getUserEmbedding(torch.Tensor(to_recommender_igl.userids).long()).numpy()\n",
    "        wordembed = evaluator._savelearner.getFeedbackEmbedding(torch.Tensor(to_recommender_igl.wordids).long()).numpy()\n",
    "        \n",
    "        with plt.style.context('seaborn-colorblind'):    \n",
    "            fig, ax = plt.subplots(1, 2)\n",
    "            fig.set_size_inches(16, 6)\n",
    "            \n",
    "            dotsize = 20*(2**2)\n",
    "            \n",
    "            nx, ny = list(zip(*[ (x, y) for x, y, u in zip(uembed[:,0], uembed[:,1], to_recommender_igl.userids) if u in to_recommender_igl.normalids ]))\n",
    "            ax[0].scatter(nx, ny, c='C2', s=dotsize, marker='x', label='normal', alpha=0.5)\n",
    "            bx, by = list(zip(*[ (x, y) for x, y, u in zip(uembed[:,0], uembed[:,1], to_recommender_igl.userids) if u not in to_recommender_igl.normalids ]))\n",
    "            ax[0].scatter(bx, by, c='C1', s=dotsize, marker='o', label='bizarro', alpha=0.5)\n",
    "            ax[0].legend(scatterpoints=1)\n",
    "            ax[0].set_title('User Embeddings', fontsize=22)\n",
    "            \n",
    "            gx, gy = list(zip(*[ (x, y) for x, y, u in zip(wordembed[:,0], wordembed[:,1], to_recommender_igl.wordids) if u in to_recommender_igl.good_words ]))\n",
    "            ax[1].scatter(gx, gy, c='C2', s=dotsize, marker='x', label='good', alpha=0.5)\n",
    "            bx, by = list(zip(*[ (x, y) for x, y, u in zip(wordembed[:,0], wordembed[:,1], to_recommender_igl.wordids) if u in to_recommender_igl.bad_words ]))\n",
    "            ax[1].scatter(bx, by, c='C1', s=dotsize, marker='o', label='bad', alpha=0.5)\n",
    "            ax[1].legend(scatterpoints=1)\n",
    "            ax[1].set_title('Word Embeddings', fontsize=22)\n",
    "            \n",
    "            for a in ax:\n",
    "                a.tick_params(labelsize=18)\n",
    "                a.legend(fontsize=20)\n",
    "                for axis in ['top','bottom','left','right']:\n",
    "                    a.spines[axis].set_linewidth(2)\n",
    "                a.tick_params(width=2)\n",
    "            \n",
    "            plt.savefig('userwordembeddings.pdf')\n",
    "    \n",
    "runExperiment()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "igl",
   "language": "python",
   "name": "igl"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
