{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2023-2024 Bytedance Ltd. and/or its affiliates \n",
    "\n",
    "\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\"); \n",
    "# you may not use this file except in compliance with the License. \n",
    "# You may obtain a copy of the License at \n",
    "\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0 \n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software \n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS, \n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n",
    "# See the License for the specific language governing permissions and \n",
    "# limitations under the License. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tool_function import *\n",
    "from dequant_function import *\n",
    "import os\n",
    "from pytorch_memlab import LineProfiler, profile\n",
    "dev = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unittest_deqaunt_cuda(param_list, tensor_buffer, tensor_buffer_listview, groupsize=128, quant_bits=4, quant_module=None, hadamard=False):\n",
    "    \"\"\"\n",
    "    Round-trips the buffer/parameter difference through the CUDA quantizer and prints the error.\n",
    "\n",
    "    The parameters are subtracted in place from the buffer (through its list of\n",
    "    views), the result is stochastically quantized and dequantized again, and the\n",
    "    dequantized tensor is compared with the subtracted buffer via analysis_diff.\n",
    "\n",
    "    Parameters:\n",
    "    - param_list: Tensors subtracted from the buffer, one per entry of tensor_buffer_listview.\n",
    "    - tensor_buffer: Flat tensor that is modified in place and then quantized.\n",
    "    - tensor_buffer_listview: Tensors aligned with param_list; the calling cells pass slices of tensor_buffer.\n",
    "    - groupsize: Divisor used to derive the group count (number of elements // groupsize).\n",
    "    - quant_bits: Bit width handed to the quantize/dequantize kernels.\n",
    "    - quant_module: Module providing stochastic_quantize, dequantize_bf16/fp32/half and Symmetric.\n",
    "    - hadamard: When True, the transformed difference is quantized and the dequantized\n",
    "      tensor is transformed again before the comparison.\n",
    "\n",
    "    Returns:\n",
    "    - dequant_tensor_cuda: The dequantized tensor.\n",
    "\n",
    "    Raises:\n",
    "    - AssertionError: If the buffer dtype is not bfloat16, float32 or float16.\n",
    "    \"\"\"\n",
    "    tensor_type = tensor_buffer.dtype\n",
    "\n",
    "    # Pick the kernel first so an unsupported dtype fails before the buffer is modified.\n",
    "    if tensor_type is torch.bfloat16:\n",
    "        dequantize = quant_module.dequantize_bf16\n",
    "    elif tensor_type is torch.float32:\n",
    "        dequantize = quant_module.dequantize_fp32\n",
    "    elif tensor_type is torch.float16:\n",
    "        dequantize = quant_module.dequantize_half\n",
    "    else:\n",
    "        raise AssertionError(\"dequant_type is not supported\")\n",
    "\n",
    "    dequant_tensor_cuda = torch.empty_like(tensor_buffer)\n",
    "\n",
    "    torch._foreach_sub_(tensor_buffer_listview, param_list)\n",
    "    print('after sub', tensor_buffer)\n",
    "\n",
    "    # The kernel must see the transformed values, otherwise the transform applied\n",
    "    # after dequantization has nothing to undo.\n",
    "    quant_input = tensor_buffer\n",
    "    if hadamard is True:\n",
    "        quant_input = fast_hadamard_transform(tensor_buffer.clone(), k=5, normalize=True)\n",
    "\n",
    "    # stochastic quantize kernel\n",
    "    groups = tensor_buffer.nelement() // groupsize\n",
    "    quant_tensor_cuda, quant_scales_cuda = quant_module.stochastic_quantize(quant_input, groups, quant_bits, quant_module.Symmetric)\n",
    "    dequantize(quant_tensor_cuda, quant_scales_cuda, dequant_tensor_cuda, groups, quant_bits, quant_module.Symmetric)\n",
    "\n",
    "    if hadamard is True:\n",
    "        # NOTE(review): assumes the normalized transform is its own inverse - confirm against tool_function.\n",
    "        dequant_tensor_cuda = fast_hadamard_transform(dequant_tensor_cuda, k=5, normalize=True)\n",
    "\n",
    "    abs_error_norm, rela_error_norm = analysis_diff(tensor_buffer, dequant_tensor_cuda)\n",
    "    print(f\"cuda version quantization, absolute error norm: {abs_error_norm}, relative error norm: {rela_error_norm}\")\n",
    "    return dequant_tensor_cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unittest_deqaunt_cuda_fused(param_list, tensor_buffer, tensor_buffer_listview, groupsize=128, quant_bits=4, quant_module=None, hadamard=False, dp_size=4, dp_rank=3):\n",
    "    \"\"\"\n",
    "    Checks the fused subtract + quantize kernel on one data-parallel shard of the buffer.\n",
    "\n",
    "    The buffer is split into dp_size equal shards. The shard of dp_rank is passed to\n",
    "    quant_module.sub_quantize together with the parameters and the shard's element\n",
    "    offset, the result is dequantized, and the dequantized tensor is compared with\n",
    "    the same shard after the parameters were subtracted from the buffer in place.\n",
    "\n",
    "    Parameters:\n",
    "    - param_list: Tensors subtracted from the buffer, one per entry of tensor_buffer_listview.\n",
    "    - tensor_buffer: Flat tensor whose length must be divisible by dp_size; modified in place.\n",
    "    - tensor_buffer_listview: Tensors aligned with param_list; the calling cells pass slices of tensor_buffer.\n",
    "    - groupsize: Divisor used to derive the group count (number of elements // groupsize).\n",
    "    - quant_bits: Bit width handed to the quantize/dequantize kernels.\n",
    "    - quant_module: Module providing sub_quantize, dequantize_bf16/fp32/half and Symmetric.\n",
    "    - hadamard: When True, the dequantized tensor is passed through fast_hadamard_transform.\n",
    "    - dp_size: Number of equal shards the buffer is split into.\n",
    "    - dp_rank: Index of the shard that is quantized.\n",
    "\n",
    "    Returns:\n",
    "    - dequant_tensor_cuda: The dequantized shard.\n",
    "\n",
    "    Raises:\n",
    "    - AssertionError: If the buffer cannot be split evenly, dp_rank is out of range,\n",
    "      or the buffer dtype is not bfloat16, float32 or float16.\n",
    "    \"\"\"\n",
    "    assert tensor_buffer.numel() % dp_size == 0\n",
    "    assert 0 <= dp_rank < dp_size, \"dp_rank must be in [0, dp_size)\"\n",
    "\n",
    "    # Pick the kernel first so an unsupported dtype fails before any kernel is launched.\n",
    "    tensor_type = tensor_buffer.dtype\n",
    "    if tensor_type is torch.bfloat16:\n",
    "        dequantize = quant_module.dequantize_bf16\n",
    "    elif tensor_type is torch.float32:\n",
    "        dequantize = quant_module.dequantize_fp32\n",
    "    elif tensor_type is torch.float16:\n",
    "        dequantize = quant_module.dequantize_half\n",
    "    else:\n",
    "        raise AssertionError(\"dequant_type is not supported\")\n",
    "\n",
    "    # Only the shard of dp_rank is needed; it is a view, so it reflects the\n",
    "    # in-place subtraction performed further down.\n",
    "    param_buffer_size = tensor_buffer.numel() // dp_size\n",
    "    dp_param_offset = dp_rank * param_buffer_size\n",
    "    dp_shard = tensor_buffer[dp_param_offset: dp_param_offset + param_buffer_size]\n",
    "    dequant_tensor_cuda = torch.empty_like(dp_shard)\n",
    "\n",
    "    # NOTE(review): groups is derived from the full buffer although only one shard is\n",
    "    # quantized - confirm this is the group count sub_quantize expects.\n",
    "    groups = tensor_buffer.nelement() // groupsize\n",
    "    quant_tensor_cuda, quant_scales_cuda = quant_module.sub_quantize(dp_shard, param_list, dp_param_offset, groups, quant_bits, quant_module.Symmetric)\n",
    "    dequantize(quant_tensor_cuda, quant_scales_cuda, dequant_tensor_cuda, groups, quant_bits, quant_module.Symmetric)\n",
    "\n",
    "    if hadamard is True:\n",
    "        # NOTE(review): the fused kernel receives the untransformed shard, so this only\n",
    "        # transforms the output - confirm whether hadamard is meant to be supported here.\n",
    "        dequant_tensor_cuda = fast_hadamard_transform(dequant_tensor_cuda, k=5, normalize=True)\n",
    "\n",
    "    # Build the reference: subtract the parameters in place, then compare the shard.\n",
    "    torch._foreach_sub_(tensor_buffer_listview, param_list)\n",
    "    abs_error_norm, rela_error_norm = analysis_diff(dp_shard, dequant_tensor_cuda)\n",
    "    print(f\"cuda version quantization, absolute error norm: {abs_error_norm}, relative error norm: {rela_error_norm}\")\n",
    "    return dequant_tensor_cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CUDA quantization extension that the cells below call into.\n",
    "pkg_path = '../../tools/jet_quant_cuda'\n",
    "print(f'pkg path: {pkg_path}')\n",
    "quantization_module = build_and_import_module(pkg_path, 'quantization_cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the generator so the random parameters are the same on every run.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "tensor1 = torch.randn((1024 * 8,), dtype=torch.bfloat16, device=dev)\n",
    "tensor2 = torch.randn((1024 * 16,), dtype=torch.bfloat16, device=dev)\n",
    "tensor3 = torch.randn((1024 * 32,), dtype=torch.bfloat16, device=dev)\n",
    "\n",
    "param_list = [tensor1, tensor2, tensor3]\n",
    "total_len = sum(tensor.numel() for tensor in param_list)\n",
    "print(f\"total len: {total_len}\")\n",
    "\n",
    "# Zero-filled on purpose: the 2048 padding elements behind the parameters are never\n",
    "# written below, yet the whole buffer is quantized by the test cells.\n",
    "param_buffer = torch.zeros((total_len + 2048,), dtype=torch.bfloat16, device=dev)\n",
    "param_buffer_list_view = []\n",
    "\n",
    "# Lay the parameters out back to back; each view holds its parameter scaled by 1.1\n",
    "# so that the later subtraction leaves a non-zero difference.\n",
    "offset = 0\n",
    "for i in range(len(param_list)):\n",
    "    start_idx = offset\n",
    "    offset += param_list[i].numel()\n",
    "    end_idx = offset\n",
    "    param_buffer_list_view.append(param_buffer[start_idx:end_idx])\n",
    "\n",
    "    param_buffer_list_view[-1].copy_(param_list[i]*1.1)\n",
    "\n",
    "print(param_list[0])\n",
    "print(param_buffer_list_view[0])\n",
    "print(param_buffer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a copy: the test subtracts the parameters from its buffer in place.\n",
    "output_buffer = param_buffer.clone()\n",
    "\n",
    "# One view per parameter, laid out back to back from the start of the copy.\n",
    "output_buffer_list_view = []\n",
    "offset = 0\n",
    "for param in param_list:\n",
    "    view_end = offset + param.numel()\n",
    "    output_buffer_list_view.append(output_buffer[offset:view_end])\n",
    "    offset = view_end\n",
    "\n",
    "dequant_tensor = unittest_deqaunt_cuda(\n",
    "    param_list=param_list,\n",
    "    tensor_buffer=output_buffer,\n",
    "    tensor_buffer_listview=output_buffer_list_view,\n",
    "    groupsize=2048,\n",
    "    quant_bits=4,\n",
    "    quant_module=quantization_module,\n",
    "    hadamard=False,\n",
    ")\n",
    "print(dequant_tensor, torch.norm(dequant_tensor))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh copy of the buffer: the fused test also subtracts the parameters in place.\n",
    "output_buffer = param_buffer.clone()\n",
    "\n",
    "# Slice the copy into consecutive per-parameter views.\n",
    "output_buffer_list_view = []\n",
    "offset = 0\n",
    "for param in param_list:\n",
    "    view_end = offset + param.numel()\n",
    "    output_buffer_list_view.append(output_buffer[offset:view_end])\n",
    "    offset = view_end\n",
    "\n",
    "dequant_tensor = unittest_deqaunt_cuda_fused(\n",
    "    param_list=param_list,\n",
    "    tensor_buffer=output_buffer,\n",
    "    tensor_buffer_listview=output_buffer_list_view,\n",
    "    groupsize=2048,\n",
    "    quant_bits=4,\n",
    "    quant_module=quantization_module,\n",
    "    hadamard=False,\n",
    ")\n",
    "print(dequant_tensor, torch.norm(dequant_tensor))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def functionA(func, *args, **kwargs):\n",
    "    \"\"\"\n",
    "    Measures the average wall-clock running time of the given function on the GPU.\n",
    "\n",
    "    The function is called 10 times as warm-up and then 20 more times between two\n",
    "    device synchronizations. The reported value is the host-side elapsed time of\n",
    "    the timed calls divided by their number.\n",
    "    \n",
    "    Parameters:\n",
    "    - func: The function to be measured.\n",
    "    - *args: Positional arguments to pass to the function.\n",
    "    - **kwargs: Keyword arguments to pass to the function.\n",
    "    \n",
    "    Returns:\n",
    "    - result: The result of the last timed call.\n",
    "    - elapsed_time_ms: The average time per call in milliseconds.\n",
    "\n",
    "    Raises:\n",
    "    - RuntimeError: If CUDA is not available.\n",
    "    \"\"\"\n",
    "    # Ensure CUDA is available\n",
    "    if not torch.cuda.is_available():\n",
    "        raise RuntimeError(\"CUDA is not available.\")\n",
    "\n",
    "    warmup = 10\n",
    "    active = 20\n",
    "\n",
    "    # Warm-up calls are not timed.\n",
    "    for _ in range(warmup):\n",
    "        func(*args, **kwargs)\n",
    "\n",
    "    # Wait for the warm-up work to finish so it is not counted in the measurement.\n",
    "    torch.cuda.synchronize(device=dev)\n",
    "\n",
    "    # perf_counter is monotonic and has the highest available resolution.\n",
    "    begin = time.perf_counter()\n",
    "    for _ in range(active):\n",
    "        result = func(*args, **kwargs)\n",
    "    # Kernel launches return before the work is done; wait before reading the clock.\n",
    "    torch.cuda.synchronize(device=dev)\n",
    "    elapsed_time_s = (time.perf_counter() - begin) / active\n",
    "    elapsed_time_ms = elapsed_time_s * 1000\n",
    "\n",
    "    return result, elapsed_time_ms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden_size = 2048\n",
    "num_layers = 24\n",
    "\n",
    "# Element counts of the flat tensors created for one layer, in creation order.\n",
    "layer_param_sizes = (\n",
    "    hidden_size,\n",
    "    hidden_size,\n",
    "    hidden_size * 3 * hidden_size,\n",
    "    hidden_size * hidden_size,\n",
    "    hidden_size,\n",
    "    hidden_size,\n",
    "    hidden_size,\n",
    "    hidden_size * 4 * hidden_size,\n",
    "    hidden_size,\n",
    "    hidden_size * hidden_size * 4,\n",
    "    hidden_size,\n",
    ")\n",
    "\n",
    "# The same set of random bfloat16 tensors is generated for every layer, layer after layer.\n",
    "param_model_list = [\n",
    "    torch.randn((numel,), dtype=torch.bfloat16, device=dev)\n",
    "    for _ in range(num_layers)\n",
    "    for numel in layer_param_sizes\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Speed Test for Stoquantize\n",
    "def unfused_stoquantize(param_buffer_tensor, param_buffer_list_view, param_model_list, groups, quant_bits, quant_mode):\n",
    "    \"\"\"\n",
    "    Subtracts every parameter from its buffer view in place, then quantizes param_buffer_tensor.\n",
    "\n",
    "    Uses the notebook-level quantization_module. The quantizer output is discarded\n",
    "    and nothing is returned.\n",
    "    \"\"\"\n",
    "    # One sub_ call per parameter, kept as separate calls rather than torch._foreach_sub_.\n",
    "    for idx, param in enumerate(param_model_list):\n",
    "        param_buffer_list_view[idx].sub_(param)\n",
    "    quantization_module.stochastic_quantize(param_buffer_tensor, groups, quant_bits, quant_mode)\n",
    "\n",
    "dp_size = 4\n",
    "dp_rank = 3\n",
    "\n",
    "total_len = sum(tensor.numel() for tensor in param_model_list) // dp_size\n",
    "print(f\"total len: {total_len}\")\n",
    "print(f\"tensor size: {total_len * param_model_list[0].element_size() / 1024 / 1024} MB, dtype: {param_model_list[0].dtype}\")\n",
    "\n",
    "\n",
    "# Zero-filled on purpose: the 2048 padding elements at the end are never written\n",
    "# below, but the last shard that is quantized includes them.\n",
    "param_buffer_tensor = torch.zeros((total_len * dp_size + 2048,), dtype=torch.bfloat16, device=dev)\n",
    "\n",
    "# Split the buffer into dp_size equally sized shard views.\n",
    "param_buffer_size = param_buffer_tensor.numel() // dp_size\n",
    "tensor_buffer_dp_view = [\n",
    "    param_buffer_tensor[i*param_buffer_size: (i+1)*param_buffer_size]\n",
    "    for i in range(dp_size)\n",
    "]\n",
    "\n",
    "# Per-parameter views, laid out back to back and filled with the parameter scaled by 1.1.\n",
    "param_buffer_list_view = []\n",
    "offset = 0\n",
    "for i in range(len(param_model_list)):\n",
    "    start_idx = offset\n",
    "    offset += param_model_list[i].numel()\n",
    "    end_idx = offset\n",
    "    param_buffer_list_view.append(param_buffer_tensor[start_idx:end_idx])\n",
    "\n",
    "    param_buffer_list_view[-1].copy_(param_model_list[i]*1.1)\n",
    "\n",
    "\n",
    "N = total_len\n",
    "quant_bits = 4\n",
    "groupsize = 1024\n",
    "groups = N // groupsize\n",
    "_, avg_time = functionA(unfused_stoquantize, tensor_buffer_dp_view[dp_rank], param_buffer_list_view, param_model_list, groups, quant_bits, quantization_module.Symmetric)\n",
    "\n",
    "# avg_time is in milliseconds, so bytes / ms / 10**6 is GB/s.\n",
    "num_bytes = param_buffer_tensor.numel() * param_buffer_tensor.element_size()\n",
    "print('Sto Quantize')\n",
    "print(f'time: {avg_time}ms')\n",
    "print(f'numbytes: {num_bytes}Bytes')\n",
    "print(f'throughput: {num_bytes / avg_time / 10**6}GB/s')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Speed test for the fused subtract + quantize kernel (sub_quantize)\n",
    "\n",
    "dp_size = 4\n",
    "dp_rank = 3\n",
    "\n",
    "total_len = sum(tensor.numel() for tensor in param_model_list) // dp_size\n",
    "print(f\"total len: {total_len}\")\n",
    "print(f\"tensor size: {total_len * param_model_list[0].element_size() / 1024 / 1024} MB, dtype: {param_model_list[0].dtype}\")\n",
    "\n",
    "\n",
    "# Zero-filled on purpose: the 2048 padding elements at the end are never written\n",
    "# below, but the last shard that is quantized includes them.\n",
    "param_buffer_tensor = torch.zeros((total_len * dp_size + 2048,), dtype=torch.bfloat16, device=dev)\n",
    "\n",
    "# Split the buffer into dp_size equally sized shard views.\n",
    "param_buffer_size = param_buffer_tensor.numel() // dp_size\n",
    "tensor_buffer_dp_view = [\n",
    "    param_buffer_tensor[i*param_buffer_size: (i+1)*param_buffer_size]\n",
    "    for i in range(dp_size)\n",
    "]\n",
    "\n",
    "# Per-parameter views, laid out back to back and filled with the parameter scaled by 1.1.\n",
    "param_buffer_list_view = []\n",
    "offset = 0\n",
    "for i in range(len(param_model_list)):\n",
    "    start_idx = offset\n",
    "    offset += param_model_list[i].numel()\n",
    "    end_idx = offset\n",
    "    param_buffer_list_view.append(param_buffer_tensor[start_idx:end_idx])\n",
    "\n",
    "    param_buffer_list_view[-1].copy_(param_model_list[i]*1.1)\n",
    "\n",
    "N = param_buffer_tensor.nelement()\n",
    "quant_bits = 4\n",
    "groupsize = 1024\n",
    "groups = N // groupsize\n",
    "# The third argument is the element offset of the dp_rank shard inside the full buffer.\n",
    "_, avg_time = functionA(quantization_module.sub_quantize, tensor_buffer_dp_view[dp_rank], param_model_list, dp_rank*N//dp_size, groups, quant_bits, quantization_module.Symmetric)\n",
    "\n",
    "# avg_time is in milliseconds, so bytes / ms / 10**6 is GB/s.\n",
    "num_bytes = param_buffer_tensor.numel() * param_buffer_tensor.element_size()\n",
    "print('Sto Quantize')\n",
    "print(f'time: {avg_time}ms')\n",
    "print(f'numbytes: {num_bytes}Bytes')\n",
    "print(f'throughput: {num_bytes / avg_time / 10**6}GB/s')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "megatron",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
