{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# load the tokenizer from files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"safemodels/openelm-270m-imdb-wd1.0-seed1\"\n",
    "\n",
    "# Fail fast: stop at the first expected checkpoint file that is absent from\n",
    "# model_path, before any loading is attempted.\n",
    "required_files = [\"config.json\", \"model.safetensors\", \"tokenizer.json\", \"tokenizer_config.json\"]\n",
    "for file in required_files:\n",
    "    if os.path.exists(os.path.join(model_path, file)):\n",
    "        continue\n",
    "    raise FileNotFoundError(f\"Required file {file} not found in {model_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer from the local checkpoint directory only (local_files_only=True).\n",
    "try:\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)\n",
    "    print(\"Tokenizer loaded successfully.\")\n",
    "except Exception as e:\n",
    "    # Best-effort fallback: the only difference in the retry is use_fast=False,\n",
    "    # so the messages say exactly that. If the retry fails too, the error propagates.\n",
    "    print(f\"Error loading tokenizer: {e}\")\n",
    "    print(\"Retrying with use_fast=False...\")\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True, use_fast=False)\n",
    "    print(\"Tokenizer loaded with use_fast=False.\")\n",
    "\n",
    "# Smoke test: encode a short prompt and decode it again.\n",
    "test_text = \"The movie\"\n",
    "encoded = tokenizer(test_text)\n",
    "decoded = tokenizer.decode(encoded['input_ids'])\n",
    "print(f\"Test encoding/decoding: '{test_text}' -> '{decoded}'\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# load the model from files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoConfig\n",
    "from safetensors.torch import load_file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint whose config and weights are loaded below. Note that this rebinds\n",
    "# model_path, which an earlier cell set to the directory used for the tokenizer.\n",
    "model_path = \"safemodels/openelm-270m-imdb-wd0.0-seed1\"\n",
    "config_path = os.path.join(model_path, \"config.json\")\n",
    "safetensors_path = f\"{model_path}/model.safetensors\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model configuration from the checkpoint's config.json.\n",
    "# NOTE(review): trust_remote_code=True is presumably needed because the architecture\n",
    "# ships its own modeling code - confirm, since it allows that code to be executed.\n",
    "config = AutoConfig.from_pretrained(\n",
    "    config_path,\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the architecture described by `config`. No pretrained weights are\n",
    "# loaded here; they are read from the local safetensors file in the next cell.\n",
    "# (The unused `model_id` variable was dropped: nothing in this notebook reads it.)\n",
    "model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the tensors stored in the checkpoint's safetensors file.\n",
    "state_dict = load_file(safetensors_path)\n",
    "\n",
    "# strict=True is the default, spelled out here: every key must match the model.\n",
    "model.load_state_dict(state_dict, strict=True)\n",
    "\n",
    "# Switch to evaluation mode before generating.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# generate sample output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "\n",
    "# Name the output after the checkpoint that is actually loaded (e.g. \"...-wd0.0-seed1\" ->\n",
    "# \"samples-wd00.json\"), so samples can never be written to the other model's file.\n",
    "# Resolved before generating so a bad model_path fails immediately.\n",
    "wd_match = re.search(r\"wd(\\d+)\\.(\\d+)\", model_path)\n",
    "if wd_match is None:\n",
    "    raise ValueError(f\"Cannot infer the weight-decay tag from model_path: {model_path}\")\n",
    "samples_file = f\"samples-wd{wd_match.group(1)}{wd_match.group(2)}.json\"\n",
    "\n",
    "num_samples = 100\n",
    "text = \"The movie\"\n",
    "\n",
    "# NOTE: `tokenizer` must be the checkpoint tokenizer loaded near the top of the notebook.\n",
    "# A later cell rebinds the name to a BPE tokenizer, so do not run this cell after that one.\n",
    "# The prompt is the same for every sample, so it is tokenized once, outside the loop.\n",
    "inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "# generate() is not given token_type_ids (presumably unsupported by the model - confirm).\n",
    "inputs.pop('token_type_ids', None)\n",
    "\n",
    "samples = []\n",
    "for i in range(num_samples):\n",
    "    print(f\"Generating sample {i+1}/{num_samples}\", end=\"\\r\")\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(\n",
    "            **inputs,\n",
    "            max_length=100,\n",
    "            num_return_sequences=1,\n",
    "            temperature=0.7,\n",
    "            top_k=50,\n",
    "            top_p=0.95,\n",
    "            do_sample=True\n",
    "        )\n",
    "\n",
    "    # Decode the output\n",
    "    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "    samples.append(generated_text)\n",
    "\n",
    "# save the samples to a file\n",
    "with open(samples_file, \"w\") as f:\n",
    "    json.dump(samples, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# reload both samples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imported here as well so this cell runs on a fresh kernel even when the\n",
    "# (slow) generation cell is skipped and the sample files already exist.\n",
    "import json\n",
    "\n",
    "with open(\"samples-wd00.json\", \"r\") as f:\n",
    "    samples_wd00 = json.load(f)\n",
    "with open(\"samples-wd10.json\", \"r\") as f:\n",
    "    samples_wd10 = json.load(f)\n",
    "\n",
    "print(f\"Loaded {len(samples_wd00)} samples from samples-wd00.json\")\n",
    "print(f\"Loaded {len(samples_wd10)} samples from samples-wd10.json\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create statistics of dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import llm_fairness\n",
    "import collections\n",
    "\n",
    "# Target vocabulary size requested from the BPE trainer (also reported below).\n",
    "VOCAB_SIZE = 32000\n",
    "\n",
    "# Load the IMDB training split once (the earlier duplicate load was discarded immediately).\n",
    "dataset = llm_fairness.data.from_name(\"imdb\")[\"train\"]\n",
    "\n",
    "# NOTE: this rebinds `tokenizer`. From here on it is the BPE tokenizer built from the\n",
    "# dataset, no longer the checkpoint tokenizer that the generation cell uses.\n",
    "tokenizer = llm_fairness.tokenizer.from_data(dataset, variant=\"BPE\", vocab_size=VOCAB_SIZE)\n",
    "dataset = dataset.map(\n",
    "    lambda x: tokenizer(\n",
    "        x[\"text\"],\n",
    "        truncation=True,\n",
    "        padding=\"max_length\",\n",
    "        max_length=128,\n",
    "    ),\n",
    "    batched=True,\n",
    "    remove_columns=dataset.column_names,\n",
    ")\n",
    "\n",
    "# Fetch the vocabulary once and decode every id once; tok2id is the inverse of id2tok.\n",
    "# If two ids decode to the same string, the later id wins in tok2id (as before).\n",
    "vocab_ids = list(tokenizer._tokenizer.get_vocab().values())\n",
    "id2tok = {tid: tokenizer.decode(tid) for tid in vocab_ids}\n",
    "tok2id = {tok: tid for tid, tok in id2tok.items()}\n",
    "\n",
    "# Count every id in the tokenized sequences.\n",
    "# NOTE(review): with padding=\"max_length\" the padding ids are presumably counted too - confirm.\n",
    "tokens = []\n",
    "for seq in dataset:\n",
    "    tokens.extend(seq['input_ids'])\n",
    "token_counts = collections.Counter(tokens)\n",
    "unique_tokens = set(tokens)\n",
    "\n",
    "print(f\"we trained a bpe with target vocab size: {VOCAB_SIZE}\")\n",
    "print(f\"our actual vocab size is {len(tok2id)}\")\n",
    "print(f\"we have {len(unique_tokens)} unique tokens\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from collections import Counter\n",
    "\n",
    "# Occurrence count of every token id seen in the tokenized training data.\n",
    "freqs = list(token_counts.values())\n",
    "\n",
    "# Bin edges are powers of three. np.arange excludes max_exp, so the edges are\n",
    "# 3^0 ... 3^9 and the last bin (index 9) is open-ended: it holds every count >= 3^9.\n",
    "min_exp = 0\n",
    "max_exp = 10\n",
    "bins = 3 ** np.arange(min_exp, max_exp)\n",
    "\n",
    "def get_bin_range(bin_num):\n",
    "    \"\"\"Return the tick label (mathtext) describing frequency bin `bin_num`.\n",
    "\n",
    "    Regular bins are labelled '3^k - 3^(k+1)'. The last bin has no upper edge and is\n",
    "    labelled '>= 3^k'. Negative bin ids are not produced by assign_bin for counts >= 1;\n",
    "    later cells use -1 for token ids without a training count, labelled 'unseen'.\n",
    "    \"\"\"\n",
    "    if bin_num < 0:\n",
    "        return \"unseen\"\n",
    "    if bin_num >= len(bins) - 1:\n",
    "        return fr\"$\\geq 3^{{{bin_num}}}$\"\n",
    "    return fr\"$3^{{{bin_num}}} - 3^{{{bin_num + 1}}}$\"\n",
    "\n",
    "def assign_bin(freq):\n",
    "    \"\"\"Map a count (or a sequence of counts) to the zero-based index of its bin.\"\"\"\n",
    "    return np.digitize(freq, bins) - 1\n",
    "\n",
    "# Number of distinct token ids per bin, sorted by bin index.\n",
    "bin_assign = assign_bin(freqs)\n",
    "bin_assign_counter = Counter(bin_assign)\n",
    "keys = np.array(list(bin_assign_counter.keys()))\n",
    "values = np.array(list(bin_assign_counter.values()))\n",
    "sorted_indices = np.argsort(keys)\n",
    "sorted_keys = keys[sorted_indices]\n",
    "sorted_values = values[sorted_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every token id seen in the tokenized training data to its frequency bin.\n",
    "# assign_bin is called once on the full list of counts instead of once per token,\n",
    "# and the bin indices are stored as plain Python ints.\n",
    "token_id_to_bin = {\n",
    "    token_id: int(bin_id)\n",
    "    for token_id, bin_id in zip(token_counts.keys(), assign_bin(list(token_counts.values())))\n",
    "}\n",
    "\n",
    "# Show a small preview rather than dumping the entire mapping into the cell output.\n",
    "dict(list(token_id_to_bin.items())[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Number of distinct token ids that fall into each frequency bin, ordered by bin id.\n",
    "bin_counts = Counter(token_id_to_bin.values())\n",
    "sorted_bin_counts = sorted(bin_counts.items())\n",
    "\n",
    "# Assemble the whole table first, then emit it in one go.\n",
    "table_lines = [\"Bin ID | Count\", \"-------|------\"]\n",
    "table_lines += [f\"{bin_id:6d} | {count:5d}\" for bin_id, count in sorted_bin_counts]\n",
    "print(\"\\n\".join(table_lines))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create the statistics from the samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def tokenize_and_bin(samples):\n",
    "    \"\"\"Count the tokens of `samples` per training-frequency bin.\n",
    "\n",
    "    Each sample is tokenized with the notebook-level `tokenizer`, and every token id\n",
    "    is looked up in `token_id_to_bin`. Ids without an entry there are counted under\n",
    "    bin -1.\n",
    "\n",
    "    Returns a Counter mapping bin id -> number of tokens.\n",
    "    \"\"\"\n",
    "    all_token_bins = []\n",
    "    for sample in samples:\n",
    "        tokens = tokenizer(sample)['input_ids']\n",
    "        all_token_bins.extend(token_id_to_bin.get(token, -1) for token in tokens)\n",
    "    return Counter(all_token_bins)\n",
    "\n",
    "# Tokenize and bin for wd00 and wd10 samples\n",
    "wd00_bin_counts = tokenize_and_bin(samples_wd00)\n",
    "wd10_bin_counts = tokenize_and_bin(samples_wd10)\n",
    "\n",
    "# Prepare data for plotting: one entry per bin that occurs in either sample set\n",
    "bin_ids = sorted(set(wd00_bin_counts.keys()) | set(wd10_bin_counts.keys()))\n",
    "wd00_counts = [wd00_bin_counts.get(bin_id, 0) for bin_id in bin_ids]\n",
    "wd10_counts = [wd10_bin_counts.get(bin_id, 0) for bin_id in bin_ids]\n",
    "\n",
    "# Print the results\n",
    "print(\"Bin ID | WD00 Count | WD10 Count\")\n",
    "print(\"-------|------------|------------\")\n",
    "for bin_id, wd00_count, wd10_count in zip(bin_ids, wd00_counts, wd10_counts):\n",
    "    print(f\"{bin_id:6d} | {wd00_count:10d} | {wd10_count:10d}\")\n",
    "\n",
    "# Bin -1 collects ids without a training count; it is not a real 3^-1 .. 3^0 range,\n",
    "# so it gets its own tick label instead of going through get_bin_range.\n",
    "bin_labels = [\"unseen\" if bin_id < 0 else get_bin_range(bin_id) for bin_id in bin_ids]\n",
    "\n",
    "# Data for plotting\n",
    "x = np.arange(len(bin_ids))\n",
    "width = 0.35\n",
    "\n",
    "# Plotting: grouped bars, WD00 on the left and WD10 on the right of each tick\n",
    "fig, ax = plt.subplots(figsize=(15, 8))\n",
    "rects1 = ax.bar(x - width/2, wd00_counts, width, label='WD00')\n",
    "rects2 = ax.bar(x + width/2, wd10_counts, width, label='WD10')\n",
    "\n",
    "ax.set_xlabel('Bin ID')\n",
    "ax.set_ylabel('Token Count')\n",
    "ax.set_title('Distribution of Tokens in Bins for WD00 and WD10')\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(bin_labels, rotation=45, ha='right')\n",
    "ax.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# tokenize_and_bin is defined once, in the previous cell; the identical copy that\n",
    "# used to live here was removed so the two definitions cannot drift apart.\n",
    "wd00_bin_counts = tokenize_and_bin(samples_wd00)\n",
    "wd10_bin_counts = tokenize_and_bin(samples_wd10)\n",
    "\n",
    "# Prepare data for plotting: one entry per bin that occurs in either sample set\n",
    "bin_ids = sorted(set(wd00_bin_counts.keys()) | set(wd10_bin_counts.keys()))\n",
    "\n",
    "# Calculate total counts for normalization\n",
    "wd00_total = sum(wd00_bin_counts.values())\n",
    "wd10_total = sum(wd10_bin_counts.values())\n",
    "\n",
    "# Normalize the counts so each series sums to 1 over all bins\n",
    "wd00_normalized = [wd00_bin_counts.get(bin_id, 0) / wd00_total for bin_id in bin_ids]\n",
    "wd10_normalized = [wd10_bin_counts.get(bin_id, 0) / wd10_total for bin_id in bin_ids]\n",
    "\n",
    "# Print the normalized results\n",
    "print(\"Bin ID | WD00 Normalized | WD10 Normalized\")\n",
    "print(\"-------|-----------------|----------------\")\n",
    "for bin_id, wd00_norm, wd10_norm in zip(bin_ids, wd00_normalized, wd10_normalized):\n",
    "    print(f\"{bin_id:6d} | {wd00_norm:15.4f} | {wd10_norm:15.4f}\")\n",
    "\n",
    "# Bin -1 collects ids without a training count; it is not a real 3^-1 .. 3^0 range,\n",
    "# so it gets its own tick label instead of going through get_bin_range.\n",
    "bin_labels = [\"unseen\" if bin_id < 0 else get_bin_range(bin_id) for bin_id in bin_ids]\n",
    "\n",
    "# Data for plotting\n",
    "x = np.arange(len(bin_ids))\n",
    "width = 0.35\n",
    "\n",
    "# Plotting: grouped bars, WD00 on the left and WD10 on the right of each tick\n",
    "fig, ax = plt.subplots(figsize=(15, 8))\n",
    "rects1 = ax.bar(x - width/2, wd00_normalized, width, label='WD00')\n",
    "rects2 = ax.bar(x + width/2, wd10_normalized, width, label='WD10')\n",
    "\n",
    "ax.set_xlabel('Bin ID')\n",
    "ax.set_ylabel('Normalized Token Frequency')\n",
    "ax.set_title('Normalized Distribution of Tokens in Bins for WD00 and WD10')\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(bin_labels, rotation=45, ha='right')\n",
    "ax.legend()\n",
    "\n",
    "# Set y-axis to logarithmic scale\n",
    "ax.set_yscale('log')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Calculate and print the difference\n",
    "print(\"\\nDifference (WD10 - WD00):\")\n",
    "for bin_id, wd00_norm, wd10_norm in zip(bin_ids, wd00_normalized, wd10_normalized):\n",
    "    diff = wd10_norm - wd00_norm\n",
    "    print(f\"Bin {bin_id}: {diff:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-wd-fairness-eaiv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
