{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "\n",
    "def process_files(file_paths, threshold=1e-4, batch_size=10):\n",
    "    \"\"\"\n",
    "    file_paths: list of 120 json file paths\n",
    "    threshold: loss cutoff\n",
    "    batch_size: number of files per batch (default = 120/10 = 12)\n",
    "    \"\"\"\n",
    "    times = []\n",
    "\n",
    "    for path in file_paths:\n",
    "        with open(path, \"r\") as f:\n",
    "            data = json.load(f)\n",
    "\n",
    "        # Extract values\n",
    "        losses = np.array(data[\"loss\"][\"values\"])\n",
    "        times_arr = np.array(data[\"time\"][\"values\"])\n",
    "\n",
    "        # Find first time where loss < threshold\n",
    "        idx = np.where(losses < threshold)[0]\n",
    "        if len(idx) > 0:\n",
    "            times.append(times_arr[idx[0]])\n",
    "        else:\n",
    "            times.append(np.nan)  # in case threshold never reached\n",
    "\n",
    "    # Group into batches\n",
    "    batches = [\n",
    "        times[i:i + batch_size]\n",
    "        for i in range(0, len(times), batch_size)\n",
    "    ]\n",
    "\n",
    "    # Log each batch and compute mean\n",
    "    batch_means = []\n",
    "    for i, batch in enumerate(batches):\n",
    "        arr = np.array(batch, dtype=float)\n",
    "        print(f\"Batch {i+1}: times = {arr}\")\n",
    "        mean_time = np.nanmean(arr)\n",
    "        print(f\"Batch {i+1} mean = {mean_time:.4f}\")\n",
    "        batch_means.append(mean_time)\n",
    "\n",
    "    overall_mean = np.nanmean(batch_means)\n",
    "    print(f\"\\nOverall mean across batches = {overall_mean:.4f}\")\n",
    "\n",
    "    return batch_means, overall_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 1: times = [0.06433331 0.06380918 0.06398259 0.06385782 0.06381661 0.06371411\n",
      " 0.06401379 0.06376701        nan        nan]\n",
      "Batch 1 mean = 0.0639\n",
      "Batch 2: times = [0.15618534 0.15703933 0.15710179 0.15726973 0.15899584 0.15609379\n",
      " 0.15930368 0.15552842 0.15515113 0.15825661]\n",
      "Batch 2 mean = 0.1571\n",
      "Batch 3: times = [0.05489043 0.0552111  0.05499523 0.05496618 0.05519235 0.05525626\n",
      " 0.05514195 0.05511456 0.05515443 0.05505971]\n",
      "Batch 3 mean = 0.0551\n",
      "Batch 4: times = [0.17859629 0.18107475 0.17730333 0.17902858 0.17812272 0.17927683\n",
      " 0.18048915 0.17897968 0.1768576  0.18100422]\n",
      "Batch 4 mean = 0.1791\n",
      "Batch 5: times = [nan nan nan nan nan nan nan nan nan nan]\n",
      "Batch 5 mean = nan\n",
      "Batch 6: times = [0.05669987 0.05698691 0.05664186 0.05693373 0.0566897  0.05698006\n",
      " 0.0569728  0.05681322 0.05659814 0.056792  ]\n",
      "Batch 6 mean = 0.0568\n",
      "Batch 7: times = [       nan        nan        nan        nan        nan        nan\n",
      "        nan 0.06246771        nan 0.06246054]\n",
      "Batch 7 mean = 0.0625\n",
      "Batch 8: times = [nan nan nan nan nan nan nan nan nan nan]\n",
      "Batch 8 mean = nan\n",
      "Batch 9: times = [nan nan nan nan nan nan nan nan nan nan]\n",
      "Batch 9 mean = nan\n",
      "Batch 10: times = [nan nan nan nan nan nan nan nan nan nan]\n",
      "Batch 10 mean = nan\n",
      "Batch 11: times = [0.05693203 0.05693478 0.05695206 0.0569168  0.05656054 0.05665146\n",
      " 0.05671885 0.05678358 0.05674566 0.05691437]\n",
      "Batch 11 mean = 0.0568\n",
      "Batch 12: times = [nan nan nan nan nan nan nan nan nan nan]\n",
      "Batch 12 mean = nan\n",
      "\n",
      "Overall mean across batches = 0.0902\n",
      "[0.06391180396080018, 0.15709256610870362, 0.05509822082519531, 0.1790733154296875, nan, 0.0568108286857605, 0.06246412825584412, nan, nan, nan, 0.056811014556884774, nan]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1651168/39230095.py:38: RuntimeWarning: Mean of empty slice\n",
      "  mean_time = np.nanmean(arr)\n"
     ]
    }
   ],
   "source": [
    "paths = []\n",
    "for i in range(120):\n",
    "    it = i + 7077\n",
    "    paths.append(f\"file_storage/runs/{it}/metrics.json\")\n",
    "\n",
    "batch_means, overall_mean = process_files(paths)\n",
    "print(batch_means)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "def CIFAR(c, d, eps):\n",
    "\n",
    "    trainset = torchvision.datasets.CIFAR10(\n",
    "        root=\"./data\",\n",
    "        train=True,\n",
    "        download=True,\n",
    "        transform=transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "            ]\n",
    "        ),\n",
    "    )\n",
    "    trainloader = torch.utils.data.DataLoader(trainset, batch_size=d, shuffle=True)\n",
    "    images, _ = next(iter(trainloader))\n",
    "\n",
    "    images = torch.reshape(images, (d, -1))\n",
    "    col_means = images.mean(dim=0, keepdim=True)   # [1, image_size]\n",
    "    images = images - col_means\n",
    "    \n",
    "    I = torch.eye(3072)\n",
    "    A = (images.T @ images) / d\n",
    "    A = A / torch.linalg.norm(A, \"fro\")\n",
    "    A = A + eps * I\n",
    "\n",
    "    return A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inverse_square_root(A, device=\"cuda\", warmup=3, iters=10):\n",
    "    \"\"\"\n",
    "    Compute the matrix inverse square root using eigen-decomposition\n",
    "    with accurate GPU timing.\n",
    "    \n",
    "    Args:\n",
    "        A (torch.Tensor): Symmetric positive-definite matrix (n x n).\n",
    "        device (str): \"cuda\" or \"cpu\".\n",
    "        warmup (int): Number of warm-up runs (not timed).\n",
    "        iters (int): Number of timed runs to average.\n",
    "\n",
    "    Returns:\n",
    "        X (torch.Tensor): Matrix inverse square root of A.\n",
    "        avg_time (float): Average runtime in milliseconds.\n",
    "    \"\"\"\n",
    "    A = A.to(device)\n",
    "\n",
    "    def compute(A):\n",
    "        eigvals, eigvecs = torch.linalg.eigh(A)\n",
    "        inv_sqrt = torch.diag_embed(eigvals.rsqrt())\n",
    "        return eigvecs @ inv_sqrt @ eigvecs.T\n",
    "\n",
    "    # Warm-up\n",
    "    for _ in range(warmup):\n",
    "        _ = compute(A)\n",
    "        if device == \"cuda\":\n",
    "            torch.cuda.synchronize()\n",
    "\n",
    "    # Timed runs with CUDA events\n",
    "    if device == \"cuda\":\n",
    "        start_event = torch.cuda.Event(enable_timing=True)\n",
    "        end_event = torch.cuda.Event(enable_timing=True)\n",
    "\n",
    "        start_event.record()\n",
    "        for _ in range(iters):\n",
    "            X = compute(A)\n",
    "        end_event.record()\n",
    "        torch.cuda.synchronize()\n",
    "        avg_time = start_event.elapsed_time(end_event) / iters  # ms\n",
    "    else:\n",
    "        import time\n",
    "        start = time.perf_counter()\n",
    "        for _ in range(iters):\n",
    "            X = compute(A)\n",
    "        end = time.perf_counter()\n",
    "        avg_time = (end - start) * 1000 / iters  # ms\n",
    "\n",
    "    return X, avg_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "[106.60761260986328, 107.0929946899414, 106.65164947509766, 107.14112091064453, 107.97875213623047, 107.3960952758789, 106.7294692993164, 107.53228759765625, 107.0387191772461, 107.54662322998047]\n",
      "107.17153244018554\n"
     ]
    }
   ],
   "source": [
    "times = []\n",
    "for i in range(10):\n",
    "    A = CIFAR(0.25, 5000, 1e-3)\n",
    "    _, time = inverse_square_root(A, warmup = 10, iters = 1)\n",
    "    times.append(time)\n",
    "\n",
    "print(times)\n",
    "print(np.mean(times))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
