{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def analyze_logits_distribution(model, tokenizer, dataset, num_bins=10):\n",
    "    token_frequencies = defaultdict(int)\n",
    "    logit_distributions = defaultdict(list)\n",
    "    \n",
    "    # Step 1: Calculate token frequencies in the dataset\n",
    "    for text in tqdm(dataset, desc=\"Calculating token frequencies\"):\n",
    "        tokens = tokenizer.encode(text)\n",
    "        for token in tokens:\n",
    "            token_frequencies[token] += 1\n",
    "    \n",
    "    # Step 2: Create frequency bins\n",
    "    sorted_tokens = sorted(token_frequencies.items(), key=lambda x: x[1], reverse=True)\n",
    "    total_tokens = len(sorted_tokens)\n",
    "    bin_size = total_tokens // num_bins\n",
    "    frequency_bins = [sorted_tokens[i:i+bin_size] for i in range(0, total_tokens, bin_size)]\n",
    "    \n",
    "    # Step 3: Analyze logit distributions for each bin\n",
    "    for bin_idx, bin_tokens in enumerate(tqdm(frequency_bins, desc=\"Analyzing logit distributions\")):\n",
    "        bin_token_ids = [token for token, _ in bin_tokens]\n",
    "        \n",
    "        for text in dataset:\n",
    "            input_ids = tokenizer.encode(text, return_tensors=\"pt\")\n",
    "            with torch.no_grad():\n",
    "                outputs = model(input_ids)\n",
    "                logits = outputs.logits[:, -1, :]  # Get logits for the last token\n",
    "                \n",
    "                for token_id in bin_token_ids:\n",
    "                    logit_distributions[bin_idx].append(logits[0, token_id].item())\n",
    "    \n",
    "    return logit_distributions\n",
    "\n",
    "def plot_side_by_side_logit_distributions(logit_distributions_wd00, logit_distributions_wd01, num_bins):\n",
    "    \"\"\"Show per-bin violin plots for two models side by side, sharing the y axis.\n",
    "\n",
    "    Args:\n",
    "        logit_distributions_wd00: mapping bin index -> logit values for the left panel\n",
    "            (\"Weight Decay 0.0\").\n",
    "        logit_distributions_wd01: same, for the right panel (\"Weight Decay 0.1\").\n",
    "        num_bins: bins 0..num_bins-1 are drawn; tick labels are 1-based (\"Bin 1\", ...).\n",
    "\n",
    "    Empty bins are left blank: `Axes.violinplot` cannot compute statistics for an\n",
    "    empty sample and would raise.\n",
    "    \"\"\"\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(20, 8), sharey=True)\n",
    "    panels = [\n",
    "        (logit_distributions_wd00, \"Weight Decay 0.0\"),\n",
    "        (logit_distributions_wd01, \"Weight Decay 0.1\"),\n",
    "    ]\n",
    "\n",
    "    for ax, (logit_distributions, title) in zip(axes, panels):\n",
    "        # Keep each bin's own x position so violins stay aligned with their tick labels.\n",
    "        non_empty = [\n",
    "            (bin_idx, logit_distributions[bin_idx])\n",
    "            for bin_idx in range(num_bins)\n",
    "            if len(logit_distributions[bin_idx]) > 0\n",
    "        ]\n",
    "        if non_empty:\n",
    "            positions = [bin_idx for bin_idx, _ in non_empty]\n",
    "            distributions = [values for _, values in non_empty]\n",
    "            parts = ax.violinplot(distributions, positions=positions, showmeans=False, showextrema=True, showmedians=True)\n",
    "\n",
    "            # Customize violin plot colors\n",
    "            for pc in parts['bodies']:\n",
    "                pc.set_facecolor('#D43F3A')\n",
    "                pc.set_edgecolor('black')\n",
    "                pc.set_alpha(0.7)\n",
    "\n",
    "            for partname in ('cbars', 'cmins', 'cmaxes', 'cmedians'):\n",
    "                vp = parts[partname]\n",
    "                vp.set_edgecolor('black')\n",
    "                vp.set_linewidth(1)\n",
    "\n",
    "        ax.set_xlabel(\"Frequency Bins (High to Low)\")\n",
    "        ax.set_ylabel(\"Logit Values\")\n",
    "        ax.set_title(title)\n",
    "        ax.set_xticks(range(num_bins))\n",
    "        ax.set_xticklabels([f\"Bin {i+1}\" for i in range(num_bins)])\n",
    "\n",
    "    # Address the figure explicitly rather than through pyplot's current-figure state.\n",
    "    fig.suptitle(\"Logit Distributions Across Frequency Bins - Weight Decay Comparison\", fontsize=16)\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Usage example: analyse both checkpoints with the same binning, then compare them visually.\n",
    "# NOTE(review): model_wd00, model_wd01, tokenizer and dataset are not defined anywhere in\n",
    "# this notebook; they must already exist in the kernel before this runs.\n",
    "num_bins = 10\n",
    "logit_distributions_wd00 = analyze_logits_distribution(model_wd00, tokenizer, dataset, num_bins=num_bins)\n",
    "logit_distributions_wd01 = analyze_logits_distribution(model_wd01, tokenizer, dataset, num_bins=num_bins)\n",
    "\n",
    "plot_side_by_side_logit_distributions(logit_distributions_wd00, logit_distributions_wd01, num_bins=num_bins)\n",
    "\n",
    "# Optional: Statistical comparison\n",
    "def compare_distributions(dist1, dist2):\n",
    "    from scipy import stats\n",
    "    \n",
    "    results = []\n",
    "    for bin_idx in range(num_bins):\n",
    "        t_stat, p_value = stats.ttest_ind(dist1[bin_idx], dist2[bin_idx])\n",
    "        results.append((bin_idx, t_stat, p_value))\n",
    "    \n",
    "    return results\n",
    "\n",
    "statistical_comparison = compare_distributions(logit_distributions_wd00, logit_distributions_wd01)\n",
    "# One report line per bin; bins are shown 1-based to match the \"Bin N\" tick labels.\n",
    "for bin_idx, t_stat, p_value in statistical_comparison:\n",
    "    summary_line = f\"Bin {bin_idx + 1}: t-statistic = {t_stat:.4f}, p-value = {p_value:.4f}\"\n",
    "    print(summary_line)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
