{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "820dcead-426d-4851-a48b-8432b2b8ae4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import importlib\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.notebook import tqdm, trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "267763f9-4159-4123-918e-9eaf2f6ec0f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enumerate the visible CUDA devices; fall back to the CPU when none is available.\n",
    "devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())] or ['cpu']\n",
    "\n",
    "# The GPU index used to be hard-coded as devices[2], which raises IndexError on machines\n",
    "# with fewer than three GPUs; clamp the preferred index to the last available device.\n",
    "preferred_gpu = 2  # select GPU on the machine\n",
    "device = devices[min(preferred_gpu, len(devices) - 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "452ae82c-b947-431b-b4d0-50b68bd4d466",
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils\n",
    "from utils import convert_to_tensor, build_darkroom_data_filename, build_darkroom_model_filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4adae7c-c58e-4ab6-b8a6-38a046ab1389",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters describing the pre-generated dataset; the dict is passed to the\n",
    "# data-filename builder in the loading cell.\n",
    "n_hists = 1\n",
    "n_samples = 1\n",
    "horizon = 100\n",
    "dim = 10\n",
    "\n",
    "dataset_config = dict(\n",
    "    n_hists=n_hists,\n",
    "    n_samples=n_samples,\n",
    "    horizon=horizon,\n",
    "    dim=dim,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "910ea624-9ba9-4304-b063-9994b4254e0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifiers of the environment and dataset whose pickled trajectories are loaded below.\n",
    "env = 'darkroom_pref_step'\n",
    "n_envs = 100000\n",
    "random_p = (0.2, 1.0)\n",
    "state_dim, action_dim = 2, 5\n",
    "\n",
    "# mode=0 yields the path bound to path_train, mode=1 the one bound to path_test.\n",
    "path_train, path_test = (\n",
    "    build_darkroom_data_filename(env, n_envs, dataset_config, random_p, mode=split_mode)\n",
    "    for split_mode in (0, 1)\n",
    ")\n",
    "\n",
    "# Only the training trajectories are unpickled here; both paths are printed afterwards.\n",
    "with open(path_train, 'rb') as f:\n",
    "    train_trajs = pickle.load(f)\n",
    "print(path_train, '\\n', path_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6e47938-afca-47d0-855d-1ad52445cd11",
   "metadata": {},
   "source": [
    "## Set-Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d66adcbb-5e77-45e1-9ce4-41a5cd915b1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GPT2Config, GPT2Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "941d853f-c506-4c74-991f-8b03a50a2cab",
   "metadata": {},
   "outputs": [],
   "source": [
    "from net import Transformer\n",
    "from dataset import DatasetBatch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dfcaf82-dd42-444e-880b-21ba8c83b9c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the keys of the first entry of train_trajs as the cell output.\n",
    "train_trajs[0].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99759166-189c-4d82-8f46-4ff650432886",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model hyper-parameters.\n",
    "horizon = 100\n",
    "state_dim = 2\n",
    "action_dim = 5\n",
    "n_layer = 4\n",
    "n_embd = 64\n",
    "n_head = 4\n",
    "dropout = False\n",
    "\n",
    "# Settings handed to the PrefDatasetBatch and PrefTransformer constructors below.\n",
    "config = dict(\n",
    "    horizon=horizon,\n",
    "    state_dim=state_dim,\n",
    "    action_dim=action_dim,\n",
    "    n_layer=n_layer,\n",
    "    n_embd=n_embd,\n",
    "    n_head=n_head,\n",
    "    dropout=dropout,\n",
    "    test=False,\n",
    "    shuffle=False,\n",
    "    store_gpu=False,\n",
    ")\n",
    "\n",
    "# Keyword arguments unpacked into the DataLoader constructors.\n",
    "opt_config = dict(batch_size=64, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4c0191a-eb16-4d91-844b-3e5e126d0eb9",
   "metadata": {},
   "source": [
    "## Pre-training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cc02369-c1db-4b86-92f4-cc41b6b2f722",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import PrefDatasetBatch\n",
    "from net import PrefTransformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7089893-be28-40ba-bb73-7375ae20eef3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_pref_batch(batch, model):\n",
    "    \"\"\"Sample one query step per trajectory and score its preferred / non-preferred action.\n",
    "\n",
    "    For every trajectory in the batch one context step is drawn at random with\n",
    "    ``np.random.choice``. The state at that step is stored in ``batch['query_states']``\n",
    "    (the batch dict is modified in place), the model is run on the batch, and the\n",
    "    log-probabilities it assigns to the argmax indices of the preferred and\n",
    "    non-preferred action vectors at that step are gathered.\n",
    "\n",
    "    Args:\n",
    "        batch: dict of tensors with ``'context_states'`` (B x H x state_dim) and\n",
    "            ``'context_pr_actions'`` / ``'context_npr_actions'`` (indexed as B x H x num_A).\n",
    "        model: callable mapping the batch to action logits; assumed to return a\n",
    "            B x (H+1) x num_A tensor -- TODO confirm against the model definition.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(log_pr_prob, log_npr_prob)``; each tensor has shape B x (H+1).\n",
    "    \"\"\"\n",
    "    batch_size, horizon, _ = batch['context_states'].shape\n",
    "    batch_idx = torch.arange(batch_size)\n",
    "    step_indices = np.random.choice(range(horizon), batch_size, replace=True)\n",
    "    batch['query_states'] = batch['context_states'][batch_idx, step_indices]\n",
    "\n",
    "    # Action vectors at the sampled step, tiled along the sequence axis: B x (H+1) x num_A.\n",
    "    pr_actions = batch['context_pr_actions'][batch_idx, step_indices].unsqueeze(1).repeat(1, horizon + 1, 1)\n",
    "    npr_actions = batch['context_npr_actions'][batch_idx, step_indices].unsqueeze(1).repeat(1, horizon + 1, 1)\n",
    "    pr_indices = pr_actions.argmax(dim=-1)  # B x (H+1)\n",
    "    npr_indices = npr_actions.argmax(dim=-1)  # B x (H+1)\n",
    "\n",
    "    # log_softmax is the numerically stable form of log(softmax(x)): the two-step version\n",
    "    # underflows to log(0) = -inf for large logit gaps, which turns the loss into inf/nan.\n",
    "    logp = torch.log_softmax(model(batch), dim=-1)\n",
    "\n",
    "    # Pick the log-probability of each selected action index; results are B x (H+1).\n",
    "    log_pr_prob = logp.gather(dim=-1, index=pr_indices.unsqueeze(-1)).squeeze(-1)\n",
    "    log_npr_prob = logp.gather(dim=-1, index=npr_indices.unsqueeze(-1)).squeeze(-1)\n",
    "\n",
    "    return log_pr_prob, log_npr_prob\n",
    "\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Sigmoid alias; within this cell it only appears in the reference formula quoted below.\n",
    "tsig = torch.sigmoid\n",
    "\n",
    "def pref_loss(log_pr_prob, log_npr_prob, beta=10.0,gamma=0.5):\n",
    "    \"\"\"Mean softplus preference loss.\n",
    "\n",
    "    Computes ``mean(softplus(-beta * (log_pr_prob - gamma * log_npr_prob)))`` over all\n",
    "    elements and returns it as a scalar tensor.\n",
    "    \"\"\"\n",
    "    # Mathematically equal reference form:\n",
    "    #   -mean(log(tsig(beta * (log_pr_prob - gamma * log_npr_prob))))\n",
    "    margin = log_pr_prob - gamma * log_npr_prob\n",
    "    return F.softplus(-beta * margin).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e1c40bb-2da5-45b8-ab01-9f5b346e71a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_epoch(model, test_loader, device, beta=20.0, gamma=0.5):\n",
    "    \"\"\"Return the mean preference loss of ``model`` over all batches of ``test_loader``.\n",
    "\n",
    "    Gradients are disabled for the whole pass. The model's train/eval mode is not\n",
    "    changed here.\n",
    "\n",
    "    Args:\n",
    "        model: network forwarded to ``construct_pref_batch``.\n",
    "        test_loader: iterable yielding dict batches whose values support ``.to(device)``.\n",
    "        device: device every batch value is moved to.\n",
    "        beta: forwarded to ``pref_loss``.\n",
    "        gamma: forwarded to ``pref_loss``.\n",
    "\n",
    "    Returns:\n",
    "        float: sum of the per-batch losses divided by the number of batches.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``test_loader`` yields no batches.\n",
    "    \"\"\"\n",
    "    test_loss = 0.0\n",
    "    n_batches = 0\n",
    "    with torch.no_grad():\n",
    "        for batch in test_loader:\n",
    "            batch = {k: v.to(device) for k, v in batch.items()}\n",
    "\n",
    "            log_pr_prob, log_npr_prob = construct_pref_batch(batch, model)\n",
    "            loss = pref_loss(log_pr_prob, log_npr_prob, beta=beta, gamma=gamma)\n",
    "            test_loss += loss.item()\n",
    "            n_batches += 1\n",
    "\n",
    "    if n_batches == 0:\n",
    "        raise ValueError('test_loader yielded no batches')\n",
    "\n",
    "    # Average over the batch count. Dividing by the last enumerate index (count - 1)\n",
    "    # inflates the mean and fails with a division by zero for a single-batch loader.\n",
    "    return test_loss / n_batches\n",
    "\n",
    "\n",
    "def train_epoch(model, train_loader, device, optimizer, beta=20.0, gamma=0.5):\n",
    "    \"\"\"Run one optimisation pass over ``train_loader`` and return the mean batch loss.\n",
    "\n",
    "    For each batch the preference loss is computed, back-propagated, gradients are\n",
    "    clipped element-wise to [-1, 1] and the optimizer takes one step. The progress\n",
    "    bar description is refreshed with the current loss every 200 batches.\n",
    "\n",
    "    Args:\n",
    "        model: network being trained; passed to ``construct_pref_batch``.\n",
    "        train_loader: iterable yielding dict batches whose values support ``.to(device)``.\n",
    "        device: device every batch value is moved to.\n",
    "        optimizer: optimizer providing ``zero_grad()`` and ``step()``.\n",
    "        beta: forwarded to ``pref_loss``.\n",
    "        gamma: forwarded to ``pref_loss``.\n",
    "\n",
    "    Returns:\n",
    "        float: sum of the per-batch losses divided by the number of batches.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``train_loader`` yields no batches.\n",
    "    \"\"\"\n",
    "    train_loss = 0.0\n",
    "    n_batches = 0\n",
    "    progress_bar = tqdm(train_loader)\n",
    "    for batch_id, batch in enumerate(progress_bar):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "\n",
    "        log_pr_prob, log_npr_prob = construct_pref_batch(batch, model)\n",
    "        optimizer.zero_grad()\n",
    "        loss = pref_loss(log_pr_prob, log_npr_prob, beta=beta, gamma=gamma)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_value_(model.parameters(), 1.0)\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.item()\n",
    "        n_batches += 1\n",
    "        if batch_id % 200 == 0:\n",
    "            progress_bar.set_description(f'{loss.item():.2f}')\n",
    "\n",
    "    if n_batches == 0:\n",
    "        raise ValueError('train_loader yielded no batches')\n",
    "\n",
    "    # Average over the batch count. Dividing by the last enumerate index (count - 1)\n",
    "    # inflates the mean and fails with a division by zero for a single-batch loader.\n",
    "    return train_loss / n_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39d5cd25-943b-48a7-89f7-d4213900f6a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training split: dataset built from path_train, wrapped in a DataLoader using opt_config.\n",
    "train_dataset = PrefDatasetBatch(path_train, config, device, num_trajs=80000)\n",
    "train_loader = DataLoader(train_dataset, **opt_config)\n",
    "\n",
    "# Test split: same construction from path_test with fewer trajectories requested.\n",
    "test_dataset = PrefDatasetBatch(path_test, config, device, num_trajs=10000)\n",
    "test_loader = DataLoader(test_dataset, **opt_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46c27201-5ab4-47c5-a432-a4583ff680c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG in use: construct_pref_batch draws its query steps with np.random.choice,\n",
    "# so seeding torch alone does not make a run reproducible.\n",
    "seed = 16\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "\n",
    "epochs = 10\n",
    "lr=1e-4\n",
    "gamma=0.1\n",
    "beta=1.0\n",
    "\n",
    "# Create the checkpoint folder up front so a missing directory cannot make the save\n",
    "# fail after epochs of training have already been spent.\n",
    "checkpoint_dir = 'trained_models'\n",
    "os.makedirs(checkpoint_dir, exist_ok=True)\n",
    "\n",
    "model = PrefTransformer(config).to(device)\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)\n",
    "\n",
    "test_losses, train_losses = [], []\n",
    "for epoch in trange(epochs):\n",
    "    print(f'============ Epoch:{epoch} ============')\n",
    "\n",
    "    # Evaluation runs before the training pass, so test_losses[i] is measured on the\n",
    "    # model as it enters epoch i (the first entry is the untrained model).\n",
    "    test_loss = eval_epoch(model, test_loader, device, beta=beta, gamma=gamma)\n",
    "    print(f'test loss:{test_loss}')\n",
    "    test_losses.append(test_loss)\n",
    "\n",
    "    train_loss = train_epoch(model, train_loader, device, optimizer, beta=beta, gamma=gamma)\n",
    "    print(f'train loss:{train_loss}')\n",
    "    train_losses.append(train_loss)\n",
    "\n",
    "    # Checkpoint the whole model object after every second epoch.\n",
    "    if (epoch+1)%2==0:\n",
    "        torch.save(model,f'{checkpoint_dir}/DRS_gamma_{gamma}_beta_{beta}_epoch_{epoch}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "762b2d21-66c7-4407-891e-30cb5c89a895",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): online_pref, eval_trajs_set and eval_trajs are not defined or imported in\n",
    "# any cell of this notebook, so this cell raises NameError after Restart & Run All --\n",
    "# presumably they came from a removed cell or a project module; TODO restore them.\n",
    "Heps = 40\n",
    "online_pref(eval_trajs_set[0], model.to(devices[0]),devices[0], Heps, 100, len(eval_trajs), 10, 100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "in-context-learning",
   "language": "python",
   "name": "in-context-learning"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
