{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import time\n",
    "from typing import Dict\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import transformers\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEFAULT_PAD_TOKEN = \"<pad>\"\n",
    "DEFAULT_BOS_TOKEN = \"<s>\"\n",
    "DEFAULT_EOS_TOKEN = \"</s>\"\n",
    "DEFAULT_UNK_TOKEN = \"<unk>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def smart_tokenizer_and_embedding_resize(\n",
    "        special_tokens_dict: Dict,\n",
    "        tokenizer: transformers.PreTrainedTokenizer,\n",
    "        model: transformers.PreTrainedModel,\n",
    "):\n",
    "    \"\"\"Resize tokenizer and embedding.\n",
    "    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.\n",
    "    \"\"\"\n",
    "    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) \n",
    "    model.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "    if num_new_tokens > 0:\n",
    "        input_embeddings = model.get_input_embeddings().weight.data\n",
    "        output_embeddings = model.get_output_embeddings().weight.data\n",
    "        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)\n",
    "        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)\n",
    "\n",
    "        input_embeddings[-num_new_tokens:] = input_embeddings_avg\n",
    "        output_embeddings[-num_new_tokens:] = output_embeddings_avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:32<00:00, 16.42s/it]\n",
      "Using pad_token, but it is not set yet.\n"
     ]
    }
   ],
   "source": [
    "def load_model_and_tokenizer(path='/mntcephfs/data/med/zhihong/workspace/LLMZoo/llama_hf_7b'):\n",
    "    model = AutoModelForCausalLM.from_pretrained(path)\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\n",
    "        path, \n",
    "        model_max_length=2048, \n",
    "        padding_side=\"right\", \n",
    "        use_fast=True\n",
    "    )\n",
    "    if tokenizer.pad_token is None:\n",
    "        smart_tokenizer_and_embedding_resize(\n",
    "            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),\n",
    "            tokenizer=tokenizer,\n",
    "            model=model,\n",
    "        )\n",
    "    tokenizer.add_special_tokens({\n",
    "        \"eos_token\": DEFAULT_EOS_TOKEN,\n",
    "        \"bos_token\": DEFAULT_BOS_TOKEN,\n",
    "        \"unk_token\": DEFAULT_UNK_TOKEN,\n",
    "    })\n",
    "    return tokenizer\n",
    "\n",
    "tokenizer = load_model_and_tokenizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_json_file(file_path):\n",
    "    with open(file_path, 'r', encoding=\"utf-8\") as file:\n",
    "        json_data = json.load(file)\n",
    "    return json_data\n",
    "\n",
    "def extract_utters(path):\n",
    "    all_utter = []\n",
    "    session_list = read_json_file(path)\n",
    "    for session_dict in session_list:\n",
    "        for utter_dict in session_dict['conversations']:\n",
    "            all_utter.append(utter_dict['value'])\n",
    "    return all_utter\n",
    "\n",
    "def compute_avg_utters(path):\n",
    "    '''所有utter的token的length / 所有utter的数量'''\n",
    "    utter_len = []\n",
    "    all_utter = extract_utters(path)\n",
    "    for utter in tqdm(all_utter):\n",
    "        idx_tensor = tokenizer(utter, return_tensors=\"pt\", padding=\"longest\")['input_ids'][0]\n",
    "        utter_len.append(len(idx_tensor)-1)\n",
    "    return np.mean(utter_len)\n",
    "\n",
    "def extract_conv(path):\n",
    "    all_conv = []\n",
    "    session_list = read_json_file(path)\n",
    "    for session_dict in session_list:\n",
    "        one_conv = ''\n",
    "        for utter_dict in session_dict['conversations']:\n",
    "            one_conv += (utter_dict['value'])\n",
    "        all_conv.append(one_conv) \n",
    "    return all_conv\n",
    "\n",
    "def compute_avg_convs(path):\n",
    "    all_conv = extract_conv(path)\n",
    "    one_conv_len = []\n",
    "    for one_conv_str in tqdm(all_conv):\n",
    "        idx_tensor = tokenizer(one_conv_str, return_tensors=\"pt\", padding=\"longest\")['input_ids'][0]\n",
    "        one_conv_len.append(len(idx_tensor)-1)\n",
    "    # assert len(one_conv_len) == 10000\n",
    "    return np.mean(one_conv_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "directory = r\"/data/hei\"\n",
    "\n",
    "files_name = []\n",
    "for root, dirs, files in os.walk(directory):\n",
    "    for file in files:\n",
    "        files_name.append(os.path.join(root, file))\n",
    "\n",
    "for file in tqdm(files_name):\n",
    "    start = time.time()\n",
    "    \n",
    "    avg_utter_len = compute_avg_utters(file)\n",
    "    avg_conv_len = compute_avg_convs(file)\n",
    "    \n",
    "    end = time.time()\n",
    "    \n",
    "    print(f'\\nfor{file}:')\n",
    "    print('avg_utter_len_by_token: ', avg_utter_len)\n",
    "    print('avg_conv_len_by_token: ', avg_conv_len)\n",
    "    print(f'Elapsed {end-start} seconds.\\n')\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "directory = r\"/data\"  \n",
    "files_name = []\n",
    "\n",
    "for root, dirs, files in os.walk(directory):\n",
    "    for file in files:\n",
    "        if \"10k\" in file:\n",
    "            files_name.append(os.path.join(root, file))\n",
    "            \n",
    "for file in tqdm(files_name):\n",
    "    start = time.time()\n",
    "    \n",
    "    avg_utter_len = compute_avg_utters(file)\n",
    "    avg_conv_len = compute_avg_convs(file)\n",
    "    \n",
    "    end = time.time()\n",
    "    \n",
    "    print(f'\\nfor{file}:')\n",
    "    print('avg_utter_len_by_token: ', avg_utter_len)\n",
    "    print('avg_conv_len_by_token: ', avg_conv_len)\n",
    "    print(f'Elapsed {end-start} seconds.\\n')\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
