{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7be1c9d-1070-45bc-8ca3-83a6c77da5bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import polars as pl\n",
    "from scipy import stats\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9912c664-a105-4fb0-acb4-0dc7e6cb0b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import wandb\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40f0bf54",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dotenv import load_dotenv\n",
    "\n",
    "# Read variables (incl. the wandb api key) from the dotenv file `vars`; assigning the\n",
    "# result to `_` keeps the notebook from echoing load_dotenv's return value.\n",
    "_ = load_dotenv(dotenv_path=\"vars\") # load wandb api key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7334fa63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# wandb project that runs are looked up in, as f\"{project}/{run_id}\".\n",
    "project = \"ANONYMOUS\"\n",
    "\n",
    "# public-API client; relies on the api key loaded from `vars` by the dotenv cell.\n",
    "api = wandb.Api()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "730c199c-2ffc-4f9c-baab-ccef665f3369",
   "metadata": {},
   "source": [
    "# helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "277e191c-cb2c-4493-bb6f-a96e7af4a200",
   "metadata": {},
   "outputs": [],
   "source": [
    "def estimate_df_size_in_mb(df):\n",
    "    memory_usage = df.estimated_size()\n",
    "    memory_usage_mb = memory_usage / (1024 ** 2)\n",
    "    # print(f\"Estimated memory usage: {memory_usage_mb:.2f} MB\")\n",
    "    return memory_usage_mb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4983454-b01a-4ff8-98b2-e0ca6012ebfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def try_loading_from_file(filename):\n",
    "    try:\n",
    "        df = pl.read_parquet(filename)\n",
    "        print(\"dataFrame loaded successfully from file\")\n",
    "        return df, None\n",
    "    except Exception as e:\n",
    "        return None, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "615b9bcc-9d29-4c39-a11c-353c57ff889f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def explode_token_metrics(row, label='token_metrics'):\n",
    "    token_metrics = json.loads(row[label])\n",
    "    for token_id, metrics in token_metrics.items():\n",
    "        yield {\n",
    "            'rid': row['rid'],\n",
    "            'seed': row['seed'],\n",
    "            'wd': row['wd'],\n",
    "            'lr': row['lr'],\n",
    "            'iter': row['iter'],\n",
    "            'epoch': row['epoch'],\n",
    "            'loss': row['loss'],\n",
    "            'grad_norm': row['grad_norm'],\n",
    "            'token_id': int(token_id),\n",
    "            'tok_loss': metrics['loss'] / metrics['total'],\n",
    "            'tok_acc': metrics['correct'] / metrics['total'],\n",
    "            'tok_freq': metrics['total']\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26b33c98-8229-458b-84b7-b146c66a5e37",
   "metadata": {},
   "outputs": [],
   "source": [
    "def from_wandb_to_df(run_id, wd, seed):\n",
    "    \"\"\"Fetch one wandb run and return its per-token training metrics as a long polars DataFrame.\n",
    "\n",
    "    Args:\n",
    "        run_id: wandb run id, resolved as f\"{project}/{run_id}\".\n",
    "        wd: expected weight decay; must match the run config rounded to 3 decimals.\n",
    "        seed: expected seed; must match the run config.\n",
    "\n",
    "    Returns:\n",
    "        One row per (logged step, token) with columns rid, seed, wd, lr, iter,\n",
    "        epoch, loss, grad_norm, token_id, tok_loss, tok_acc, tok_freq.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: the run config disagrees with `wd` or `seed`.\n",
    "        ValueError: the alternating log rows cannot be paired, or no token\n",
    "            metrics were logged.\n",
    "    \"\"\"\n",
    "    run = api.run(f\"{project}/{run_id}\")\n",
    "\n",
    "    # check weight decay and seed against the run config before downloading history.\n",
    "    config = dict(run.config)\n",
    "    wd_cfg = float(round(config['args.weight_decay'], 3))\n",
    "    seed_cfg = int(config['args.seed'])\n",
    "    assert abs(wd_cfg - wd) < 1e-5, f\"weight decay {wd_cfg} does not match expected {wd}\"\n",
    "    assert seed_cfg == seed, f\"seed {seed_cfg} does not match expected {seed}\"\n",
    "\n",
    "    # NOTE(review): `run.history()` returns *sampled* rows (samples=500 by default), so\n",
    "    # longer runs may be down-sampled and break the even/odd pairing below; consider\n",
    "    # `run.scan_history()` if runs log more rows than that.\n",
    "    df = pl.DataFrame(run.history())\n",
    "\n",
    "    # if the column \"eval_token_metrics\" exists, drop the last row (presumably the eval log).\n",
    "    if 'eval_token_metrics' in df.columns:\n",
    "        df = df.slice(0, df.height - 1)\n",
    "\n",
    "    # add run id, wd and seed to the dataframe.\n",
    "    df = df.with_columns([\n",
    "        pl.lit(run_id).alias('rid'),\n",
    "        pl.lit(wd).alias('wd'),\n",
    "        pl.lit(seed).alias('seed'),\n",
    "    ])\n",
    "\n",
    "    # rename stuff for practical reasons. (this 'iter' is discarded by the select below;\n",
    "    # the final 'iter' column comes from the odd rows' 'step'.)\n",
    "    df = df.rename({\n",
    "        'train/grad_norm': 'grad_norm',\n",
    "        'train/loss': 'loss',\n",
    "        'train/learning_rate': 'lr',\n",
    "        'train/epoch': 'epoch',\n",
    "        'train/global_step': 'iter'\n",
    "    })\n",
    "\n",
    "    # select and order the columns we need (this implicitly drops _runtime/_timestamp).\n",
    "    df = df.select(['_step', 'rid', 'seed', 'wd', 'lr', 'step', 'epoch', 'loss', 'grad_norm', 'token_metrics'])\n",
    "    df = df.sort(['_step', 'seed', 'wd', 'lr', 'step'])\n",
    "    # NOTE(review): magic threshold - runs with more than 200 rows lose their last two rows; confirm intent.\n",
    "    if df.height > 200:\n",
    "        df = df.slice(0, df.height - 2)\n",
    "    # the pairing below needs an even number of rows.\n",
    "    if df.height % 2 != 0:\n",
    "        df = df.slice(0, df.height - 1)\n",
    "\n",
    "    # merge the alternating log records: even `_step` rows carry lr/epoch/loss/grad_norm,\n",
    "    # odd `_step` rows carry `step` and `token_metrics`; the i-th of each form one record.\n",
    "    even_df = df.filter(pl.col('_step') % 2 == 0).drop(['step', 'token_metrics'])\n",
    "    odd_df = df.filter(pl.col('_step') % 2 != 0).select(['step', 'token_metrics'])\n",
    "    if even_df.height != odd_df.height:\n",
    "        # a horizontal concat of unequal heights would misalign records (or fail / null-pad,\n",
    "        # depending on the polars version), so refuse explicitly.\n",
    "        raise ValueError(\n",
    "            f\"run {run_id}: {even_df.height} even-step rows vs {odd_df.height} odd-step rows; \"\n",
    "            \"cannot pair the alternating log records\"\n",
    "        )\n",
    "    df = pl.concat([even_df, odd_df], how='horizontal').drop('_step')\n",
    "\n",
    "    # rename and select some stuff again.\n",
    "    df = df.rename({'step': 'iter'}).with_columns(pl.col('iter').cast(pl.Int64))\n",
    "    df = df.select(['rid', 'seed', 'wd', 'lr', 'iter', 'epoch', 'loss', 'grad_norm', 'token_metrics'])\n",
    "\n",
    "    # explode token metrics entries: one row per (logged step, token).\n",
    "    exploded_data = [\n",
    "        record for df_row in df.iter_rows(named=True) for record in explode_token_metrics(df_row)\n",
    "    ]\n",
    "    if not exploded_data:\n",
    "        raise ValueError(f\"run {run_id}: no token metrics to explode\")\n",
    "    df = pl.DataFrame(exploded_data)\n",
    "\n",
    "    # cast to compact dtypes ('rid' is left as-is).\n",
    "    return df.with_columns([\n",
    "        pl.col(\"seed\").cast(pl.Int32),\n",
    "        pl.col(\"wd\").cast(pl.Float32),\n",
    "        pl.col(\"lr\").cast(pl.Float32),\n",
    "        pl.col(\"iter\").cast(pl.Int32),\n",
    "        pl.col(\"epoch\").cast(pl.Float32),\n",
    "        pl.col(\"loss\").cast(pl.Float32),\n",
    "        pl.col(\"grad_norm\").cast(pl.Float32),\n",
    "        pl.col(\"token_id\").cast(pl.Int64),\n",
    "        pl.col(\"tok_loss\").cast(pl.Float32),\n",
    "        pl.col(\"tok_acc\").cast(pl.Float32),\n",
    "        pl.col(\"tok_freq\").cast(pl.Int64),\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1890d38-4717-49c2-af29-b8fb13791b6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_all_runs(runs, filename=None, save=True, load_from_file=True):\n",
    "    df, meta = try_loading_from_file(filename)\n",
    "    if df is not None and load_from_file:\n",
    "        return df, meta\n",
    "    print(\"proceeding to fetch data from wandb runs...\")\n",
    "    l = len(runs)\n",
    "    dfs = []\n",
    "    meta = []\n",
    "    for i, run_id in enumerate(runs):\n",
    "        wd, seed = float(runs[run_id]['wd']), runs[run_id]['seed']\n",
    "        print(f\"fetching run {run_id} ({i+1}/{l}) wd={wd} seed={seed}\".ljust(50), end=\"\\r\")\n",
    "        df = from_wandb_to_df(run_id, wd, seed)\n",
    "        dfs.append(df)\n",
    "        meta.append((len(df), wd, seed))\n",
    "    df = pl.concat(dfs)\n",
    "    if save is True and filename is not None:\n",
    "        df.write_parquet(filename)\n",
    "    return pl.concat(dfs), meta"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c62c1916-8f24-4525-b508-81f016e19fa3",
   "metadata": {},
   "source": [
    "# Load runs: read run metadata from the CSV export, then fetch per-token metrics from wandb (or reuse the parquet cache)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "485c7fd3-4a62-43db-b42c-399c4dc90601",
   "metadata": {},
   "outputs": [],
   "source": [
    "RUNS_CSV = 'results/omseries-train.csv'\n",
    "\n",
    "\n",
    "def _parse_int(text):\n",
    "    \"\"\"Return `text` as an int, or None if it is empty or not an integer literal.\"\"\"\n",
    "    try:\n",
    "        return int(text)\n",
    "    except (TypeError, ValueError):\n",
    "        return None\n",
    "\n",
    "\n",
    "def _parse_float(text):\n",
    "    \"\"\"Return `text` as a float, or None if it is empty or not a number.\"\"\"\n",
    "    try:\n",
    "        return float(text)\n",
    "    except (TypeError, ValueError):\n",
    "        return None\n",
    "\n",
    "\n",
    "runs = {}\n",
    "\n",
    "# newline='' is what the csv module requires, so quoted fields with embedded newlines parse correctly.\n",
    "with open(RUNS_CSV, mode='r', newline='') as file:\n",
    "    reader = csv.DictReader(file, delimiter=',')\n",
    "    for row in reader:\n",
    "        run_id = row[\"ID\"]\n",
    "        if row[\"State\"] != \"finished\":\n",
    "            continue  # skip unfinished or crashed runs\n",
    "        # numeric fields fall back to None when they are empty or fail to parse.\n",
    "        runs[run_id] = {\n",
    "            \"created\": row[\"Created\"],\n",
    "            \"runtime\": _parse_int(row[\"Runtime\"]),\n",
    "            \"backbone\": row[\"args.backbone\"],\n",
    "            \"dataset\": row[\"args.dataset\"],\n",
    "            \"wd\": _parse_float(row[\"args.weight_decay\"]),\n",
    "            \"seed\": _parse_int(row[\"args.seed\"]),\n",
    "            \"max_length\": _parse_int(row[\"args.max_length\"]),\n",
    "            \"per_device_batch_size\": _parse_int(row[\"args.per_device_batch_size\"]),\n",
    "            \"grad_accumulation_steps\": _parse_int(row[\"args.gradient_accumulation_steps\"]),\n",
    "            \"learning_rate\": _parse_float(row[\"args.learning_rate\"]),\n",
    "            \"logging_steps\": _parse_int(row[\"args.logging_steps\"]),\n",
    "            \"vocab_size\": _parse_int(row[\"args.vocab_size\"]),\n",
    "            \"max_training_steps\": _parse_int(row[\"args.max_training_steps\"]),\n",
    "            \"train_loss\": _parse_float(row[\"train/loss\"]),\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2eb6b0da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the parquet cache when it can be read; otherwise fetch every run from wandb and write the cache.\n",
    "df, meta = load_all_runs(runs, filename=\"results/omseries-train.parquet\", save=True, load_from_file=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a4026e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
