{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"Qwen/Qwen2.5-Coder-7B-Instruct\"\n",
    "\n",
    "# Weights are loaded in bfloat16; device_map=\"auto\" leaves weight placement to the loader.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# Later cells move their input tensors to `device`. Take it from the loaded model rather than\n",
    "# guessing cuda/cpu, so inputs land on the device where the model's first weights were placed.\n",
    "device = model.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Two candidate prompts: the second assignment overrides the first, so the long prompt is the one used.\n",
    "chat_text = \"How can I sort a list of strings in Python?\"\n",
    "chat_text = \"Show me how to implement a toy version of a relational database. Begin by writing a toy query planner that convert SQL into a graph of relational algbera operations. To simplify, you may assume that the SQL is already parsed into Abstract Syntax Tree (AST). You also only need to implement basic \\\"select\\\" that may have some columns only, with \\\"where\\\" clause, but no sort by or pagination.\"\n",
    "chat_template = [{\"role\": \"user\", \"content\": chat_text}]\n",
    "\n",
    "# Render the conversation to a plain string (no tokenization yet).\n",
    "formatted_text = tokenizer.apply_chat_template(chat_template, tokenize=False)\n",
    "\n",
    "# formatted_text = formatted_text + tokenizer.eos_token + formatted_text + tokenizer.eos_token\n",
    "\n",
    "# Keep only what follows the first \"<|im_end|>\" marker (empty string when the marker is absent).\n",
    "formatted_text = formatted_text.partition(\"<|im_end|>\")[2]\n",
    "\n",
    "print(formatted_text)\n",
    "encoded_chat = tokenizer(formatted_text, return_tensors=\"pt\").to(device)\n",
    "chat_tokens = encoded_chat[\"input_ids\"]\n",
    "print(chat_tokens)\n",
    "# Number of tokens in the (single) encoded sequence.\n",
    "print(chat_tokens.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two text passages, each preceded by the tokenizer's EOS token, concatenated into one string.\n",
    "eos = tokenizer.eos_token\n",
    "first_passage = \"\"\"Iran’s Wild Card for Defense Stocks In 2012 … Dividends Hang in the Balance (ATK, GD, LLL, LMT, RTN)December 28, 2011 by Jon C. Ogg\n",
    "It is no secret that Iran is a saber-rattling nation. The country wants to be relevant on the global stage so much that it keeps up with its nuclear ambitions regardless of global trading sanctions and regardless of efforts from the Western nations trying to stop it. And now the big news is not the nuclear front, but an Iranian minister claiming that Iran could effectively block the flow of traffic through the Gulf of Hormuz easier than drinking a glass of water.\n",
    "In the age of austerity and military budgets being slashed to deal with deficits, Iran has a chance of turning 2012 accidentally into the year of defense stocks. Alliant Techsystems Inc. (NYSE: ATK), General Dynamics Corp. (NYSE: GD), L-3 Communications Holdings Inc. (NYSE: LLL), Lockheed Martin Corporation (NYSE: LMT) and Raytheon Co. (NYSE: RTN) could all hang in the balance. With operations all but gone in Iraq and with the trend in Afghanistan being one of leaving, Iran is the obvious wild card.\"\"\"\n",
    "second_passage = \"\"\"L-3 Communications Holdings Inc. (NYSE: LLL) trades at $66.85 and the 52-week trading range is $58.30 to $88.55. Thomson Reuters has a consensus price target of $71.64, implying upside of only about 7%. L-3 yields about 2.7% in its dividend and shares are up about 5% from the Thanksgiving break. Lockheed Martin Corp. (NYSE: LMT) trades at $81.25 and the 52-week trading range is $66.36 to $82.43. Thomson Reuters has a consensus price target of $80.12, implying that the stock is above a full-value price. Its dividend yield is quite high at about 5% and shares are up almost 5% since its $1.00 dividend was reflected in the stock in late November. Raytheon Co. (NYSE: RTN) trades at $48.50 and the 52-week trading range is $38.35 to $53.12. Thomson Reuters has a consensus price target of $49.53, implying upside of only about 2%. Its dividend yield is currently about 3.7% and shares are up about 14% since the Thanksgiving break.\"\"\"\n",
    "pretrain_text = eos + first_passage + eos + second_passage\n",
    "\n",
    "# Tokenize, move the encoding to the target device, and keep only the token ids.\n",
    "encoded_pretrain = tokenizer(pretrain_text, return_tensors=\"pt\").to(device)\n",
    "pretrain_tokens = encoded_pretrain[\"input_ids\"]\n",
    "print(pretrain_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokens = pretrain_tokens\n",
    "# tokens = chat_tokens\n",
    "# input_ids = tokens\n",
    "# labels = input_ids.clone()\n",
    "\n",
    "# torch.set_grad_enabled(False)\n",
    "\n",
    "\n",
    "# with torch.no_grad():\n",
    "#     outputs = model(input_ids=input_ids, labels=labels)\n",
    "#     loss = outputs.loss\n",
    "#     perplexity = torch.exp(loss).item()\n",
    "\n",
    "# print(f\"Cross entropy loss: {loss.item():.4f}\")\n",
    "# print(f\"Perplexity: {perplexity:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequence to score: `chat_tokens` (current choice) or `pretrain_tokens`.\n",
    "tokens = chat_tokens\n",
    "input_ids = tokens\n",
    "labels = input_ids.clone()\n",
    "\n",
    "# Only the last TAIL_TOKENS next-token predictions are scored.\n",
    "# NOTE(review): the reason for 85 is not documented; presumably tuned to the current prompt - confirm.\n",
    "TAIL_TOKENS = 85\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model(input_ids=input_ids, labels=labels)\n",
    "    logits = outputs.logits\n",
    "\n",
    "# The logit at position t predicts the token at position t + 1, so drop the last logit\n",
    "# and the first label to align them. squeeze(0) removes the batch dimension.\n",
    "shifted_logits = logits[:, :-1, :].squeeze(0)\n",
    "shifted_labels = labels[:, 1:].squeeze(0)\n",
    "\n",
    "# The model is loaded in bfloat16; upcast the scored slice to float32 so that log-softmax\n",
    "# and the averages below are not quantized to bfloat16 precision.\n",
    "shifted_logits = shifted_logits[-TAIL_TOKENS:].float()\n",
    "shifted_labels = shifted_labels[-TAIL_TOKENS:]\n",
    "\n",
    "# Per-token negative log-likelihood of the actual next token.\n",
    "log_probs = torch.nn.functional.log_softmax(shifted_logits, dim=-1)\n",
    "token_log_probs = log_probs[torch.arange(shifted_labels.size(0)), shifted_labels]\n",
    "token_losses = -token_log_probs\n",
    "token_perplexities = torch.exp(token_losses)\n",
    "\n",
    "print(f\"Average token loss: {token_losses.mean().item():.4f}\")\n",
    "# Arithmetic mean of per-token perplexities: dominated by the hardest tokens, not the usual perplexity.\n",
    "print(f\"Average token perplexity: {token_perplexities.mean().item():.4f}\")\n",
    "# Standard perplexity: exponential of the mean negative log-likelihood.\n",
    "print(f\"Perplexity (exp of average loss): {torch.exp(token_losses.mean()).item():.4f}\")\n",
    "\n",
    "# print_tokens = tokenizer.convert_ids_to_tokens(shifted_labels.tolist())\n",
    "\n",
    "# plt.figure(figsize=(14, 5))\n",
    "# plt.plot(token_losses.tolist(), marker='o')\n",
    "# plt.xticks(ticks=range(len(print_tokens)), labels=print_tokens, rotation=45)\n",
    "# plt.ylabel(\"Cross-Entropy Loss\")\n",
    "# plt.title(\"Token-wise Loss\")\n",
    "# plt.grid(True)\n",
    "# plt.tight_layout()\n",
    "# plt.show()\n",
    "\n",
    "# plt.figure(figsize=(14, 5))\n",
    "# plt.plot(token_perplexities.tolist(), marker='o', color='orange')\n",
    "# plt.xticks(ticks=range(len(print_tokens)), labels=print_tokens, rotation=45)\n",
    "# plt.ylabel(\"Perplexity\")\n",
    "# plt.title(\"Token-wise Perplexity\")\n",
    "# plt.grid(True)\n",
    "# plt.tight_layout()\n",
    "# plt.yscale('log')\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
