{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import copy\n",
    "import json\n",
    "import pickle\n",
    "import os\n",
    "import random\n",
    "import re\n",
    "import string\n",
    "import math\n",
    "from datetime import datetime\n",
    "\n",
    "import evaluate\n",
    "import torch\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "from transformers import LlamaTokenizer\n",
    "from tqdm import tqdm\n",
    "\n",
    "from eval import *\n",
    "from llama.metrics import *\n",
    "from llama.generation import Llama\n",
    "from llama.mixed_generation import MixedLlama\n",
    "from llama.tokenizer import Tokenizer\n",
    "from ngrams.ngram_models import make_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "mixing_options = [\"sample\", \"sample_new_weights_with_score\", \"sample_weights_with_current\"]\n",
    "smoothing_options = [None, \"geom\", \"all\"]\n",
    "n_drafts = 3 #2, 3\n",
    "n_token_sample = n_drafts * 3\n",
    "n_token_consider = 32000\n",
    "tokenizer = Tokenizer(\"../7B/tokenizer.model\") # LlamaTokenizer.from_pretrained(\"./7B_HF\", add_bos_token=True)\n",
    "mixing_method = mixing_options[1]\n",
    "smoothing = smoothing_options[0]\n",
    "sample_tokens = False\n",
    "sample_beams = False\n",
    "prompt_len = 40\n",
    "max_gen_len = 40\n",
    "ckpt_path = \"../ckpts-200k\"\n",
    "\n",
    "# weighting\n",
    "i_weights = [0.01, 0.04, 0.15, 0.18, 0.12]\n",
    "i_length = [1, 2, 3, 4, 5]\n",
    "alpha = 0.6\n",
    "temp = 0.06"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Making bigram...\n",
      "1310800\n",
      "Making trigram...\n",
      "671088728\n",
      "Making fourgram...\n",
      "2684354648\n"
     ]
    }
   ],
   "source": [
    "if ckpt_path is not None:\n",
    "    ngrams = make_models(ckpt_path, bigram=True, trigram=True, fourgram=True, fivegram=False, sixgram=False, sevengram=False)\n",
    "else:\n",
    "    ngrams = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "mixed_device = torch.device(\"cuda:0\")\n",
    "reg_device = torch.device(\"cuda:1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mixed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"RANK\"] = \"0\"\n",
    "os.environ[\"WORLD_SIZE\"] = \"1\"\n",
    "os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n",
    "os.environ[\"MASTER_PORT\"] = \"9988\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> initializing model parallel with size 1\n",
      "> initializing ddp with size 1\n",
      "> initializing pipeline with size 1\n",
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/temp/miniconda3/envs/mixed/lib/python3.11/site-packages/torch/__init__.py:696: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
      "  _C._set_default_tensor_type(t)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded in 7.07 seconds\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "weight_path = \"../7B/\"\n",
    "mixed_model = MixedLlama.build(ckpt_dir=weight_path, \n",
    "                                 tokenizer_path='../7B/tokenizer.model', \n",
    "                                 max_seq_len=1000, \n",
    "                                 max_batch_size=16,\n",
    "                                 device=mixed_device,\n",
    "                                 model_parallel_size=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nucleus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/temp/miniconda3/envs/mixed/lib/python3.11/site-packages/torch/__init__.py:696: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
      "  _C._set_default_tensor_type(t)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded in 6.86 seconds\n"
     ]
    }
   ],
   "source": [
    "reg_model = Llama.build(ckpt_dir=\"../7B/\", \n",
    "                    tokenizer_path='../7B/tokenizer.model', \n",
    "                    max_seq_len=1000, \n",
    "                    max_batch_size=16,\n",
    "                    device=mixed_device, # reg_device,\n",
    "                    model_parallel_size=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length: 7993\n"
     ]
    }
   ],
   "source": [
    "trivia_path = \"../../datasets/qa/wikipedia-dev.json\"\n",
    "with open(trivia_path, \"r\") as f:\n",
    "    triviaqa = json.load(f)[\"Data\"]\n",
    "print(f\"Length: {len(triviaqa)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 7993/7993 [00:01<00:00, 7558.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "207\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# find longest \n",
    "longest = 0\n",
    "for sample in tqdm(triviaqa):\n",
    "    for answer in sample[\"Answer\"][\"Aliases\"]:\n",
    "        tmp = tokenizer.encode([answer], False, False)[0]\n",
    "        if len(tmp) > longest:\n",
    "            longest = len(tmp)\n",
    "max_gen_len = math.ceil(1.5 * longest)\n",
    "print(max_gen_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_default_dtype(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_types = [\"mixed\", \"regular\"]\n",
    "model_type = model_types[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/triviaqa/default.yaml\n",
    "def evaluate_trivia(model_type, question, max_gen_len):\n",
    "    question = \"Question: \" + question + \"\\nAnswer:\"\n",
    "    text_len = len(question) # for truncating\n",
    "    prompt_len = len(tokenizer.encode([question], True, False)[0]) # for model\n",
    "    if model_type == \"regular\":\n",
    "        input = [question for _ in range(n_drafts)]\n",
    "        # print(input)\n",
    "        sequences, _ = evaluate_nucleus_losses(data=input,\n",
    "                                               model=reg_model,\n",
    "                                               tokenizer=tokenizer,\n",
    "                                               prompt_len=prompt_len,\n",
    "                                               max_gen_len=max_gen_len,\n",
    "                                               temp=0.6,\n",
    "                                               bsz=8,\n",
    "                                               marker=False)\n",
    "        n_pd, seq_len = sequences.shape\n",
    "    elif model_type == \"mixed\":\n",
    "        sequences, _ = evaluate_mixed_losses(data=[question],\n",
    "                                                   model=mixed_model,\n",
    "                                                   tokenizer=tokenizer,\n",
    "                                                   prompt_len=prompt_len,\n",
    "                                                   max_gen_len=max_gen_len,\n",
    "                                                   alpha=alpha,\n",
    "                                                   temp=temp,\n",
    "                                                   n_drafts=n_drafts,\n",
    "                                                   n_token_consider=n_token_consider,\n",
    "                                                   n_token_sample=n_token_sample,\n",
    "                                                   mixing_method=mixing_method,\n",
    "                                                   smoothing=smoothing,\n",
    "                                                   debug=False,\n",
    "                                                   bsz=8, # for timing\n",
    "                                                   i_weights=i_weights[:3],\n",
    "                                                   i_length=i_length[:3],\n",
    "                                                   ngrams=ngrams,\n",
    "                                                   sample_beams=sample_beams,\n",
    "                                                   sample_tokens=sample_tokens,\n",
    "                                                   marker=False)\n",
    "        n_p, n_d, seq_len = sequences.shape\n",
    "    sequences = sequences.reshape(-1, seq_len).tolist()\n",
    "    for d_idx in range(len(sequences)):\n",
    "        draft = sequences[d_idx]\n",
    "        if -1 in draft:\n",
    "            draft = draft[:draft.index(-1)]\n",
    "        sequences[d_idx] = draft\n",
    "    decoded_seq = tokenizer.decode(sequences)\n",
    "    answers = []\n",
    "    for s in decoded_seq:\n",
    "        # print(s)\n",
    "        answers.append(re.split(\"[,.\\n]\", s[text_len:].strip())[0])\n",
    "    return answers\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precision from 1 to 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████| 7993/7993 [35:41<00:00,  3.73it/s]\n"
     ]
    }
   ],
   "source": [
    "questions = {}\n",
    "predictions = {}\n",
    "print(f\"Precision from 1 to {n_drafts}\")\n",
    "for sample in tqdm(triviaqa):\n",
    "    # adaptive gen\n",
    "    longest = 0\n",
    "    shortest = 1000\n",
    "    total = 0\n",
    "    for answer in sample[\"Answer\"][\"Aliases\"]:\n",
    "        tmp = tokenizer.encode([answer], False, False)[0]\n",
    "        if len(tmp) > longest:\n",
    "            longest = len(tmp)\n",
    "        if len(tmp) < shortest:\n",
    "            shortest = len(tmp)\n",
    "        total += len(tmp)\n",
    "    \n",
    "    # inf\n",
    "    id = sample[\"QuestionId\"]\n",
    "    question = sample[\"Question\"]\n",
    "    answer = evaluate_trivia(model_type, question, max_gen_len=longest + 3)\n",
    "    predictions[id] = answer\n",
    "    questions[id] = question\n",
    "    # if len(questions) == 20:\n",
    "    #     break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "precisions = {}\n",
    "for i in range(1, n_drafts+1):\n",
    "    prec = str(i)\n",
    "    responses = {k: v[:i] for k, v in predictions.items()}\n",
    "    precisions[prec] = responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Which Lloyd Webber musical premiered in the US on 10th December 1993?\n",
      "['The Phantom of the Opera', 'The Phantom of the Opera', 'Sun Phantom of the Opera']\n",
      "================\n",
      "Who was the next British Prime Minister after Arthur Balfour?\n",
      "['H', 'H', 'David']\n",
      "================\n",
      "Who had a 70s No 1 hit with Kiss You All Over?\n",
      "['Exile', 'Exile', 'Exile']\n",
      "================\n",
      "What claimed the life of singer Kathleen Ferrier?\n",
      "['Cancer', 'Cancer', 'Cancer of']\n",
      "================\n",
      "Which actress was voted Miss Greenwich Village in 1942?\n",
      "['Lauren Bacall', 'Lauren Bacall', 'Maruren Bacall']\n",
      "================\n",
      "What was the name of Michael Jackson's autobiography written in 1988?\n",
      "['Moonwalk', 'Moonwalk', 'Moonwalk']\n",
      "================\n",
      "Which volcano in Tanzania is the highest mountain in Africa?\n",
      "['Mount Kilimanjaro is the highest mountain in Africa', 'Mount Kilimanjaro is the highest mountain in Africa', 'Mount Kilimanjaro']\n",
      "================\n",
      "The flag of Libya is a plain rectangle of which color?\n",
      "['Green', 'Green', 'Green']\n",
      "================\n",
      "Of which African country is Niamey the capital?\n",
      "['Niger', 'Niger', 'Niger']\n",
      "================\n",
      "Which musical featured the song The Street Where You Live?\n",
      "['My Fair Lady', 'My Fair Lady', 'My Fair Lady']\n",
      "================\n"
     ]
    }
   ],
   "source": [
    "counter = 0\n",
    "for k in predictions:\n",
    "    if counter >= 10:\n",
    "        break\n",
    "    print(questions[k])\n",
    "    print(predictions[k])\n",
    "    counter += 1\n",
    "    print(\"================\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"../trivia/\", exist_ok=True)\n",
    "for prec in range(1, n_drafts+1):\n",
    "    out_path = f\"../nucleus_extra/trivia_extra/ngram_4trivia_{model_type}_{prec}_4.json\"\n",
    "    with open(out_path, \"w\") as f:\n",
    "        json.dump(precisions[str(prec)], f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mixed",
   "language": "python",
   "name": "mixed"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
