{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modelzipper.tutils import *\n",
    "from pprint import pprint\n",
    "import re\n",
    "import datasets\n",
    "import numpy as np\n",
    "from modelzipper.tutils import *\n",
    "import random\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### process msimpo training data offline\n",
    "\n",
    "#### Preferred Data Format\n",
    "\n",
    "```python\n",
    "wrap_batch[\"concatenated_input_ids\"] = torch.tensor(batch[0][\"concatenated_input_ids\"])\n",
    "wrap_batch[\"concatenated_attention_mask\"] = torch.tensor(batch[0][\"concatenated_attention_mask\"])\n",
    "wrap_batch[\"concatenated_labels\"] = torch.tensor(batch[0][\"concatenated_labels\"])\n",
    "wrap_batch[\"position_ids\"] = torch.tensor(batch[0][\"position_ids\"])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def create_position_ids(N, L):\n",
    "    \"\"\"Sample a contiguous window of N position ids from the space [0, L).\n",
    "\n",
    "    Args:\n",
    "        N: number of consecutive positions to return.\n",
    "        L: size of the position space (max_chunk_size) to sample from.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor [start, ..., start + N - 1]; start is 0 when N == L and is\n",
    "        otherwise drawn with np.random.randint from 0 .. L - N inclusive.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if N is greater than L, because no such window fits.\n",
    "    \"\"\"\n",
    "    if N > L:\n",
    "        raise ValueError(\"N should not be greater than L\")\n",
    "    if N == L:\n",
    "        start_pos = 0\n",
    "    else:\n",
    "        # randint excludes its upper bound, so pass L - N + 1 to keep the last\n",
    "        # valid window (the one starting at L - N) reachable.\n",
    "        start_pos = int(np.random.randint(0, L - N + 1))\n",
    "    return torch.arange(start_pos, start_pos + N)\n",
    "\n",
    "def create_covering_position_ids(N, L):\n",
    "    \"\"\"Build windows of N consecutive ids that together span the range 0..L-1.\n",
    "\n",
    "    For positive N the number of windows is ceil(L / N); the first window starts\n",
    "    at 0, the last one ends at L - 1, and the starts in between are spread with\n",
    "    integer division. Raises ValueError when N is greater than L.\n",
    "    \"\"\"\n",
    "    if N > L:\n",
    "        raise ValueError(\"N should not be greater than L\")\n",
    "\n",
    "    window_count = (L + N - 1) // N\n",
    "    slack = L - N  # largest start offset that still fits a full window\n",
    "\n",
    "    windows = []\n",
    "    for idx in range(window_count):\n",
    "        begin = idx * slack // (window_count - 1) if window_count > 1 else 0\n",
    "        # A window that would run past L is moved back inside the range.\n",
    "        if begin + N > L:\n",
    "            begin = slack if L > N else 0\n",
    "        windows.append(torch.arange(begin, min(begin + N, L)))\n",
    "    return windows\n",
    "\n",
    "def auto_padding(t: torch.Tensor, length: int, filling_value=-100, return_attention_mask=False):\n",
    "    \"\"\"Pad or truncate a 1-D tensor to `length` elements.\n",
    "\n",
    "    Args:\n",
    "        t: 1-D tensor to resize.\n",
    "        length: target number of elements.\n",
    "        filling_value: value written into the padded tail.\n",
    "        return_attention_mask: when True, also return an int mask holding 1 for\n",
    "            elements taken from `t` and 0 for padding.\n",
    "\n",
    "    Returns:\n",
    "        The resized tensor, or `(resized_tensor, attention_mask)` when\n",
    "        `return_attention_mask` is True. Whether a tuple is returned depends only\n",
    "        on that flag, on the truncating path as well as on the padding path.\n",
    "    \"\"\"\n",
    "    if length < t.size(0):\n",
    "        # Truncation: every kept element is real, so the mask is all ones. It uses\n",
    "        # the same int dtype as the mask built on the padding path below.\n",
    "        truncated = t[:length]\n",
    "        if return_attention_mask:\n",
    "            return truncated, torch.ones(truncated.size(0), dtype=torch.int)\n",
    "        return truncated\n",
    "    padded_tensor = torch.full((length,), filling_value, dtype=t.dtype)\n",
    "    padded_tensor[:t.size(0)] = t\n",
    "    if return_attention_mask:\n",
    "        attention_mask = torch.zeros(length, dtype=torch.int)\n",
    "        attention_mask[:t.size(0)] = 1\n",
    "        return padded_tensor, attention_mask\n",
    "    return padded_tensor\n",
    "\n",
    "\n",
    "# Smoke test: list the windows of 5 ids generated for a 120-position space.\n",
    "N, L = 5, 120\n",
    "position_ids_list = create_covering_position_ids(N=N, L=L)\n",
    "print(position_ids_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine_fn(lst, max_candidates=2):\n",
    "    \"\"\"Concatenate one tensor per sub-list for every combination of choices.\n",
    "\n",
    "    Sub-lists longer than `max_candidates` are first reduced to a random sample\n",
    "    of that size; shorter ones are used as they are. The result holds one\n",
    "    concatenated tensor per element of the Cartesian product of the sub-lists.\n",
    "    \"\"\"\n",
    "    pools = []\n",
    "    for candidates in lst:\n",
    "        if len(candidates) > max_candidates:\n",
    "            candidates = random.sample(candidates, max_candidates)\n",
    "        pools.append(candidates)\n",
    "    return [torch.cat(choice) for choice in itertools.product(*pools)]\n",
    "\n",
    "def create_system_suffix(tokenizer, system_suffix, special_token_id: int=13):\n",
    "    \"\"\"Tokenize the system prompt and append one special token to it.\n",
    "\n",
    "    Returns a triple `(input_ids, attention_mask, position_ids)` of equal length:\n",
    "    the prompt tokens followed by `special_token_id`, the tokenizer's mask\n",
    "    extended with a 1 for that token, and positions counting up from 0.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(system_suffix, return_tensors=\"pt\", add_special_tokens=False)\n",
    "    prompt_ids, prompt_mask = encoded.input_ids[0], encoded.attention_mask[0]\n",
    "    # The appended special token is attended to like the prompt tokens.\n",
    "    input_ids = torch.concatenate([prompt_ids, torch.tensor([special_token_id])], dim=0)\n",
    "    attention_mask = torch.concatenate([prompt_mask, torch.tensor([1])], dim=0)\n",
    "    position_ids = torch.arange(0, input_ids.size(-1))\n",
    "    return input_ids, attention_mask, position_ids\n",
    "\n",
    "def create_chunked_reference(tokenizer: AutoTokenizer, all_refs: List[str], real_reference_size: int, max_embedding_size: int, system_prompt_size: int, qa_size: int, special_token_id: int=13):\n",
    "    \"\"\"Tokenize the references into equally sized chunks, each closed by a special token.\n",
    "\n",
    "    Every reference gets `real_reference_size // len(all_refs)` slots of the real\n",
    "    input (the last slot holds `special_token_id`) and an equally sized slice of\n",
    "    the position space between `system_prompt_size` and\n",
    "    `max_embedding_size - qa_size`.\n",
    "\n",
    "    Args:\n",
    "        tokenizer: called as `tokenizer(text, return_tensors=\"pt\", add_special_tokens=False)`;\n",
    "            only `.input_ids[0]` of the result is used.\n",
    "        all_refs: reference texts, one chunk each.\n",
    "        real_reference_size: number of real input slots shared by all chunks.\n",
    "        max_embedding_size: size of the whole position space.\n",
    "        system_prompt_size: number of positions reserved in front of the references.\n",
    "        qa_size: number of positions reserved behind the references.\n",
    "        special_token_id: token id appended to every chunk.\n",
    "\n",
    "    Returns:\n",
    "        `(candidate_position_ids, input_ids, attention_mask, all_spe_pos)`: the list\n",
    "        of candidate position-id tensors produced by `combine_fn`, the concatenated\n",
    "        padded chunk token ids, the matching attention mask, and the indices of the\n",
    "        special tokens inside `input_ids`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `all_refs` is empty or no position is left per chunk.\n",
    "    \"\"\"\n",
    "    num_chunks = len(all_refs)\n",
    "    if num_chunks == 0:\n",
    "        raise ValueError(\"all_refs must contain at least one reference\")\n",
    "\n",
    "    real_max_chunk_size = real_reference_size // num_chunks - 1  # allocate one position for beacon token\n",
    "    fake_position_chunk_size = (max_embedding_size - system_prompt_size - qa_size) // num_chunks\n",
    "    if fake_position_chunk_size <= 0:\n",
    "        raise ValueError(\"not enough positions between the system prompt and the QA part for all references\")\n",
    "\n",
    "    # Tokenize every reference and cut it to the per-chunk token budget.\n",
    "    truncted_refer_tok_lst = [\n",
    "        tokenizer(ref, return_tensors=\"pt\", add_special_tokens=False).input_ids[0][:real_max_chunk_size]\n",
    "        for ref in all_refs\n",
    "    ]\n",
    "\n",
    "    # Chunk i owns the positions [begin_i, begin_i + fake_position_chunk_size) and\n",
    "    # its last position is given to the chunk's special token. The bounds are\n",
    "    # derived from the chunk count instead of a stepped arange over the span,\n",
    "    # because that arange produces one boundary too few whenever the span divides\n",
    "    # evenly by the number of references.\n",
    "    begin_positional_chunks = [system_prompt_size + i * fake_position_chunk_size for i in range(num_chunks)]\n",
    "    end_positional_chunks = [begin + fake_position_chunk_size - 1 for begin in begin_positional_chunks]\n",
    "\n",
    "    # Candidate position ids per chunk: each covering window shifted into the\n",
    "    # chunk's slice, padded with 0 and followed by the special token's position.\n",
    "    # NOTE(review): the last covering window can include the chunk's final\n",
    "    # position, which is also assigned to the special token - confirm this\n",
    "    # overlap is intended.\n",
    "    candicated_padded_position_ids = []\n",
    "    for item, begin, end in zip(truncted_refer_tok_lst, begin_positional_chunks, end_positional_chunks):\n",
    "        chunk_candidates = []\n",
    "        for window in create_covering_position_ids(item.size(-1), fake_position_chunk_size):\n",
    "            padded_window = auto_padding(window + begin, real_max_chunk_size, filling_value=0, return_attention_mask=False)\n",
    "            chunk_candidates.append(torch.concatenate([padded_window, torch.tensor([end])], dim=0))\n",
    "        candicated_padded_position_ids.append(chunk_candidates)\n",
    "    candicated_padded_position_ids = combine_fn(candicated_padded_position_ids)\n",
    "\n",
    "    # Token ids and attention mask per chunk: padded tokens plus the special token.\n",
    "    padded_ref_input_ids_lst, padded_ref_attention_mask_lst = [], []\n",
    "    for item in truncted_refer_tok_lst:\n",
    "        chunk_ids, chunk_mask = auto_padding(item, real_max_chunk_size, filling_value=0, return_attention_mask=True)\n",
    "        padded_ref_input_ids_lst.append(torch.concatenate([chunk_ids, torch.tensor([special_token_id])], dim=0))\n",
    "        padded_ref_attention_mask_lst.append(torch.concatenate([chunk_mask, torch.tensor([1])], dim=0))\n",
    "\n",
    "    padded_ref_input_ids = torch.concatenate(padded_ref_input_ids_lst, dim=0)\n",
    "    padded_ref_attention_mask = torch.concatenate(padded_ref_attention_mask_lst, dim=0)\n",
    "    # Index of each chunk's special token inside padded_ref_input_ids.\n",
    "    all_spe_pos = torch.arange(real_max_chunk_size, real_reference_size, real_max_chunk_size + 1)\n",
    "\n",
    "    return candicated_padded_position_ids, padded_ref_input_ids, padded_ref_attention_mask, all_spe_pos\n",
    "\n",
    "\n",
    "def _create_padded_answer(tokenizer, answer_text, answer_size, first_position):\n",
    "    \"\"\"Tokenize one answer and fit its ids, mask, labels and position ids to answer_size.\"\"\"\n",
    "    # Cut over-long answers up front so that all four tensors are shortened consistently.\n",
    "    tok_answer = tokenizer(answer_text, return_tensors=\"pt\", add_special_tokens=False).input_ids[0][:answer_size]\n",
    "    padded_tok_answer, answer_attention_mask = auto_padding(tok_answer, answer_size, filling_value=0, return_attention_mask=True)\n",
    "    answer_labels = auto_padding(tok_answer, answer_size, filling_value=-100)\n",
    "    answer_position_ids = torch.arange(tok_answer.size(-1)) + first_position\n",
    "    answer_position_ids = auto_padding(answer_position_ids, answer_size, filling_value=0)\n",
    "    return padded_tok_answer, answer_attention_mask, answer_labels, answer_position_ids\n",
    "\n",
    "\n",
    "def create_qa(QUESTION_TEMPLATE, ANSWER_TEMPLATE, combined_question, combined_answer, prefix_a: str, suffix_a: str, last_position: int, qa_size: int, special_token_id: int, tokenizer=None):\n",
    "    \"\"\"Build the question tensors and the chosen / rejected answer tensors of one sample.\n",
    "\n",
    "    Args:\n",
    "        QUESTION_TEMPLATE: format string with a `{question}` field.\n",
    "        ANSWER_TEMPLATE: format string with an `{answer}` field.\n",
    "        combined_question: text inserted into the question template.\n",
    "        combined_answer: text of the chosen answer.\n",
    "        prefix_a: text of the first rejected answer.\n",
    "        suffix_a: text of the second rejected answer.\n",
    "        last_position: the question's position ids start at `last_position + 1`.\n",
    "        qa_size: combined length of the question (with its special token) and\n",
    "            one padded answer.\n",
    "        special_token_id: token id appended after the question tokens.\n",
    "        tokenizer: tokenizer to use; when None the notebook-level `tokenizer`\n",
    "            variable is used, which is what earlier callers relied on.\n",
    "\n",
    "    Returns:\n",
    "        A 16-tuple: question ids, question attention mask, question position ids,\n",
    "        then (ids, attention mask, labels, position ids) for the chosen answer,\n",
    "        the prefix rejected answer and the suffix rejected answer, and finally the\n",
    "        index of the special token inside the question tensors.\n",
    "\n",
    "    Raises:\n",
    "        NameError: if no tokenizer is passed and none is defined at notebook level.\n",
    "        ValueError: if the question with its special token is longer than qa_size.\n",
    "    \"\"\"\n",
    "    if tokenizer is None:\n",
    "        try:\n",
    "            tokenizer = globals()[\"tokenizer\"]\n",
    "        except KeyError:\n",
    "            raise NameError(\"name 'tokenizer' is not defined\") from None\n",
    "\n",
    "    # Create Question: its tokens followed by the special token, all attended to\n",
    "    question = QUESTION_TEMPLATE.format(question=combined_question)\n",
    "    tok_question = tokenizer(question, return_tensors=\"pt\", add_special_tokens=False).input_ids[0]\n",
    "    question_size = tok_question.size(-1)\n",
    "    padded_tok_question = torch.concatenate([tok_question, torch.tensor([special_token_id])], dim=0)\n",
    "    padded_question_attention_mask = torch.ones(question_size + 1, dtype=torch.long)\n",
    "    # Consecutive positions after last_position; the final one belongs to the special token.\n",
    "    question_position_input_ids = torch.arange(last_position + 1, last_position + question_size + 2)\n",
    "    spe_tok_pos = question_size\n",
    "\n",
    "    answer_size = qa_size - padded_question_attention_mask.size(-1)\n",
    "    if answer_size < 0:\n",
    "        raise ValueError(\"the question does not fit into qa_size\")\n",
    "    # Answer positions continue right after the question's special token.\n",
    "    system_reference_question_size = last_position + 1 + question_position_input_ids.size(-1)\n",
    "\n",
    "    # Create Chosen / Rejected Answers / and their labels\n",
    "    chosen = _create_padded_answer(tokenizer, ANSWER_TEMPLATE.format(answer=combined_answer), answer_size, system_reference_question_size)\n",
    "    prefix_rejected = _create_padded_answer(tokenizer, ANSWER_TEMPLATE.format(answer=prefix_a), answer_size, system_reference_question_size)\n",
    "    suffix_rejected = _create_padded_answer(tokenizer, ANSWER_TEMPLATE.format(answer=suffix_a), answer_size, system_reference_question_size)\n",
    "\n",
    "    return (\n",
    "        padded_tok_question, padded_question_attention_mask, question_position_input_ids,\n",
    "        *chosen,\n",
    "        *prefix_rejected,\n",
    "        *suffix_rejected,\n",
    "        spe_tok_pos,\n",
    "    )\n",
    "    \n",
    "\n",
    "# Block test for create_chunked_reference on three tiny references.\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\"Meta-Llama-3-8B-Instruct\")\n",
    "all_refs = [\"hello, world\", \"Any, iowpq\", \"reason medsa\"]\n",
    "real_reference_size, max_embedding_size = 28, 1340\n",
    "system_prompt_size, qa_size = 2, 4\n",
    "(\n",
    "    candicated_padded_concat_position_ids,\n",
    "    padded_ref_input_ids,\n",
    "    padded_ref_attention_mask,\n",
    "    all_spe_pos,\n",
    ") = create_chunked_reference(\n",
    "    tokenizer, all_refs, real_reference_size, max_embedding_size, system_prompt_size, qa_size\n",
    ")\n",
    "\n",
    "# Block test for create_qa with one chosen and two rejected answers.\n",
    "QUESTION_TEMPLATE = \"<|start_header_id|>user<|end_header_id|>\\n\\nPlease answer the following question according to the references: {question}<|eot_id|>\"\n",
    "ANSWER_TEMPLATE = \"<|start_header_id|>assistant<|end_header_id|>\\n\\nThe answer is: {answer}<|eot_id|><|end_of_text|>\"\n",
    "\n",
    "(\n",
    "    padded_tok_question,\n",
    "    padded_question_attention_mask,\n",
    "    question_position_input_ids,\n",
    "    padded_tok_chosen_answer,\n",
    "    padded_chosen_answer_attention_mask,\n",
    "    tok_chosen_answer_labels,\n",
    "    chosen_answer_position_ids,\n",
    "    padded_tok_prefix_rejected_answer,\n",
    "    padded_prefix_rejected_answer_attention_mask,\n",
    "    tok_prefix_rejected_answer_labels,\n",
    "    prefix_rejected_answer_position_ids,\n",
    "    padded_tok_suffix_rejected_answer,\n",
    "    padded_suffix_rejected_answer_attention_mask,\n",
    "    tok_suffix_rejected_answer_labels,\n",
    "    suffix_rejected_answer_position_ids,\n",
    "    spe_tok_pos,\n",
    ") = create_qa(\n",
    "    QUESTION_TEMPLATE, ANSWER_TEMPLATE, \"who are you\", \"jack\", \"prefix_a\", \"suffix_a\",\n",
    "    last_position=1024, qa_size=100, special_token_id=13,\n",
    ")\n",
    "\n",
    "# One shape per output line, in the same order as before.\n",
    "print(\n",
    "    padded_tok_question.shape,\n",
    "    padded_question_attention_mask.shape,\n",
    "    question_position_input_ids.shape,\n",
    "    padded_tok_chosen_answer.shape,\n",
    "    chosen_answer_position_ids.shape,\n",
    "    prefix_rejected_answer_position_ids.shape,\n",
    "    padded_tok_suffix_rejected_answer.shape,\n",
    "    padded_suffix_rejected_answer_attention_mask.shape,\n",
    "    suffix_rejected_answer_position_ids.shape,\n",
    "    tok_suffix_rejected_answer_labels.shape,\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_index(id_lst, prefix_id, suffix_id):\n",
    "    \"\"\"Return the positions of prefix_id and suffix_id inside id_lst as a pair.\n",
    "\n",
    "    Both lookups go through `id_lst.index`, so for a list the first occurrence is\n",
    "    reported and a ValueError is raised when an id is missing.\n",
    "    \"\"\"\n",
    "    prefix_pos = id_lst.index(prefix_id)\n",
    "    suffix_pos = id_lst.index(suffix_id)\n",
    "    return prefix_pos, suffix_pos\n",
    "\n",
    "def create_covering_position_ids(N, L):\n",
    "    \"\"\"Build windows of N consecutive ids that together span the range 0..L-1.\n",
    "\n",
    "    For positive N there are ceil(L / N) windows; the first starts at 0 and the\n",
    "    last ends at L - 1. Raises ValueError when N is greater than L.\n",
    "\n",
    "    NOTE(review): a function with this name is also defined in an earlier cell;\n",
    "    once this cell has run, this definition is the one in effect.\n",
    "    \"\"\"\n",
    "    if N > L:\n",
    "        raise ValueError(\"N should not be greater than L\")\n",
    "    total = (L + N - 1) // N\n",
    "    last_start = L - N\n",
    "    starts = [k * last_start // (total - 1) if total > 1 else 0 for k in range(total)]\n",
    "    # A start whose window would run past L falls back to the last admissible start.\n",
    "    starts = [(last_start if L > N else 0) if s + N > L else s for s in starts]\n",
    "    return [torch.arange(s, min(s + N, L)) for s in starts]\n",
    "\n",
    "\n",
    "def auto_padding(t: torch.Tensor, length: int, filling_value=-100, return_attention_mask=False):\n",
    "    \"\"\"Pad or truncate a 1-D tensor to `length` elements.\n",
    "\n",
    "    Args:\n",
    "        t: 1-D tensor to resize.\n",
    "        length: target number of elements.\n",
    "        filling_value: value written into the padded tail.\n",
    "        return_attention_mask: when True, also return an int mask holding 1 for\n",
    "            elements taken from `t` and 0 for padding.\n",
    "\n",
    "    Returns:\n",
    "        The resized tensor, or `(resized_tensor, attention_mask)` when\n",
    "        `return_attention_mask` is True. Whether a tuple is returned depends only\n",
    "        on that flag, on the truncating path as well as on the padding path.\n",
    "    \"\"\"\n",
    "    if length < t.size(0):\n",
    "        # Truncation: every kept element is real, so the mask is all ones. It uses\n",
    "        # the same int dtype as the mask built on the padding path below.\n",
    "        truncated = t[:length]\n",
    "        if return_attention_mask:\n",
    "            return truncated, torch.ones(truncated.size(0), dtype=torch.int)\n",
    "        return truncated\n",
    "    padded_tensor = torch.full((length,), filling_value, dtype=t.dtype)\n",
    "    padded_tensor[:t.size(0)] = t\n",
    "    if return_attention_mask:\n",
    "        attention_mask = torch.zeros(length, dtype=torch.int)\n",
    "        attention_mask[:t.size(0)] = 1\n",
    "        return padded_tensor, attention_mask\n",
    "    return padded_tensor\n",
    "\n",
    "\n",
    "def create_covering_position_ipt_data(tokenizer, all_refs: List[str], combined_question: str, combined_answer: str, prefix_a: str, suffix_a: str, qa_size: int, max_embedding_size: int, real_reference_size: int, special_token_id: int = None, prefix_id: int = None, suffix_id: int = None):\n",
    "    \"\"\"Assemble one training sample per candidate set of reference position ids.\n",
    "\n",
    "    The system prompt, the chunked references and the question are concatenated\n",
    "    into `input_ids`, `attention_mask` and `position_ids`; the chosen answer and\n",
    "    the two rejected answers are stored as nested dicts whose labels are prefixed\n",
    "    with -100 for every system / reference / question token.\n",
    "\n",
    "    Returns:\n",
    "        `(samples, mean_length)`: the list of sample dicts and the mean size of\n",
    "        their `input_ids`.\n",
    "    \"\"\"\n",
    "    SYSTEM_SUFFIX = \"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nBelow is some references. Please read it carefully and answer the following question.<|eot_id|>\"\n",
    "    QUESTION_TEMPLATE = \"<|start_header_id|>user<|end_header_id|>\\n\\nPlease answer the following question according to the references: {question}<|eot_id|>\"\n",
    "    ANSWER_TEMPLATE = \"<|start_header_id|>assistant<|end_header_id|>\\n\\nThe answer is: {answer}<|eot_id|><|end_of_text|>\"\n",
    "\n",
    "    # System prompt; its final token is the first special token of the sample.\n",
    "    system_ids, system_mask, system_pos = create_system_suffix(tokenizer, SYSTEM_SUFFIX, special_token_id)\n",
    "    system_prompt_size = system_mask.size(-1)\n",
    "    all_spe_pos = [system_prompt_size - 1]\n",
    "\n",
    "    # Chunked references; shift their special-token indices behind the system prompt.\n",
    "    candidate_ref_positions, ref_ids, ref_mask, ref_spe_pos = create_chunked_reference(\n",
    "        tokenizer, all_refs, real_reference_size, max_embedding_size, system_prompt_size, qa_size, special_token_id\n",
    "    )\n",
    "    ref_spe_pos += system_prompt_size\n",
    "    all_spe_pos.extend(ref_spe_pos.tolist())\n",
    "\n",
    "    # Question and the three answers; positions start after the reference span.\n",
    "    last_position = max_embedding_size - qa_size\n",
    "    qa_fields = create_qa(\n",
    "        QUESTION_TEMPLATE, ANSWER_TEMPLATE, combined_question, combined_answer, prefix_a, suffix_a, last_position, qa_size, special_token_id=special_token_id\n",
    "    )\n",
    "    question_ids, question_mask, question_pos = qa_fields[:3]\n",
    "    answer_fields = {\n",
    "        \"chosen_answer\": qa_fields[3:7],\n",
    "        \"prefix_rejected_answer\": qa_fields[7:11],\n",
    "        \"suffix_rejected_answer\": qa_fields[11:15],\n",
    "    }\n",
    "    spe_tok_pos = qa_fields[15]\n",
    "    all_spe_pos.append(spe_tok_pos + system_prompt_size + ref_ids.size(-1))\n",
    "\n",
    "    # One sample per candidate combination of reference position ids.\n",
    "    all_datasets, input_sizes = [], []\n",
    "    for ref_position_id in candidate_ref_positions:\n",
    "        sample = {\n",
    "            \"input_ids\": torch.concatenate([system_ids, ref_ids, question_ids], dim=0),\n",
    "            \"attention_mask\": torch.concatenate([system_mask, ref_mask, question_mask], dim=0),\n",
    "            \"position_ids\": torch.concatenate([system_pos, ref_position_id, question_pos], dim=0),\n",
    "            \"all_spe_pos\": all_spe_pos,\n",
    "        }\n",
    "        # -100 labels for everything in front of the answer.\n",
    "        context_labels = torch.full((sample[\"attention_mask\"].size(-1),), -100)\n",
    "        for answer_name, (ans_ids, ans_mask, ans_labels, ans_pos) in answer_fields.items():\n",
    "            sample[answer_name] = {\n",
    "                \"input_ids\": ans_ids,\n",
    "                \"attention_mask\": ans_mask,\n",
    "                \"labels\": torch.concatenate([context_labels, ans_labels], dim=0),\n",
    "                \"position_ids\": ans_pos,\n",
    "            }\n",
    "        sample[\"chosen_ids\"] = (prefix_id, suffix_id)\n",
    "        all_datasets.append(sample)\n",
    "        input_sizes.append(sample[\"input_ids\"].size(-1))\n",
    "\n",
    "    return all_datasets, sum(input_sizes) / len(input_sizes)\n",
    "\n",
    "\n",
    "# Load the source dataset stored on disk under \"vllm_data\" and show its summary.\n",
    "dataset = datasets.load_from_disk(\"vllm_data\")\n",
    "print(dataset)\n",
    "\n",
    "# Rebinds the notebook-level `tokenizer`, which create_qa reads as a global.\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\"Meta-Llama-3-8B-Instruct\")\n",
    "training_samples = []\n",
    "avg_real_seq_length = 0\n",
    "# First token id produced for the reserved special token string; passed below as special_token_id.\n",
    "spe_token_id = tokenizer(\"<|reserved_special_token_0|>\", add_special_tokens=False).input_ids[0]\n",
    "with tqdm(total=len(dataset), desc=f\"Initial Avg Length: {avg_real_seq_length}\") as pbar:\n",
    "    for item in dataset:\n",
    "        all_ref_text = item[\"all_ref_text\"]\n",
    "        combined_question, final_answer = item[\"combined_question\"], item[\"final_answer\"]\n",
    "        # prefix_q / suffix_q are read here but not used further down in this cell.\n",
    "        prefix_q, suffix_q = item[\"prefix_q\"], item[\"suffix_q\"]\n",
    "        prefix_a, suffix_a = item[\"prefix_a\"], item[\"suffix_a\"]\n",
    "        # Turn the two ids into their indices within this item's all_ref_ids list.\n",
    "        prefix_id, suffix_id = find_index(item[\"all_ref_ids\"], item[\"prefix_id\"], item[\"suffix_id\"])\n",
    "        all_datasets, ref_length = create_covering_position_ipt_data(tokenizer, all_ref_text, combined_question, final_answer, prefix_a, suffix_a, qa_size=512, max_embedding_size=65536, real_reference_size=8192, special_token_id=spe_token_id, prefix_id=prefix_id, suffix_id=suffix_id)\n",
    "        # Running mean: each item adds ref_length / len(dataset), so the value is the full mean only after the last item.\n",
    "        avg_real_seq_length += ref_length / len(dataset)\n",
    "        training_samples.extend(all_datasets)\n",
    "        pbar.set_description(f\"Current Avg Seq Length: {avg_real_seq_length:.2f}\")\n",
    "        pbar.update(1)\n",
    "\n",
    "# Total number of samples collected over all items.\n",
    "print(len(training_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regroup the per-sample dicts into one dict mapping each key to the list of its values.\n",
    "converted_dict = {}\n",
    "for item in training_samples:\n",
    "    for key, value in item.items():\n",
    "        converted_dict.setdefault(key, []).append(value)\n",
    "\n",
    "# Build the dataset from those columns and write it to disk.\n",
    "hf_datasets = datasets.Dataset.from_dict(converted_dict)\n",
    "hf_datasets.save_to_disk(\"/data/save_path\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "zecheng",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
