{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2023-2024 Bytedance Ltd. and/or its affiliates \n",
    "\n",
    "\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\"); \n",
    "# you may not use this file except in compliance with the License. \n",
    "# You may obtain a copy of the License at \n",
    "\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0 \n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software \n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS, \n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n",
    "# See the License for the specific language governing permissions and \n",
    "# limitations under the License. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tool_function import *\n",
    "from dequant_function import *\n",
    "import os\n",
    "from pytorch_memlab import LineProfiler, profile\n",
    "dev = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pkg_path = \"../../tools/jet_quant_cuda\"\n",
    "print('pkg path:', pkg_path)\n",
    "quantization_module = build_and_import_module(pkg_path, 'quantization_cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def functionA(func, *args, **kwargs):\n",
    "    \"\"\"\n",
    "    Measures the GPU running time of the given function using CUDA events.\n",
    "    \n",
    "    Parameters:\n",
    "    - func: The function to be measured.\n",
    "    - *args: Positional arguments to pass to the function.\n",
    "    - **kwargs: Keyword arguments to pass to the function.\n",
    "    \n",
    "    Returns:\n",
    "    - result: The result of the function execution.\n",
    "    - elapsed_time_ms: The time taken to execute the function on the GPU in milliseconds.\n",
    "    \"\"\"\n",
    "    # Ensure CUDA is available\n",
    "    if not torch.cuda.is_available():\n",
    "        raise RuntimeError(\"CUDA is not available.\")\n",
    "    \n",
    "    active = 20\n",
    "\n",
    "    # Create CUDA events for timing\n",
    "    # start_event = torch.cuda.Event(enable_timing=True)\n",
    "    # end_event = torch.cuda.Event(enable_timing=True)\n",
    "    \n",
    "    # warmup\n",
    "    for _ in range(10):\n",
    "        _ = func(*args, **kwargs)\n",
    "\n",
    "    # Synchronize and empty the cache before starting\n",
    "    torch.cuda.synchronize(device=dev)\n",
    "    \n",
    "    # Record the start event\n",
    "    # start_event.record()\n",
    "    begin = time.time()\n",
    "    # Call the function with provided arguments\n",
    "    for _ in range(active):\n",
    "        result = func(*args, **kwargs)\n",
    "    torch.cuda.synchronize()\n",
    "    elapsed_time_s = (time.time()-begin)\n",
    "    elapsed_time_s = elapsed_time_s / active\n",
    "    elapsed_time_ms = elapsed_time_s * 1000\n",
    "    # # Record the end event\n",
    "    # end_event.record()\n",
    "    \n",
    "    # # Wait for the events to be recorded\n",
    "    # torch.cuda.synchronize(device=dev)\n",
    "    \n",
    "    # Calculate the elapsed time\n",
    "    # elapsed_time_ms = start_event.elapsed_time(end_event)\n",
    "    # elapsed_time_ms = elapsed_time_ms / active\n",
    "\n",
    "    return result, elapsed_time_ms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unittest_deqaunt_cuda_unfused(param_list, param_buffer, groupsize=128, quant_bits=4, quant_module=None):\n",
    "\n",
    "    original_param_buffer = param_buffer.clone()\n",
    "    tensor_type = param_buffer.dtype\n",
    "    param_buffer_listview = []\n",
    "    \n",
    "    offset = 0\n",
    "    for i in range(len(param_list)):\n",
    "        param = param_list[i]\n",
    "        param_buffer_listview.append(param_buffer[offset : offset+param.numel()])\n",
    "        offset += param.numel()\n",
    "\n",
    "    # for i in range(len(param_list)):\n",
    "    #     param_buffer_listview[i].sub_(param_list[i])\n",
    "    torch._foreach_sub_(param_buffer_listview, param_list)\n",
    "    original_param_buffer_delta = param_buffer.clone()\n",
    "    print('after sub', param_buffer)\n",
    "\n",
    "    # stochastic quantize kernel\n",
    "    N = param_buffer.nelement()\n",
    "    groups = N // groupsize\n",
    "    quant_tensor_cuda, quant_scales_cuda = quant_module.stochastic_quantize(param_buffer, groups, quant_bits, quant_module.Symmetric)\n",
    "    \n",
    "    if tensor_type is torch.bfloat16:\n",
    "        quant_module.dequantize_bf16(quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module.Symmetric)\n",
    "    elif tensor_type is torch.float32:\n",
    "        quant_module.dequantize_fp32(quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module.Symmetric)\n",
    "    elif tensor_type is torch.float16:\n",
    "        quant_module.dequantize_half(quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module.Symmetric)\n",
    "    else:\n",
    "        assert(False), \"dequant_type is not supported\"\n",
    "\n",
    "    abs_error_norm, rela_error_norm = analysis_diff(original_param_buffer_delta, param_buffer)\n",
    "    print(f\"unfused dequantization&add, weight_diff absolute error norm: {abs_error_norm}, relative error norm: {rela_error_norm}\")\n",
    "\n",
    "    torch._foreach_add_(param_buffer_listview, param_list)\n",
    "\n",
    "    abs_error_norm, rela_error_norm = analysis_diff(original_param_buffer, param_buffer)\n",
    "    print(f\"unfused dequantization&add, weight_diff/weight absolute error norm: {abs_error_norm}, relative error norm: {rela_error_norm}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unittest_deqaunt_cuda_fused(param_list, param_buffer, dp_param_offset, groupsize=128, quant_bits=4, quant_module=None):\n",
    "\n",
    "    original_param_buffer = param_buffer.clone()\n",
    "    tensor_type = param_buffer.dtype\n",
    "    param_buffer_listview = []\n",
    "    \n",
    "    offset = 0\n",
    "    for i in range(len(param_list)):\n",
    "        param = param_list[i]\n",
    "        param_buffer_listview.append(param_buffer[offset : offset+param.numel()])\n",
    "        offset += param.numel()\n",
    "\n",
    "    # for i in range(len(param_list)):\n",
    "    #     param_buffer_listview[i].sub_(param_list[i])\n",
    "    # torch._foreach_sub_(param_buffer_listview, param_list)\n",
    "    original_param_buffer_delta = param_buffer.clone()\n",
    "    print('after sub', param_buffer)\n",
    "\n",
    "    # stochastic quantize kernel\n",
    "    N = param_buffer.nelement()\n",
    "    groups = N // groupsize\n",
    "    # quant_tensor_cuda, quant_scales_cuda = quant_module.stochastic_quantize(param_buffer, groups, quant_bits, quant_module.Symmetric)\n",
    "    quant_tensor_cuda, quant_scales_cuda = quant_module.sub_quantize(param_buffer, param_list, dp_param_offset, groups, quant_bits, quant_module.Symmetric)\n",
    "\n",
    "    if tensor_type is torch.bfloat16:\n",
    "        quant_module.dequantize_add_bf16(quant_tensor_cuda, quant_scales_cuda, param_buffer, param_list, dp_param_offset, groups, quant_bits, quant_module.Symmetric)\n",
    "    elif tensor_type is torch.float32:\n",
    "        # quant_module.dequantize_fp32(quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module.Symmetric)\n",
    "        assert(False), \"dequant_type is not supported\"\n",
    "    elif tensor_type is torch.float16:\n",
    "        # quant_module.dequantize_half(quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module.Symmetric)\n",
    "        assert(False), \"dequant_type is not supported\"\n",
    "    else:\n",
    "        assert(False), \"dequant_type is not supported\"\n",
    "\n",
    "    # abs_error_norm, rela_error_norm = analysis_diff(original_param_buffer_delta, param_buffer)\n",
    "    # print(f\"unfused dequantization&add, weight_diff absolute error norm: {abs_error_norm}, relative error norm: {rela_error_norm}\")\n",
    "\n",
    "    # torch._foreach_add_(param_buffer_listview, param_list)\n",
    "\n",
    "    abs_error_norm, rela_error_norm = analysis_diff(original_param_buffer, param_buffer)\n",
    "    print(f\"unfused dequantization&add, weight_diff/weight absolute error norm: {abs_error_norm}, relative error norm: {rela_error_norm}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor1 = torch.randn((1024 * 8,), dtype=torch.bfloat16, device=dev)\n",
    "tensor2 = torch.randn((1024 * 16,), dtype=torch.bfloat16, device=dev)\n",
    "tensor3 = torch.randn((1024 * 32,), dtype=torch.bfloat16, device=dev)\n",
    "\n",
    "param_list = [tensor1, tensor2, tensor3]\n",
    "total_len = sum([tensor.numel() for tensor in param_list])\n",
    "print(f\"total len: {total_len}\")\n",
    "\n",
    "param_buffer = torch.zeros(size=(total_len + 2048,), dtype=torch.bfloat16, device=dev)\n",
    "param_buffer_list_view = []\n",
    "\n",
    "offset = 0\n",
    "for i in range(len(param_list)):\n",
    "    start_idx = offset\n",
    "    offset += param_list[i].numel()\n",
    "    end_idx = offset\n",
    "    param_buffer[start_idx:end_idx].copy_(param_list[i]*1.1)\n",
    "\n",
    "print(param_list[0])\n",
    "print(param_buffer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "dp_size = 4\n",
    "dp_rank = 2\n",
    "N = param_buffer.numel() // dp_size\n",
    "# unittest_deqaunt_cuda_unfused(param_list=param_list, param_buffer=param_buffer[N*dp_rank: N*(dp_rank+1)], groupsize=2048, quant_bits=4,quant_module=quantization_module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dp_size = 4\n",
    "dp_rank = 2\n",
    "N = param_buffer.numel() // dp_size\n",
    "\n",
    "unittest_deqaunt_cuda_fused(param_list=param_list, param_buffer=param_buffer[N*dp_rank: N*(dp_rank+1)], dp_param_offset=N*dp_rank, groupsize=512, quant_bits=4,quant_module=quantization_module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "param_model_list = []\n",
    "hidden_size = 2048\n",
    "for i in range(24):\n",
    "    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "\n",
    "    param_model_list.append(torch.randn((hidden_size * 3 * hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "    param_model_list.append(torch.randn((hidden_size * hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "\n",
    "    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "\n",
    "    param_model_list.append(torch.randn((hidden_size * 4 * hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "    param_model_list.append(torch.randn((hidden_size * hidden_size * 4,), dtype=torch.bfloat16, device=dev))\n",
    "    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "\n",
    "# for i in range(1):\n",
    "#     param_model_list.append(torch.randn((hidden_size * 4 * hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "#     param_model_list.append(torch.randn((hidden_size * hidden_size,), dtype=torch.bfloat16, device=dev))\n",
    "#     param_model_list.append(torch.randn((hidden_size * hidden_size * 4,), dtype=torch.bfloat16, device=dev))\n",
    "total_len = sum([tensor.numel() for tensor in param_model_list])\n",
    "print(f\"total len: {total_len}\")\n",
    "param_buffer = torch.zeros(size=(total_len,), dtype=torch.bfloat16, device=dev)\n",
    "param_buffer_list_view = []\n",
    "\n",
    "offset = 0\n",
    "for i in range(len(param_model_list)):\n",
    "    start_idx = offset\n",
    "    offset += param_model_list[i].numel()\n",
    "    end_idx = offset\n",
    "    param_buffer[start_idx:end_idx].copy_(param_model_list[i]*1.1)\n",
    "\n",
    "    param_buffer_list_view.append(param_buffer[start_idx:end_idx])\n",
    "\n",
    "dp_size = 1\n",
    "dp_rank = 0\n",
    "\n",
    "tensor_buffer_dp_view = []\n",
    "for i in range(dp_size):\n",
    "    param_buffer_size = param_buffer.numel() // dp_size\n",
    "    tensor_buffer_dp_view.append(param_buffer[i*param_buffer_size: (i+1)*param_buffer_size])\n",
    "\n",
    "groupsize = 512\n",
    "N = tensor_buffer_dp_view[dp_rank].nelement()\n",
    "groups = N // groupsize\n",
    "quant_tensor_cuda, quant_scales_cuda = quantization_module.sub_quantize(tensor_buffer_dp_view[dp_rank], param_list, dp_rank * tensor_buffer_dp_view[0].numel(), groups, 4, quantization_module.Symmetric)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Speed Test for Stoquantize\n",
    "def unfused_dequantize(param_buffer_list_view, param_model_list, quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module):\n",
    "    quant_module.dequantize_bf16(quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module.Symmetric)\n",
    "    torch._foreach_add_(param_buffer_list_view, param_model_list)\n",
    "    # for i in range(len(param_model_list)):\n",
    "    #     param_buffer_listview[i].add_(param_model_list[i])\n",
    "\n",
    "\n",
    "\n",
    "total_len = sum([tensor.numel() for tensor in param_model_list]) // dp_size\n",
    "print(f\"total len: {total_len}\")\n",
    "print(f\"tensor size: {total_len * param_model_list[0].element_size() / 1024 / 1024} MB, dtype: {param_model_list[0].dtype}\")\n",
    "\n",
    "\n",
    "N = total_len\n",
    "quant_bits = 4\n",
    "groupsize = 1024\n",
    "groups = N // groupsize\n",
    "_, avg_time = functionA(unfused_dequantize, param_buffer_list_view, param_model_list, quant_tensor_cuda, quant_scales_cuda, tensor_buffer_dp_view[dp_rank], groups, quant_bits, quantization_module)\n",
    "\n",
    "num_bytes = tensor_buffer_dp_view[dp_rank].numel() * tensor_buffer_dp_view[dp_rank].element_size()\n",
    "print('unfused Dequantization')\n",
    "print(f'time: {avg_time}ms')\n",
    "print(f'numbytes: {num_bytes}Bytes')\n",
    "print(f'throughput: {num_bytes / avg_time / 10**6}GB/s')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Speed Test for Stoquantize\n",
    "def fused_dequantize(dp_param_offset, param_model_list, quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module):\n",
    "    quant_module.dequantize_add_bf16(quant_tensor_cuda, quant_scales_cuda, param_buffer, param_model_list, dp_param_offset, groups, quant_bits, quant_module.Symmetric)\n",
    "    # quant_module.dequantize_bf16(quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module.Symmetric)\n",
    "    # torch._foreach_add_(param_buffer_list_view, param_model_list)\n",
    "    # for i in range(len(param_model_list)):\n",
    "    #     param_buffer_listview[i].add_(param_model_list[i])\n",
    "\n",
    "\n",
    "\n",
    "total_len = sum([tensor.numel() for tensor in param_model_list]) // dp_size\n",
    "print(f\"total len: {total_len}\")\n",
    "print(f\"tensor size: {total_len * param_model_list[0].element_size() / 1024 / 1024} MB, dtype: {param_model_list[0].dtype}\")\n",
    "\n",
    "\n",
    "N = total_len\n",
    "quant_bits = 4\n",
    "groupsize = 1024\n",
    "groups = N // groupsize\n",
    "_, avg_time = functionA(fused_dequantize, dp_rank * tensor_buffer_dp_view[0].numel(), param_model_list, quant_tensor_cuda, quant_scales_cuda, tensor_buffer_dp_view[dp_rank], groups, quant_bits, quantization_module)\n",
    "\n",
    "num_bytes = tensor_buffer_dp_view[dp_rank].numel() * tensor_buffer_dp_view[dp_rank].element_size()\n",
    "print('unfused Dequantization')\n",
    "print(f'time: {avg_time}ms')\n",
    "print(f'numbytes: {num_bytes}Bytes')\n",
    "print(f'throughput: {num_bytes / avg_time / 10**6}GB/s')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "megatron",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
