{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "import torch\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5f6d5da8181344178c7758ba0613bf52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md: 0.00B [00:00, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e7b75e326b74a45a812e032b0426e89",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train-00000-of-00004.parquet:   0%|          | 0.00/252M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c66763ecbad4cc6afa9fa62323a7be6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train-00001-of-00004.parquet:   0%|          | 0.00/252M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90c93bf9f5044ab482e0eb73cf7eadbd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train-00002-of-00004.parquet:   0%|          | 0.00/253M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fc32f9512d7244a8a740c8e7dd8e2555",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train-00003-of-00004.parquet:   0%|          | 0.00/253M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d44d6a3066214e21b32bfba829d8430a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/300000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pro = load_dataset('Magpie-Align/Magpie-Llama-3.1-Pro-300K-Filtered', cache_dir = '/workspace/CACHE/MODELS')\n",
    "tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')\n",
    "pro_tok_dict = {tok: 0 for tok, tok_id in tokenizer.vocab.items()}\n",
    "\n",
    "for idx in tqdm(range(len(pro['train']))):\n",
    "    sample = pro['train'][idx]\n",
    "    inst = sample['conversations'][0]['value']\n",
    "    resp = sample['conversations'][1]['value']\n",
    "    inst_toks = tokenizer.tokenize(inst)\n",
    "    resp_toks = tokenizer.tokenize(resp)\n",
    "\n",
    "    for tok in inst_toks:\n",
    "        pro_tok_dict[tok] += 1\n",
    "    for tok in resp_toks:\n",
    "        pro_tok_dict[tok] += 1\n",
    "\n",
    "total_dict = {tok: pro_tok_dict[tok] + reasoning_tok_dict[tok] for tok in pro_tok_dict.keys()}\n",
    "\n",
    "import json \n",
    "import os \n",
    "save_model_dir = \"Meta-Llama-3-8B-Instruct\"\n",
    "os.makedirs(f'./result/{save_model_dir}', exist_ok=True)\n",
    "with open(f'./result/{save_model_dir}/pro_tok_dict.json', 'w') as f:\n",
    "    json.dump(pro_tok_dict, f, ensure_ascii=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
