{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jamesliu/anaconda3/envs/sparsedoping/lib/python3.11/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "import sys\n",
    "sys.path.append('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using a model of type llama to instantiate a model of type llama_sparse. This is not supported for all configurations of models and can yield errors.\n",
      "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n",
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  9.30it/s]\n"
     ]
    }
   ],
   "source": [
    "from sparsedoping.model import LlamaSparseModelForCausalLM, LlamaSparseConfig\n",
    "\n",
    "from transformers import AutoConfig, AutoModelForCausalLM\n",
    "\n",
    "# Base checkpoint shared by the tokenizer and the sparse model.\n",
    "MODEL_ID = \"meta-llama/Llama-2-7B-hf\"\n",
    "# Directory holding the sparsity artifacts (histograms/ and lookup/).\n",
    "# Set SPARSEDOPING_MODEL_DIR to run on another machine without editing this cell.\n",
    "MODEL_DIR = os.environ.get(\"SPARSEDOPING_MODEL_DIR\", \"/home/jamesliu/models/Llama-2-7B\")\n",
    "\n",
    "# Make the custom config/model pair resolvable through the transformers Auto* classes.\n",
    "AutoConfig.register(\"llama_sparse\", LlamaSparseConfig)\n",
    "AutoModelForCausalLM.register(LlamaSparseConfig, LlamaSparseModelForCausalLM)\n",
    "\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)\n",
    "model = LlamaSparseModelForCausalLM.from_pretrained(\n",
    "    MODEL_ID,\n",
    "    torch_dtype=torch.float16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    attn_implementation=\"flash_attention_2\",\n",
    "    histogram_path=os.path.join(MODEL_DIR, \"histograms\"),\n",
    "    apply_prefill=False,\n",
    ").to(\"cuda\")\n",
    "\n",
    "greedy_sparsity_path = os.path.join(MODEL_DIR, \"lookup\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):\n",
    "    logits = model(\n",
    "        cur_token,\n",
    "        position_ids=input_pos,\n",
    "        cache_position=cache_position,\n",
    "        past_key_values=past_key_values,\n",
    "        return_dict=False,\n",
    "        use_cache=True\n",
    "    )[0]\n",
    "    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]\n",
    "    return new_token\n",
    "\n",
    "def graph_wrapper(fn, *init_args, **init_kwargs):\n",
    "    \"\"\"Capture ``fn(*init_args, **init_kwargs)`` in a CUDA graph and return a replay callable.\n",
    "\n",
    "    ``fn`` is run once on a side stream as warm-up and then recorded into a\n",
    "    ``torch.cuda.CUDAGraph``. The objects passed here serve as the graph's\n",
    "    static inputs: the returned ``replay`` copies every tensor argument it\n",
    "    receives into the static tensor at the same position / keyword, replays\n",
    "    the graph, and returns the output object produced at capture time (the\n",
    "    same object on every call).\n",
    "    \"\"\"\n",
    "    side_stream = torch.cuda.Stream(device=\"cuda\")\n",
    "    side_stream.wait_stream(torch.cuda.current_stream())\n",
    "    with torch.cuda.stream(side_stream):\n",
    "        # Warm-up run before capture.\n",
    "        fn(*init_args, **init_kwargs)\n",
    "    torch.cuda.current_stream().wait_stream(side_stream)\n",
    "\n",
    "    cuda_graph = torch.cuda.CUDAGraph()\n",
    "    with torch.cuda.graph(cuda_graph, stream=side_stream):\n",
    "        captured_output = fn(*init_args, **init_kwargs)\n",
    "\n",
    "    def replay(*args, **kwargs):\n",
    "        # Refresh the static inputs in place; non-tensor arguments are ignored.\n",
    "        for position, value in enumerate(args):\n",
    "            if isinstance(value, torch.Tensor):\n",
    "                init_args[position].copy_(value)\n",
    "        for key, value in kwargs.items():\n",
    "            if isinstance(value, torch.Tensor):\n",
    "                init_kwargs[key].copy_(value)\n",
    "\n",
    "        cuda_graph.replay()\n",
    "        return captured_output\n",
    "\n",
    "    return replay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import get_layer_greedy_sparsities\n",
    "\n",
    "from data import get_dataset\n",
    "from transformers import StaticCache\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "import torch.cuda.nvtx as nvtx\n",
    "\n",
    "\n",
    "def eval_speed(model, tokenizer, device, sparsities, dataset, debug=False):\n",
    "    \"\"\"Benchmark single-token decode throughput using a CUDA-graphed decode step.\n",
    "\n",
    "    Concatenates the ``\"text\"`` field of every sample in ``dataset``, prefills\n",
    "    the first 128 tokens into a ``StaticCache``, captures ``decode_one_tokens``\n",
    "    with ``graph_wrapper`` and replays it for the remaining decode steps while\n",
    "    timing. The model's own argmax token is fed back each step; the ids\n",
    "    written to the local ``generated_ids`` buffer come from the tokenized text.\n",
    "\n",
    "    Assumes the model's sparsity mode is \"default\" or \"turbo\".\n",
    "\n",
    "    Args:\n",
    "        model: causal LM exposing ``reset_sparsities`` / ``set_sparsities``.\n",
    "        tokenizer: tokenizer applied to the concatenated text.\n",
    "        device: device the tokenized inputs are moved to.\n",
    "        sparsities: settings forwarded to ``model.set_sparsities``.\n",
    "        dataset: iterable of samples with a ``\"text\"`` entry.\n",
    "        debug: unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(tokens_per_second, seconds_per_token)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the text tokenizes to fewer than ``max_len`` tokens.\n",
    "    \"\"\"\n",
    "    from time import perf_counter\n",
    "\n",
    "    prefill_len = 128\n",
    "    max_len = prefill_len + 512\n",
    "    batch_size = 1\n",
    "    # The token at index prefill_len comes from the prefill pass itself.\n",
    "    num_decode_steps = max_len - prefill_len - 1\n",
    "\n",
    "    # nvtx.range closes each range on exit, so profiler ranges stay balanced.\n",
    "    with nvtx.range(\"eval_speed\"):\n",
    "        model.eval()\n",
    "\n",
    "        text = \"\".join(sample[\"text\"] + \"\\n\\n\" for sample in dataset)\n",
    "\n",
    "        # truncation=True makes explicit what max_length alone only does with a warning.\n",
    "        all_encodings = tokenizer(\n",
    "            text, return_tensors=\"pt\", truncation=True, max_length=max_len\n",
    "        ).to(device)\n",
    "        prefill_encodings = tokenizer(\n",
    "            text, return_tensors=\"pt\", truncation=True, max_length=prefill_len\n",
    "        ).to(device)\n",
    "\n",
    "        available_tokens = all_encodings[\"input_ids\"].shape[1]\n",
    "        if available_tokens < max_len:\n",
    "            raise ValueError(\n",
    "                f\"dataset yields {available_tokens} tokens but eval_speed needs at least {max_len}\"\n",
    "            )\n",
    "\n",
    "        with nvtx.range(\"model_setup\"):\n",
    "            model.reset_sparsities()\n",
    "            model.set_sparsities(sparsities)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            with nvtx.range(\"prefill\"):\n",
    "                past_key_values = StaticCache(\n",
    "                    config=model.config,\n",
    "                    max_batch_size=batch_size,\n",
    "                    max_cache_len=4096,\n",
    "                    device=model.device,\n",
    "                    dtype=model.dtype,\n",
    "                )\n",
    "                cache_position = torch.arange(prefill_len, device=model.device)\n",
    "                generated_ids = torch.zeros(\n",
    "                    batch_size, max_len, dtype=torch.int, device=model.device\n",
    "                )  # bsz, num tokens total\n",
    "                generated_ids[:, cache_position] = (\n",
    "                    prefill_encodings[\"input_ids\"].to(model.device).to(torch.int)\n",
    "                )\n",
    "\n",
    "                logits = model(\n",
    "                    **prefill_encodings,\n",
    "                    cache_position=cache_position,\n",
    "                    past_key_values=past_key_values,\n",
    "                    return_dict=False,\n",
    "                    use_cache=True,\n",
    "                )[0]\n",
    "\n",
    "                next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]\n",
    "                generated_ids[:, prefill_len] = next_token[:, 0]\n",
    "                cache_position = torch.tensor([prefill_len + 1], device=model.device)\n",
    "                torch.cuda.synchronize()\n",
    "\n",
    "            # next_token and cache_position become the graph's static inputs;\n",
    "            # the in-place `cache_position += 1` below updates that input directly.\n",
    "            wrapped_decode_one_tokens = graph_wrapper(\n",
    "                decode_one_tokens,\n",
    "                model,\n",
    "                next_token,\n",
    "                None,\n",
    "                cache_position,\n",
    "                past_key_values,\n",
    "            )\n",
    "            torch.cuda.synchronize()\n",
    "\n",
    "            # Loop-invariant backend selection: enter it once, outside the timed region.\n",
    "            with torch.backends.cuda.sdp_kernel(\n",
    "                enable_flash=False, enable_mem_efficient=False, enable_math=True\n",
    "            ):\n",
    "                torch.cuda.synchronize()\n",
    "                start_time = perf_counter()\n",
    "\n",
    "                with nvtx.range(\"token_generation_loop\"):\n",
    "                    for i in tqdm(range(num_decode_steps)):\n",
    "                        with nvtx.range(f\"token_{i}\"):\n",
    "                            next_token = wrapped_decode_one_tokens(\n",
    "                                model, next_token.clone(), None, cache_position, past_key_values\n",
    "                            )\n",
    "                            generated_ids[:, cache_position] = all_encodings.input_ids[\n",
    "                                :, cache_position\n",
    "                            ].int()\n",
    "                        cache_position += 1\n",
    "\n",
    "                    # Wait for all queued GPU work before stopping the clock.\n",
    "                    torch.cuda.synchronize()\n",
    "                end_time = perf_counter()\n",
    "\n",
    "        tokens_per_second = num_decode_steps / (end_time - start_time)\n",
    "        return tokens_per_second, 1 / tokens_per_second"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 895/895 [00:08<00:00, 110.81it/s] \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(96.70615449744425, 0.010340603503435018)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from data import get_dataset\n",
    "from utils import get_layer_greedy_sparsities\n",
    "\n",
    "# 100 samples from the wikitext-2-raw-v1 train split.\n",
    "dataset = get_dataset(\"wikitext\", subset=\"wikitext-2-raw-v1\", split=\"train\", size=100)\n",
    "\n",
    "model.convert_column_mode()\n",
    "\n",
    "# Uniform setting: the same sparsity level for every projection in every layer.\n",
    "sparsity_level = 0.5\n",
    "projs = ['up', 'gate','down','q','k','v','o']\n",
    "sparsities = dict(\n",
    "    (name, [sparsity_level] * len(model.model.layers)) for name in projs\n",
    ")\n",
    "\n",
    "# Alternative: per-layer greedy sparsities loaded from the lookup directory.\n",
    "# sparsities = get_layer_greedy_sparsities([sparsity_level]*len(model.model.layers), greedy_sparsity_path)\n",
    "\n",
    "# Alternative mode: model.set_sparsity_mode(\"default\")\n",
    "model.set_sparsity_mode(\"turbo\")\n",
    "\n",
    "eval_speed(model, tokenizer, \"cuda\", sparsities, dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100it [00:01, 85.75it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "6.059079046405714"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from data import get_dataset\n",
    "from eval_ppl import eval_ppl\n",
    "\n",
    "# 100 samples from the wikitext-2-raw-v1 train split.\n",
    "dataset = get_dataset(\"wikitext\", subset=\"wikitext-2-raw-v1\", split=\"train\", size=100)\n",
    "\n",
    "model.set_sparsity_mode(\"turbo\")\n",
    "\n",
    "# Sparsity level 0 for every projection in every layer.\n",
    "sparsity_level = 0\n",
    "projs = ['up', 'gate','down','q','k','v','o']\n",
    "sparsities = dict(\n",
    "    (name, [sparsity_level] * len(model.model.layers)) for name in projs\n",
    ")\n",
    "model.set_sparsities(sparsities)\n",
    "# Alternative mode: model.set_sparsity_mode(\"dev\")\n",
    "\n",
    "eval_ppl(model, tokenizer, \"cuda\", dataset)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sparsedoping",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
