{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "db0ea784e3fa45ac8c7b659ddbae8821",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Qwen2ForCausalLM(\n",
       "  (model): Qwen2Model(\n",
       "    (embed_tokens): Embedding(152064, 3584)\n",
       "    (layers): ModuleList(\n",
       "      (0-27): 28 x Qwen2DecoderLayer(\n",
       "        (self_attn): Qwen2SdpaAttention(\n",
       "          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)\n",
       "          (k_proj): Linear(in_features=3584, out_features=512, bias=True)\n",
       "          (v_proj): Linear(in_features=3584, out_features=512, bias=True)\n",
       "          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)\n",
       "          (rotary_emb): Qwen2RotaryEmbedding()\n",
       "        )\n",
       "        (mlp): Qwen2MLP(\n",
       "          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)\n",
       "          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)\n",
       "          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): Qwen2RMSNorm()\n",
       "        (post_attention_layernorm): Qwen2RMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): Qwen2RMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=3584, out_features=152064, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "model_name = \"Qwen/Qwen2.5-7B-Instruct\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).cuda()\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'diabetes': 'hypertension', 'negative': 'positive', 'March': 'July', '2023': '2024', 'metformin': 'insulin'}\n",
      "\n",
      "Template: The patient was screened {negative} for type 2 {diabetes} in {March} {_num_2023}, and is currently taking {metformin} twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "\n",
      "==== Generated text variants ====\n",
      "0: The patient was screened negative for type 2 diabetes in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "1: The patient was screened negative for type 2 diabetes in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "2: The patient was screened negative for type 2 diabetes in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "3: The patient was screened negative for type 2 diabetes in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "4: The patient was screened negative for type 2 diabetes in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "5: The patient was screened negative for type 2 diabetes in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "6: The patient was screened negative for type 2 diabetes in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "7: The patient was screened negative for type 2 diabetes in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "8: The patient was screened positive for type 2 diabetes in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "9: The patient was screened positive for type 2 diabetes in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "10: The patient was screened positive for type 2 diabetes in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "11: The patient was screened positive for type 2 diabetes in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "12: The patient was screened positive for type 2 diabetes in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "13: The patient was screened positive for type 2 diabetes in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "14: The patient was screened positive for type 2 diabetes in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "15: The patient was screened positive for type 2 diabetes in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "16: The patient was screened negative for type 2 hypertension in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "17: The patient was screened negative for type 2 hypertension in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "18: The patient was screened negative for type 2 hypertension in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "19: The patient was screened negative for type 2 hypertension in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "20: The patient was screened negative for type 2 hypertension in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "21: The patient was screened negative for type 2 hypertension in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "22: The patient was screened negative for type 2 hypertension in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "23: The patient was screened negative for type 2 hypertension in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "24: The patient was screened positive for type 2 hypertension in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "25: The patient was screened positive for type 2 hypertension in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "26: The patient was screened positive for type 2 hypertension in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "27: The patient was screened positive for type 2 hypertension in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "28: The patient was screened positive for type 2 hypertension in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "29: The patient was screened positive for type 2 hypertension in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "30: The patient was screened positive for type 2 hypertension in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "31: The patient was screened positive for type 2 hypertension in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "\n",
      "==== Concept pairs (positive/negative indices) ====\n",
      "diabetes/hypertension: [(0, 16), (1, 17), (2, 18), (3, 19), (4, 20), (5, 21), (6, 22), (7, 23), (8, 24), (9, 25), (10, 26), (11, 27), (12, 28), (13, 29), (14, 30), (15, 31)]\n",
      "negative/positive: [(0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (16, 24), (17, 25), (18, 26), (19, 27), (20, 28), (21, 29), (22, 30), (23, 31)]\n",
      "March/July: [(0, 4), (1, 5), (2, 6), (3, 7), (8, 12), (9, 13), (10, 14), (11, 15), (16, 20), (17, 21), (18, 22), (19, 23), (24, 28), (25, 29), (26, 30), (27, 31)]\n",
      "2023/2024: [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31)]\n",
      "metformin/insulin: [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31)]\n"
     ]
    }
   ],
   "source": [
    "import itertools\n",
    "import math\n",
    "\n",
    "def tokenize(text):\n",
    "    return text.replace(\",\", \"\").replace(\".\", \"\").split()\n",
    "\n",
    "def build_concept_dict(tokens, selections):\n",
    "    \"\"\"\n",
    "    tokens: list of tokenized words\n",
    "    selections: list of (index, opposite_word)\n",
    "    \"\"\"\n",
    "    concept_dict = {}\n",
    "    for idx, opposite in selections:\n",
    "        word = tokens[idx]\n",
    "        concept_dict[word] = opposite\n",
    "    return concept_dict\n",
    "\n",
    "def sanitize_key(key):\n",
    "    \"\"\"Make numeric keys compatible with Python format placeholders\"\"\"\n",
    "    return f\"_num_{key}\" if key.isdigit() else key\n",
    "\n",
    "# Example base text\n",
    "base_text = (\n",
    "    \"The patient was screened negative for type 2 diabetes in March 2023, \"\n",
    "    \"and is currently taking metformin twice daily.\"\n",
    "    \"Based on these findings, provide the most likely diagnosis:\"\n",
    ")\n",
    "\n",
    "tokens = tokenize(base_text)\n",
    "\n",
    "selections = [\n",
    "    (8, \"hypertension\"),  # Disease: diabetes → hypertension\n",
    "    (4, \"positive\"),      # Status: negative → positive\n",
    "    (10, \"July\"),         # Time: March → July\n",
    "    (11, \"2024\"),         # Year: 2023 → 2024\n",
    "    (16, \"insulin\")       # Medication: metformin → insulin\n",
    "]\n",
    "threshold = 0.9\n",
    "\n",
    "concept_dict = build_concept_dict(tokens, selections)\n",
    "print(concept_dict)\n",
    "\n",
    "key_map = {k: sanitize_key(k) for k in concept_dict.keys()}\n",
    "\n",
    "# ✅ Dynamically build a template\n",
    "template = base_text\n",
    "for k in concept_dict.keys():\n",
    "    template = template.replace(k, \"{\" + key_map[k] + \"}\")\n",
    "\n",
    "print(\"\\nTemplate:\", template)\n",
    "\n",
    "# ✅ Build all possible combinations\n",
    "concept_keys = list(concept_dict.keys())\n",
    "concept_values = [(k, concept_dict[k]) for k in concept_keys]\n",
    "all_combinations = list(itertools.product(*[(pos, neg) for pos, neg in concept_values]))\n",
    "\n",
    "# ✅ Generate all text variants\n",
    "text = []\n",
    "for combo in all_combinations:\n",
    "    fill_dict = {\n",
    "        key_map[k]: v for k, v in zip(concept_keys, combo)\n",
    "    }\n",
    "    text.append(template.format(**fill_dict))\n",
    "\n",
    "# ✅ Build concept_pairs (mapping each positive example to its negative counterpart)\n",
    "concept_pairs = {}\n",
    "for concept in concept_keys:\n",
    "    pos, neg = concept, concept_dict[concept]\n",
    "    pairs = []\n",
    "    concept_idx = concept_keys.index(concept)\n",
    "    for i, combo in enumerate(all_combinations):\n",
    "        if combo[concept_idx] == pos:\n",
    "            neg_combo = list(combo)\n",
    "            neg_combo[concept_idx] = neg\n",
    "            neg_index = all_combinations.index(tuple(neg_combo))\n",
    "            pairs.append((i, neg_index))\n",
    "    concept_pairs[f\"{pos}/{neg}\"] = pairs\n",
    "\n",
    "# ✅ Output\n",
    "print(\"\\n==== Generated text variants ====\")\n",
    "for idx, t in enumerate(text):\n",
    "    print(f\"{idx}: {t}\")\n",
    "\n",
    "print(\"\\n==== Concept pairs (positive/negative indices) ====\")\n",
    "for key, pairs in concept_pairs.items():\n",
    "    print(f\"{key}: {pairs}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.7` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "/opt/conda/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n",
      "/opt/conda/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:509: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `20` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n",
      "  warnings.warn(\n",
      "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': tensor([[   785,   8720,    572,  ...,   4363,  22982,     25],\n",
      "        [   785,   8720,    572,  ...,     25, 151643, 151643],\n",
      "        [   785,   8720,    572,  ...,   4363,  22982,     25],\n",
      "        ...,\n",
      "        [   785,   8720,    572,  ...,     25, 151643, 151643],\n",
      "        [   785,   8720,    572,  ...,   4363,  22982,     25],\n",
      "        [   785,   8720,    572,  ...,     25, 151643, 151643]],\n",
      "       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],\n",
      "        [1, 1, 1,  ..., 1, 0, 0],\n",
      "        [1, 1, 1,  ..., 1, 1, 1],\n",
      "        ...,\n",
      "        [1, 1, 1,  ..., 1, 0, 0],\n",
      "        [1, 1, 1,  ..., 1, 1, 1],\n",
      "        [1, 1, 1,  ..., 1, 0, 0]], device='cuda:0')}\n",
      "[Input 0] The patient was screened negative for type 2 diabetes in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 0] The patient was screened negative for type 2 diabetes in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes\n",
      "B: Type 2 diabetes\n",
      "C: Gestational diabetes\n",
      "D: Prediabetes\n",
      "\n",
      "Given the information provided:\n",
      "\n",
      "- The patient was screened negative for type 2 diabetes in March 2023.\n",
      "\n",
      "\n",
      "[Input 1] The patient was screened negative for type 2 diabetes in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 1] The patient was screened negative for type 2 diabetes in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? Given the information provided, it's not possible to definitively diagnose whether the patient has type 1 or type 2 diabetes based solely on the screening result from March 20\n",
      "\n",
      "[Input 2] The patient was screened negative for type 2 diabetes in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 2] The patient was screened negative for type 2 diabetes in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes\n",
      "B: Type 2 diabetes\n",
      "C: Gestational diabetes\n",
      "D: Prediabetes\n",
      "\n",
      "Given the information provided:\n",
      "\n",
      "- The patient was screened negative for type 2 diabetes in March 2024.\n",
      "\n",
      "\n",
      "[Input 3] The patient was screened negative for type 2 diabetes in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 3] The patient was screened negative for type 2 diabetes in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: for the patient.\n",
      "Based on the information provided, it seems that the patient is already being treated with insulin for a condition that has not been explicitly stated. Since the patient was screened negative for type 2 diabetes in March 2024,\n",
      "\n",
      "[Input 4] The patient was screened negative for type 2 diabetes in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 4] The patient was screened negative for type 2 diabetes in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes\n",
      "B: Type 2 diabetes\n",
      "C: Gestational diabetes\n",
      "D: Prediabetes\n",
      "\n",
      "Given the information provided:\n",
      "\n",
      "- The patient was screened negative for type 2 diabetes in July 2023.\n",
      "\n",
      "\n",
      "[Input 5] The patient was screened negative for type 2 diabetes in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 5] The patient was screened negative for type 2 diabetes in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: for the patient, and suggest a management plan.\n",
      "Based on the information provided, it seems that the patient has been diagnosed with type 1 diabetes, given their need for insulin therapy. Type 2 diabetes typically does not require insulin unless it progresses to\n",
      "\n",
      "[Input 6] The patient was screened negative for type 2 diabetes in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 6] The patient was screened negative for type 2 diabetes in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes\n",
      "B: Type 2 diabetes\n",
      "C: Gestational diabetes\n",
      "D: Prediabetes\n",
      "\n",
      "Given the information provided:\n",
      "\n",
      "- The patient was screened negative for type 2 diabetes in July 2024.\n",
      "\n",
      "\n",
      "[Input 7] The patient was screened negative for type 2 diabetes in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 7] The patient was screened negative for type 2 diabetes in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: for the patient.\n",
      "Based on the information provided, it seems that the patient is already being treated with insulin for a condition that has not been explicitly stated. Since the patient was screened negative for type 2 diabetes in July 2024,\n",
      "\n",
      "[Input 8] The patient was screened positive for type 2 diabetes in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 8] The patient was screened positive for type 2 diabetes in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes B: Type 2 diabetes C: Gestational diabetes D: Prediabetes\n",
      "Based on the information provided, the most likely diagnosis is:\n",
      "\n",
      "B: Type 2 diabetes\n",
      "\n",
      "Here's the reasoning:\n",
      "- The patient\n",
      "\n",
      "[Input 9] The patient was screened positive for type 2 diabetes in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 9] The patient was screened positive for type 2 diabetes in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? Provide a brief explanation.\n",
      "Based on the information provided, the patient was screened positive for type 2 diabetes. The key points supporting this are:\n",
      "\n",
      "1. **Insulin Use**:\n",
      "\n",
      "[Input 10] The patient was screened positive for type 2 diabetes in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 10] The patient was screened positive for type 2 diabetes in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: Type 2 Diabetes Mellitus.\n",
      "The patient has been diagnosed with Type 2 Diabetes Mellitus, as indicated by the screening positivity and the current treatment with metformin. Metformin is a common first-line medication used in the management of Type\n",
      "\n",
      "[Input 11] The patient was screened positive for type 2 diabetes in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 11] The patient was screened positive for type 2 diabetes in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? Provide a brief explanation.\n",
      "Based on the information provided, the patient was screened positive for type 2 diabetes. The key points supporting this are:\n",
      "\n",
      "1. **Insulin Use**:\n",
      "\n",
      "[Input 12] The patient was screened positive for type 2 diabetes in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 12] The patient was screened positive for type 2 diabetes in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: Type 2 Diabetes Mellitus.\n",
      "The patient has been diagnosed with Type 2 Diabetes Mellitus, as indicated by their screening positivity and current treatment with metformin. Metformin is a common first-line medication used to manage blood glucose levels in\n",
      "\n",
      "[Input 13] The patient was screened positive for type 2 diabetes in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 13] The patient was screened positive for type 2 diabetes in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? Provide a brief explanation.\n",
      "Based on the information provided, the patient was screened positive for type 2 diabetes. The key points supporting this are:\n",
      "\n",
      "1. **Insulin Use**:\n",
      "\n",
      "[Input 14] The patient was screened positive for type 2 diabetes in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 14] The patient was screened positive for type 2 diabetes in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: Type 2 Diabetes Mellitus.\n",
      "The patient has been diagnosed with Type 2 Diabetes Mellitus, as indicated by the screening positivity and the current treatment with metformin. Metformin is a common first-line medication used to manage blood glucose levels\n",
      "\n",
      "[Input 15] The patient was screened positive for type 2 diabetes in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 15] The patient was screened positive for type 2 diabetes in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? Provide a brief explanation.\n",
      "Based on the information provided, the patient was screened positive for type 2 diabetes. The key points supporting this are:\n",
      "\n",
      "1. **Screening Positive**:\n",
      "\n",
      "[Input 16] The patient was screened negative for type 2 hypertension in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 16] The patient was screened negative for type 2 hypertension in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes B: Type 2 diabetes C: Hypertension D: Metformin toxicity\n",
      "Based on the information provided, the most likely diagnosis is:\n",
      "\n",
      "B: Type 2 diabetes\n",
      "\n",
      "Here's the reasoning:\n",
      "- The\n",
      "\n",
      "[Input 17] The patient was screened negative for type 2 hypertension in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 17] The patient was screened negative for type 2 hypertension in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? To determine whether the patient was diagnosed with type 1 or type 2 diabetes based on the provided information, let's analyze the key points:\n",
      "\n",
      "1. **Screening Negative for\n",
      "\n",
      "[Input 18] The patient was screened negative for type 2 hypertension in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 18] The patient was screened negative for type 2 hypertension in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes B: Type 2 diabetes C: Hypertension D: Metformin toxicity\n",
      "Based on the information provided:\n",
      "\n",
      "- The patient was screened negative for type 2 hypertension in March 2024.\n",
      "-\n",
      "\n",
      "[Input 19] The patient was screened negative for type 2 hypertension in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 19] The patient was screened negative for type 2 hypertension in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: for the patient.\n",
      "Based on the information provided, it seems that the patient has been screened negative for type 2 hypertension in March 2024. However, the patient is currently taking insulin twice daily. This suggests that the most likely diagnosis\n",
      "\n",
      "[Input 20] The patient was screened negative for type 2 hypertension in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 20] The patient was screened negative for type 2 hypertension in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes B: Type 2 diabetes C: Hypertension D: Metformin toxicity\n",
      "Based on the information provided, the most likely diagnosis is:\n",
      "\n",
      "B: Type 2 diabetes\n",
      "\n",
      "Here's the reasoning:\n",
      "- The\n",
      "\n",
      "[Input 21] The patient was screened negative for type 2 hypertension in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 21] The patient was screened negative for type 2 hypertension in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? To determine whether the patient was diagnosed with type 1 or type 2 diabetes based on the provided information, let's analyze the key points:\n",
      "\n",
      "1. **Screening Negative for\n",
      "\n",
      "[Input 22] The patient was screened negative for type 2 hypertension in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 22] The patient was screened negative for type 2 hypertension in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes B: Type 2 diabetes C: Hypertension D: Metformin toxicity\n",
      "Based on the information provided:\n",
      "\n",
      "- The patient was screened negative for type 2 hypertension in July 2024.\n",
      "-\n",
      "\n",
      "[Input 23] The patient was screened negative for type 2 hypertension in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 23] The patient was screened negative for type 2 hypertension in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: for the patient.\n",
      "Based on the information provided, it seems there might be some confusion or missing context. The patient was screened negative for type 2 hypertension in July 2024, but they are currently taking insulin twice daily. Typically,\n",
      "\n",
      "[Input 24] The patient was screened positive for type 2 hypertension in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 24] The patient was screened positive for type 2 hypertension in March 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes B: Type 2 diabetes C: Hypertension D: Metformin toxicity\n",
      "Based on the information provided, the most likely diagnosis is:\n",
      "\n",
      "B: Type 2 diabetes\n",
      "\n",
      "Here's the reasoning:\n",
      "- The\n",
      "\n",
      "[Input 25] The patient was screened positive for type 2 hypertension in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 25] The patient was screened positive for type 2 hypertension in March 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? The information provided does not directly indicate whether the patient has been diagnosed with type 1 or type 2 diabetes. However, there are some key points to consider:\n",
      "\n",
      "1. **\n",
      "\n",
      "[Input 26] The patient was screened positive for type 2 hypertension in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 26] The patient was screened positive for type 2 hypertension in March 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: Type 2 diabetes mellitus.\n",
      "The information provided suggests that the patient has been diagnosed with type 2 hypertension (high blood pressure) and is currently taking metformin, which is a medication commonly used to treat type 2 diabetes. However,\n",
      "\n",
      "[Input 27] The patient was screened positive for type 2 hypertension in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 27] The patient was screened positive for type 2 hypertension in March 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? The information provided does not directly indicate whether the patient has been diagnosed with type 1 or type 2 diabetes. However, there are some key points to consider:\n",
      "\n",
      "1. **\n",
      "\n",
      "[Input 28] The patient was screened positive for type 2 hypertension in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 28] The patient was screened positive for type 2 hypertension in July 2023, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: Type 2 diabetes mellitus.\n",
      "The information provided suggests that the patient has been diagnosed with type 2 hypertension (high blood pressure) and is currently taking metformin, which is a medication commonly used to treat type 2 diabetes. However,\n",
      "\n",
      "[Input 29] The patient was screened positive for type 2 hypertension in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 29] The patient was screened positive for type 2 hypertension in July 2023, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? The information provided does not directly indicate whether the patient has been diagnosed with type 1 or type 2 diabetes. However, there are a few points to consider:\n",
      "\n",
      "1. **\n",
      "\n",
      "[Input 30] The patient was screened positive for type 2 hypertension in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 30] The patient was screened positive for type 2 hypertension in July 2024, and is currently taking metformin twice daily.Based on these findings, provide the most likely diagnosis: A: Type 1 diabetes B: Type 2 diabetes C: Hypertension D: Metformin toxicity\n",
      "Based on the information provided:\n",
      "\n",
      "- The patient was screened positive for type 2 hypertension in July 2024.\n",
      "-\n",
      "\n",
      "[Input 31] The patient was screened positive for type 2 hypertension in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis:\n",
      "[Output 31] The patient was screened positive for type 2 hypertension in July 2024, and is currently taking insulin twice daily.Based on these findings, provide the most likely diagnosis: was the patient diagnosed with type 1 or type 2 diabetes? The information provided does not directly indicate whether the patient has been diagnosed with type 1 or type 2 diabetes. However, there are some key points to consider:\n",
      "\n",
      "1. **\n",
      "\n"
     ]
    }
   ],
   "source": [
    "inputs = tokenizer(text, return_tensors=\"pt\", padding=True, truncation=True).to(\"cuda\")\n",
    "mask = inputs[\"attention_mask\"]\n",
    "# Index of the last real (non-pad) token in each row. The argmax of\n",
    "# position * mask is correct for both left and right padding, whereas\n",
    "# mask.sum(dim=1) - 1 is only correct under right padding.\n",
    "token_positions = torch.arange(mask.size(1), device=mask.device)\n",
    "last_idx = (mask * token_positions).argmax(dim=1)\n",
    "\n",
    "# Decoder-only models must be left-padded for batched generation; with right\n",
    "# padding, new tokens are appended after the pad tokens of shorter prompts.\n",
    "# Tokenize a separate left-padded batch for generate() only, so `inputs`\n",
    "# (reused by the hook/analysis cells below) is left untouched.\n",
    "original_padding_side = tokenizer.padding_side\n",
    "tokenizer.padding_side = \"left\"\n",
    "try:\n",
    "    gen_inputs = tokenizer(text, return_tensors=\"pt\", padding=True, truncation=True).to(\"cuda\")\n",
    "finally:\n",
    "    tokenizer.padding_side = original_padding_side\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(\n",
    "        **gen_inputs,\n",
    "        max_new_tokens=50,\n",
    "        do_sample=False  # set to True to enable sampling\n",
    "    )\n",
    "\n",
    "decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "for i, out in enumerate(decoded):\n",
    "    print(f\"[Input {i}] {text[i]}\")\n",
    "    print(f\"[Output {i}] {out}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Per-layer captures from one forward pass over `inputs`:\n",
    "#   attn_inputs[i] <- input tensor of layer i's self_attn.v_proj\n",
    "#   mlp_inputs[i]  <- input tensor of layer i's mlp.down_proj (the intermediate\n",
    "#                     MLP activation feeding down_proj, not the MLP block input)\n",
    "# Both are detached and moved to CPU.\n",
    "attn_inputs = [None for _ in range(len(model.model.layers))]\n",
    "mlp_inputs = [None for _ in range(len(model.model.layers))]\n",
    "hooks = []\n",
    "\n",
    "def make_attn_input_hook(i):\n",
    "    \"\"\"Return a forward hook that stores layer i's v_proj input in attn_inputs[i].\"\"\"\n",
    "    def hook(module, inp, output):\n",
    "        attn_inputs[i] = inp[0].detach().cpu()\n",
    "    return hook\n",
    "\n",
    "def make_mlp_input_hook(i):\n",
    "    \"\"\"Return a forward hook that stores layer i's down_proj input in mlp_inputs[i].\"\"\"\n",
    "    def hook(module, inp, output):\n",
    "        mlp_inputs[i] = inp[0].detach().cpu()\n",
    "    return hook\n",
    "\n",
    "# Remove the hooks in `finally` so a failed forward pass (e.g. CUDA OOM) does\n",
    "# not leave them attached and firing on every later call to the model.\n",
    "try:\n",
    "    for i, layer in enumerate(model.model.layers):\n",
    "        hooks.append(layer.self_attn.v_proj.register_forward_hook(make_attn_input_hook(i)))\n",
    "        hooks.append(layer.mlp.down_proj.register_forward_hook(make_mlp_input_hook(i)))\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "finally:\n",
    "    for h in hooks:\n",
    "        h.remove()\n",
    "    hooks.clear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Top-10 logits cosine similarity: 0.9108\n",
      "Sample 0: next predicted token -> ' A'\n",
      "Sample 1: next predicted token -> ' A'\n",
      "Sample 2: next predicted token -> ' A'\n",
      "Sample 3: next predicted token -> ' A'\n",
      "Sample 4: next predicted token -> ' A'\n",
      "Sample 5: next predicted token -> ' A'\n",
      "Sample 6: next predicted token -> ' A'\n",
      "Sample 7: next predicted token -> ' A'\n",
      "Sample 8: next predicted token -> ' A'\n",
      "Sample 9: next predicted token -> ' A'\n",
      "Sample 10: next predicted token -> ' Type'\n",
      "Sample 11: next predicted token -> ' Type'\n",
      "Sample 12: next predicted token -> ' Type'\n",
      "Sample 13: next predicted token -> ' Type'\n",
      "Sample 14: next predicted token -> ' Type'\n",
      "Sample 15: next predicted token -> ' Type'\n",
      "Sample 16: next predicted token -> ' A'\n",
      "Sample 17: next predicted token -> ' A'\n",
      "Sample 18: next predicted token -> ' A'\n",
      "Sample 19: next predicted token -> ' A'\n",
      "Sample 20: next predicted token -> ' A'\n",
      "Sample 21: next predicted token -> ' A'\n",
      "Sample 22: next predicted token -> ' A'\n",
      "Sample 23: next predicted token -> ' A'\n",
      "Sample 24: next predicted token -> ' A'\n",
      "Sample 25: next predicted token -> ' A'\n",
      "Sample 26: next predicted token -> ' Type'\n",
      "Sample 27: next predicted token -> ' A'\n",
      "Sample 28: next predicted token -> ' Type'\n",
      "Sample 29: next predicted token -> ' Type'\n",
      "Sample 30: next predicted token -> ' A'\n",
      "Sample 31: next predicted token -> ' A'\n"
     ]
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "topk = 10\n",
    "logits = outputs.logits\n",
    "attn_mask = inputs[\"attention_mask\"]\n",
    "# Position of the last real (non-pad) token in each row. The argmax of\n",
    "# position * mask is correct for both left and right padding, whereas\n",
    "# mask.sum(dim=1) - 1 is only correct under right padding.\n",
    "token_positions = torch.arange(attn_mask.size(1), device=attn_mask.device)\n",
    "last_indices = (attn_mask * token_positions).argmax(dim=1)  # shape: (batch_size,)\n",
    "batch_size = logits.size(0)\n",
    "\n",
    "last_logits = logits[torch.arange(batch_size, device=logits.device), last_indices]  # (batch_size, vocab_size)\n",
    "\n",
    "def topk_sparse(row, k):\n",
    "    \"\"\"Return a dense float32 CPU copy of 1-D `row` with all but its top-k entries zeroed.\"\"\"\n",
    "    top = torch.topk(row, k=k)\n",
    "    return torch.zeros(row.size(-1)).scatter(0, top.indices.cpu(), top.values.float().cpu())\n",
    "\n",
    "# Top-k sparse logit vector cosine similarity between samples 0 and 1\n",
    "if batch_size >= 2:\n",
    "    vec0 = topk_sparse(last_logits[0], topk)\n",
    "    vec1 = topk_sparse(last_logits[1], topk)\n",
    "    topk_cos_sim = F.cosine_similarity(vec0, vec1, dim=0).item()\n",
    "    print(f\"\\nTop-{topk} logits cosine similarity: {topk_cos_sim:.4f}\")\n",
    "else:\n",
    "    print(f\"\\nTop-{topk} logits cosine similarity: skipped (need at least 2 samples, got {batch_size})\")\n",
    "\n",
    "pred_token_ids = torch.argmax(last_logits, dim=-1)  # shape: (batch_size,)\n",
    "pred_tokens = tokenizer.batch_decode(pred_token_ids)\n",
    "for i, token in enumerate(pred_tokens):\n",
    "    print(f\"Sample {i}: next predicted token -> {repr(token)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "topk = 5\n",
    "\n",
    "def mask_top_k_spectral(s_rep, k=5):\n",
    "    \"\"\"Keep the k largest entries of `s_rep` along its last dim and zero the rest.\n",
    "\n",
    "    Selection is by signed value (torch.topk default), not by magnitude.\n",
    "    \"\"\"\n",
    "    top = torch.topk(s_rep, k=k, dim=-1)\n",
    "    mask = torch.zeros_like(s_rep)\n",
    "    mask.scatter_(dim=-1, index=top.indices, src=top.values)\n",
    "    return mask\n",
    "\n",
    "def normalize(x: float) -> float:\n",
    "    \"\"\"Clamp x to the closed interval [0, 1].\"\"\"\n",
    "    return max(0.0, min(1.0, x))\n",
    "\n",
    "attn_spectral = []\n",
    "\n",
    "# One entry per layer whose v_proj input was captured; layers without a\n",
    "# capture are skipped, so list position equals layer index only when every\n",
    "# layer was captured.\n",
    "for i, layer in enumerate(model.model.layers):\n",
    "    if attn_inputs[i] is not None:\n",
    "        # Pick each row's v_proj input at sequence position last_idx.\n",
    "        idx = last_idx.view(-1, 1, 1).expand(-1, 1, attn_inputs[i].size(-1)).to(\"cuda\")\n",
    "        # torch.linalg.svd rejects half/bfloat16 inputs, so work in float32.\n",
    "        x = attn_inputs[i].to(\"cuda\").gather(dim=1, index=idx).squeeze(1).float()  # (batch_size, hidden_dim)\n",
    "        W = layer.self_attn.v_proj.weight.T.detach().float()  # (in_dim, out_dim)\n",
    "        U, S, Vh = torch.linalg.svd(W, full_matrices=False)\n",
    "        # Project onto W's left singular vectors, scaled by the singular values.\n",
    "        s_rep = (x @ U) * S  # (batch_size, r)\n",
    "\n",
    "        s_rep = mask_top_k_spectral(s_rep, k=topk)\n",
    "        attn_spectral.append(s_rep.cpu().numpy())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "--- Start computing the separation layer for each concept (threshold < 0.9) ---\n",
      "Layer 00 | Concept 'diabetes/hypertension': Average similarity = 1.0000\n",
      "Layer 01 | Concept 'diabetes/hypertension': Average similarity = 1.0000\n",
      "Layer 02 | Concept 'diabetes/hypertension': Average similarity = 1.0000\n",
      "Layer 03 | Concept 'diabetes/hypertension': Average similarity = 0.9624\n",
      "Layer 04 | Concept 'diabetes/hypertension': Average similarity = 0.8736\n",
      "==> Concept 'diabetes/hypertension' separates at layer 4!\n",
      "\n",
      "Layer 00 | Concept 'negative/positive': Average similarity = 1.0000\n",
      "Layer 01 | Concept 'negative/positive': Average similarity = 1.0000\n",
      "Layer 02 | Concept 'negative/positive': Average similarity = 0.9999\n",
      "Layer 03 | Concept 'negative/positive': Average similarity = 0.9627\n",
      "Layer 04 | Concept 'negative/positive': Average similarity = 0.9559\n",
      "Layer 05 | Concept 'negative/positive': Average similarity = 1.0000\n",
      "Layer 06 | Concept 'negative/positive': Average similarity = 0.9775\n",
      "Layer 07 | Concept 'negative/positive': Average similarity = 0.9999\n",
      "Layer 08 | Concept 'negative/positive': Average similarity = 0.9616\n",
      "Layer 09 | Concept 'negative/positive': Average similarity = 0.9400\n",
      "==> Concept 'negative/positive' separates at layer 9!\n",
      "\n",
      "Layer 00 | Concept 'March/July': Average similarity = 1.0000\n",
      "Layer 01 | Concept 'March/July': Average similarity = 1.0000\n",
      "Layer 02 | Concept 'March/July': Average similarity = 1.0000\n",
      "Layer 03 | Concept 'March/July': Average similarity = 1.0000\n",
      "Layer 04 | Concept 'March/July': Average similarity = 0.9737\n",
      "Layer 05 | Concept 'March/July': Average similarity = 1.0000\n",
      "Layer 06 | Concept 'March/July': Average similarity = 0.9775\n",
      "Layer 07 | Concept 'March/July': Average similarity = 1.0000\n",
      "Layer 08 | Concept 'March/July': Average similarity = 0.9921\n",
      "Layer 09 | Concept 'March/July': Average similarity = 0.9833\n",
      "Layer 10 | Concept 'March/July': Average similarity = 0.8854\n",
      "==> Concept 'March/July' separates at layer 10!\n",
      "\n",
      "Layer 00 | Concept '2023/2024': Average similarity = 1.0000\n",
      "Layer 01 | Concept '2023/2024': Average similarity = 1.0000\n",
      "Layer 02 | Concept '2023/2024': Average similarity = 1.0000\n",
      "Layer 03 | Concept '2023/2024': Average similarity = 0.9816\n",
      "Layer 04 | Concept '2023/2024': Average similarity = 0.9644\n",
      "Layer 05 | Concept '2023/2024': Average similarity = 0.9999\n",
      "Layer 06 | Concept '2023/2024': Average similarity = 0.9552\n",
      "Layer 07 | Concept '2023/2024': Average similarity = 0.9998\n",
      "Layer 08 | Concept '2023/2024': Average similarity = 0.9463\n",
      "Layer 09 | Concept '2023/2024': Average similarity = 0.9742\n",
      "Layer 10 | Concept '2023/2024': Average similarity = 0.8372\n",
      "==> Concept '2023/2024' separates at layer 10!\n",
      "\n",
      "Layer 00 | Concept 'metformin/insulin': Average similarity = 1.0000\n",
      "Layer 01 | Concept 'metformin/insulin': Average similarity = 1.0000\n",
      "Layer 02 | Concept 'metformin/insulin': Average similarity = 0.9998\n",
      "Layer 03 | Concept 'metformin/insulin': Average similarity = 0.8709\n",
      "==> Concept 'metformin/insulin' separates at layer 3!\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import graphviz\n",
    "import numpy as np\n",
    "\n",
    "def cosine_similarity(v1, v2):\n",
    "    \"\"\"Compute cosine similarity between two vectors.\n",
    "\n",
    "    Zero-norm inputs would otherwise divide by zero and yield NaN: two all-zero\n",
    "    vectors are treated as identical (1.0), and an all-zero vector against a\n",
    "    non-zero one as orthogonal (0.0).\n",
    "    \"\"\"\n",
    "    norm1 = np.linalg.norm(v1)\n",
    "    norm2 = np.linalg.norm(v2)\n",
    "    if norm1 == 0 or norm2 == 0:\n",
    "        return 1.0 if norm1 == norm2 else 0.0\n",
    "    return np.dot(v1, v2) / (norm1 * norm2)\n",
    "\n",
    "# `concept_pairs` (concept name -> iterable of (i, j) sample-index pairs to\n",
    "# compare), `threshold` and `normalize` are defined in earlier cells.\n",
    "separation_results = {}\n",
    "sep_threshold = normalize(threshold)  # loop-invariant, so compute once\n",
    "\n",
    "print(f\"\\n--- Start computing the separation layer for each concept (threshold < {sep_threshold}) ---\")\n",
    "\n",
    "# For each concept, find the first layer whose average pair similarity drops\n",
    "# below the threshold; concepts that never drop below it are recorded as inf.\n",
    "for concept_name, pairs in concept_pairs.items():\n",
    "    for layer_idx, layer_data in enumerate(attn_spectral):\n",
    "        similarities = [cosine_similarity(layer_data[i], layer_data[j]) for i, j in pairs]\n",
    "        avg_similarity = np.mean(similarities)\n",
    "\n",
    "        # Print similarities for the first few layers for inspection\n",
    "        if layer_idx < 15:\n",
    "            print(f\"Layer {layer_idx:02d} | Concept '{concept_name}': Average similarity = {avg_similarity:.4f}\")\n",
    "\n",
    "        if avg_similarity < sep_threshold:\n",
    "            separation_results[concept_name] = layer_idx\n",
    "            print(f\"==> Concept '{concept_name}' separates at layer {layer_idx}!\\n\")\n",
    "            break\n",
    "    else:\n",
    "        # No break: the concept never separated in any layer\n",
    "        separation_results[concept_name] = float('inf')\n",
    "        print(f\"==> Concept '{concept_name}' never separates in any layer.\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"379pt\" height=\"317pt\" viewBox=\"0.00 0.00 378.64 316.64\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4.32 312.32)\">\n",
       "<title>ConceptSeparationTree</title>\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4.32,4.32 -4.32,-312.32 374.32,-312.32 374.32,4.32 -4.32,4.32\"/>\n",
       "<!-- All Concepts (Layer 0) -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>All Concepts (Layer 0)</title>\n",
       "<polygon fill=\"#eadbc8\" stroke=\"#000000\" stroke-width=\"1.2\" points=\"170,-36 50,-36 50,0 170,0 170,-36\"/>\n",
       "<text text-anchor=\"middle\" x=\"110\" y=\"-15.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">All Concepts (Layer 0)</text>\n",
       "</g>\n",
       "<!-- metformin/insulin -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>metformin/insulin</title>\n",
       "<polygon fill=\"#e6f4ea\" stroke=\"#000000\" stroke-width=\"1.2\" points=\"100,-104 0,-104 0,-68 100,-68 100,-104\"/>\n",
       "<text text-anchor=\"middle\" x=\"50\" y=\"-89\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">metformin/insulin</text>\n",
       "<text text-anchor=\"middle\" x=\"50\" y=\"-78\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">(Layer 3)</text>\n",
       "</g>\n",
       "<!-- All Concepts (Layer 0)&#45;&gt;metformin/insulin -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>All Concepts (Layer 0)-&gt;metformin/insulin</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M94.24,-36.34C85.49,-45.97 74.63,-57.9 65.87,-67.55\"/>\n",
       "</g>\n",
       "<!-- Remaining (≤ Layer 3) -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>Remaining (≤ Layer 3)</title>\n",
       "<g id=\"a_node3\"><a xlink:title=\"2023/2024\n",
       "March/July\n",
       "diabetes/hypertension\n",
       "negative/positive\">\n",
       "<polygon fill=\"#eadbc8\" stroke=\"#000000\" stroke-width=\"1.2\" points=\"218,-104 122,-104 122,-68 218,-68 218,-104\"/>\n",
       "<text text-anchor=\"middle\" x=\"170\" y=\"-83.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">Remaining (n=4)</text>\n",
       "</a>\n",
       "</g>\n",
       "</g>\n",
       "<!-- All Concepts (Layer 0)&#45;&gt;Remaining (≤ Layer 3) -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>All Concepts (Layer 0)-&gt;Remaining (≤ Layer 3)</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M125.76,-36.34C134.51,-45.97 145.37,-57.9 154.13,-67.55\"/>\n",
       "</g>\n",
       "<!-- diabetes/hypertension -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>diabetes/hypertension</title>\n",
       "<polygon fill=\"#e6f4ea\" stroke=\"#000000\" stroke-width=\"1.2\" points=\"165.5,-172 44.5,-172 44.5,-136 165.5,-136 165.5,-172\"/>\n",
       "<text text-anchor=\"middle\" x=\"105\" y=\"-157\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">diabetes/hypertension</text>\n",
       "<text text-anchor=\"middle\" x=\"105\" y=\"-146\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">(Layer 4)</text>\n",
       "</g>\n",
       "<!-- Remaining (≤ Layer 3)&#45;&gt;diabetes/hypertension -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>Remaining (≤ Layer 3)-&gt;diabetes/hypertension</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M152.92,-104.34C143.44,-113.97 131.69,-125.9 122.19,-135.55\"/>\n",
       "</g>\n",
       "<!-- Remaining (≤ Layer 4) -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>Remaining (≤ Layer 4)</title>\n",
       "<g id=\"a_node5\"><a xlink:title=\"2023/2024\n",
       "March/July\n",
       "negative/positive\">\n",
       "<polygon fill=\"#eadbc8\" stroke=\"#000000\" stroke-width=\"1.2\" points=\"284,-172 188,-172 188,-136 284,-136 284,-172\"/>\n",
       "<text text-anchor=\"middle\" x=\"236\" y=\"-151.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">Remaining (n=3)</text>\n",
       "</a>\n",
       "</g>\n",
       "</g>\n",
       "<!-- Remaining (≤ Layer 3)&#45;&gt;Remaining (≤ Layer 4) -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>Remaining (≤ Layer 3)-&gt;Remaining (≤ Layer 4)</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M187.34,-104.34C196.97,-113.97 208.9,-125.9 218.55,-135.55\"/>\n",
       "</g>\n",
       "<!-- negative/positive -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>negative/positive</title>\n",
       "<polygon fill=\"#e6f4ea\" stroke=\"#000000\" stroke-width=\"1.2\" points=\"224.5,-240 127.5,-240 127.5,-204 224.5,-204 224.5,-240\"/>\n",
       "<text text-anchor=\"middle\" x=\"176\" y=\"-225\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">negative/positive</text>\n",
       "<text text-anchor=\"middle\" x=\"176\" y=\"-214\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">(Layer 9)</text>\n",
       "</g>\n",
       "<!-- Remaining (≤ Layer 4)&#45;&gt;negative/positive -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>Remaining (≤ Layer 4)-&gt;negative/positive</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M220.24,-172.34C211.49,-181.97 200.63,-193.9 191.87,-203.55\"/>\n",
       "</g>\n",
       "<!-- Remaining (≤ Layer 9) -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>Remaining (≤ Layer 9)</title>\n",
       "<g id=\"a_node7\"><a xlink:title=\"2023/2024\n",
       "March/July\">\n",
       "<polygon fill=\"#eadbc8\" stroke=\"#000000\" stroke-width=\"1.2\" points=\"343,-240 247,-240 247,-204 343,-204 343,-240\"/>\n",
       "<text text-anchor=\"middle\" x=\"295\" y=\"-219.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">Remaining (n=2)</text>\n",
       "</a>\n",
       "</g>\n",
       "</g>\n",
       "<!-- Remaining (≤ Layer 4)&#45;&gt;Remaining (≤ Layer 9) -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>Remaining (≤ Layer 4)-&gt;Remaining (≤ Layer 9)</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M251.5,-172.34C260.11,-181.97 270.78,-193.9 279.4,-203.55\"/>\n",
       "</g>\n",
       "<!-- 2023/2024 -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>2023/2024</title>\n",
       "<polygon fill=\"#e6f4ea\" stroke=\"#000000\" stroke-width=\"1.2\" points=\"284,-308 220,-308 220,-272 284,-272 284,-308\"/>\n",
       "<text text-anchor=\"middle\" x=\"252\" y=\"-293\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">2023/2024</text>\n",
       "<text text-anchor=\"middle\" x=\"252\" y=\"-282\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">(Layer 10)</text>\n",
       "</g>\n",
       "<!-- Remaining (≤ Layer 9)&#45;&gt;2023/2024 -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>Remaining (≤ Layer 9)-&gt;2023/2024</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M283.7,-240.34C277.43,-249.97 269.65,-261.9 263.37,-271.55\"/>\n",
       "</g>\n",
       "<!-- March/July -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>March/July</title>\n",
       "<polygon fill=\"#e6f4ea\" stroke=\"#000000\" stroke-width=\"1.2\" points=\"370,-308 306,-308 306,-272 370,-272 370,-308\"/>\n",
       "<text text-anchor=\"middle\" x=\"338\" y=\"-293\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">March/July</text>\n",
       "<text text-anchor=\"middle\" x=\"338\" y=\"-282\" font-family=\"Helvetica,sans-Serif\" font-size=\"10.00\">(Layer 10)</text>\n",
       "</g>\n",
       "<!-- Remaining (≤ Layer 9)&#45;&gt;March/July -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>Remaining (≤ Layer 9)-&gt;March/July</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M306.3,-240.34C312.57,-249.97 320.35,-261.9 326.63,-271.55\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>"
      ],
      "text/plain": [
       "<IPython.core.display.SVG object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved vector outputs to figs/: concept_separation_tree.pdf, concept_separation_tree.svg\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from IPython.display import display, SVG\n",
    "import graphviz\n",
    "\n",
    "# ---- Graph with ICLR-style aesthetics ----\n",
    "LEAF_COLOR = '#E6F4EA'   # light green\n",
    "TRUNK_COLOR = '#EADBC8'  # light brown / beige\n",
    "# Export directory; set CONCEPT_TREE_FIG_DIR to override the default path.\n",
    "out_dir = os.environ.get('CONCEPT_TREE_FIG_DIR', '/codespace/3_representation-engineering/C_Representation-Manifolds/figs')\n",
    "\n",
    "dot = graphviz.Digraph('ConceptSeparationTree', comment='Concept Separation Hierarchy Tree')\n",
    "\n",
    "# Global graph settings\n",
    "dot.attr(\n",
    "    rankdir='BT',             # Bottom-to-top hierarchy\n",
    "    splines='spline',\n",
    "    concentrate='true',\n",
    "    nodesep='0.30',\n",
    "    ranksep='0.45',\n",
    "    pad='0.06',\n",
    "    fontname='Helvetica'\n",
    ")\n",
    "\n",
    "# Node style: plain boxes (square corners) for a more formal look\n",
    "dot.attr('node',\n",
    "    shape='box',\n",
    "    style='filled',\n",
    "    color='#000000',\n",
    "    fillcolor='#F7F7F7',\n",
    "    fontname='Helvetica',\n",
    "    fontsize='10',\n",
    "    margin='0.08,0.05',\n",
    "    penwidth='1.2'\n",
    ")\n",
    "\n",
    "# Edge style (no labels, no arrows)\n",
    "dot.attr('edge',\n",
    "    color='#000000',\n",
    "    arrowhead='none',\n",
    "    arrowsize='0.6',\n",
    "    fontname='Helvetica',\n",
    "    fontsize='9',\n",
    "    penwidth='1.0'\n",
    ")\n",
    "\n",
    "# ---- Build hierarchy ----\n",
    "# Distinct finite separation layers in ascending order; each one peels its\n",
    "# concepts off as leaves and hangs the rest under a new trunk node.\n",
    "unique_sorted_layers = sorted({l for l in separation_results.values() if l != float('inf')})\n",
    "\n",
    "remaining_concepts = set(concept_pairs.keys())\n",
    "parent_node_name = \"All Concepts (Layer 0)\"\n",
    "dot.node(parent_node_name, parent_node_name, fillcolor=TRUNK_COLOR)\n",
    "\n",
    "for layer_num in unique_sorted_layers:\n",
    "    concepts_now = {name for name, sep in separation_results.items() if sep == layer_num}\n",
    "\n",
    "    # Concepts separated at this layer → leaves (green); sorted so the layout\n",
    "    # does not depend on set iteration order\n",
    "    for concept in sorted(concepts_now):\n",
    "        node_label = f\"{concept}\\n(Layer {layer_num})\"\n",
    "        dot.node(concept, node_label, fillcolor=LEAF_COLOR)\n",
    "        dot.edge(parent_node_name, concept)\n",
    "\n",
    "    # Update remaining and add a compact group node → trunk (brown)\n",
    "    remaining_concepts -= concepts_now\n",
    "    if remaining_concepts:\n",
    "        new_parent = f\"Remaining (≤ Layer {layer_num})\"\n",
    "        remaining_sorted = sorted(remaining_concepts)\n",
    "        node_label = f\"Remaining (n={len(remaining_sorted)})\"\n",
    "        tooltip_text = \"\\n\".join(remaining_sorted)\n",
    "        dot.node(new_parent, node_label, fillcolor=TRUNK_COLOR, tooltip=tooltip_text)\n",
    "        dot.edge(parent_node_name, new_parent)\n",
    "        parent_node_name = new_parent\n",
    "\n",
    "# Unseparated concepts are also leaves (green)\n",
    "for concept in sorted(n for n, s in separation_results.items() if s == float('inf')):\n",
    "    node_label = f\"{concept}\\n(not separated)\"\n",
    "    dot.node(concept, node_label, fillcolor=LEAF_COLOR, style='filled', color='#000000', penwidth='1.2')\n",
    "    dot.edge(parent_node_name, concept)\n",
    "\n",
    "# ---- Inline SVG display + vector exports ----\n",
    "try:\n",
    "    svg_bytes = dot.pipe(format='svg')\n",
    "    display(SVG(svg_bytes))\n",
    "\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    # Save vector formats for publication; cleanup=True deletes the\n",
    "    # intermediate DOT source file that render() writes next to the PDF.\n",
    "    dot.render(os.path.join(out_dir, 'concept_separation_tree'), format='pdf', view=False, cleanup=True)\n",
    "    with open(os.path.join(out_dir, 'concept_separation_tree.svg'), 'wb') as f:\n",
    "        f.write(svg_bytes)\n",
    "    print(f\"Saved vector outputs to {out_dir}: concept_separation_tree.pdf, concept_separation_tree.svg\")\n",
    "except graphviz.ExecutableNotFound:\n",
    "    print(\"Graphviz not found. Install from https://graphviz.org/download/\")\n",
    "    print(dot.source)\n",
    "except Exception as e:\n",
    "    print(\"Failed to render SVG:\", e)\n",
    "    print(\"DOT source:\\n\", dot.source)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
