{
 "cells": [
  {
   "cell_type": "code",
   "id": "0a5ee93a-085a-402e-9a2b-37e512c5ea09",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-13T07:44:48.786111Z",
     "start_time": "2025-03-13T07:44:48.780222Z"
    }
   },
   "source": [
    "import xml.etree.ElementTree as ET\n",
    "import re\n",
    "\n",
    "def find_elements_by_context_ref(xml_file, context_id):\n",
    "    try:\n",
    "        tree = ET.parse(xml_file)\n",
    "        root = tree.getroot()\n",
    "\n",
    "        matching_elements = []\n",
    "        for element in root.iter():\n",
    "\n",
    "            if element.get(\"contextRef\") == context_id and \"us-gaap\" in element.tag:\n",
    "                truncated_content = element.text[:100] if element.text else \"\"  # Truncate content\n",
    "                element.text = truncated_content\n",
    "                \n",
    "                ele = ET.tostring(element, encoding=\"unicode\").replace(\"ns0\", \"us-gaap\")\n",
    "                if \"TextBlock\" in ele or \"style=\" in ele:\n",
    "                    continue\n",
    "\n",
    "                ele = ele.replace('xmlns:us-gaap=\"http://fasb.org/us-gaap/2023\"', \"\").replace(f'contextRef=\"{context_id}\"', \"\") \n",
    "                ele = re.sub(r\"</.*?>\", \"</>\", ele)  # Remove closing tag text (to reduce token count)\n",
    "                ele = re.sub(r\"\\w+=\\\".*?\\\"\", \"\", ele)  # Remove attributes\n",
    "                ele = re.sub(r\"\\s+\", \" \", ele)  # Remove consecutive spaces\n",
    "\n",
    "                matching_elements.append(ele)\n",
    "\n",
    "        return \"\\n\".join(matching_elements)\n",
    "\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: XML file not found: {xml_file}\")\n",
    "        return \"\"\n",
    "\n",
    "# replace the file name with xbrl raw text\n",
    "def add_xml(qa_string, limit=1000000):\n",
    "    if '<' not in qa_string or ',id:' not in qa_string:\n",
    "        return qa_string\n",
    "\n",
    "    # Extract information from the QA string\n",
    "    start = qa_string.find(\"<\") + 1\n",
    "    end = qa_string.find(\">\")\n",
    "    placeholder = qa_string[start:end]\n",
    "    parts = placeholder.split(\",id:\")\n",
    "    doc_path = \"train/DowJones30/\" + parts[0]\n",
    "    \n",
    "    context_id = parts[1]\n",
    "\n",
    "    # Get the XML content using the custom grep function\n",
    "    xml_content = find_elements_by_context_ref(doc_path, context_id)[:limit]\n",
    "\n",
    "    # Replace the placeholder with the XML content\n",
    "    new_qa_string = qa_string.replace(f\"<{placeholder}>\", xml_content + \"\\n\\n\")\n",
    "    return new_qa_string\n"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "id": "1773b3dc-155a-45a2-bf10-09fa1927513d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-13T07:44:48.811327Z",
     "start_time": "2025-03-13T07:44:48.805719Z"
    }
   },
   "source": [
    "import json\n",
    "from typing import List, Dict\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "import random\n",
    "import os.path\n",
    "\n",
    "random.seed(42)\n",
    "\n",
    "def get_xbrl_dataset(data: List[Dict], example_q=None, example_a=None):\n",
    "    \"\"\"\n",
    "    Saves entries with matching category1 or category2 in the format for fine-tuning.\n",
    "\n",
    "    Args:\n",
    "        data (List[Dict]): The input JSON data.\n",
    "        category (str): The category name to match.\n",
    "        output_file (str): The output file path.\n",
    "    \"\"\"\n",
    "\n",
    "    results = {}\n",
    "    for entry in tqdm(data):\n",
    "        if (entry[\"doc_path\"], entry[\"answer\"], entry[\"contextID\"][0]) in results.keys():\n",
    "            continue\n",
    "\n",
    "        question = entry[\"query\"]\n",
    "        question = re.sub(r\"\\(.*?\\)\", \"\", question)\n",
    "        doc_path = entry[\"doc_path\"]\n",
    "        context_ids = entry[\"contextID\"]\n",
    "\n",
    "        if not os.path.isfile('train/DowJones30/' + doc_path):\n",
    "            # print(f\"missing file {doc_path}\")\n",
    "            continue\n",
    "\n",
    "        example_qa = \"\"\n",
    "        if example_q is not None and example_a is not None:\n",
    "            example_qa = f\"\\nExample question: {example_q}\\nExample answer: {example_a}\"\n",
    "        target = entry[\"raw_answer\"]\n",
    "\n",
    "        if entry['category1'] == 'formula_calculation' or entry['category2'] == 'formula_calculation':\n",
    "            question += \" Answer with a formula substituted with values. \"\n",
    "            target = entry[\"value_formula_answer\"]\n",
    "\n",
    "        context = \\\n",
    "            f\"\"\"\"You are a knowledgeable XBRL assistant that can answer questions based on XML data. \n",
    "             You will be provided with a context extracted from an XBRL file and a question related to it. The example question can help you to learn the format of the answer.\n",
    "             Your task is to analyze the XBRL context and provide an accurate and very concise answer to the question, DO NOT output xml, code, explanation or create new question.\n",
    "            \\nXBRL file:\\n ```xml\\n <{doc_path},id:{context_ids[0]}> ```\\n\n",
    "            {example_qa}\n",
    "            \\nQuestion: {question}\n",
    "            \\nAnswer:\"\"\"\n",
    "\n",
    "        context_xml = add_xml(context)\n",
    "        if len(context_xml) > 24000:\n",
    "            continue\n",
    "\n",
    "\n",
    "        # print(entry[\"answer\"])\n",
    "        # entry[\"doc_path\"], entry[\"answer\"], entry[\"contextID\"][0]\n",
    "        results[entry[\"doc_path\"], entry[\"answer\"], entry[\"contextID\"][0]] = {\"context\": context_xml,\n",
    "                                                                              \"target\": str(target),\n",
    "                                                                              \"doc_path\": entry['doc_path']}\n",
    "\n",
    "    print(\"final length\", len(results))\n",
    "    return list(results.values())\n",
    "\n",
    "\n",
    "def gen_xbrl(cat, example_q, example_a):\n",
    "    with open(\"xbrl_bench_34020.json.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "        filtered_data = [entry for entry in data if entry['category1'] == cat or entry['category2'] == cat]\n",
    "\n",
    "        all_doc_path = list(set([entry['doc_path'] for entry in filtered_data]))\n",
    "        print(f\"Total data size for this {cat}: {len(filtered_data)}, total number of filings {len(all_doc_path)}\")\n",
    "        random.shuffle(filtered_data)\n",
    "\n",
    "        # train_data = filtered_data[split_size:]\n",
    "        # train_data = train_data\n",
    "\n",
    "        dataset = get_xbrl_dataset(filtered_data[:2500], example_q, example_a)\n",
    "        dataset = dataset[:1500]\n",
    "        test_data = []\n",
    "        train_data = []\n",
    "        random.shuffle(all_doc_path)\n",
    "        for x in all_doc_path:\n",
    "            portion = [entry for entry in dataset if entry[\"doc_path\"] == x]\n",
    "            if len(test_data) < 100:\n",
    "                test_data += portion\n",
    "            else:\n",
    "                train_data += portion\n",
    "\n",
    "        return train_data, test_data\n"
   ],
   "outputs": [],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "id": "e6d109f6-b096-43b0-82d5-de5a19c43925",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2025-03-13T07:45:01.515545Z",
     "start_time": "2025-03-13T07:44:48.817140Z"
    }
   },
   "source": [
    "tags_train, tags_test = gen_xbrl(\"xbrl_tags\", \n",
    "         example_q = \"What is the US GAAP XBRL tag for Cash and Cash Equivalents as reported by Example Company Inc for the Fiscal Year ending in FY 2022\", \n",
    "         example_a = \"us-gaap:AnExampleTagName\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total data size for this xbrl_tags: 2730, total number of filings 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2500/2500 [00:12<00:00, 199.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "final length 546\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "id": "81b08ecc-8809-4159-bc36-52c8accab3c0",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2025-03-13T07:46:06.583121Z",
     "start_time": "2025-03-13T07:45:01.521903Z"
    }
   },
   "source": [
    "value_train, value_test = gen_xbrl(\"value\", \n",
    "         example_q = \"What is the value of Exapmle company's income for the Fiscal year ending in FY 2020?\", \n",
    "         example_a = \"2540000000\")\n",
    "\n",
    "\n",
    "formula_train, formula_test = gen_xbrl(\"formula_calculation\",\n",
    "         example_q = \"Can you provide the formula for Operating Profit Margin from Example Corp for the Fiscal Year ending in FY 2022?\",\n",
    "         example_a = \"(50000000 / 3590000000) * 100\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total data size for this value: 12600, total number of filings 150\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2500/2500 [00:43<00:00, 56.89it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "final length 1586\n",
      "Total data size for this formula_calculation: 4195, total number of filings 150\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2500/2500 [00:20<00:00, 119.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "final length 774\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "code",
   "id": "d02a6391-9bae-4608-88aa-1a529d7f7bb1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-13T07:46:06.737070Z",
     "start_time": "2025-03-13T07:46:06.734835Z"
    }
   },
   "source": [
    "train = tags_train * 3 + value_train + formula_train\n",
    "test = tags_test + value_test + formula_test"
   ],
   "outputs": [],
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "id": "e7849df2-7b34-4c6c-8229-cb7794821b15",
   "metadata": {},
   "source": [
    "Now After combining train from both dataset there might be repeated filings, remove all questions in train where the filings existed in testing"
   ]
  },
  {
   "cell_type": "code",
   "id": "e015c68e-82c9-4062-86a3-35c5abe0f267",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-13T07:46:06.878528Z",
     "start_time": "2025-03-13T07:46:06.871448Z"
    }
   },
   "source": [
    "def check_and_remove_repeat(train, test):\n",
    "    print(\"train lenth:\", len(train), \"test length\", len(test))\n",
    "    train_doc_path = set([entry['doc_path'] for entry in train])\n",
    "    test_doc_path = set([entry['doc_path'] for entry in test])\n",
    "    repeated = list(train_doc_path.intersection(test_doc_path))\n",
    "    print(\"number of repeated filings between train/test:\", len(repeated))\n",
    "    train = [x for x in train if x['doc_path'] not in repeated]\n",
    "    return train, test\n",
    "    \n",
    "train, test = check_and_remove_repeat(train, test)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train lenth: 3380 test length 316\n",
      "number of repeated filings between train/test: 32\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "markdown",
   "id": "4c0fa7db-1f5f-4e78-ad4c-1183b1849f32",
   "metadata": {},
   "source": [
    "Check again and save"
   ]
  },
  {
   "cell_type": "code",
   "id": "532669d8-daf4-455b-ad4a-d4fce5f99cc1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-13T07:46:07.110552Z",
     "start_time": "2025-03-13T07:46:06.980261Z"
    }
   },
   "source": [
    "check_and_remove_repeat(train, test)\n",
    "\n",
    "\n",
    "with open(f\"train/xbrl_train.jsonl\", \"w\") as f_train:\n",
    "    for example in train:\n",
    "        f_train.write(json.dumps(example) + \"\\n\")\n",
    "\n",
    "with open(f\"test/xbrl_xbrl_tags_test.jsonl\", \"w\") as f_test:\n",
    "    for example in tags_test:\n",
    "        f_test.write(json.dumps(example) + \"\\n\")\n",
    "\n",
    "with open(f\"test/xbrl_value_test.jsonl\", \"w\") as f_test:\n",
    "    for example in value_test:\n",
    "        f_test.write(json.dumps(example) + \"\\n\")\n",
    "\n",
    "with open(f\"test/xbrl_formula_test.jsonl\", \"w\") as f_test:\n",
    "    for example in formula_test:\n",
    "        f_test.write(json.dumps(example) + \"\\n\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train lenth: 2800 test length 316\n",
      "number of repeated filings between train/test: 0\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "code",
   "id": "a0fdc778-e767-4a54-a284-eafff12ff37b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-13T07:46:07.125266Z",
     "start_time": "2025-03-13T07:46:07.124182Z"
    }
   },
   "source": [],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
