{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "63c5eae4-7248-46c4-bf70-405f9251a3cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import json\n",
    "import random\n",
    "import pickle\n",
    "from datetime import datetime\n",
    "from heapq import heappush, heappop, heapify, nsmallest, nlargest\n",
    "\n",
    "import evaluate\n",
    "import torch\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "from transformers import LlamaTokenizer\n",
    "from tqdm import tqdm\n",
    "\n",
    "from eval import *\n",
    "from llama.metrics import *\n",
    "from llama.generation import Llama\n",
    "from llama.mixed_generation import MixedLlama\n",
    "from llama.tokenizer import Tokenizer\n",
    "from ngrams.ngram_models import make_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "91324579-0a1f-44fe-a8c9-03c9cc831f69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load data\n",
    "with open(\"../../gpt-2-output-dataset/data/webtext.valid.jsonl\", \"r\") as f:\n",
    "    dataset = [json.loads(line)[\"text\"] for line in f]\n",
    "mixing_options = [\"sample\", \"sample_new_weights_with_score\", \"sample_weights_with_current\"]\n",
    "smoothing_options = [None, \"geom\", \"all\"]\n",
    "data = dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b5aaca9e-91ac-428b-81bd-07396286606a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# params\n",
    "prompt_len = 15\n",
    "max_gen_len = 10\n",
    "n_drafts = 3\n",
    "n_token_sample = 3 * n_drafts\n",
    "n_token_consider = 32000\n",
    "bsz = 32\n",
    "tokenizer = Tokenizer('../7B/tokenizer.model')\n",
    "mixing_method = mixing_options[1]\n",
    "smoothing = smoothing_options[1]\n",
    "sample_tokens = False\n",
    "sample_beams = False\n",
    "\n",
    "# weighting\n",
    "ckpt_path = None\n",
    "ckpt_path = \"../ckpts-200k\"\n",
    "i_weights = [0.01, 0.04, 0.15]#, 0.18, 0.12]\n",
    "i_length = [1, 2, 3]#, 4, 5]\n",
    "alpha = 0.54\n",
    "temp = 0.06\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ab8397eb-2c22-427a-92f4-b4b607769b72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Making bigram...\n",
      "1310800\n",
      "Making trigram...\n",
      "671088728\n",
      "Making fourgram...\n",
      "2684354648\n"
     ]
    }
   ],
   "source": [
    "if ckpt_path is not None:\n",
    "    ngrams = make_models(ckpt_path, bigram=True, trigram=True, fourgram=True, fivegram=False, sixgram=False, sevengram=False)\n",
    "else:\n",
    "    ngrams = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7bd44c5e-c557-4d2a-91ec-1c372ada56b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "mixed_device = torch.device(\"cuda:0\")\n",
    "reg_device = torch.device(\"cuda:1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0344b7aa-57ca-4a2f-aea8-b008e9fccfa3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> initializing model parallel with size 1\n",
      "> initializing ddp with size 1\n",
      "> initializing pipeline with size 1\n",
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/temp/miniconda3/envs/mixed/lib/python3.11/site-packages/torch/__init__.py:696: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
      "  _C._set_default_tensor_type(t)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded in 8.36 seconds\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "os.environ[\"RANK\"] = \"0\"\n",
    "os.environ[\"WORLD_SIZE\"] = \"1\"\n",
    "os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n",
    "os.environ[\"MASTER_PORT\"] = \"9998\"\n",
    "mixed_model = MixedLlama.build(ckpt_dir=\"../7B/\", \n",
    "                         tokenizer_path='../7B/tokenizer.model', \n",
    "                         max_seq_len=60, \n",
    "                         max_batch_size=32,\n",
    "                         device=mixed_device,\n",
    "                         model_parallel_size=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "48105dab-58aa-4c09-8dbd-8200fdaf513b",
   "metadata": {},
   "outputs": [],
   "source": [
    "alphas = [0.44 + 0.02 * i for i in range(13)]\n",
    "temps = [0.04 + 0.02 * i for i in range(8)]\n",
    "alphas.append(0.55)\n",
    "# gen_lens = [5, 15, 20]\n",
    "# prompt_lens = [5, 25]\n",
    "total_len = len(alphas) * len(temps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "856f7198-b54e-4f80-b17d-3386c1a2e3d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                           | 0/112 [02:10<?, ?it/s]\n",
      "alpha: 0.55 temp: 0.18000000000000002 n_drafts: 3 prompt_len: 15: 100%|███████████████| 112/112 [06:56<00:00,  3.82s/it]"
     ]
    }
   ],
   "source": [
    "loop = tqdm(total=total_len, position=0, leave=True)\n",
    "results = {}\n",
    "dset = random.sample(dataset, 96)\n",
    "for alpha in alphas:\n",
    "    for temp in temps:\n",
    "        mixed_sequences, mixed_ppl = evaluate_mixed_losses(data=dset,\n",
    "                       model=mixed_model,\n",
    "                       tokenizer=tokenizer,\n",
    "                       prompt_len=prompt_len,\n",
    "                       max_gen_len=max_gen_len,\n",
    "                       alpha=alpha,\n",
    "                       temp=temp,\n",
    "                       n_drafts=n_drafts,\n",
    "                       n_token_consider=n_token_consider,\n",
    "                       n_token_sample=n_token_sample,\n",
    "                       mixing_method=mixing_method,\n",
    "                       smoothing=smoothing,\n",
    "                       debug=False,\n",
    "                       bsz=bsz,\n",
    "                       i_weights=i_weights,\n",
    "                       i_length=i_length,\n",
    "                       ngrams=ngrams,\n",
    "                       sample_beams=sample_beams,\n",
    "                       sample_tokens=sample_tokens,\n",
    "                       marker=False)\n",
    "        param = alpha, temp, n_drafts, prompt_len, n_token_sample, sample_tokens, sample_beams, mixing_method, max_gen_len\n",
    "        results[param] = mixed_sequences\n",
    "        loop.set_description(f\"alpha: {alpha} temp: {temp} n_drafts: {n_drafts} prompt_len: {prompt_len}\")\n",
    "        loop.update(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "517481b3-9ebf-44d0-8b8e-4d272d2c3f9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(f\"../tuning/4gram_mixed_7B.pkl\", \"wb\") as f:\n",
    "    pickle.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "42a67b4e-4f45-44bd-a6e5-c7acc6102ace",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(results, f\"../tuning/p{prompt_len}_d{n_drafts}_mixed_7B.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7c5aa812-92d5-4c40-bff6-c07e2978996e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with open(f\"../tuning/glen_3_mixed_7B.pkl\", \"rb\") as f:\n",
    "#     r = pickle.load(f)\n",
    "# g_5 = {}\n",
    "# g_15 = {}\n",
    "# g_20 = {}\n",
    "# for params in r:\n",
    "#     alpha, temp, n_drafts, prompt_len, n_token_sample, sample_tokens, sample_beams, mixing_method, max_gen_len = params\n",
    "#     if max_gen_len == 5:\n",
    "#         g_5[params] = r[params]\n",
    "#     elif max_gen_len == 15:\n",
    "#         g_15[params] = r[params]\n",
    "#     elif max_gen_len == 20:\n",
    "#         g_20[params] = r[params]\n",
    "#     else:\n",
    "#         print(max_gen_len)\n",
    "# with open(f\"../tuning/g5_d3_mixed_7B.pkl\", \"wb\") as f:\n",
    "#     pickle.dump(g_5, f)\n",
    "# with open(f\"../tuning/g15_d3_mixed_7B.pkl\", \"wb\") as f:\n",
    "#     pickle.dump(g_15, f)\n",
    "# with open(f\"../tuning/g20_d3_mixed_7B.pkl\", \"wb\") as f:\n",
    "#     pickle.dump(g_20, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ac7f6ac-2345-42a7-88f7-f07f06696429",
   "metadata": {},
   "source": [
    "# Checking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "5b76a496-66db-4f58-b59b-1a9c61188616",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(6.296875, (0.64, 0.08, 3, 15, 9, False, False, 'sample_new_weights_with_score', 10)), (6.38671875, (0.6, 0.08, 3, 15, 9, False, False, 'sample_new_weights_with_score', 10)), (6.53125, (0.55, 0.1, 3, 15, 9, False, False, 'sample_new_weights_with_score', 10)), (6.734375, (0.62, 0.08, 3, 15, 9, False, False, 'sample_new_weights_with_score', 10)), (6.58984375, (0.56, 0.1, 3, 15, 9, False, False, 'sample_new_weights_with_score', 10)), (6.55078125, (0.6799999999999999, 0.06, 3, 15, 9, False, False, 'sample_new_weights_with_score', 10)), (6.9140625, (0.5800000000000001, 0.04, 3, 15, 9, False, False, 'sample_new_weights_with_score', 10)), (7.27734375, (0.6, 0.1, 3, 15, 9, False, False, 'sample_new_weights_with_score', 10)), (6.8046875, (0.5, 0.12, 3, 15, 9, False, False, 'sample_new_weights_with_score', 10)), (6.984375, (0.54, 0.1, 3, 15, 9, False, False, 'sample_new_weights_with_score', 10))]\n"
     ]
    }
   ],
   "source": [
    "# view tuning results\n",
    "import random\n",
    "new_dset = random.sample(dataset, 500)\n",
    "with open(f\"../tuning/4gram_mixed_7B_llama_tune.pkl\", \"rb\") as f:\n",
    "    results = pickle.load(f)\n",
    "heapify(results)\n",
    "top_10 = results[:10]\n",
    "print(top_10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "eb6ff535-47c5-4fe6-9646-d40f05231bbf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10 [05:50<?, ?it/s]\n",
      "\n",
      "  0%|          | 0/16 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▋         | 1/16 [00:02<00:44,  2.94s/it]\u001b[A\n",
      " 12%|█▎        | 2/16 [00:05<00:37,  2.65s/it]\u001b[A\n",
      " 19%|█▉        | 3/16 [00:07<00:34,  2.62s/it]\u001b[A\n",
      " 25%|██▌       | 4/16 [00:10<00:30,  2.58s/it]\u001b[A\n",
      " 31%|███▏      | 5/16 [00:12<00:27,  2.54s/it]\u001b[A\n",
      " 38%|███▊      | 6/16 [00:15<00:25,  2.53s/it]\u001b[A\n",
      " 44%|████▍     | 7/16 [00:18<00:22,  2.53s/it]\u001b[A\n",
      " 50%|█████     | 8/16 [00:20<00:20,  2.51s/it]\u001b[A\n",
      " 56%|█████▋    | 9/16 [00:22<00:17,  2.51s/it]\u001b[A\n",
      " 62%|██████▎   | 10/16 [00:25<00:15,  2.52s/it]\u001b[A\n",
      " 69%|██████▉   | 11/16 [00:28<00:12,  2.52s/it]\u001b[A\n",
      " 75%|███████▌  | 12/16 [00:30<00:10,  2.53s/it]\u001b[A\n",
      " 81%|████████▏ | 13/16 [00:33<00:07,  2.55s/it]\u001b[A\n",
      " 88%|████████▊ | 14/16 [00:35<00:05,  2.54s/it]\u001b[A\n",
      " 94%|█████████▍| 15/16 [00:38<00:02,  2.54s/it]\u001b[A\n",
      "100%|██████████| 16/16 [00:39<00:00,  2.50s/it]\u001b[A\n",
      "alpha: 0.5 temp: 0.12 n_drafts: 3 prompt_len: 25:  10%|█         | 1/10 [00:40<06:00, 40.00s/it]\n",
      "  0%|          | 0/16 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▋         | 1/16 [00:02<00:37,  2.50s/it]\u001b[A\n",
      " 12%|█▎        | 2/16 [00:04<00:34,  2.49s/it]\u001b[A\n",
      " 19%|█▉        | 3/16 [00:07<00:33,  2.54s/it]\u001b[A\n",
      " 25%|██▌       | 4/16 [00:10<00:30,  2.54s/it]\u001b[A\n",
      " 31%|███▏      | 5/16 [00:12<00:27,  2.52s/it]\u001b[A\n",
      " 38%|███▊      | 6/16 [00:15<00:25,  2.52s/it]\u001b[A\n",
      " 44%|████▍     | 7/16 [00:17<00:22,  2.52s/it]\u001b[A\n",
      " 50%|█████     | 8/16 [00:20<00:20,  2.51s/it]\u001b[A\n",
      " 56%|█████▋    | 9/16 [00:22<00:17,  2.50s/it]\u001b[A\n",
      " 62%|██████▎   | 10/16 [00:25<00:15,  2.53s/it]\u001b[A\n",
      " 69%|██████▉   | 11/16 [00:27<00:12,  2.50s/it]\u001b[A\n",
      " 75%|███████▌  | 12/16 [00:30<00:10,  2.50s/it]\u001b[A\n",
      " 81%|████████▏ | 13/16 [00:32<00:07,  2.52s/it]\u001b[A\n",
      " 88%|████████▊ | 14/16 [00:35<00:05,  2.52s/it]\u001b[A\n",
      " 94%|█████████▍| 15/16 [00:37<00:02,  2.51s/it]\u001b[A\n",
      "100%|██████████| 16/16 [00:39<00:00,  2.46s/it]\u001b[A\n",
      "alpha: 0.54 temp: 0.12 n_drafts: 3 prompt_len: 25:  20%|██        | 2/10 [01:19<05:17, 39.68s/it]\n",
      "  0%|          | 0/16 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▋         | 1/16 [00:02<00:36,  2.41s/it]\u001b[A\n",
      " 12%|█▎        | 2/16 [00:04<00:33,  2.41s/it]\u001b[A\n",
      " 19%|█▉        | 3/16 [00:07<00:32,  2.47s/it]\u001b[A\n",
      " 25%|██▌       | 4/16 [00:09<00:30,  2.52s/it]\u001b[A\n",
      " 31%|███▏      | 5/16 [00:12<00:27,  2.48s/it]\u001b[A\n",
      " 38%|███▊      | 6/16 [00:14<00:24,  2.48s/it]\u001b[A\n",
      " 44%|████▍     | 7/16 [00:17<00:22,  2.48s/it]\u001b[A\n",
      " 50%|█████     | 8/16 [00:19<00:19,  2.47s/it]\u001b[A\n",
      " 56%|█████▋    | 9/16 [00:22<00:17,  2.45s/it]\u001b[A\n",
      " 62%|██████▎   | 10/16 [00:24<00:14,  2.45s/it]\u001b[A\n",
      " 69%|██████▉   | 11/16 [00:27<00:12,  2.43s/it]\u001b[A\n",
      " 75%|███████▌  | 12/16 [00:29<00:09,  2.43s/it]\u001b[A\n",
      " 81%|████████▏ | 13/16 [00:31<00:07,  2.46s/it]\u001b[A\n",
      " 88%|████████▊ | 14/16 [00:34<00:04,  2.46s/it]\u001b[A\n",
      " 94%|█████████▍| 15/16 [00:36<00:02,  2.46s/it]\u001b[A\n",
      "100%|██████████| 16/16 [00:38<00:00,  2.41s/it]\u001b[A\n",
      "alpha: 0.5800000000000001 temp: 0.1 n_drafts: 3 prompt_len: 25:  30%|███       | 3/10 [01:58<04:34, 39.20s/it]\n",
      "  0%|          | 0/16 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▋         | 1/16 [00:02<00:37,  2.50s/it]\u001b[A\n",
      " 12%|█▎        | 2/16 [00:04<00:34,  2.47s/it]\u001b[A\n",
      " 19%|█▉        | 3/16 [00:07<00:32,  2.53s/it]\u001b[A\n",
      " 25%|██▌       | 4/16 [00:10<00:30,  2.52s/it]\u001b[A\n",
      " 31%|███▏      | 5/16 [00:12<00:27,  2.51s/it]\u001b[A\n",
      " 38%|███▊      | 6/16 [00:15<00:25,  2.52s/it]\u001b[A\n",
      " 44%|████▍     | 7/16 [00:17<00:22,  2.52s/it]\u001b[A\n",
      " 50%|█████     | 8/16 [00:20<00:19,  2.50s/it]\u001b[A\n",
      " 56%|█████▋    | 9/16 [00:22<00:17,  2.48s/it]\u001b[A\n",
      " 62%|██████▎   | 10/16 [00:24<00:14,  2.48s/it]\u001b[A\n",
      " 69%|██████▉   | 11/16 [00:27<00:12,  2.47s/it]\u001b[A\n",
      " 75%|███████▌  | 12/16 [00:29<00:09,  2.48s/it]\u001b[A\n",
      " 81%|████████▏ | 13/16 [00:32<00:07,  2.51s/it]\u001b[A\n",
      " 88%|████████▊ | 14/16 [00:35<00:05,  2.52s/it]\u001b[A\n",
      " 94%|█████████▍| 15/16 [00:37<00:02,  2.52s/it]\u001b[A\n",
      "100%|██████████| 16/16 [00:39<00:00,  2.45s/it]\u001b[A\n",
      "alpha: 0.55 temp: 0.12 n_drafts: 3 prompt_len: 25:  40%|████      | 4/10 [02:37<03:55, 39.23s/it]             \n",
      "  0%|          | 0/16 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▋         | 1/16 [00:02<00:37,  2.51s/it]\u001b[A\n",
      " 12%|█▎        | 2/16 [00:05<00:34,  2.50s/it]\u001b[A\n",
      " 19%|█▉        | 3/16 [00:07<00:33,  2.55s/it]\u001b[A\n",
      " 25%|██▌       | 4/16 [00:10<00:30,  2.53s/it]\u001b[A\n",
      " 31%|███▏      | 5/16 [00:12<00:27,  2.53s/it]\u001b[A\n",
      " 38%|███▊      | 6/16 [00:15<00:25,  2.56s/it]\u001b[A\n",
      " 44%|████▍     | 7/16 [00:17<00:22,  2.55s/it]\u001b[A\n",
      " 50%|█████     | 8/16 [00:20<00:20,  2.53s/it]\u001b[A\n",
      " 56%|█████▋    | 9/16 [00:22<00:17,  2.53s/it]\u001b[A\n",
      " 62%|██████▎   | 10/16 [00:25<00:15,  2.52s/it]\u001b[A\n",
      " 69%|██████▉   | 11/16 [00:27<00:12,  2.52s/it]\u001b[A\n",
      " 75%|███████▌  | 12/16 [00:30<00:10,  2.53s/it]\u001b[A\n",
      " 81%|████████▏ | 13/16 [00:32<00:07,  2.54s/it]\u001b[A\n",
      " 88%|████████▊ | 14/16 [00:35<00:05,  2.54s/it]\u001b[A\n",
      " 94%|█████████▍| 15/16 [00:38<00:02,  2.54s/it]\u001b[A\n",
      "100%|██████████| 16/16 [00:39<00:00,  2.48s/it]\u001b[A\n",
      "alpha: 0.45999999999999996 temp: 0.16 n_drafts: 3 prompt_len: 25:  50%|█████     | 5/10 [03:17<03:17, 39.42s/it]\n",
      "  0%|          | 0/16 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▋         | 1/16 [00:02<00:36,  2.44s/it]\u001b[A\n",
      " 12%|█▎        | 2/16 [00:04<00:33,  2.41s/it]\u001b[A\n",
      " 19%|█▉        | 3/16 [00:07<00:32,  2.47s/it]\u001b[A\n",
      " 25%|██▌       | 4/16 [00:09<00:29,  2.46s/it]\u001b[A\n",
      " 31%|███▏      | 5/16 [00:12<00:27,  2.45s/it]\u001b[A\n",
      " 38%|███▊      | 6/16 [00:14<00:24,  2.45s/it]\u001b[A\n",
      " 44%|████▍     | 7/16 [00:17<00:21,  2.44s/it]\u001b[A\n",
      " 50%|█████     | 8/16 [00:19<00:19,  2.42s/it]\u001b[A\n",
      " 56%|█████▋    | 9/16 [00:21<00:16,  2.40s/it]\u001b[A\n",
      " 62%|██████▎   | 10/16 [00:24<00:14,  2.42s/it]\u001b[A\n",
      " 69%|██████▉   | 11/16 [00:26<00:12,  2.41s/it]\u001b[A\n",
      " 75%|███████▌  | 12/16 [00:29<00:09,  2.41s/it]\u001b[A\n",
      " 81%|████████▏ | 13/16 [00:31<00:07,  2.43s/it]\u001b[A\n",
      " 88%|████████▊ | 14/16 [00:34<00:04,  2.45s/it]\u001b[A\n",
      " 94%|█████████▍| 15/16 [00:36<00:02,  2.44s/it]\u001b[A\n",
      "100%|██████████| 16/16 [00:38<00:00,  2.39s/it]\u001b[A\n",
      "alpha: 0.56 temp: 0.1 n_drafts: 3 prompt_len: 25:  60%|██████    | 6/10 [03:55<02:36, 39.00s/it]                \n",
      "  0%|          | 0/16 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▋         | 1/16 [00:02<00:34,  2.29s/it]\u001b[A\n",
      " 12%|█▎        | 2/16 [00:04<00:32,  2.32s/it]\u001b[A\n",
      " 19%|█▉        | 3/16 [00:07<00:30,  2.37s/it]\u001b[A\n",
      " 25%|██▌       | 4/16 [00:09<00:28,  2.40s/it]\u001b[A\n",
      " 31%|███▏      | 5/16 [00:11<00:26,  2.38s/it]\u001b[A\n",
      " 38%|███▊      | 6/16 [00:14<00:23,  2.37s/it]\u001b[A\n",
      " 44%|████▍     | 7/16 [00:16<00:21,  2.35s/it]\u001b[A\n",
      " 50%|█████     | 8/16 [00:18<00:18,  2.33s/it]\u001b[A\n",
      " 56%|█████▋    | 9/16 [00:21<00:16,  2.32s/it]\u001b[A\n",
      " 62%|██████▎   | 10/16 [00:23<00:14,  2.34s/it]\u001b[A\n",
      " 69%|██████▉   | 11/16 [00:25<00:11,  2.34s/it]\u001b[A\n",
      " 75%|███████▌  | 12/16 [00:28<00:09,  2.32s/it]\u001b[A\n",
      " 81%|████████▏ | 13/16 [00:30<00:07,  2.35s/it]\u001b[A\n",
      " 88%|████████▊ | 14/16 [00:32<00:04,  2.35s/it]\u001b[A\n",
      " 94%|█████████▍| 15/16 [00:35<00:02,  2.35s/it]\u001b[A\n",
      "100%|██████████| 16/16 [00:36<00:00,  2.30s/it]\u001b[A\n",
      "alpha: 0.6 temp: 0.08 n_drafts: 3 prompt_len: 25:  70%|███████   | 7/10 [04:32<01:54, 38.30s/it]\n",
      "  0%|          | 0/16 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▋         | 1/16 [00:02<00:35,  2.36s/it]\u001b[A\n",
      " 12%|█▎        | 2/16 [00:04<00:32,  2.32s/it]\u001b[A\n",
      " 19%|█▉        | 3/16 [00:07<00:30,  2.37s/it]\u001b[A\n",
      " 25%|██▌       | 4/16 [00:09<00:28,  2.37s/it]\u001b[A\n",
      " 31%|███▏      | 5/16 [00:11<00:25,  2.35s/it]\u001b[A\n",
      " 38%|███▊      | 6/16 [00:14<00:23,  2.35s/it]\u001b[A\n",
      " 44%|████▍     | 7/16 [00:16<00:21,  2.35s/it]\u001b[A\n",
      " 50%|█████     | 8/16 [00:18<00:18,  2.34s/it]\u001b[A\n",
      " 56%|█████▋    | 9/16 [00:21<00:16,  2.33s/it]\u001b[A\n",
      " 62%|██████▎   | 10/16 [00:23<00:14,  2.34s/it]\u001b[A\n",
      " 69%|██████▉   | 11/16 [00:25<00:11,  2.33s/it]\u001b[A\n",
      " 75%|███████▌  | 12/16 [00:28<00:09,  2.34s/it]\u001b[A\n",
      " 81%|████████▏ | 13/16 [00:30<00:07,  2.38s/it]\u001b[A\n",
      " 88%|████████▊ | 14/16 [00:33<00:04,  2.39s/it]\u001b[A\n",
      " 94%|█████████▍| 15/16 [00:35<00:02,  2.39s/it]\u001b[A\n",
      "100%|██████████| 16/16 [00:37<00:00,  2.31s/it]\u001b[A\n",
      "alpha: 0.6599999999999999 temp: 0.08 n_drafts: 3 prompt_len: 25:  80%|████████  | 8/10 [05:09<01:15, 37.90s/it]\n",
      "  0%|          | 0/16 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▋         | 1/16 [00:02<00:35,  2.36s/it]\u001b[A\n",
      " 12%|█▎        | 2/16 [00:04<00:32,  2.35s/it]\u001b[A\n",
      " 19%|█▉        | 3/16 [00:07<00:31,  2.41s/it]\u001b[A\n",
      " 25%|██▌       | 4/16 [00:09<00:28,  2.40s/it]\u001b[A\n",
      " 31%|███▏      | 5/16 [00:11<00:26,  2.38s/it]\u001b[A\n",
      " 38%|███▊      | 6/16 [00:14<00:23,  2.36s/it]\u001b[A\n",
      " 44%|████▍     | 7/16 [00:16<00:21,  2.37s/it]\u001b[A\n",
      " 50%|█████     | 8/16 [00:18<00:18,  2.35s/it]\u001b[A\n",
      " 56%|█████▋    | 9/16 [00:21<00:16,  2.35s/it]\u001b[A\n",
      " 62%|██████▎   | 10/16 [00:23<00:14,  2.35s/it]\u001b[A\n",
      " 69%|██████▉   | 11/16 [00:25<00:11,  2.34s/it]\u001b[A\n",
      " 75%|███████▌  | 12/16 [00:28<00:09,  2.33s/it]\u001b[A\n",
      " 81%|████████▏ | 13/16 [00:30<00:07,  2.37s/it]\u001b[A\n",
      " 88%|████████▊ | 14/16 [00:33<00:04,  2.39s/it]\u001b[A\n",
      " 94%|█████████▍| 15/16 [00:35<00:02,  2.39s/it]\u001b[A\n",
      "100%|██████████| 16/16 [00:37<00:00,  2.32s/it]\u001b[A\n",
      "alpha: 0.64 temp: 0.08 n_drafts: 3 prompt_len: 25:  90%|█████████ | 9/10 [05:46<00:37, 37.67s/it]              \n",
      "  0%|          | 0/16 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▋         | 1/16 [00:02<00:37,  2.48s/it]\u001b[A\n",
      " 12%|█▎        | 2/16 [00:04<00:34,  2.48s/it]\u001b[A\n",
      " 19%|█▉        | 3/16 [00:07<00:32,  2.53s/it]\u001b[A\n",
      " 25%|██▌       | 4/16 [00:10<00:30,  2.52s/it]\u001b[A\n",
      " 31%|███▏      | 5/16 [00:12<00:27,  2.51s/it]\u001b[A\n",
      " 38%|███▊      | 6/16 [00:15<00:25,  2.52s/it]\u001b[A\n",
      " 44%|████▍     | 7/16 [00:17<00:22,  2.53s/it]\u001b[A\n",
      " 50%|█████     | 8/16 [00:20<00:19,  2.50s/it]\u001b[A\n",
      " 56%|█████▋    | 9/16 [00:22<00:17,  2.50s/it]\u001b[A\n",
      " 62%|██████▎   | 10/16 [00:25<00:15,  2.50s/it]\u001b[A\n",
      " 69%|██████▉   | 11/16 [00:27<00:12,  2.50s/it]\u001b[A\n",
      " 75%|███████▌  | 12/16 [00:30<00:10,  2.50s/it]\u001b[A\n",
      " 81%|████████▏ | 13/16 [00:32<00:07,  2.52s/it]\u001b[A\n",
      " 88%|████████▊ | 14/16 [00:35<00:05,  2.53s/it]\u001b[A\n",
      " 94%|█████████▍| 15/16 [00:37<00:02,  2.52s/it]\u001b[A\n",
      "100%|██████████| 16/16 [00:39<00:00,  2.46s/it]\u001b[A\n",
      "alpha: 0.33999999999999997 temp: 0.22 n_drafts: 3 prompt_len: 25: 100%|██████████| 10/10 [06:25<00:00, 38.21s/it]"
     ]
    }
   ],
   "source": [
    "rerank_results = {}\n",
    "loop = tqdm(total=10, position=0, leave=True)\n",
    "for loss, param in top_10:\n",
    "    alpha, temp, n_drafts, prompt_len, n_token_sample, sample_tokens, sample_beams, mixing_method, max_gen_len = param\n",
    "    mixed_sequences, mixed_ppl = evaluate_mixed_losses(data=dataset[:500],\n",
    "                   model=mixed_model,\n",
    "                   tokenizer=tokenizer,\n",
    "                   prompt_len=prompt_len,\n",
    "                   max_gen_len=max_gen_len,\n",
    "                   alpha=alpha,\n",
    "                   temp=temp,\n",
    "                   n_drafts=n_drafts,\n",
    "                   n_token_consider=n_token_consider,\n",
    "                   n_token_sample=n_token_sample,\n",
    "                   mixing_method=mixing_method,\n",
    "                   smoothing=smoothing,\n",
    "                   debug=False,\n",
    "                   bsz=bsz,\n",
    "                   i_weights=i_weights,\n",
    "                   i_length=i_length,\n",
    "                   ngrams=ngrams,\n",
    "                   sample_beams=sample_beams,\n",
    "                   sample_tokens=sample_tokens,\n",
    "                   marker=True)\n",
    "    rerank_results[param] = mixed_sequences\n",
    "    loop.set_description(f\"alpha: {alpha} temp: {temp} n_drafts: {n_drafts} prompt_len: {prompt_len}\")\n",
    "    loop.update(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d9b493fd-24ef-410b-9730-c2594ab7a09b",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(f\"../tuning/rr_p{prompt_len}_d{n_drafts}_mixed_7B.pkl\", \"wb\") as f:\n",
    "    pickle.dump(rerank_results, f)\n",
    "torch.save(rerank_results, f\"../tuning/rr_p{prompt_len}_d{n_drafts}_mixed_7B.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "9b3ad6e5-7a62-4c7c-bc89-424eefb2839d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(7.73828125, (0.45999999999999996, 0.16, 3, 25, 9, False, False, 'sample_new_weights_with_score', 10)), (7.94140625, (0.5, 0.12, 3, 25, 9, False, False, 'sample_new_weights_with_score', 10)), (8.0234375, (0.33999999999999997, 0.22, 3, 25, 9, False, False, 'sample_new_weights_with_score', 10)), (8.0703125, (0.6, 0.08, 3, 25, 9, False, False, 'sample_new_weights_with_score', 10)), (8.078125, (0.5800000000000001, 0.1, 3, 25, 9, False, False, 'sample_new_weights_with_score', 10)), (8.09375, (0.6599999999999999, 0.08, 3, 25, 9, False, False, 'sample_new_weights_with_score', 10)), (8.1015625, (0.64, 0.08, 3, 25, 9, False, False, 'sample_new_weights_with_score', 10)), (8.53125, (0.55, 0.12, 3, 25, 9, False, False, 'sample_new_weights_with_score', 10)), (8.609375, (0.56, 0.1, 3, 25, 9, False, False, 'sample_new_weights_with_score', 10)), (8.796875, (0.54, 0.12, 3, 25, 9, False, False, 'sample_new_weights_with_score', 10))]\n"
     ]
    }
   ],
   "source": [
    "with open(f\"../tuning/rr_p{prompt_len}_d{n_drafts}_mixed_7B_llama_tune.pkl\", \"rb\") as f:\n",
    "    results = pickle.load(f)\n",
    "results.sort(key=lambda tup: tup[0])  # sorts in place\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e5e2faf3-e835-406a-b797-1399c4660277",
   "metadata": {},
   "outputs": [],
   "source": [
    "param_to_save = {\n",
    "    \"alpha\": results[0][1][0],\n",
    "    \"temp\": results[0][1][1],\n",
    "    \"n_drafts\": results[0][1][2],\n",
    "    \"prompt_len\": results[0][1][3],\n",
    "    \"n_token_sample\": results[0][1][4],\n",
    "    \"n_token_consider\": n_token_consider,\n",
    "    \"mixing_method\": mixing_method,\n",
    "    \"smoothing\": smoothing,\n",
    "    \"sample_tokens\": int(sample_tokens),\n",
    "    \"sample_beams\": int(sample_beams),\n",
    "    \"i_weights\": i_weights,\n",
    "    \"i_length\": i_length\n",
    "}\n",
    "with open(f\"../params/p{prompt_len}_d{n_drafts}_mixed.json\", \"w\") as f:\n",
    "    json.dump(param_to_save, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b12d003f-76f6-4eb9-891c-45ac5fcdee98",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mixed",
   "language": "python",
   "name": "mixed"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
