{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "noise_index = 5\n",
    "with open(f'/home//StepByStep/results/reasoning/combined/{noise_index}/accuracy/Activation_example_start_5_end_16.pickle', 'rb') as f:\n",
    "    input_ids = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = 8\n",
    "with open(f'result/{noise_index}/example_{example}.pickle', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_id = input_ids['input_ids'][example - 5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import LlamaForCausalLM, LlamaTokenizer\n",
    "MODEL_PATH = '/home/models//Llama-2-7b-hf'\n",
    "\n",
    "tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens = []\n",
    "for i in range(len(input_id)):\n",
    "    tokens.append(f'{tokenizer.decode(input_id[i])}_{i}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save to file txt\n",
    "with open(f'result/{noise_index}/example/example_{example}_tokens.txt', 'w') as f:\n",
    "    for token in tokens:\n",
    "        f.write(token + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_data = {}\n",
    "for d in data:\n",
    "    attend = []\n",
    "    for d1 in d['attending_layer_head']:\n",
    "        token = tokenizer.decode(input_id[d1['pos']])\n",
    "        token_passed = tokenizer.decode(d1['input_id'])\n",
    "        attend.append(f'layer_{d1[\"layer\"]}_head_{d1[\"head\"]}_token_{token}_pos_{d1[\"pos\"]}_tokenPassed_{token_passed}')\n",
    "    token = tokenizer.decode(input_id[d['pos']])\n",
    "    clean_data[f'layer_{d[\"layer\"]}_head_{d[\"head\"]}_token_{token}_pos_{d[\"pos\"]}'] = attend\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "depth_0_entry = []\n",
    "for d in data:\n",
    "    if d['depth'] == 0:\n",
    "        depth_0_entry.append(f'layer_{d[\"layer\"]}_head_{d[\"head\"]}_token_{tokenizer.decode(input_id[d[\"pos\"]])}_pos_{d[\"pos\"]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['layer_21_head_26_token_is_pos_-1',\n",
       " 'layer_21_head_27_token_is_pos_-1',\n",
       " 'layer_22_head_16_token_is_pos_-1',\n",
       " 'layer_31_head_27_token_is_pos_-1']"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "depth_0_entry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "nested = {}\n",
    "\n",
    "def nested_data(data, entry):\n",
    "    nested_dict = {}\n",
    "    new_entry = entry.split('_tokenPassed')[0]\n",
    "    # print(new_entry)\n",
    "    if(new_entry not in data):\n",
    "        # print(\"Not present\")\n",
    "        return None\n",
    "    for d in data[new_entry]:\n",
    "        nested_dict[d] = nested_data(data, d)\n",
    "    return nested_dict\n",
    "\n",
    "for entry in depth_0_entry:\n",
    "    nested[entry] = nested_data(clean_data, entry)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "with open(f'result/{noise_index}/example/example_{example}_nested.json', 'w') as f:\n",
    "    json.dump(nested, f, indent=4, sort_keys=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nested_data = []\n",
    "# for data in clean_data:\n",
    "#     for entry in data:\n",
    "        \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "nested_json = {}\n",
    "def create_nested_json(data):\n",
    "    temp_json = {}\n",
    "    for key, value in data.items():\n",
    "        nested_json[key] = []\n",
    "        if len(value) == 0:\n",
    "            temp_json[key] = value\n",
    "        else:\n",
    "            nested_json[key].append(create_nested_json(value))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'list' object has no attribute 'items'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mcreate_nested_json\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclean_data\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[14], line 9\u001b[0m, in \u001b[0;36mcreate_nested_json\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m      7\u001b[0m     temp_json[key] \u001b[38;5;241m=\u001b[39m value\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m----> 9\u001b[0m     nested_json[key]\u001b[38;5;241m.\u001b[39mappend(\u001b[43mcreate_nested_json\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m)\n",
      "Cell \u001b[0;32mIn[14], line 4\u001b[0m, in \u001b[0;36mcreate_nested_json\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_nested_json\u001b[39m(data):\n\u001b[1;32m      3\u001b[0m     temp_json \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m----> 4\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m \u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitems\u001b[49m():\n\u001b[1;32m      5\u001b[0m         nested_json[key] \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m      6\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(value) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'items'"
     ]
    }
   ],
   "source": [
    "create_nested_json(clean_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_nested_dict(data_list):\n",
    "    nested_dict = {}\n",
    "    for entry in data_list:\n",
    "        key = f\"layer_{entry['layer']}_head_{entry['head']}_pos_{entry['pos']}\"\n",
    "        print(entry)\n",
    "        if 'attending_layer_head' in entry:\n",
    "            nested_dict[key] = [create_nested_dict(entry['attending_layer_head'])]\n",
    "        else:\n",
    "            nested_dict[key] = []\n",
    "    return nested_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_nested_dict(clean_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nested_dict(node):\n",
    "    result = {}\n",
    "    key = f'layer_{node[\"layer\"]}_head_{node[\"head\"]}_pos_{node[\"pos\"]}'\n",
    "    children = node.get('attending_layer_head', [])\n",
    "    if children:\n",
    "        result[key] = {nested_dict(child) for child in children}\n",
    "    else:\n",
    "        result[key] = {}\n",
    "    return result\n",
    "\n",
    "nested_data = [nested_dict(entry) for entry in clean_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting anytree\n",
      "  Downloading anytree-2.12.1-py3-none-any.whl (44 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.9/44.9 kB\u001b[0m \u001b[31m263.8 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: six in /home//venv/lib/python3.10/site-packages (from anytree) (1.16.0)\n",
      "Installing collected packages: anytree\n",
      "Successfully installed anytree-2.12.1\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install anytree\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from anytree import Node, RenderTree, AsciiStyle\n",
    "\n",
    "def create_tree(data, parent=None):\n",
    "    if not isinstance(data, dict):\n",
    "        return Node(str(data), parent=parent)\n",
    "\n",
    "    current_node = Node(\"root\") if parent is None else None\n",
    "    for key, value in data.items():\n",
    "        child_node = create_tree(value, parent=current_node)\n",
    "        if current_node is None:\n",
    "            current_node = child_node\n",
    "        else:\n",
    "            current_node.children += (child_node,)\n",
    "\n",
    "    return current_node\n",
    "\n",
    "def visualize_tree(root):\n",
    "    for pre, _, node in RenderTree(root, style=AsciiStyle()):\n",
    "        print(f\"{pre}{node.name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {\n",
    "    \"layer_16_head_24_token_<0x0A>\": {\n",
    "        \"layer_5_head_22_token_W\": {\n",
    "            \"layer_0_head_22_token_<0x0A>\": {},\n",
    "            \"layer_3_head_22_token_<0x0A>\": {\n",
    "                \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                \"layer_0_head_15_token_<0x0A>\": {}\n",
    "            }\n",
    "        },\n",
    "        \"layer_9_head_19_token_W\": {\n",
    "            \"layer_8_head_29_token_W\": {\n",
    "                \"layer_7_head_16_token_:\": {\n",
    "                    \"layer_6_head_21_token_<0x0A>\": {\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                        \"layer_3_head_22_token_<0x0A>\": {\n",
    "                            \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                            \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                        }\n",
    "                    },\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {}\n",
    "                },\n",
    "                \"layer_3_head_16_token_:\": {\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                    \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                }\n",
    "            },\n",
    "            \"layer_5_head_28_token_W\": {\n",
    "                \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                \"layer_3_head_22_token_<0x0A>\": {\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                    \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"layer_30_head_24_token_<0x0A>\": {\n",
    "        \"layer_16_head_24_token_<0x0A>\": {\n",
    "            \"layer_5_head_22_token_W\": {\n",
    "                \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                \"layer_3_head_22_token_<0x0A>\": {\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                    \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                }\n",
    "            },\n",
    "            \"layer_9_head_19_token_W\": {\n",
    "                \"layer_8_head_29_token_W\": {\n",
    "                    \"layer_7_head_16_token_:\": {\n",
    "                        \"layer_6_head_21_token_<0x0A>\": {\n",
    "                            \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                            \"layer_3_head_22_token_<0x0A>\": {\n",
    "                                \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                                \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                            }\n",
    "                        },\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {}\n",
    "                    },\n",
    "                    \"layer_3_head_16_token_:\": {\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                        \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                    }\n",
    "                },\n",
    "                \"layer_5_head_28_token_W\": {\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                    \"layer_3_head_22_token_<0x0A>\": {\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                        \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                    }\n",
    "                }\n",
    "            }\n",
    "        },\n",
    "        \"layer_18_head_30_token_<0x0A>\": {\n",
    "            \"layer_16_head_1_token_ren\": {\n",
    "                \"layer_12_head_17_token_is\": {\n",
    "                    \"layer_1_head_7_token_<0x0A>\": {\n",
    "                        \"layer_0_head_2_token_:\": {},\n",
    "                        \"layer_0_head_31_token_:\": {}\n",
    "                    },\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {}\n",
    "                },\n",
    "                \"layer_13_head_29_token_is\": {\n",
    "                    \"layer_9_head_26_token_ren\": {\n",
    "                        \"layer_6_head_21_token_<0x0A>\": {\n",
    "                            \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                            \"layer_3_head_22_token_<0x0A>\": {\n",
    "                                \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                                \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                            }\n",
    "                        },\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {}\n",
    "                    },\n",
    "                    \"layer_3_head_2_token_ren\": {\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                        \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                    }\n",
    "                }\n",
    "            },\n",
    "            \"layer_12_head_2_token_ren\": {\n",
    "                \"layer_10_head_9_token_is\": {\n",
    "                    \"layer_6_head_21_token_<0x0A>\": {\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                        \"layer_3_head_22_token_<0x0A>\": {\n",
    "                            \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                            \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                        }\n",
    "                    },\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {}\n",
    "                },\n",
    "                \"layer_2_head_12_token_is\": {\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                    \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "ename": "TreeError",
     "evalue": "Cannot add non-node object None. It is not a subclass of 'NodeMixin'.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTreeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[24], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m root \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_tree\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      2\u001b[0m visualize_tree(root)\n",
      "Cell \u001b[0;32mIn[22], line 9\u001b[0m, in \u001b[0;36mcreate_tree\u001b[0;34m(data, parent)\u001b[0m\n\u001b[1;32m      7\u001b[0m current_node \u001b[38;5;241m=\u001b[39m Node(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mroot\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m parent \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m data\u001b[38;5;241m.\u001b[39mitems():\n\u001b[0;32m----> 9\u001b[0m     child_node \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_tree\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcurrent_node\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     10\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m current_node \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m     11\u001b[0m         current_node \u001b[38;5;241m=\u001b[39m child_node\n",
      "Cell \u001b[0;32mIn[22], line 9\u001b[0m, in \u001b[0;36mcreate_tree\u001b[0;34m(data, parent)\u001b[0m\n\u001b[1;32m      7\u001b[0m current_node \u001b[38;5;241m=\u001b[39m Node(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mroot\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m parent \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m data\u001b[38;5;241m.\u001b[39mitems():\n\u001b[0;32m----> 9\u001b[0m     child_node \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_tree\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcurrent_node\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     10\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m current_node \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m     11\u001b[0m         current_node \u001b[38;5;241m=\u001b[39m child_node\n",
      "Cell \u001b[0;32mIn[22], line 13\u001b[0m, in \u001b[0;36mcreate_tree\u001b[0;34m(data, parent)\u001b[0m\n\u001b[1;32m     11\u001b[0m         current_node \u001b[38;5;241m=\u001b[39m child_node\n\u001b[1;32m     12\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 13\u001b[0m         \u001b[43mcurrent_node\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mchildren\u001b[49m \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (child_node,)\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m current_node\n",
      "File \u001b[0;32m~/venv/lib/python3.10/site-packages/anytree/node/nodemixin.py:247\u001b[0m, in \u001b[0;36mNodeMixin.children\u001b[0;34m(self, children)\u001b[0m\n\u001b[1;32m    243\u001b[0m \u001b[38;5;129m@children\u001b[39m\u001b[38;5;241m.\u001b[39msetter\n\u001b[1;32m    244\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mchildren\u001b[39m(\u001b[38;5;28mself\u001b[39m, children):\n\u001b[1;32m    245\u001b[0m     \u001b[38;5;66;03m# convert iterable to tuple\u001b[39;00m\n\u001b[1;32m    246\u001b[0m     children \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(children)\n\u001b[0;32m--> 247\u001b[0m     \u001b[43mNodeMixin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__check_children\u001b[49m\u001b[43m(\u001b[49m\u001b[43mchildren\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    248\u001b[0m     \u001b[38;5;66;03m# ATOMIC start\u001b[39;00m\n\u001b[1;32m    249\u001b[0m     old_children \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren\n",
      "File \u001b[0;32m~/venv/lib/python3.10/site-packages/anytree/node/nodemixin.py:235\u001b[0m, in \u001b[0;36mNodeMixin.__check_children\u001b[0;34m(children)\u001b[0m\n\u001b[1;32m    233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(child, (NodeMixin, LightNodeMixin)):\n\u001b[1;32m    234\u001b[0m     msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot add non-node object \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m. It is not a subclass of \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mNodeMixin\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m (child,)\n\u001b[0;32m--> 235\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m TreeError(msg)\n\u001b[1;32m    236\u001b[0m childid \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mid\u001b[39m(child)\n\u001b[1;32m    237\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m childid \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m seen:\n",
      "\u001b[0;31mTreeError\u001b[0m: Cannot add non-node object None. It is not a subclass of 'NodeMixin'."
     ]
    }
   ],
   "source": [
    "root = create_tree(data)\n",
    "visualize_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'NoneType' object has no attribute 'name'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[26], line 139\u001b[0m\n\u001b[1;32m     22\u001b[0m data \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m     23\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlayer_16_head_24_token_<0x0A>\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\n\u001b[1;32m     24\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlayer_5_head_22_token_W\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    135\u001b[0m     }\n\u001b[1;32m    136\u001b[0m }\n\u001b[1;32m    138\u001b[0m root \u001b[38;5;241m=\u001b[39m create_tree(data)\n\u001b[0;32m--> 139\u001b[0m \u001b[43mvisualize_tree\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[26], line 20\u001b[0m, in \u001b[0;36mvisualize_tree\u001b[0;34m(root)\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvisualize_tree\u001b[39m(root):\n\u001b[1;32m     19\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m pre, _, node \u001b[38;5;129;01min\u001b[39;00m RenderTree(root, style\u001b[38;5;241m=\u001b[39mAsciiStyle()):\n\u001b[0;32m---> 20\u001b[0m         \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpre\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00m\u001b[43mnode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'name'"
     ]
    }
   ],
   "source": [
    "from anytree import Node, RenderTree, AsciiStyle\n",
    "\n",
    "def create_tree(data, parent=None):\n",
    "    if not isinstance(data, dict):\n",
    "        return Node(str(data), parent=parent)\n",
    "\n",
    "    current_node = Node(\"root\") if parent is None else None\n",
    "    for key, value in data.items():\n",
    "        child_node = create_tree(value, parent=current_node)\n",
    "        if current_node is None:\n",
    "            current_node = child_node\n",
    "        else:\n",
    "            current_node.children += (child_node,)\n",
    "\n",
    "    return current_node\n",
    "\n",
    "def visualize_tree(root):\n",
    "    for pre, _, node in RenderTree(root, style=AsciiStyle()):\n",
    "        print(f\"{pre}{node.name}\")\n",
    "\n",
    "data = {\n",
    "    \"layer_16_head_24_token_<0x0A>\": {\n",
    "        \"layer_5_head_22_token_W\": {\n",
    "            \"layer_0_head_22_token_<0x0A>\": {},\n",
    "            \"layer_3_head_22_token_<0x0A>\": {\n",
    "                \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                \"layer_0_head_15_token_<0x0A>\": {}\n",
    "            }\n",
    "        },\n",
    "        \"layer_9_head_19_token_W\": {\n",
    "            \"layer_8_head_29_token_W\": {\n",
    "                \"layer_7_head_16_token_:\": {\n",
    "                    \"layer_6_head_21_token_<0x0A>\": {\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                        \"layer_3_head_22_token_<0x0A>\": {\n",
    "                            \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                            \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                        }\n",
    "                    },\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {}\n",
    "                },\n",
    "                \"layer_3_head_16_token_:\": {\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                    \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                }\n",
    "            },\n",
    "            \"layer_5_head_28_token_W\": {\n",
    "                \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                \"layer_3_head_22_token_<0x0A>\": {\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                    \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"layer_30_head_24_token_<0x0A>\": {\n",
    "        \"layer_16_head_24_token_<0x0A>\": {\n",
    "            \"layer_5_head_22_token_W\": {\n",
    "                \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                \"layer_3_head_22_token_<0x0A>\": {\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                    \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                }\n",
    "            },\n",
    "            \"layer_9_head_19_token_W\": {\n",
    "                \"layer_8_head_29_token_W\": {\n",
    "                    \"layer_7_head_16_token_:\": {\n",
    "                        \"layer_6_head_21_token_<0x0A>\": {\n",
    "                            \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                            \"layer_3_head_22_token_<0x0A>\": {\n",
    "                                \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                                \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                            }\n",
    "                        },\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {}\n",
    "                    },\n",
    "                    \"layer_3_head_16_token_:\": {\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                        \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                    }\n",
    "                },\n",
    "                \"layer_5_head_28_token_W\": {\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                    \"layer_3_head_22_token_<0x0A>\": {\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                        \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                    }\n",
    "                }\n",
    "            }\n",
    "        },\n",
    "        \"layer_18_head_30_token_<0x0A>\": {\n",
    "            \"layer_16_head_1_token_ren\": {\n",
    "                \"layer_12_head_17_token_is\": {\n",
    "                    \"layer_1_head_7_token_<0x0A>\": {\n",
    "                        \"layer_0_head_2_token_:\": {},\n",
    "                        \"layer_0_head_31_token_:\": {}\n",
    "                    },\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {}\n",
    "                },\n",
    "                \"layer_13_head_29_token_is\": {\n",
    "                    \"layer_9_head_26_token_ren\": {\n",
    "                        \"layer_6_head_21_token_<0x0A>\": {\n",
    "                            \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                            \"layer_3_head_22_token_<0x0A>\": {\n",
    "                                \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                                \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                            }\n",
    "                        },\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {}\n",
    "                    },\n",
    "                    \"layer_3_head_2_token_ren\": {\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                        \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                    }\n",
    "                }\n",
    "            },\n",
    "            \"layer_12_head_2_token_ren\": {\n",
    "                \"layer_10_head_9_token_is\": {\n",
    "                    \"layer_6_head_21_token_<0x0A>\": {\n",
    "                        \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                        \"layer_3_head_22_token_<0x0A>\": {\n",
    "                            \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                            \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                        }\n",
    "                    },\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {}\n",
    "                },\n",
    "                \"layer_2_head_12_token_is\": {\n",
    "                    \"layer_0_head_22_token_<0x0A>\": {},\n",
    "                    \"layer_0_head_15_token_<0x0A>\": {}\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "root = create_tree(data)\n",
    "visualize_tree(root)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
