{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import polars as pl\n",
    "from scipy import stats\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import wandb\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "base = 'results/separated'\n",
    "exp = 'apple-small-imdb'\n",
    "# Debug subset: only the first N_ROWS rows of the file are loaded.\n",
    "# In the saved run, 500 rows yielded 500 distinct (seed, wd, token_id) groups,\n",
    "# i.e. a single loss per group, so every learning speed came out as 0.\n",
    "# Set N_ROWS = None to load the full file before writing results.\n",
    "N_ROWS = 500\n",
    "df = pl.read_parquet(f'{base}/{exp}.parquet', n_rows=N_ROWS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Helper functions: log-3 frequency bins and the learning-speed metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "min_exp = 0  # e.g., starting from 3^0\n",
    "max_exp = 10  # e.g., ending at 3^10 (adjust as needed)\n",
    "bins = 3 ** np.arange(min_exp, max_exp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_bin_range(bin_num):\n",
    "    exp_start = bin_num\n",
    "    exp_end = bin_num + 1\n",
    "    return fr\"$3^{{{exp_start}}} - 3^{{{exp_end}}}$\"\n",
    "\n",
    "def assign_bin(freq):\n",
    "    return np.digitize(freq, bins) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_learning_speed(losses):\n",
    "    # learning speed is inverse of AUC\n",
    "    # faster learning = smaller area\n",
    "    losses = np.array(losses, dtype=np.float32)\n",
    "    min_loss = losses.min()\n",
    "    range_losses = np.ptp(losses)\n",
    "    if range_losses == 0: return 0\n",
    "    normalized_losses = (losses - min_loss) / range_losses\n",
    "    auc = np.trapz(normalized_losses, dx=1)\n",
    "    return 1 - (auc / len(losses))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute the learning speed of each (seed, wd, token) loss curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (5, 12)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>seed</th><th>wd</th><th>lr</th><th>iter</th><th>epoch</th><th>loss</th><th>grad_norm</th><th>token_id</th><th>tok_loss</th><th>tok_acc</th><th>tok_freq</th><th>freq_bin</th></tr><tr><td>i32</td><td>f32</td><td>f32</td><td>i32</td><td>f32</td><td>f32</td><td>f32</td><td>i64</td><td>f32</td><td>f32</td><td>i64</td><td>i64</td></tr></thead><tbody><tr><td>1</td><td>2.0</td><td>0.00005</td><td>100</td><td>0.255754</td><td>7.2086</td><td>1.884823</td><td>2742</td><td>10.363544</td><td>0.0</td><td>27</td><td>3</td></tr><tr><td>1</td><td>2.0</td><td>0.00005</td><td>100</td><td>0.255754</td><td>7.2086</td><td>1.884823</td><td>149</td><td>4.259922</td><td>0.099795</td><td>12175</td><td>8</td></tr><tr><td>1</td><td>2.0</td><td>0.00005</td><td>100</td><td>0.255754</td><td>7.2086</td><td>1.884823</td><td>1638</td><td>9.690016</td><td>0.0</td><td>170</td><td>4</td></tr><tr><td>1</td><td>2.0</td><td>0.00005</td><td>100</td><td>0.255754</td><td>7.2086</td><td>1.884823</td><td>2204</td><td>9.897054</td><td>0.0</td><td>105</td><td>4</td></tr><tr><td>1</td><td>2.0</td><td>0.00005</td><td>100</td><td>0.255754</td><td>7.2086</td><td>1.884823</td><td>139</td><td>4.247193</td><td>0.031617</td><td>14391</td><td>8</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (5, 12)\n",
       "┌──────┬─────┬─────────┬──────┬───┬───────────┬──────────┬──────────┬──────────┐\n",
       "│ seed ┆ wd  ┆ lr      ┆ iter ┆ … ┆ tok_loss  ┆ tok_acc  ┆ tok_freq ┆ freq_bin │\n",
       "│ ---  ┆ --- ┆ ---     ┆ ---  ┆   ┆ ---       ┆ ---      ┆ ---      ┆ ---      │\n",
       "│ i32  ┆ f32 ┆ f32     ┆ i32  ┆   ┆ f32       ┆ f32      ┆ i64      ┆ i64      │\n",
       "╞══════╪═════╪═════════╪══════╪═══╪═══════════╪══════════╪══════════╪══════════╡\n",
       "│ 1    ┆ 2.0 ┆ 0.00005 ┆ 100  ┆ … ┆ 10.363544 ┆ 0.0      ┆ 27       ┆ 3        │\n",
       "│ 1    ┆ 2.0 ┆ 0.00005 ┆ 100  ┆ … ┆ 4.259922  ┆ 0.099795 ┆ 12175    ┆ 8        │\n",
       "│ 1    ┆ 2.0 ┆ 0.00005 ┆ 100  ┆ … ┆ 9.690016  ┆ 0.0      ┆ 170      ┆ 4        │\n",
       "│ 1    ┆ 2.0 ┆ 0.00005 ┆ 100  ┆ … ┆ 9.897054  ┆ 0.0      ┆ 105      ┆ 4        │\n",
       "│ 1    ┆ 2.0 ┆ 0.00005 ┆ 100  ┆ … ┆ 4.247193  ┆ 0.031617 ┆ 14391    ┆ 8        │\n",
       "└──────┴─────┴─────────┴──────┴───┴───────────┴──────────┴──────────┴──────────┘"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-row log-3 frequency bin (see assign_bin); with_columns replaces any existing 'freq_bin'.\n",
    "freq_bin_expr = (\n",
    "    pl.col('tok_freq')\n",
    "    .map_elements(assign_bin, return_dtype=pl.Int64)\n",
    "    .alias('freq_bin')\n",
    ")\n",
    "df = df.with_columns(freq_bin_expr)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "processing 500/500...\r"
     ]
    }
   ],
   "source": [
    "# One learning-speed score per (seed, wd, token_id) loss curve.\n",
    "# NOTE(review): 'lr' is not part of the key; if a file holds several learning\n",
    "# rates, their losses would be concatenated into one curve - confirm.\n",
    "group_keys = ['seed', 'wd', 'token_id']\n",
    "n_groups = df.select(group_keys).unique().height\n",
    "results = []\n",
    "# A single group_by pass replaces re-filtering the whole frame once per group\n",
    "# (O(groups x rows)); polars keeps the original row order inside each group.\n",
    "for idx, ((seed, wd, token_id), group_df) in enumerate(\n",
    "    df.group_by(group_keys, maintain_order=True)\n",
    "):\n",
    "    print(f'processing {idx+1}/{n_groups}...', end='\\r')\n",
    "    losses = group_df['tok_loss'].to_numpy()\n",
    "    avg_freq = group_df['tok_freq'].mean()\n",
    "    # float() keeps the 'speed' column dtype stable (the helper may return int 0).\n",
    "    speed = float(calculate_learning_speed(losses))\n",
    "    results.append([seed, wd, token_id, speed, avg_freq])\n",
    "\n",
    "df_learning_speed = pl.DataFrame(\n",
    "    results,\n",
    "    schema=['seed', 'wd', 'token_id', 'speed', 'avg_freq'],\n",
    "    orient=\"row\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_dir = f'{base}/learning-speed'\n",
    "os.makedirs(out_dir, exist_ok=True)  # no-op if the directory already exists\n",
    "df_learning_speed.write_parquet(f'{out_dir}/{exp}.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-wd-fairness-eaiv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
