{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens import HookedTransformer\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from functools import partial\n",
    "from utils.backup_analysis import load_model\n",
    "from sys import getsizeof\n",
    "from collections.abc import Mapping, Container\n",
    "from utils.component_evaluation import compute_copy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only forward passes are run in the cells below, so autograd is switched off globally.\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_gpu_memory_usage(label=\"\", device=\"cuda:0\"):\n",
    "    \"\"\"Print the CUDA memory currently allocated and reserved on `device`, in GB.\n",
    "\n",
    "    Args:\n",
    "        label: Text printed at the start of the line to identify the measurement.\n",
    "        device: Device handed to torch.cuda.memory_allocated / memory_reserved.\n",
    "    \"\"\"\n",
    "    bytes_per_gb = 1024 ** 3\n",
    "    allocated_gb = torch.cuda.memory_allocated(device) / bytes_per_gb\n",
    "    reserved_gb = torch.cuda.memory_reserved(device) / bytes_per_gb\n",
    "    print(f\"{label} - Memory Allocated: {allocated_gb:.2f} GB, Memory Reserved: {reserved_gb:.2f} GB\")\n",
    "\n",
    "def tensor_memory_size(tensor):\n",
    "    \"\"\"Return the size in bytes of the elements held by `tensor`.\"\"\"\n",
    "    # element count multiplied by the per-element width of the tensor's dtype\n",
    "    return tensor.numel() * tensor.element_size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration; the load_model call below receives all four values.\n",
    "BASE_MODEL = \"pythia-2.8b\"\n",
    "VARIANT = None\n",
    "CACHE = \"model_cache\"  # presumably a local cache directory -- confirm in utils.backup_analysis\n",
    "DEVICE = \"cuda:0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 143000 is load_model's third positional argument -- presumably the\n",
    "# training-step checkpoint; confirm against utils.backup_analysis.load_model.\n",
    "model = load_model(\n",
    "    BASE_MODEL,\n",
    "    VARIANT,\n",
    "    143000,\n",
    "    CACHE,\n",
    "    DEVICE,\n",
    "    large_model=True,\n",
    ")\n",
    "print_gpu_memory_usage(\"After loading model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.data_utils import generate_data_and_caches\n",
    "\n",
    "# The tokenizer flag must be cleared before generate_data_and_caches is called, so keep this order.\n",
    "model.tokenizer.add_bos_token = False\n",
    "\n",
    "# NOTE(review): 100 is presumably the number of prompts; later cells compare against 100 as well.\n",
    "ioi_dataset, abc_dataset = generate_data_and_caches(model, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably (layer, head) index pairs -- confirm against compute_copy_score.\n",
    "list_of_heads = [(9, 4), (12, 2)]\n",
    "copy_scores = compute_copy_score(\n",
    "    model,\n",
    "    list_of_heads,\n",
    "    ioi_dataset=ioi_dataset,\n",
    "    batch_size=12,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print_gpu_memory_usage(label=\"After computing copy scores\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from utils.head_metrics import BatchIOIDataset, collate_fn\n",
    "# ioi_dataset.__class__ = BatchIOIDataset\n",
    "# ioi_dataloader = DataLoader(ioi_dataset, batch_size=20, collate_fn=partial(collate_fn, device=model.cfg.device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens import ActivationCache\n",
    "from typing import List\n",
    "\n",
    "def concatenate_activation_caches(caches: List[ActivationCache]) -> ActivationCache:\n",
    "    \"\"\"Merge per-batch activation caches into one cache by concatenating along dim 0.\n",
    "\n",
    "    Args:\n",
    "        caches: Non-empty list of caches, in batch order. Every cache must contain the\n",
    "            keys of the first one.\n",
    "\n",
    "    Returns:\n",
    "        A new ActivationCache whose entry for each key is the concatenation of that key's\n",
    "        tensors across `caches`, bound to the same model object as `caches[0]`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `caches` is empty.\n",
    "    \"\"\"\n",
    "    if not caches:\n",
    "        raise ValueError(\"caches must contain at least one ActivationCache\")\n",
    "    first = caches[0]\n",
    "    merged = {key: torch.cat([c[key] for c in caches], dim=0) for key in first.keys()}\n",
    "    # ActivationCache's second argument is the model object itself, not its name; reuse the\n",
    "    # model the batch caches already carry instead of depending on a notebook-level global.\n",
    "    return ActivationCache(merged, first.model, has_batch_dim=True)\n",
    "\n",
    "def run_with_cache_batched(model: HookedTransformer, input_ids: torch.Tensor, batch_size: int = 20):\n",
    "    \"\"\"Run `model.run_with_cache` over `input_ids` in chunks of `batch_size` rows.\n",
    "\n",
    "    The logits and cache of each chunk are moved to CPU right after the forward pass, and\n",
    "    GPU memory usage is printed once per chunk.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(logits, cache)`: the chunk logits concatenated along dim 0 and the chunk\n",
    "        caches merged by `concatenate_activation_caches`.\n",
    "    \"\"\"\n",
    "    n_rows = input_ids.shape[0]\n",
    "    logits_chunks, cache_chunks = [], []\n",
    "    for start in range(0, n_rows, batch_size):\n",
    "        # Slicing stops at the end of the tensor, so the final chunk may be shorter.\n",
    "        chunk = input_ids[start:start + batch_size]\n",
    "        chunk_logits, chunk_cache = model.run_with_cache(chunk)\n",
    "        # Rebind to the CPU versions before reporting memory so the names no longer\n",
    "        # reference the GPU results.\n",
    "        chunk_logits, chunk_cache = chunk_logits.cpu(), chunk_cache.to('cpu')\n",
    "        print_gpu_memory_usage(f\"Batch {start // batch_size}\")\n",
    "        logits_chunks.append(chunk_logits)\n",
    "        cache_chunks.append(chunk_cache)\n",
    "    return torch.cat(logits_chunks, dim=0), concatenate_activation_caches(cache_chunks)\n",
    "\n",
    "logits_from_batched, cache_from_batched = run_with_cache_batched(\n",
    "    model,\n",
    "    ioi_dataset.toks.long(),\n",
    "    batch_size=12,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add up the byte footprint of every tensor in the batched cache and report the total in GB.\n",
    "running_size = sum(\n",
    "    tensor_memory_size(cache_from_batched[key]) for key in cache_from_batched.keys()\n",
    ")\n",
    "print(f\"Total size: {running_size / (1024 ** 3):.2f} GB\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same total as printed above, shown as the raw byte count.\n",
    "running_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unbatched reference run over the whole dataset, for comparison with the batched results.\n",
    "logits, cache = model.run_with_cache(ioi_dataset.toks.long())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the shape of one cache entry from the unbatched run.\n",
    "cache[\"hook_embed\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report every cache entry whose leading dimension differs from the number of prompts.\n",
    "# The prompt count is read from the token tensor instead of being hard-coded.\n",
    "n_prompts = ioi_dataset.toks.shape[0]\n",
    "for k in cache.keys():\n",
    "    if cache[k].shape[0] != n_prompts:\n",
    "        print(k, cache[k].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run_with_cache_batched moved its logits to CPU, while `logits` comes straight from the\n",
    "# model; compare on the batched tensor's device so allclose does not hit a device mismatch.\n",
    "torch.allclose(logits.to(logits_from_batched.device), logits_from_batched, atol=1e-01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare every cached activation between the unbatched and the batched run.\n",
    "mismatched_keys = []\n",
    "for k in cache.keys():\n",
    "    batched_act = cache_from_batched[k]\n",
    "    # The batched cache was moved to CPU; bring the reference onto the same device first.\n",
    "    if not torch.allclose(cache[k].to(batched_act.device), batched_act, atol=1e-01):\n",
    "        print(k, cache[k].shape, batched_act.shape)\n",
    "        mismatched_keys.append(k)\n",
    "\n",
    "# The per-prompt logit check does not depend on the cache key, so it runs once (and only\n",
    "# when some activation disagreed) instead of once per mismatched key.\n",
    "if mismatched_keys:\n",
    "    end_positions = ioi_dataset.word_idx['end']\n",
    "    for p in range(logits.shape[0]):\n",
    "        end = end_positions[p]\n",
    "        ref_logits = logits[p][end].to(logits_from_batched.device)\n",
    "        batched_logits = logits_from_batched[p][end]\n",
    "        if not torch.allclose(ref_logits, batched_logits, atol=1e-02):\n",
    "            print(f\"index {p}\")\n",
    "            print(ioi_dataset.ioi_prompts[p])\n",
    "            print(ref_logits[:5], batched_logits[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
