{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import json\n",
    "import os\n",
    "import re\n",
    "from datetime import datetime\n",
    "\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from tqdm import tqdm\n",
    "\n",
    "from eval import *\n",
    "from llama.metrics import *\n",
    "from llama.generation import Llama\n",
    "from llama.mixed_generation import MixedLlama\n",
    "from llama.tokenizer import Tokenizer\n",
    "from ngrams.ngram_models import make_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0a09c14aeffd49f9a85cf05d1b5be1b9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme:   0%|          | 0.00/8.77k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading data: 100%|████████████████████████████████████████████████████████████| 4.46M/4.46M [00:00<00:00, 9.02MB/s]\n",
      "Downloading data: 100%|██████████████████████████████████████████████████████████████| 214k/214k [00:00<00:00, 2.03MB/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "56e61c94ecc74844b611a4cc02395e38",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/87925 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9543d790f6f948ea923ca22c867ca758",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating validation split:   0%|          | 0/3610 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Dataset: validation split of nq_open.\n",
    "nq = load_dataset(\"nq_open\")[\"validation\"]\n",
    "\n",
    "# Draft and token-sampling sizes.\n",
    "n_drafts = 5\n",
    "n_token_sample = n_drafts * 3\n",
    "n_token_consider = 32000  # presumably the tokenizer vocabulary size -- TODO confirm\n",
    "\n",
    "# Options forwarded unchanged to evaluate_mixed_losses by evaluate_nq.\n",
    "mixing_method = \"sample_new_weights_with_score\"\n",
    "smoothing = None\n",
    "sample_beams = False\n",
    "sample_tokens = False\n",
    "i_length = list(range(1, 6))\n",
    "i_weights = [0.01, 0.04, 0.15, 0.18, 0.12]\n",
    "\n",
    "# NOTE(review): evaluate_nq assigns its own local alpha/temp from the prompt length,\n",
    "# so these module-level values are not the ones it passes to the mixed model.\n",
    "alpha = 0.6\n",
    "temp = 0.06"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Making bigram...\n",
      "1310800\n",
      "Making trigram...\n",
      "671088728\n",
      "Making fourgram...\n",
      "2684354648\n",
      "Making fivegram...\n",
      "5368709200\n",
      "Making sixgram...\n",
      "5368709200\n"
     ]
    }
   ],
   "source": [
    "# Create the bigram through sixgram models from ../ckpts-200k; sevengram is switched off.\n",
    "ngrams = make_models(\n",
    "    \"../ckpts-200k\",\n",
    "    bigram=True,\n",
    "    trigram=True,\n",
    "    fourgram=True,\n",
    "    fivegram=True,\n",
    "    sixgram=True,\n",
    "    sevengram=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer shared by evaluate_nq and the adaptive-length computation.\n",
    "tokenizer_path = \"../tokenizer.model\"\n",
    "tokenizer = Tokenizer(tokenizer_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mixed_device is handed to MixedLlama.build and reg_device to Llama.build.\n",
    "mixed_device, reg_device = (torch.device(f\"cuda:{gpu_idx}\") for gpu_idx in (0, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> initializing model parallel with size 1\n",
      "> initializing ddp with size 1\n",
      "> initializing pipeline with size 1\n",
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/torch/__init__.py:614: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
      "  _C._set_default_tensor_type(t)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded in 35.50 seconds\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Mixed model: MixedLlama built from the 7B checkpoint directory on mixed_device.\n",
    "weight_path = \"../7B/\"\n",
    "mixed_model = MixedLlama.build(\n",
    "    ckpt_dir=weight_path,\n",
    "    tokenizer_path='../tokenizer.model',\n",
    "    max_seq_len=1000,\n",
    "    max_batch_size=16,\n",
    "    device=mixed_device,\n",
    "    model_parallel_size=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "Loaded in 24.35 seconds\n"
     ]
    }
   ],
   "source": [
    "# Regular model: plain Llama built from the same 7B checkpoint directory on reg_device.\n",
    "reg_model = Llama.build(\n",
    "    ckpt_dir=\"../7B/\",\n",
    "    tokenizer_path='../tokenizer.model',\n",
    "    max_seq_len=1000,\n",
    "    max_batch_size=16,\n",
    "    device=reg_device,\n",
    "    model_parallel_size=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoding modes understood by evaluate_nq; the active one is selected by name.\n",
    "model_types = [\"greedy\", \"mixed\", \"regular\"]\n",
    "model_type = model_types[model_types.index(\"regular\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_nq(model_type, question, max_gen_len):\n",
    "    \"\"\"Generate answer strings for one NQ-Open question with the selected decoder.\n",
    "\n",
    "    Args:\n",
    "        model_type: One of \"regular\", \"greedy\" or \"mixed\".\n",
    "            \"regular\" sends n_drafts copies of the prompt to reg_model with temp=0.6,\n",
    "            \"greedy\" sends a single copy to reg_model with temp=0, and\n",
    "            \"mixed\" sends a single copy to mixed_model together with the n-gram models.\n",
    "        question: Question text without a trailing question mark; it is wrapped in a\n",
    "            Q/A prompt and a question mark is appended.\n",
    "        max_gen_len: Forwarded to the generation helper as max_gen_len.\n",
    "\n",
    "    Returns:\n",
    "        A list with one string per generated sequence: the decoded text after the\n",
    "        prompt, stripped, and cut at the first comma, period or newline.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If model_type is not one of the three supported values.\n",
    "    \"\"\"\n",
    "    prompt = \"Answer these questions:\\n\\nQ: \" + question + \"?\\nA:\"\n",
    "    text_len = len(prompt)  # characters; used to cut the prompt off the decoded text\n",
    "    prompt_len = len(tokenizer.encode([prompt], True, False)[0])  # tokens; for the model\n",
    "\n",
    "    if model_type in (\"regular\", \"greedy\"):\n",
    "        # Both modes share one call; they differ only in batch contents and temperature.\n",
    "        if model_type == \"regular\":\n",
    "            prompts, sampling_temp = [prompt] * n_drafts, 0.6\n",
    "        else:\n",
    "            prompts, sampling_temp = [prompt], 0\n",
    "        sequences, _ = evaluate_nucleus_losses(data=prompts,\n",
    "                                               model=reg_model,\n",
    "                                               tokenizer=tokenizer,\n",
    "                                               prompt_len=prompt_len,\n",
    "                                               max_gen_len=max_gen_len,\n",
    "                                               temp=sampling_temp,\n",
    "                                               bsz=8,\n",
    "                                               marker=False)\n",
    "    elif model_type == \"mixed\":\n",
    "        # alpha/temp are picked from the prompt's token length.\n",
    "        if prompt_len <= 10:\n",
    "            alpha, temp = 0.34, 0.12\n",
    "        elif prompt_len <= 20:\n",
    "            alpha, temp = 0.6, 0.12\n",
    "        elif prompt_len <= 30:\n",
    "            alpha, temp = 0.5, 0.12\n",
    "        else:\n",
    "            alpha, temp = 0.55, 0.1\n",
    "        sequences, _ = evaluate_mixed_losses(data=[prompt],\n",
    "                                             model=mixed_model,\n",
    "                                             tokenizer=tokenizer,\n",
    "                                             prompt_len=prompt_len,\n",
    "                                             max_gen_len=max_gen_len,\n",
    "                                             alpha=alpha,\n",
    "                                             temp=temp,\n",
    "                                             n_drafts=n_drafts,\n",
    "                                             n_token_consider=n_token_consider,\n",
    "                                             n_token_sample=n_token_sample,\n",
    "                                             mixing_method=mixing_method,\n",
    "                                             smoothing=smoothing,\n",
    "                                             debug=False,\n",
    "                                             bsz=8, # for timing\n",
    "                                             i_weights=i_weights,\n",
    "                                             i_length=i_length,\n",
    "                                             ngrams=ngrams,\n",
    "                                             sample_beams=sample_beams,\n",
    "                                             sample_tokens=sample_tokens,\n",
    "                                             marker=False)\n",
    "    else:\n",
    "        # Previously an unknown mode fell through to a NameError on `sequences`.\n",
    "        raise ValueError(f\"unknown model_type: {model_type!r}\")\n",
    "\n",
    "    # Flatten every leading dimension so each row is one generated sequence.\n",
    "    rows = sequences.reshape(-1, sequences.shape[-1]).tolist()\n",
    "    # Cut each row at its first -1 (presumably the padding id -- TODO confirm).\n",
    "    rows = [row[:row.index(-1)] if -1 in row else row for row in rows]\n",
    "    decoded_seq = tokenizer.decode(rows)\n",
    "    # Drop the prompt by character count and keep text up to the first , . or newline.\n",
    "    answers = [re.split(\"[,.\\n]\", text[text_len:].strip())[0] for text in decoded_seq]\n",
    "    return answers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precision from 1 to 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████| 3610/3610 [22:28<00:00,  2.68it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run evaluate_nq over every validation sample and collect its answers.\n",
    "predictions = []\n",
    "print(f\"Precision from 1 to {n_drafts}\")\n",
    "for sample in tqdm(nq):\n",
    "    # Adaptive generation budget: token length of the shortest reference answer,\n",
    "    # capped at 1000 (also the value used when a sample has no references), plus 3.\n",
    "    # Note that this budget is derived from the reference answers themselves.\n",
    "    ref_token_lens = [len(tokenizer.encode([ref], False, False)[0]) for ref in sample[\"answer\"]]\n",
    "    shortest = min([1000] + ref_token_lens)\n",
    "\n",
    "    # Inference\n",
    "    question = sample[\"question\"]\n",
    "    drafts = evaluate_nq(model_type, question, max_gen_len=shortest + 3)\n",
    "    predictions.append({\"question\": question, \"answer\": drafts})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# precisions[\"n\"] lists, for every question, only its first n answers (n = 1..n_drafts).\n",
    "precisions = {}\n",
    "for n_kept in range(1, n_drafts + 1):\n",
    "    precisions[str(n_kept)] = [\n",
    "        {\"question\": pred[\"question\"], \"answer\": pred[\"answer\"][:n_kept]}\n",
    "        for pred in predictions\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'question': 'when was the last time anyone was on the moon', 'answer': ['2019', '2019', '2019-', '2019-', '1019']}\n",
      "================\n",
      "{'question': \"who wrote he ain't heavy he's my brother lyrics\", 'answer': ['The song was written by', 'The lyr was written by', 'The Hol was written by', 'Neil song was written by', 'Neil lyr was written by']}\n",
      "================\n",
      "{'question': 'how many seasons of the bastard executioner are there', 'answer': ['1', 'There1', 'there1', '1', 'There1']}\n",
      "================\n",
      "{'question': 'when did the eagles win last super bowl', 'answer': ['2018', 'The2018', '1018', '2017', 'the2018']}\n",
      "================\n",
      "{'question': \"who won last year's ncaa women's basketball\", 'answer': ['the university of connecticut', 'The university of connecticut', 'university of connecticut', 'the University of connecticut', 'The University of connecticut']}\n",
      "================\n",
      "{'question': 'when did the isle of wight become an island', 'answer': ['1207', 'when1207', '1287', '1277', 'when1287']}\n",
      "================\n",
      "{'question': 'love yourself by justin bieber is about who', 'answer': ['love yourself by justin b', 'love yourself is justin b', 'Justin yourself by justin b', 'Justin yourself is justin b', 'It yourself by justin b']}\n",
      "================\n",
      "{'question': 'who was the ruler of england in 1616', 'answer': ['James I', 'James I of', 'King I', 'j I', 'James I']}\n",
      "================\n",
      "{'question': 'what is the hot coffee mod in san andreas', 'answer': ['The Hot Coffee mod is a modification for Grand', 'The Hot Coffee mod is a mod for Grand', 'The hot Coffee mod is a modification for Grand', 'The Hot Coffee mod is a modification that Grand', 'It Hot Coffee mod is a modification for Grand']}\n",
      "================\n",
      "{'question': 'what is the maximum data rate for the 802.11a standard select one', 'answer': ['54 Mbps', '54Mbps', '54 mbps', '54 Mbps', '54 Mbps']}\n",
      "================\n"
     ]
    }
   ],
   "source": [
    "# Preview: print the first 10 predictions, each followed by a separator line.\n",
    "for shown in predictions[:10]:\n",
    "    print(shown)\n",
    "    print(\"================\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Saving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['1', '2', '3', '4', '5'])\n"
     ]
    }
   ],
   "source": [
    "# Write one JSONL file per precision level under ../nq/ (one JSON object per line).\n",
    "os.makedirs(\"../nq/\", exist_ok=True)\n",
    "print(precisions.keys())\n",
    "for n_kept in range(1, n_drafts + 1):\n",
    "    with open(f\"../nq/eval_{model_type}_{n_kept}_test.jsonl\", \"w\") as sink:\n",
    "        sink.writelines(json.dumps(entry) + \"\\n\" for entry in precisions[str(n_kept)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-export each per-precision file, keeping only the answer at position prec-1.\n",
    "# NOTE(review): the saving cell writes eval_{model_type}_{prec}_test.jsonl, while this\n",
    "# cell reads eval_{model_type}_{prec}.jsonl -- confirm the un-suffixed files are intended.\n",
    "for prec in range(1, n_drafts + 1):\n",
    "    src_path = f\"../nq/eval_{model_type}_{prec}.jsonl\"\n",
    "    dst_path = f\"../nq/eval_{model_type}_1_{prec}.jsonl\"\n",
    "    with open(src_path, \"r\") as src, open(dst_path, \"w\") as dst:\n",
    "        for raw_line in src:\n",
    "            record = json.loads(raw_line)\n",
    "            record[\"answer\"] = [record[\"answer\"][prec - 1]]\n",
    "            dst.write(json.dumps(record) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mixed",
   "language": "python",
   "name": "mixed"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
