{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bef39784-bac9-4471-8eae-2b41ef9224db",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import argparse\n",
    "\n",
    "import json\n",
    "import glob\n",
    "import os\n",
    "\n",
    "from open_lm.params import parse_args\n",
    "from open_lm.model import test_perplexity_model\n",
    "\n",
    "from transformers import GPTNeoXTokenizerFast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a696fa2d-ac83-4c3c-8fc4-f87c978d6728",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parse_args([])\n",
    "args.model = \"open_lm_160m\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d27fb666-d2a7-464b-b8cf-fd75d131855b",
   "metadata": {},
   "outputs": [],
   "source": [
    "############################ SET THOSE VALUES #################################\n",
    "\n",
    "# Set the path for the pretrained model to be evaluated\n",
    "args.classif_model_path = \"pretrained_models/C4.pt\"\n",
    "\n",
    "#Set the device\n",
    "device = 'cuda'\n",
    "\n",
    "# Set the directory containing the evaluation .jsonl files (could be one or more files, it will iterate over all)\n",
    "input_dir = \"cross_dataset\"\n",
    "#input_dir = \"paloma\" \n",
    "#input_dir = \"wikitext_103\"\n",
    " "
   ]
  },
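  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f2a9c1e-5b7d-4e8a-9c0f-6d4b2a8e1c53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only (an assumption inferred from the t[\"text\"] access in\n",
    "# the evaluation loop below): each line of an input .jsonl file should be\n",
    "# a JSON object with a \"text\" field, like this:\n",
    "example_line = '{\"text\": \"The quick brown fox jumps over the lazy dog.\"}'\n",
    "print(json.loads(example_line)[\"text\"])"
   ]
  },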
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e8d645a-801c-4ecc-83e1-a37b87eeb6eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model and move to device\n",
    "model = test_perplexity_model(args)\n",
    "model = model.to(device)\n",
    "model.eval()\n",
    "\n",
    "# Automatically find all .jsonl files in the directory\n",
    "input_files = sorted(glob.glob(os.path.join(input_dir, \"*.jsonl\")))\n",
    "\n",
    "# Load tokenizer\n",
    "tokenizer = GPTNeoXTokenizerFast.from_pretrained('EleutherAI/gpt-neox-20b')\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Parameters\n",
    "batch_size = 32\n",
    "max_length = 256"
   ]
  },
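  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e5d3b1a-2c4f-4a6b-8d9e-0f1a2b3c4d5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check, not part of the evaluation itself: tokenize a\n",
    "# tiny batch with the same padding/truncation settings used below and\n",
    "# confirm the shapes and pad token id look sensible.\n",
    "demo = tokenizer(\n",
    "    [\"Hello world.\", \"A somewhat longer sentence to force padding.\"],\n",
    "    padding=True,\n",
    "    truncation=True,\n",
    "    max_length=max_length,\n",
    "    return_tensors=\"pt\",\n",
    ")\n",
    "print(demo[\"input_ids\"].shape, \"pad_token_id =\", tokenizer.pad_token_id)"
   ]
  },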
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb95e8f0-d522-4a10-a9db-522d3419194e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_perplexity(model, tokenizer, texts, batch_size, max_length, device):\n",
    "    model.to(device).eval()\n",
    "    total_log_likelihood = 0.0\n",
    "    total_tokens = 0\n",
    "\n",
    "    loss_fct = torch.nn.CrossEntropyLoss(\n",
    "        ignore_index=tokenizer.pad_token_id,\n",
    "        reduction='sum'\n",
    "    )\n",
    "\n",
    "    for i in tqdm(range(0, len(texts), batch_size), desc=\"Batches\"):\n",
    "        batch = texts[i : i + batch_size]\n",
    "        batch_texts = [t[\"text\"] for t in batch]\n",
    "\n",
    "        enc = tokenizer(\n",
    "            batch_texts,\n",
    "            padding=True,\n",
    "            truncation=True,\n",
    "            max_length=max_length,\n",
    "            return_tensors=\"pt\",\n",
    "        ).to(device)\n",
    "\n",
    "        input_ids = enc[\"input_ids\"]\n",
    "        attention_mask = enc[\"attention_mask\"]\n",
    "\n",
    "        targets = input_ids.clone()\n",
    "        targets[:, :-1] = input_ids[:, 1:]\n",
    "        targets[:, -1] = tokenizer.pad_token_id\n",
    "\n",
    "        with torch.no_grad():\n",
    "            logits = model(input_ids)[0]\n",
    "\n",
    "        B, L, V = logits.size()\n",
    "        logits_flat = logits.view(-1, V)\n",
    "        targets_flat = targets.view(-1)\n",
    "\n",
    "        loss_sum = loss_fct(logits_flat, targets_flat).item()\n",
    "        non_pad = (targets_flat != tokenizer.pad_token_id).sum().item()\n",
    "\n",
    "        total_log_likelihood += -loss_sum\n",
    "        total_tokens += non_pad\n",
    "\n",
    "    if total_tokens == 0:\n",
    "        return float(\"inf\"), 0.0, 0\n",
    "\n",
    "    ppl = math.exp(-total_log_likelihood / total_tokens)\n",
    "    return ppl, total_log_likelihood, total_tokens\n",
    "\n",
    "# Compute perplexity\n",
    "overall_ll = 0.0\n",
    "overall_tok = 0\n",
    "\n",
    "print(\"Perplexity by file:\")\n",
    "for fname in input_files:\n",
    "    with open(fname, 'r') as f:\n",
    "        texts = [json.loads(line) for line in f]\n",
    "\n",
    "    ppl, ll, tok = calculate_perplexity(\n",
    "        model, tokenizer, texts,\n",
    "        batch_size=batch_size,\n",
    "        max_length=max_length,\n",
    "        device=device\n",
    "    )\n",
    "    print(f\"  {os.path.basename(fname)}: {ppl:.2f}\")\n",
    "    overall_ll += ll\n",
    "    overall_tok += tok\n",
    "\n",
    "if overall_tok > 0:\n",
    "    overall_ppl = math.exp(-overall_ll / overall_tok)\n",
    "    print(f\"\\nOverall Perplexity: {overall_ppl:.2f}\")\n",
    "else:\n",
    "    print(\"\\nNo tokens processed; cannot compute overall perplexity.\")\n",
    "\n",
    "print(f\"Total tokens: {overall_tok:,}\")\n"
   ]
  },
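  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a2b3c4d-5e6f-4a7b-8c9d-0e1f2a3b4c5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worked example of the formula used above, with made-up numbers:\n",
    "# perplexity = exp(-(1/N) * sum_i log p(token_i | tokens_<i)).\n",
    "# For three tokens with log-likelihoods -1.0, -2.0 and -3.0, the average\n",
    "# negative log-likelihood is 2.0, so perplexity = exp(2.0), about 7.39.\n",
    "log_likelihoods = [-1.0, -2.0, -3.0]\n",
    "print(math.exp(-sum(log_likelihoods) / len(log_likelihoods)))"
   ]
  }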
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
