{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Qwen2.5-Omni Attention Extraction & Region Analysis Pipeline\n",
    "\n",
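    "This notebook extracts **audio → image attention** from the Qwen2.5-Omni-7B thinker (the self-attention weights in which audio tokens act as queries and image tokens as keys), then analyzes how that attention is distributed across four image quadrants (TL/TR/BL/BR).\n",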
    "\n",
    "**Workflow:**\n",
    "1. Configure all paths in the cell below\n",
    "2. Download & load the model\n",
    "3. Extract per-layer attention matrices from each (audio, image) pair → save as `.pkl`\n",
    "4. Analyze PKL files: map attention to image quadrants → export as `.csv`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. User Configuration\n",
    "\n",
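    "**Please fill in all paths below before running.**\n",
    "\n",
    "`PROMPT_TEMPLATE` (near the end of the cell) is written in Chinese and is passed to the model as-is. English translation, for reference only:\n",
    "\n",
    "> **Instructions.** You are a participant in a psycholinguistic experiment. Next, you will first see four pictures and then hear a sentence related to the pictures (the sentence will mention some of the pictures). Your task is to look at the pictures carefully, listen attentively to the sentence, and understand its meaning. Note: at the end of some trials, there will be a question about that trial. Here are the pictures and the sentence: `<image><audio>`"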
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [User Configuration]\n# Set all paths, select AUDIO_MODE, and adjust parameters here.\n# No output files. This cell must be run before all other cells.\n# To switch audio mode, change AUDIO_MODE and re-run this cell.\n# ==============================================================================\n# USER CONFIGURATION - Please modify all paths below\n# ==============================================================================\nfrom pathlib import Path\n\n# --- Model paths ---\n# Directory for HuggingFace cache (model weights will be cached here)\nHF_HOME = Path(\"/path/to/hf_home\")\n# Local directory to store the downloaded Qwen2.5-Omni-7B model\nMODEL_DIR = Path(\"/path/to/hf_models/Qwen2.5-Omni-7B\")\n\n# --- Data root ---\n# Root directory containing your .xlsx files\nROOT_DIR = Path(\"/path/to/your/project_root\")\n\n# --- Image directory (shared across all audio modes) ---\nBASE_IMAGE_DIR_STR = \"/path/to/your/image_directory\"\n\n# --- Audio mode selection ---\n# Choose ONE of: \"before_tar\", \"before_er\", \"before_sp\"\n# This determines which audio directory and output paths are used\n# for BOTH attention extraction and region analysis.\nAUDIO_MODE = \"before_tar\"  # <-- change this to switch mode\n\n# Base directory that contains the three audio subdirectories\nBASE_AUDIO_ROOT = \"/path/to/your/audio_root\"\n\n# Mode config: mode name -> (audio_dir, pkl_subdir, csv_subdir)\n# Audio dir is under BASE_AUDIO_ROOT; pkl/csv subdirs are under ROOT_DIR.\n# You can edit subdirectory names if your folder structure differs.\n_AUDIO_MODE_CONFIG = {\n    \"before_tar\": {\n        \"audio_dir\": BASE_AUDIO_ROOT + \"/audio_cut_before_tar\",\n        \"pkl_subdir\": \"audio_cut_before_tar/attention_pkls_raw\",\n        \"csv_subdir\": \"audio_cut_before_tar/quad_outputs_fixedboxes\",\n    },\n    \"before_er\": {\n        \"audio_dir\": BASE_AUDIO_ROOT + \"/audio_cut_before_er\",\n        \"pkl_subdir\": \"audio_cut_before_er/attention_pkls_raw\",\n      
  \"csv_subdir\": \"audio_cut_before_er/quad_outputs_fixedboxes\",\n    },\n    \"before_sp\": {\n        \"audio_dir\": BASE_AUDIO_ROOT + \"/audio_cut_before_sp\",\n        \"pkl_subdir\": \"audio_cut_before_sp/attention_pkls_raw\",\n        \"csv_subdir\": \"audio_cut_before_sp/quad_outputs_fixedboxes\",\n    },\n}\n\nassert AUDIO_MODE in _AUDIO_MODE_CONFIG, \\\n    f\"Invalid AUDIO_MODE '{AUDIO_MODE}'. Choose from: {list(_AUDIO_MODE_CONFIG.keys())}\"\nACTIVE_CONDITION = _AUDIO_MODE_CONFIG[AUDIO_MODE]\n\n# --- Canvas & Region settings ---\n# Canvas size of your stimulus images (pixels)\nCANVAS_W, CANVAS_H = 1008, 756\n\n# Four quadrant regions [x1, y1, x2, y2] in pixel coordinates\nREGIONS = {\n    \"TL\": [84, 28, 392, 336],\n    \"TR\": [616, 28, 924, 336],\n    \"BL\": [84, 420, 392, 728],\n    \"BR\": [616, 420, 924, 728],\n}\n\n# --- Attention extraction settings ---\nHEAD_AGG = \"mean\"          # Head aggregation: 'mean' or 'max'\nDTYPE_SAVE = \"float16\"     # Save precision: 'float16' or 'float32'\nAUDIO_STRIDE = 1           # Audio token downsampling (1 = no downsampling)\nMAX_AUDIO_TOKENS = None    # Max audio tokens (None = no limit)\nLAYER_STRIDE = 1           # Layer downsampling (1 = all layers)\n\n# --- Region analysis settings ---\nREGION_AGG = \"mean\"        # Region aggregation: 'mean', 'sum', or 'max'\nDECIMALS = 6               # Decimal places in CSV output\nINCLUDE_REST = True        # True: output TL/TR/BL/BR/REST (5 regions); False: 4 regions only\nFAIL_FAST = True           # Stop on first error in batch processing\n\n# --- ViT patch size (Qwen-VL series typically uses 14) ---\nPATCH_SIZE = 14\n\n# --- Excel sheet name (None = first sheet) ---\nSHEET_NAME = None\n\n# --- Prompt template ---\nPROMPT_TEMPLATE = (\n\"\"\"## 指令\\n你是一个心理语言学实验的被试。接下来，你会首先看到四幅图，然后会听到一个和图片相关的句子（句子中会提到某些图片）。你的任务是：仔细看图片，并认真听句子，理解其意思。注意：某些试次结束时，会有一个关于本试次的问题。以下是图片和句子：<image><audio> \"\"\"\n)\n\nprint(f\"Configuration loaded. 
AUDIO_MODE = '{AUDIO_MODE}'\")\nprint(f\"  Audio dir:  {ACTIVE_CONDITION['audio_dir']}\")\nprint(f\"  PKL output: {ROOT_DIR / ACTIVE_CONDITION['pkl_subdir']}\")\nprint(f\"  CSV output: {ROOT_DIR / ACTIVE_CONDITION['csv_subdir']}\")"
   ]
  },
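  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Optional: check `REGIONS` against the token grid (sketch).** Part 2 assumes that each image token covers a 28 × 28 px cell (`PATCH_SIZE` = 14 with a 2 × 2 token merge), so every box in `REGIONS` must sit on that grid. The next cell is a minimal sketch under that assumption; `region_token_layout` is a hypothetical helper, not part of the original pipeline. It converts pixel boxes to token rows and columns and reports how many tokens each box covers. Because it does not need the model, misaligned boxes can be caught before the download."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper (not part of the original pipeline): map pixel boxes onto the\n",
    "# merged token grid (one token per cell_px x cell_px block) and count tokens per box.\n",
    "def region_token_layout(regions, canvas_w, canvas_h, cell_px=28):\n",
    "    grid_w, grid_h = canvas_w // cell_px, canvas_h // cell_px\n",
    "    layout, seen = {}, set()\n",
    "    for name, (x1, y1, x2, y2) in regions.items():\n",
    "        if any(v % cell_px for v in (x1, y1, x2, y2)):\n",
    "            raise ValueError(f'{name}: boundaries must be multiples of {cell_px}px')\n",
    "        gx1, gx2 = x1 // cell_px, x2 // cell_px\n",
    "        gy1, gy2 = y1 // cell_px, y2 // cell_px\n",
    "        if not (0 <= gx1 < gx2 <= grid_w and 0 <= gy1 < gy2 <= grid_h):\n",
    "            raise ValueError(f'{name}: box falls outside the {grid_h}x{grid_w} grid')\n",
    "        cells = {(y, x) for y in range(gy1, gy2) for x in range(gx1, gx2)}\n",
    "        if cells & seen:\n",
    "            raise ValueError(f'{name}: overlaps another region')\n",
    "        seen |= cells\n",
    "        layout[name] = {'rows': gy2 - gy1, 'cols': gx2 - gx1, 'tokens': len(cells)}\n",
    "    return (grid_h, grid_w), layout\n",
    "\n",
    "\n",
    "# Example with the default canvas and boxes from the configuration cell.\n",
    "# To check your own settings, call region_token_layout(REGIONS, CANVAS_W, CANVAS_H).\n",
    "example_regions = {\n",
    "    'TL': [84, 28, 392, 336],\n",
    "    'TR': [616, 28, 924, 336],\n",
    "    'BL': [84, 420, 392, 728],\n",
    "    'BR': [616, 420, 924, 728],\n",
    "}\n",
    "example_grid, example_layout = region_token_layout(example_regions, 1008, 756)\n",
    "print('token grid (rows, cols):', example_grid)  # (27, 36)\n",
    "for name, info in example_layout.items():\n",
    "    print(name, info)  # each default box: 11 rows x 11 cols = 121 tokens"
   ]
  },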
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# [Install Dependencies]\n# Installs/upgrades required Python packages.\n# accelerate is required because the model is loaded with device_map=\"auto\" when a GPU is available.\n# No output files. You may need to restart the runtime after first install.\n!pip install -U \"transformers>=4.57\" accelerate torchcodec torchaudio openpyxl"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# [Import Libraries]\n# Imports all required Python modules.\n# No output files. If any import fails, check that the dependencies cell above was run.\nimport os\nimport math\nimport pickle\nfrom pathlib import Path\nfrom typing import Optional, List, Dict, Tuple\n\nimport pandas as pd\nimport torch\nimport torchaudio\nimport numpy as np\nfrom PIL import Image\nfrom transformers import (\n    Qwen2_5OmniThinkerForConditionalGeneration,\n    Qwen2_5OmniProcessor,\n)\nfrom huggingface_hub import snapshot_download\n\nprint(\"All imports successful.\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Download & Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# [Download Model]\n# Creates necessary directories and downloads Qwen2.5-Omni-7B model files\n# from the Hugging Face Hub to MODEL_DIR. Skips files that already exist.\n# Output: model weights, config, and tokenizer files saved to MODEL_DIR (~15GB on first run).\nHF_HOME.mkdir(parents=True, exist_ok=True)\nMODEL_DIR.mkdir(parents=True, exist_ok=True)\nROOT_DIR.mkdir(parents=True, exist_ok=True)\n\n# huggingface_hub reads HF_HOME when it is first imported, so these two variables only\n# change the cache location if they are set before the imports cell is run.\nos.environ[\"HF_HOME\"] = str(HF_HOME)\nos.environ[\"HUGGINGFACE_HUB_CACHE\"] = str(HF_HOME / \"hub\")\n\nrepo_id = \"Qwen/Qwen2.5-Omni-7B\"\nprint(f\"Downloading {repo_id} to {MODEL_DIR} ...\")\n# local_dir_use_symlinks and resume_download are deprecated in recent huggingface_hub releases\n# (files in local_dir are real copies and interrupted downloads resume), so they are omitted.\nsnapshot_download(\n    repo_id=repo_id,\n    local_dir=str(MODEL_DIR),\n    allow_patterns=[\"*.json\", \"*.safetensors\", \"*.py\", \"*.md\", \"token*\", \"*.txt\"],\n)\nprint(\"Model downloaded to:\", MODEL_DIR)"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# [Load Model & Processor]\n# Loads the Qwen2.5-Omni-7B model and processor into GPU/CPU memory.\n# Uses attn_implementation=\"eager\" so that output_attentions works correctly.\n# No output files. This cell takes a few minutes on first load.\nMODEL_ID_OR_PATH = str(MODEL_DIR)\nDEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(\"Using device:\", DEVICE)\n\n\ndef load_model_and_processor():\n    \"\"\"\n    Load Qwen2.5-Omni-7B with attn_implementation='eager'\n    so that output_attentions is available.\n    \"\"\"\n    if torch.cuda.is_available():\n        major, _ = torch.cuda.get_device_capability()\n        dtype = torch.bfloat16 if major >= 8 else torch.float16\n        device_map = \"auto\"\n    else:\n        dtype = torch.float32\n        device_map = None\n\n    print(f\"Loading model from {MODEL_ID_OR_PATH} (attn_implementation='eager') ...\")\n    model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(\n        MODEL_ID_OR_PATH,\n        torch_dtype=dtype,\n        device_map=device_map,\n        low_cpu_mem_usage=True,\n        attn_implementation=\"eager\",\n    )\n    if device_map is None:\n        model = model.to(DEVICE)\n    model.eval()\n\n    print(f\"Loading processor from {MODEL_ID_OR_PATH} ...\")\n    processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_ID_OR_PATH)\n    print(\"Model and processor loaded.\")\n    return model, processor\n\n\nmodel, processor = load_model_and_processor()"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Part 1 — Attention Extraction Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [Define Attention Extraction Functions]\n# Defines helper functions and the main extract_attentions_from_excel() function.\n# No output files. These functions are called in the next cell.\n\ndef make_path(xlsx_dir: Path, base_dir_str: str, p_str: str) -> Path:\n    p = Path(p_str)\n    if p.is_absolute():\n        return p\n    if base_dir_str:\n        return (Path(base_dir_str) / p).resolve()\n    return (xlsx_dir / p).resolve()\n\n\ndef extract_attentions_from_excel(\n    xlsx_path: Path,\n    model,\n    processor,\n    base_audio_dir: str,\n    base_image_dir: str,\n    prompt_template: str,\n    *,\n    head_agg: str = \"mean\",\n    dtype_save: str = \"float16\",\n    audio_stride: int = 1,\n    max_audio_tokens: Optional[int] = None,\n    layer_stride: int = 1,\n    output_dir: Optional[Path] = None,\n    sheet_name: Optional[str] = None,\n):\n    \"\"\"\n    Export per-layer (Q_sub, K_sub) attention matrices (head-aggregated only)\n    for each (audio, image) pair listed in the Excel file.\n    \"\"\"\n    # --- Parameter validation ---\n    head_agg = str(head_agg).lower().strip()\n    assert head_agg in (\"mean\", \"max\"), f\"head_agg must be 'mean' or 'max'\"\n    dtype_save = str(dtype_save).lower().strip()\n    assert dtype_save in (\"float16\", \"float32\"), f\"dtype_save must be 'float16' or 'float32'\"\n    np_dtype = np.float16 if dtype_save == \"float16\" else np.float32\n\n    try:\n        torch.set_float32_matmul_precision(\"high\")\n    except Exception:\n        pass\n\n    # --- Output directory ---\n    if output_dir is None:\n        output_dir = Path(\".\") / \"attention_pkls_raw\"\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    # --- Build chat-template inputs ---\n    def _build_inputs_np(image_path: Path, audio_np: np.ndarray, sr: int):\n        assert isinstance(audio_np, np.ndarray) and audio_np.ndim == 1\n        conversations = [\n            {\n                \"role\": \"system\",\n                \"content\": 
[{\n                    \"type\": \"text\",\n                    \"text\": \"You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.\"\n                }],\n            },\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"text\", \"text\": prompt_template},\n                    {\"type\": \"image\", \"path\": str(image_path)},\n                    {\"type\": \"audio\", \"audio\": audio_np, \"sampling_rate\": int(sr)},\n                ],\n            },\n        ]\n        inputs = processor.apply_chat_template(\n            conversations,\n            add_generation_prompt=True,\n            tokenize=True,\n            return_tensors=\"pt\",\n            return_dict=True,\n            padding=True,\n        )\n        for k, v in inputs.items():\n            if isinstance(v, torch.Tensor):\n                inputs[k] = v.to(model.device)\n        return inputs\n\n    # --- Locate modal spans ---\n    def _locate_modal_spans(inputs):\n        ids = inputs[\"input_ids\"][0].detach().cpu().tolist()\n        cfg = model.config\n        tok = processor.tokenizer\n\n        def first_pos(token_id):\n            if token_id is None: return None\n            try: return ids.index(int(token_id))\n            except ValueError: return None\n\n        def get_token_id(attr, fallbacks):\n            v = getattr(cfg, attr, None)\n            if v is not None: return int(v)\n            for s in fallbacks:\n                try:\n                    tid = tok.convert_tokens_to_ids(s)\n                    if isinstance(tid, int) and tid != tok.unk_token_id:\n                        return tid\n                except Exception:\n                    pass\n            return None\n\n        audio_start_id  = get_token_id(\"audio_start_token_id\",  [\"<|audio_start|>\", \"<audio>\", \"<|AUDIO_START|>\"])\n        
audio_end_id    = get_token_id(\"audio_end_token_id\",    [\"<|audio_eos|>\", \"<|audio_end|>\", \"<|AUDIO_END|>\", \"</audio>\"])\n        vision_start_id = get_token_id(\"vision_start_token_id\", [\"<|vision_start|>\", \"<image>\", \"<|image_start|>\", \"<|IM_START|>\"])\n        vision_end_id   = get_token_id(\"vision_end_token_id\",   [\"<|vision_end|>\", \"<image_end>\", \"<|image_end|>\", \"<|IM_END|>\"])\n\n        a_start = first_pos(audio_start_id)\n        v_start = first_pos(vision_start_id)\n        if a_start is None or v_start is None:\n            raise RuntimeError(f\"Cannot find audio/vision start tokens (audio={a_start}, vision={v_start}).\")\n\n        # Audio length = number of audio placeholder tokens. Prefer counting the positions\n        # between the audio start and end markers: feature_attention_mask / input_features\n        # count mel frames, which the audio encoder downsamples before they become tokens,\n        # so those fallbacks overestimate the span and pull trailing text tokens into Q.\n        a_close = None\n        if audio_end_id is not None:\n            try:\n                a_close = ids.index(int(audio_end_id), a_start + 1)\n            except ValueError:\n                a_close = None\n        if a_close is not None:\n            audio_len = a_close - a_start - 1\n        elif \"audio_feature_lengths\" in inputs:\n            audio_len = int(inputs[\"audio_feature_lengths\"][0].detach().cpu().item())\n        elif \"feature_attention_mask\" in inputs:\n            audio_len = int(inputs[\"feature_attention_mask\"][0].detach().cpu().sum().item())\n        elif \"input_features\" in inputs:\n            audio_len = int(inputs[\"input_features\"].shape[-1])\n        else:\n            raise RuntimeError(\"Cannot infer audio length.\")\n        a_end = a_start + audio_len\n\n        # Image grid\n        g = inputs[\"image_grid_thw\"].detach().cpu()\n        while g.ndim > 1: g = g[0]\n        if g.numel() != 3:\n            raise RuntimeError(f\"Unexpected image_grid_thw shape: {inputs['image_grid_thw'].shape}\")\n        T, H_p, W_p = [int(x) for x in g.tolist()]\n        grid_total = int(T * H_p * W_p)\n\n        v_end_before_audio = None\n        if vision_end_id is not None:\n            try:\n                cand = ids.index(int(vision_end_id), v_start + 1)\n                if cand < a_start:\n                    v_end_before_audio = cand\n            except ValueError:\n                v_end_before_audio = None\n\n        pre_boundary_space = max(0, a_start - (v_start + 1))\n        return (a_start, a_end, v_start, v_end_before_audio, (T, H_p, W_p), 
grid_total, pre_boundary_space)\n\n    # --- Compose image columns (two segments) ---\n    def _compose_image_columns_and_meta(\n        a_start, a_end, v_start, v_end_before_audio, T, H_p, W_p,\n        K_total_after_forward: int\n    ):\n        m_default = 2\n        H_eff_def = math.ceil(H_p / m_default)\n        W_eff_def = math.ceil(W_p / m_default)\n        K_eff_theory_def = T * H_eff_def * W_eff_def\n\n        post_boundary_space = max(0, K_total_after_forward - (a_end + 1))\n        pre_boundary_space = max(0, a_start - (v_start + 1))\n        pre_eff_len_used  = min(pre_boundary_space, K_eff_theory_def)\n        post_eff_len_used = min(max(0, K_eff_theory_def - pre_eff_len_used), post_boundary_space)\n\n        image_indices_pre  = torch.arange(v_start + 1, v_start + 1 + pre_eff_len_used, dtype=torch.long)\n        post_start         = a_end + 1\n        image_indices_post = torch.arange(post_start, post_start + post_eff_len_used, dtype=torch.long) if post_eff_len_used > 0 else torch.empty(0, dtype=torch.long)\n\n        image_indices_all = torch.cat([image_indices_pre, image_indices_post], dim=0)\n        K_sub = int(image_indices_all.numel())\n        image_k_rel_eff = torch.arange(0, K_sub, dtype=torch.long)\n\n        ratio = (T * H_p * W_p) / max(1, K_sub)\n        m_infer = int(round(math.sqrt(ratio)))\n        m_infer = max(1, min(4, m_infer))\n\n        merge_meta = {\n            \"merge_size_guess\": int(m_default),\n            \"merge_source\": \"assumed_default_2\",\n            \"merge_size_inferred_from_counts\": int(m_infer),\n            \"eff_grid_thw\": (int(T), int(H_eff_def), int(W_eff_def)),\n            \"K_eff_theory\": int(K_eff_theory_def),\n            \"K_sub\": int(K_sub),\n            \"coverage_eff\": float(K_sub / max(1, K_eff_theory_def))\n        }\n        alloc_meta = {\n            \"pre_boundary_space\": int(pre_boundary_space),\n            \"post_boundary_space\": int(post_boundary_space),\n            
\"pre_eff_len_used\": int(pre_eff_len_used),\n            \"post_eff_len_used\": int(post_eff_len_used)\n        }\n        return (image_indices_pre, image_indices_post, image_indices_all, image_k_rel_eff,\n                (T, H_eff_def, W_eff_def), merge_meta, alloc_meta)\n\n    # --- Forward with GPU/CPU fallback ---\n    def _forward_with_fallback(inputs):\n        try:\n            with torch.inference_mode():\n                out = model(**inputs, output_attentions=True)\n            if torch.cuda.is_available():\n                torch.cuda.synchronize()\n            return out, model.device, None\n        except RuntimeError as e:\n            msg = str(e)\n            if any(key in msg for key in (\"device-side assert\", \"CUDA error\", \"out of memory\", \"illegal memory access\")):\n                print(f\"  [FALLBACK] GPU error: {msg.split(chr(10))[0]}\")\n                print(\"  -> Falling back to CPU for this sample.\")\n                inputs_cpu = {k: (v.detach().cpu() if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}\n                model_cpu = model.to(\"cpu\")\n                with torch.inference_mode():\n                    out_cpu = model_cpu(**inputs_cpu, output_attentions=True)\n                model_cpu.to(torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"))\n                if torch.cuda.is_available():\n                    torch.cuda.empty_cache()\n                return out_cpu, torch.device(\"cpu\"), f\"forward:{msg.splitlines()[0]}\"\n            else:\n                raise\n\n    # ==================== Main processing loop ====================\n    print(f\"\\n>>> Processing: {xlsx_path.name}\")\n    try:\n        df = pd.read_excel(xlsx_path, engine=\"openpyxl\") if sheet_name is None else pd.read_excel(xlsx_path, sheet_name=sheet_name, engine=\"openpyxl\")\n    except Exception as e:\n        print(f\"  [ERROR] Cannot read Excel: {e}\")\n        return\n\n    for c in [\"Item\", \"Audio_File_new\", 
\"Image_File\"]:\n        if c not in df.columns:\n            print(f\"  [ERROR] {xlsx_path.name} missing column: {c}. Skipping.\")\n            return\n\n    target_sr = processor.feature_extractor.sampling_rate\n    print(f\"    Target sampling rate: {target_sr} Hz\")\n    resampler_cache: Dict[int, torchaudio.transforms.Resample] = {}\n    xlsx_dir = xlsx_path.parent\n\n    model_id = (\n        str(getattr(model.config, \"_name_or_path\", \"\")) or\n        str(getattr(model, \"name_or_path\", \"\")) or\n        MODEL_ID_OR_PATH\n    )\n\n    for _, row in df.iterrows():\n        item = row[\"Item\"]\n        try:\n            audio_path = make_path(xlsx_dir, base_audio_dir, str(row[\"Audio_File_new\"]).strip())\n            image_path = make_path(xlsx_dir, base_image_dir, str(row[\"Image_File\"]).strip())\n        except Exception as e:\n            print(f\"  [SKIP] Item={item} path error: {e}\"); continue\n        if not audio_path.exists():\n            print(f\"  [SKIP] Item={item} audio not found: {audio_path}\"); continue\n        if not image_path.exists():\n            print(f\"  [SKIP] Item={item} image not found: {image_path}\"); continue\n\n        safe_item = str(item).replace('/', '_').replace('\\\\', '_')\n        pkl_name = f\"Item_{safe_item}__{audio_path.stem}__{image_path.stem}.pkl\"\n        pkl_path = output_dir / pkl_name\n        if pkl_path.exists():\n            print(f\"  [SKIP] Item={item} PKL already exists: {pkl_name}\"); continue\n\n        print(f\"  --- Processing Item={item} ---\")\n        try:\n            # Load audio\n            try:\n                audio_wav, sr0 = torchaudio.load(audio_path)\n            except Exception as e:\n                print(f\"  [SKIP] Item={item} cannot load audio: {e}\"); continue\n            if sr0 != target_sr:\n                if sr0 not in resampler_cache:\n                    resampler_cache[sr0] = torchaudio.transforms.Resample(sr0, target_sr).to(audio_wav.device)\n                
audio_wav = resampler_cache[sr0](audio_wav)\n            if audio_wav.shape[0] > 1: audio_wav = audio_wav.mean(dim=0, keepdim=True)\n            audio_np = audio_wav.squeeze(0).detach().cpu().numpy().astype(np.float32, copy=False)\n\n            # Verify image\n            try:\n                _ = Image.open(image_path).convert(\"RGB\")\n            except Exception as e:\n                print(f\"  [SKIP] Item={item} cannot load image: {e}\"); continue\n\n            # Build inputs\n            inputs = _build_inputs_np(image_path=image_path, audio_np=audio_np, sr=target_sr)\n\n            # Locate modal spans\n            try:\n                a_start, a_end, v_start, v_end_before_audio, (T, H_p, W_p), grid_total, pre_boundary_space = _locate_modal_spans(inputs)\n            except Exception as e:\n                print(f\"  [SKIP] Item={item} cannot locate modal spans: {e}\"); continue\n\n            # Forward pass\n            out, attn_device, forward_fallback = _forward_with_fallback(inputs)\n            attentions = out.attentions\n            if attentions is None:\n                raise RuntimeError(\"Model did not return attentions; ensure attn_implementation='eager'.\")\n\n            H_heads, Q_total, K_total = attentions[0].squeeze(0).shape\n\n            # Compose image columns\n            (image_idx_pre, image_idx_post, image_idx_all,\n             image_k_rel_eff, eff_grid_thw, merge_meta, alloc_meta) = _compose_image_columns_and_meta(\n                a_start, a_end, v_start, v_end_before_audio, T, H_p, W_p, K_total_after_forward=K_total\n            )\n\n            # Clip to K_total range\n            ok_mask = (image_idx_all >= 0) & (image_idx_all < K_total)\n            image_idx_all = image_idx_all[ok_mask]\n            image_k_rel_eff = image_k_rel_eff[ok_mask]\n            if image_idx_pre.numel() > 0:\n                image_idx_pre = image_idx_pre[(image_idx_pre >= 0) & (image_idx_pre < K_total)]\n            if image_idx_post.numel() > 
0:\n                image_idx_post = image_idx_post[(image_idx_post >= 0) & (image_idx_post < K_total)]\n\n            # Audio Q segment\n            audio_idx_cpu = torch.arange(a_start + 1, a_start + 1 + (a_end - a_start), dtype=torch.long)\n            audio_idx_cpu = audio_idx_cpu[(audio_idx_cpu >= 0) & (audio_idx_cpu < Q_total)]\n            if audio_stride > 1:\n                audio_idx_cpu = audio_idx_cpu[::int(audio_stride)]\n            if max_audio_tokens is not None and audio_idx_cpu.numel() > int(max_audio_tokens):\n                audio_idx_cpu = audio_idx_cpu[:int(max_audio_tokens)]\n            if audio_idx_cpu.numel() == 0 or image_idx_all.numel() == 0:\n                print(f\"  [SKIP] Item={item} empty indices: Q_total={Q_total}, K_total={K_total}\"); continue\n\n            Q_sub = int(audio_idx_cpu.numel()); K_sub = int(image_idx_all.numel())\n            print(f\"    layers={len(attentions)} (using {len(range(0, len(attentions), int(max(1, layer_stride))))}), Q_sub={Q_sub}, K_sub={K_sub}, H={H_heads}\")\n\n            device = attn_device if torch.cuda.is_available() else torch.device(\"cpu\")\n            use_cpu_post = (device.type == \"cpu\")\n            audio_rows_gpu = audio_idx_cpu.to(device)\n            image_cols_gpu = image_idx_all.to(device)\n\n            # Extract per-layer attention\n            layer_ids = list(range(0, len(attentions), int(max(1, layer_stride))))\n            attn_mats: List[np.ndarray] = []\n            post_fallback_reason = None\n\n            for li in layer_ids:\n                L_bhqk = attentions[li][0]\n                if not use_cpu_post:\n                    try:\n                        sub = L_bhqk.index_select(1, audio_rows_gpu).index_select(2, image_cols_gpu)\n                        M = sub.mean(dim=0) if head_agg == \"mean\" else sub.amax(dim=0)\n                        arr = M.detach().to(\"cpu\", dtype=torch.float32).numpy().astype(np_dtype, copy=False)\n                        
attn_mats.append(arr)\n                        del sub, M\n                        if torch.cuda.is_available(): torch.cuda.empty_cache()\n                        continue\n                    except RuntimeError as e:\n                        post_fallback_reason = f\"postproc:{str(e).splitlines()[0]}\"\n                        use_cpu_post = True\n\n                layer_cpu = L_bhqk.detach().to(\"cpu\", dtype=torch.float32)\n                sub_cpu = layer_cpu.index_select(1, audio_idx_cpu).index_select(2, image_idx_all)\n                M_cpu = sub_cpu.mean(dim=0) if head_agg == \"mean\" else sub_cpu.amax(dim=0)\n                attn_mats.append(M_cpu.numpy().astype(np_dtype, copy=False))\n                del layer_cpu, sub_cpu, M_cpu\n\n            # Effective 2D coordinates\n            T_eff, H_eff, W_eff = eff_grid_thw\n            k_rel_eff_np = image_k_rel_eff.detach().cpu().numpy()\n            t_eff = (k_rel_eff_np // (H_eff * W_eff)).astype(np.int32) if T_eff > 1 else np.zeros_like(k_rel_eff_np, dtype=np.int32)\n            hw_rem = k_rel_eff_np % (H_eff * W_eff)\n            h_eff = (hw_rem // W_eff).astype(np.int32)\n            w_eff = (hw_rem %  W_eff).astype(np.int32)\n            eff_coords = np.stack([t_eff, h_eff, w_eff], axis=1)\n\n            # Assemble payload\n            payload = {\n                \"schema_version\": \"qk_attn.v2\",\n                \"model_id\": model_id,\n                \"head_agg\": head_agg,\n                \"dtype\": dtype_save,\n                \"item\": item,\n                \"audio_path\": str(audio_path),\n                \"image_path\": str(image_path),\n                \"Q_total\": int(Q_total),\n                \"K_total\": int(K_total),\n                \"audio_indices\": audio_idx_cpu.detach().cpu().tolist(),\n                \"image_indices_pre\": image_idx_pre.detach().cpu().tolist(),\n                \"image_indices_post\": image_idx_post.detach().cpu().tolist(),\n                \"image_indices\": 
image_idx_all.detach().cpu().tolist(),\n                \"image_grid_thw\": (int(T), int(H_p), int(W_p)),\n                \"eff_grid_thw\": (int(T_eff), int(H_eff), int(W_eff)),\n                \"image_k_rel_eff\": image_k_rel_eff.detach().cpu().tolist(),\n                \"eff_coords_per_token\": eff_coords.tolist(),\n                \"layer_ids\": [int(x) for x in layer_ids],\n                \"attn_per_layer\": attn_mats,\n                \"notes\": {\n                    \"device_forward\": str(attn_device),\n                    \"device_postproc\": \"cpu\" if (use_cpu_post or attn_device.type == \"cpu\") else \"cuda\",\n                    \"forward_fallback\": forward_fallback,\n                    \"postproc_fallback\": post_fallback_reason,\n                    \"H_heads\": int(H_heads),\n                    \"AUDIO_STRIDE\": int(audio_stride),\n                    \"MAX_AUDIO_TOKENS\": (None if max_audio_tokens is None else int(max_audio_tokens)),\n                    \"LAYER_STRIDE\": int(layer_stride),\n                    \"a_start\": int(a_start), \"a_end\": int(a_end),\n                    \"v_start\": int(v_start),\n                    \"v_end_before_audio\": (None if v_end_before_audio is None else int(v_end_before_audio)),\n                    \"grid_total\": int(grid_total),\n                    **merge_meta,\n                    **alloc_meta,\n                }\n            }\n\n            with open(pkl_path, \"wb\") as f:\n                pickle.dump(payload, f, protocol=4)\n            print(f\"  [+] Saved: {pkl_name}\")\n\n        except Exception as e:\n            print(f\"  [FAIL] Item={item}: {e}\")\n\n    print(f\"\\n<<< Finished: {xlsx_path.name}\")\n\n\nprint(\"Attention extraction functions defined.\")"
   ]
  },
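  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**How the extraction slices attention (toy sketch).** For every saved layer, `extract_attentions_from_excel` takes the thinker's `(heads, Q, K)` attention tensor, keeps the rows of the audio tokens (queries) and the columns of the image tokens (keys), and aggregates over heads with `HEAD_AGG`. The next cell reproduces that slicing on a made-up token sequence in NumPy. The marker ids, `span_between`, and `audio_to_image_submatrix` are hypothetical illustrations; the real function reads the marker ids from `model.config` and works on torch tensors. The cell also computes the image-token count for the default canvas, assuming the processor keeps the 1008 × 756 image at its native size (both sides are already multiples of 28)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helpers for illustration only; the pipeline itself reads the real\n",
    "# marker ids from model.config and operates on torch tensors.\n",
    "def span_between(ids, start_id, end_id):\n",
    "    # Positions strictly between the first start marker and the next end marker.\n",
    "    s = ids.index(start_id)\n",
    "    e = ids.index(end_id, s + 1)\n",
    "    return np.arange(s + 1, e)\n",
    "\n",
    "\n",
    "def audio_to_image_submatrix(attn_hqk, audio_rows, image_cols, head_agg='mean'):\n",
    "    # attn_hqk: one layer's attention weights with shape (heads, Q, K).\n",
    "    sub = attn_hqk[:, audio_rows][:, :, image_cols]  # (heads, Q_sub, K_sub)\n",
    "    return sub.mean(axis=0) if head_agg == 'mean' else sub.max(axis=0)\n",
    "\n",
    "\n",
    "# Merged image-token count for a 1008x756 canvas, assuming it is not resized:\n",
    "# grid_thw = (1, 54, 72) in 14px patches; a 2x2 merge leaves 1 * 27 * 36 tokens.\n",
    "toy_t, toy_hp, toy_wp = 1, 756 // 14, 1008 // 14\n",
    "print('image tokens:', toy_t * (toy_hp // 2) * (toy_wp // 2))  # 972\n",
    "\n",
    "# Toy sequence with made-up ids: 1/2 = vision start/end, 3/4 = audio start/end,\n",
    "# 10 = image placeholder, 11 = audio placeholder, 12 = text.\n",
    "toy_ids = [12, 1, 10, 10, 10, 10, 2, 3, 11, 11, 11, 4, 12]\n",
    "toy_rows = span_between(toy_ids, 3, 4)  # audio queries: positions 8-10\n",
    "toy_cols = span_between(toy_ids, 1, 2)  # image keys: positions 2-5\n",
    "\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_attn = toy_rng.random((2, len(toy_ids), len(toy_ids)))\n",
    "toy_attn /= toy_attn.sum(axis=-1, keepdims=True)  # rows sum to 1, like a softmax output\n",
    "toy_M = audio_to_image_submatrix(toy_attn, toy_rows, toy_cols)\n",
    "print(toy_M.shape)  # (3, 4)"
   ]
  },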
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 6. Run Attention Extraction\n\nUses the selected `AUDIO_MODE` to choose the audio directory and the PKL output directory."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [Run Attention Extraction]\n# Reads all .xlsx files in ROOT_DIR, processes each (audio, image) pair\n# through the model, and saves per-layer attention matrices.\n# Output files: one .pkl file per item, saved to {ROOT_DIR}/{pkl_subdir}/\n# Skips items whose .pkl already exists (delete old .pkl to re-extract).\n\nxlsx_files = sorted([p for p in ROOT_DIR.glob(\"*.xlsx\") if not p.name.startswith(\"~$\")])\n\nif not xlsx_files:\n    print(f\"No .xlsx files found in {ROOT_DIR.resolve()}\")\nelse:\n    audio_dir = ACTIVE_CONDITION[\"audio_dir\"]\n    pkl_out = ROOT_DIR / ACTIVE_CONDITION[\"pkl_subdir\"]\n    pkl_out.mkdir(parents=True, exist_ok=True)\n\n    print(f\"AUDIO_MODE:  {AUDIO_MODE}\")\n    print(f\"Audio dir:   {audio_dir}\")\n    print(f\"PKL output:  {pkl_out}\")\n    print(f\"Excel files: {len(xlsx_files)}\")\n    print(\"=\" * 60)\n\n    for x_file in xlsx_files:\n        extract_attentions_from_excel(\n            xlsx_path=x_file,\n            model=model,\n            processor=processor,\n            base_audio_dir=audio_dir,\n            base_image_dir=BASE_IMAGE_DIR_STR,\n            prompt_template=PROMPT_TEMPLATE,\n            head_agg=HEAD_AGG,\n            dtype_save=DTYPE_SAVE,\n            audio_stride=AUDIO_STRIDE,\n            max_audio_tokens=MAX_AUDIO_TOKENS,\n            layer_stride=LAYER_STRIDE,\n            output_dir=pkl_out,\n            sheet_name=SHEET_NAME,\n        )\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\nprint(\"\\n=== Attention extraction complete ===\")"
   ]
  },
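  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Inspecting one saved PKL (sketch).** Each PKL stores `attn_per_layer` (one `(Q_sub, K_sub)` array per saved layer), `layer_ids`, `eff_grid_thw`, and `eff_coords_per_token` (the `(t, h, w)` cell of each image column). The hypothetical helper `attention_to_grid` in the next cell, which is not part of the original pipeline, averages over layers and audio rows and writes each image column back onto the merged token grid, giving a heat map that can be plotted (for example with matplotlib's `imshow`). It is demonstrated on a small synthetic payload; the commented lines show how it could be applied to a real file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper (not part of the original pipeline): collapse one PKL payload\n",
    "# into a (H_eff, W_eff) heat map of audio -> image attention.\n",
    "def attention_to_grid(payload, layers=None):\n",
    "    mats = payload['attn_per_layer']  # list of (Q_sub, K_sub) arrays, one per saved layer\n",
    "    if layers is not None:\n",
    "        wanted = set(layers)\n",
    "        mats = [m for lid, m in zip(payload['layer_ids'], mats) if lid in wanted]\n",
    "    if not mats:\n",
    "        raise ValueError('no layers selected')\n",
    "    stacked = np.stack([np.asarray(m, dtype=np.float32) for m in mats])  # (L, Q_sub, K_sub)\n",
    "    per_token = stacked.mean(axis=(0, 1))  # average over layers and audio query rows\n",
    "    _, h_eff, w_eff = payload['eff_grid_thw']\n",
    "    coords = np.asarray(payload['eff_coords_per_token'], dtype=np.int64)  # (K_sub, 3): t, h, w\n",
    "    grid = np.full((h_eff, w_eff), np.nan, dtype=np.float32)  # NaN = no saved token at that cell\n",
    "    grid[coords[:, 1], coords[:, 2]] = per_token\n",
    "    return grid\n",
    "\n",
    "\n",
    "# Applying it to a real file (the path is a placeholder):\n",
    "#   import pickle\n",
    "#   with open('path/to/Item_xxx.pkl', 'rb') as f:\n",
    "#       grid = attention_to_grid(pickle.load(f))\n",
    "#   grid.shape is (27, 36) for a 1008x756 canvas\n",
    "\n",
    "# Synthetic payload: 2 saved layers, 2 audio rows, 6 image tokens on a 2 x 3 grid.\n",
    "toy_payload = {\n",
    "    'attn_per_layer': [np.arange(12, dtype=np.float16).reshape(2, 6),\n",
    "                       np.zeros((2, 6), dtype=np.float16)],\n",
    "    'layer_ids': [0, 1],\n",
    "    'eff_grid_thw': (1, 2, 3),\n",
    "    'eff_coords_per_token': [[0, k // 3, k % 3] for k in range(6)],\n",
    "}\n",
    "print(attention_to_grid(toy_payload))              # mean over both layers\n",
    "print(attention_to_grid(toy_payload, layers=[0]))  # layer 0 only"
   ]
  },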
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 7. Part 2 — Region Analysis Functions\n",
    "\n",
    "Reads the `.pkl` files from Part 1, maps attention to four image quadrants (TL/TR/BL/BR), and exports normalized CSV files."
   ]
  },
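  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Region aggregation in miniature (sketch).** The next cell illustrates the idea behind Part 2 on a toy grid: reduce the attention values inside each box with `mean`, `sum`, or `max` (the options of `REGION_AGG`), optionally add a `REST` region for cells outside all four boxes (as with `INCLUDE_REST`), and divide by the total so the region shares add up to 1. `region_shares` is a hypothetical helper, and dividing by the total is an assumed normalization; the functions defined in the following cell may normalize differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical sketch (not the notebook's own implementation): aggregate a (H_eff, W_eff)\n",
    "# attention grid inside pixel boxes, then normalize the region values so they sum to 1.\n",
    "def region_shares(grid, regions, cell_px=28, agg='mean', include_rest=True):\n",
    "    # NaN-aware reducers, so grid cells without a saved token (NaN) are ignored.\n",
    "    reducer = {'mean': np.nanmean, 'sum': np.nansum, 'max': np.nanmax}[agg]\n",
    "    outside = np.ones(grid.shape, dtype=bool)  # True for cells not covered by any box\n",
    "    values = {}\n",
    "    for name, (x1, y1, x2, y2) in regions.items():\n",
    "        rows = slice(y1 // cell_px, y2 // cell_px)\n",
    "        cols = slice(x1 // cell_px, x2 // cell_px)\n",
    "        values[name] = float(reducer(grid[rows, cols]))\n",
    "        outside[rows, cols] = False\n",
    "    if include_rest:\n",
    "        values['REST'] = float(reducer(grid[outside]))\n",
    "    total = sum(values.values())\n",
    "    # Assumed normalization: share of the summed region values (raw values if total <= 0).\n",
    "    return {name: v / total for name, v in values.items()} if total > 0 else values\n",
    "\n",
    "\n",
    "# Toy grid for a 1008x756 canvas (27 x 36 cells): uniform attention, TL box 4x higher.\n",
    "toy_regions = {'TL': [84, 28, 392, 336], 'TR': [616, 28, 924, 336],\n",
    "               'BL': [84, 420, 392, 728], 'BR': [616, 420, 924, 728]}\n",
    "toy_grid = np.ones((27, 36), dtype=np.float32)\n",
    "toy_grid[1:12, 3:14] = 4.0  # TL box in token coordinates: rows 1-11, cols 3-13\n",
    "toy_shares = region_shares(toy_grid, toy_regions)\n",
    "print({k: round(v, 3) for k, v in toy_shares.items()})\n",
    "# {'TL': 0.5, 'TR': 0.125, 'BL': 0.125, 'BR': 0.125, 'REST': 0.125}"
   ]
  },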
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [Define Region Analysis Functions]\n# Defines functions to map image tokens to quadrant regions (TL/TR/BL/BR),\n# aggregate attention per region, normalize, and export to CSV.\n# No output files. These functions are called in the next cell.\n\n# Fixed box constants (derived from REGIONS, PATCH_SIZE, CANVAS_W and CANVAS_H)\nCELL_PX = PATCH_SIZE * 2  # one merged visual token covers 2x2 patches (= 14*2 = 28 px)\nFIXED_BOXES = REGIONS.copy()\n\nH_EFF_EXPECT = CANVAS_H // CELL_PX\nW_EFF_EXPECT = CANVAS_W // CELL_PX\nTOKENS_PER_BOX_EXPECT = (\n    (FIXED_BOXES[\"TL\"][2] - FIXED_BOXES[\"TL\"][0]) // CELL_PX\n) * (\n    (FIXED_BOXES[\"TL\"][3] - FIXED_BOXES[\"TL\"][1]) // CELL_PX\n)\nTOTAL_BOX_TOKENS_EXPECT = TOKENS_PER_BOX_EXPECT * 4\n\nprint(f\"Effective grid: {H_EFF_EXPECT} x {W_EFF_EXPECT} (cell={CELL_PX}px)\")\nprint(f\"Tokens per box: {TOKENS_PER_BOX_EXPECT}, total (4 boxes): {TOTAL_BOX_TOKENS_EXPECT}\")\n\n\ndef _load_v2_strict(pkl_path: str) -> dict:\n    with open(pkl_path, \"rb\") as f:\n        d = pickle.load(f)\n    if d.get(\"schema_version\", \"\") != \"qk_attn.v2\":\n        raise RuntimeError(f\"{pkl_path}: requires schema_version='qk_attn.v2', got {d.get('schema_version')}\")\n    if \"eff_grid_thw\" not in d or \"image_k_rel_eff\" not in d or \"eff_coords_per_token\" not in d:\n        raise RuntimeError(f\"{pkl_path}: missing v2 fields eff_grid_thw/image_k_rel_eff/eff_coords_per_token\")\n    return d\n\n\ndef _assert_canvas_and_grid(d: dict) -> Tuple[int, int, Tuple[int, int, int]]:\n    with Image.open(d[\"image_path\"]) as img:\n        W0, H0 = img.size\n    if (W0, H0) != (CANVAS_W, CANVAS_H):\n        raise RuntimeError(f\"Image resolution mismatch: expected {CANVAS_W}x{CANVAS_H}, got {W0}x{H0}\")\n    T_eff, H_eff, W_eff = d[\"eff_grid_thw\"]\n    if not (T_eff == 1 and H_eff == H_EFF_EXPECT and W_eff == W_EFF_EXPECT):\n        raise RuntimeError(f\"Grid mismatch: expected (1,{H_EFF_EXPECT},{W_EFF_EXPECT}), got ({T_eff},{H_eff},{W_eff})\")\n    if W0 // W_eff != CELL_PX or H0 // H_eff != CELL_PX:\n        raise 
RuntimeError(f\"Cell size mismatch: expected {CELL_PX}px\")\n    return W0, H0, (T_eff, H_eff, W_eff)\n\n\ndef _boxes_to_eff_ranges_strict(W_eff: int, H_eff: int) -> Dict[str, Tuple[range, range]]:\n    ranges: Dict[str, Tuple[range, range]] = {}\n    for name, (x1, y1, x2, y2) in FIXED_BOXES.items():\n        for v in (x1, y1, x2, y2):\n            if v % CELL_PX != 0:\n                raise RuntimeError(f\"Box {name} boundary {v} is not a multiple of {CELL_PX}\")\n        gx1 = x1 // CELL_PX; gx2 = x2 // CELL_PX\n        gy1 = y1 // CELL_PX; gy2 = y2 // CELL_PX\n        if not (0 <= gx1 < gx2 <= W_eff and 0 <= gy1 < gy2 <= H_eff):\n            raise RuntimeError(f\"Box {name} out of grid range\")\n        cols = gx2 - gx1; rows = gy2 - gy1\n        if rows * cols != TOKENS_PER_BOX_EXPECT:\n            raise RuntimeError(f\"Box {name} token count: {rows}x{cols}={rows*cols}, expected {TOKENS_PER_BOX_EXPECT}\")\n        ranges[name] = (range(gy1, gy2), range(gx1, gx2))\n    # Check that no two boxes share a grid cell\n    seen = set()\n    for name, (rows, cols) in ranges.items():\n        for y in rows:\n            for x in cols:\n                if (y, x) in seen:\n                    raise RuntimeError(f\"Overlapping cell at ({y},{x})\")\n                seen.add((y, x))\n    return ranges\n\n\ndef _ranges_to_krel(rows: range, cols: range, W_eff: int) -> np.ndarray:\n    # Row-major flattening of grid cells: k = y * W_eff + x\n    return np.asarray([y * W_eff + x for y in rows for x in cols], dtype=np.int64)\n\n\ndef _build_region_cols_from_pkl(d: dict, include_rest: bool = False) -> Dict[str, np.ndarray]:\n    _, _, (T_eff, H_eff, W_eff) = _assert_canvas_and_grid(d)\n    if T_eff != 1:\n        raise RuntimeError(f\"Only static images supported (T_eff=1), got T_eff={T_eff}\")\n    ranges = _boxes_to_eff_ranges_strict(W_eff=W_eff, H_eff=H_eff)\n    k_rel_have = np.asarray(d[\"image_k_rel_eff\"], dtype=np.int64)\n    K_sub = int(k_rel_have.shape[0])\n    idx_map: Dict[int, int] = {}\n    for col, k in enumerate(k_rel_have.tolist()):\n        if k 
in idx_map:\n            raise RuntimeError(f\"Duplicate k_rel_eff: k={k}\")\n        idx_map[k] = col\n\n    region_cols: Dict[str, np.ndarray] = {}\n    all_cols_list = []\n    for name, (rows, cols) in ranges.items():\n        need_k = _ranges_to_krel(rows, cols, W_eff)\n        miss = [int(k) for k in need_k.tolist() if k not in idx_map]\n        if miss:\n            raise RuntimeError(f\"PKL missing {len(miss)} tokens for box {name}\")\n        cols_in_sub = np.asarray([idx_map[int(k)] for k in need_k.tolist()], dtype=np.int64)\n        if cols_in_sub.size != TOKENS_PER_BOX_EXPECT:\n            raise RuntimeError(f\"Box {name} column count: {cols_in_sub.size}, expected {TOKENS_PER_BOX_EXPECT}\")\n        region_cols[name] = cols_in_sub\n        all_cols_list.append(cols_in_sub)\n\n    all_cols_cat = np.concatenate(all_cols_list, axis=0)\n    if all_cols_cat.size != TOTAL_BOX_TOKENS_EXPECT:\n        raise RuntimeError(f\"Total box columns: {all_cols_cat.size}, expected {TOTAL_BOX_TOKENS_EXPECT}\")\n    unique_cols = np.unique(all_cols_cat)\n    if unique_cols.size != TOTAL_BOX_TOKENS_EXPECT:\n        raise RuntimeError(\"Overlapping columns detected\")\n\n    if include_rest:\n        full = np.arange(K_sub, dtype=np.int64)\n        region_cols[\"REST\"] = np.setdiff1d(full, unique_cols, assume_unique=False)\n\n    return region_cols\n\n\ndef _aggregate_layers_and_normalize(\n    d: dict,\n    region_cols: Dict[str, np.ndarray],\n    agg: str = \"mean\",\n    decimals: int = 6\n) -> Tuple[pd.DataFrame, pd.DataFrame]:\n    # Reject unknown modes instead of silently falling back to mean (agg also appears in output filenames)\n    if agg not in (\"mean\", \"sum\", \"max\"):\n        raise ValueError(f\"agg must be 'mean', 'sum' or 'max', got {agg!r}\")\n    attn_layers = d[\"attn_per_layer\"]\n    layer_ids = d.get(\"layer_ids\", list(range(len(attn_layers))))\n    audio_abs = np.asarray(d[\"audio_indices\"], dtype=np.int64)\n    Q_sub, K_sub = attn_layers[0].shape\n\n    order = [\"TL\", \"TR\", \"BL\", \"BR\"] + ([\"REST\"] if \"REST\" in region_cols else [])\n    R = len(order)\n\n    def _agg_one(A):\n        A = A.astype(np.float32, copy=False)\n        out = np.zeros((Q_sub, R), 
dtype=np.float32)\n        for j, name in enumerate(order):\n            cols = region_cols[name]\n            if cols.size == 0:\n                out[:, j] = 0.0; continue\n            sub = A[:, cols]\n            if agg == \"sum\":    v = sub.sum(axis=1)\n            elif agg == \"max\":  v = sub.max(axis=1)\n            else:               v = sub.mean(axis=1)\n            out[:, j] = v\n        rs = out.sum(axis=1, keepdims=True)\n        zero_mask = (rs == 0.0)\n        rs[zero_mask] = 1.0\n        out = out / rs\n        if np.any(zero_mask):\n            out[zero_mask.flatten(), :] = 1.0 / float(R)\n        return out\n\n    per_layer = [_agg_one(A) for A in attn_layers]\n    layer_labels = [f\"L{lid+1:02d}\" for lid in layer_ids]\n    q_idx = np.arange(Q_sub, dtype=int)\n\n    # Human-readable\n    data_h = {\"q_idx\": q_idx.tolist(), \"audio_abs_index\": audio_abs.tolist()}\n    for mat, lab in zip(per_layer, layer_labels):\n        rounded = np.round(mat, decimals)\n        col_str = []\n        for i in range(rounded.shape[0]):\n            parts = [f\"{name}={rounded[i, j]:.{decimals}f}\" for j, name in enumerate(order)]\n            col_str.append(\";\".join(parts))\n        data_h[lab] = np.array(col_str, dtype=object)\n    df_human = pd.DataFrame(data_h)\n\n    # Machine-friendly\n    data_m = {\"q_idx\": q_idx.tolist(), \"audio_abs_index\": audio_abs.tolist()}\n    for mat, lab in zip(per_layer, layer_labels):\n        out = np.round(mat, decimals=decimals)\n        for j, name in enumerate(order):\n            data_m[f\"{lab}_{name}\"] = out[:, j]\n    df_machine = pd.DataFrame(data_m)\n\n    return df_human, df_machine\n\n\ndef process_one_pkl_fixed_boxes(\n    pkl_path: Path,\n    out_root: Path,\n    agg: str = \"mean\",\n    decimals: int = 6,\n    include_rest: bool = False\n) -> Tuple[Path, Path]:\n    d = _load_v2_strict(str(pkl_path))\n    region_cols = _build_region_cols_from_pkl(d, include_rest=include_rest)\n\n    sizes = {k: 
int(len(v)) for k, v in region_cols.items()}\n    total_known = sizes.get(\"TL\", 0) + sizes.get(\"TR\", 0) + sizes.get(\"BL\", 0) + sizes.get(\"BR\", 0)\n    rest_sz = sizes.get(\"REST\", 0)\n    K_sub = len(d[\"image_k_rel_eff\"])\n    print(f\"[{pkl_path.name}] TL/TR/BL/BR={total_known}, REST={rest_sz}, total={total_known + rest_sz}/{K_sub}\")\n\n    df_human, df_machine = _aggregate_layers_and_normalize(d, region_cols, agg=agg, decimals=decimals)\n\n    out_human_dir = out_root / \"human_readable\"\n    out_machine_dir = out_root / \"machine_friendly\"\n    out_human_dir.mkdir(parents=True, exist_ok=True)\n    out_machine_dir.mkdir(parents=True, exist_ok=True)\n\n    stem = pkl_path.stem\n    tag = \"fixedboxes5\" if (\"REST\" in region_cols) else \"fixedboxes4\"\n    human_csv = out_human_dir / f\"{stem}.{tag}_{agg}_norm.human.csv\"\n    machine_csv = out_machine_dir / f\"{stem}.{tag}_{agg}_norm.machine.csv\"\n\n    df_human.to_csv(human_csv, index=False, encoding=\"utf-8\")\n    df_machine.to_csv(machine_csv, index=False, encoding=\"utf-8\")\n    print(f\"  -> HUMAN:   {human_csv}\")\n    print(f\"  -> MACHINE: {machine_csv}\")\n    return human_csv, machine_csv\n\n\ndef process_dir_fixed_boxes(\n    pkl_dir: Path,\n    out_root: Path,\n    pattern: str = \"*.pkl\",\n    agg: str = \"mean\",\n    decimals: int = 6,\n    fail_fast: bool = True,\n    include_rest: bool = False\n) -> None:\n    pkl_files = sorted(pkl_dir.glob(pattern))\n    if not pkl_files:\n        print(f\"[WARN] No PKL found in {pkl_dir} (pattern={pattern})\")\n        return\n    print(f\"[INFO] Found {len(pkl_files)} PKL in {pkl_dir}\")\n    out_root.mkdir(parents=True, exist_ok=True)\n\n    ok, fail = 0, 0\n    for p in pkl_files:\n        try:\n            process_one_pkl_fixed_boxes(p, out_root=out_root, agg=agg, decimals=decimals, include_rest=include_rest)\n            ok += 1\n        except Exception as e:\n            print(f\"[FAIL] {p.name}: {e}\")\n            fail += 1\n        
    if fail_fast:\n                raise\n    print(f\"[DONE] success={ok}, failed={fail}\")\n\n\nprint(\"Region analysis functions defined.\")"
   ]
  },
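  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The region-analysis functions boil down to two steps: each fixed pixel box is mapped to effective-grid cells and flattened row-major (`k = y * W_eff + x`, as in `_ranges_to_krel`), and each audio query's attention over those image-token columns is aggregated per region, row-normalized, and replaced by a uniform distribution when the row sums to zero. The sketch below replays both steps on a hypothetical 112x112 canvas (a 4x4 grid of 28-px cells) with random attention. The canvas size, box coordinates and the helpers `box_to_token_indices` and `region_shares` are illustrative assumptions, not part of the pipeline, and the sketch assumes every image token is present so no `image_k_rel_eff` lookup is needed.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy setup: a 112x112 canvas with 28-px cells gives a 4x4 token grid\n",
    "CELL_PX = 28\n",
    "CANVAS_W, CANVAS_H = 112, 112\n",
    "W_EFF, H_EFF = CANVAS_W // CELL_PX, CANVAS_H // CELL_PX\n",
    "\n",
    "# Hypothetical quadrant boxes in pixels: (x1, y1, x2, y2)\n",
    "BOXES = {\n",
    "    'TL': (0, 0, 56, 56),\n",
    "    'TR': (56, 0, 112, 56),\n",
    "    'BL': (0, 56, 56, 112),\n",
    "    'BR': (56, 56, 112, 112),\n",
    "}\n",
    "\n",
    "\n",
    "def box_to_token_indices(box, w_eff, cell_px):\n",
    "    # Pixel bounds -> grid cells, then row-major flattening: k = y * w_eff + x\n",
    "    x1, y1, x2, y2 = (v // cell_px for v in box)\n",
    "    return np.array([y * w_eff + x for y in range(y1, y2) for x in range(x1, x2)], dtype=np.int64)\n",
    "\n",
    "\n",
    "def region_shares(attn, region_idx, order=('TL', 'TR', 'BL', 'BR')):\n",
    "    # attn: (num_audio_queries, num_image_tokens)\n",
    "    # Mean attention per region, then normalize each query row to sum to 1\n",
    "    out = np.stack([attn[:, region_idx[name]].mean(axis=1) for name in order], axis=1)\n",
    "    row_sum = out.sum(axis=1, keepdims=True)\n",
    "    zero = row_sum[:, 0] == 0.0\n",
    "    out = out / np.where(zero[:, None], 1.0, row_sum)\n",
    "    out[zero, :] = 1.0 / len(order)  # uniform fallback for rows with no mass in any box\n",
    "    return out\n",
    "\n",
    "\n",
    "region_idx = {name: box_to_token_indices(b, W_EFF, CELL_PX) for name, b in BOXES.items()}\n",
    "rng = np.random.default_rng(0)\n",
    "attn = rng.random((3, W_EFF * H_EFF)).astype(np.float32)\n",
    "attn[2, :] = 0.0  # this query exercises the uniform fallback\n",
    "shares = region_shares(attn, region_idx)\n",
    "print(region_idx['TR'])  # [2 3 6 7]\n",
    "print(shares.round(3))\n",
    "```\n",
    "\n",
    "Because the four boxes have the same token count, normalizing per-region means gives the same shares as normalizing per-region sums; the choice of `agg` only changes the result for `max`, or once the differently sized `REST` region is included."
   ]
  },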
  {
   "cell_type": "markdown",
   "metadata": {},
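   "source": "## 8. Run Region Analysis\n\nUses the `AUDIO_MODE` selected in the configuration cell to locate the PKL files written by the extraction step, maps their attention to the four quadrants, and exports the results as CSV."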
  },
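  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each machine-friendly CSV written by the cell below starts with `q_idx` and `audio_abs_index`, followed by one numeric column per layer and region named `{layer}_{region}`, where layers are labelled `L01`, `L02`, ... from the stored `layer_ids`. A minimal pandas sketch for reshaping such a table into long format and averaging over audio query tokens is shown here; the small DataFrame and its values are hypothetical stand-ins for a file that would normally be loaded with `pd.read_csv`, and `to_long` is an illustrative helper, not part of the pipeline.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical stand-in for one machine-friendly CSV (normally: pd.read_csv(path))\n",
    "df = pd.DataFrame({\n",
    "    'q_idx': [0, 1],\n",
    "    'audio_abs_index': [120, 121],\n",
    "    'L01_TL': [0.4, 0.1], 'L01_TR': [0.2, 0.3], 'L01_BL': [0.2, 0.3], 'L01_BR': [0.2, 0.3],\n",
    "    'L02_TL': [0.25, 0.7], 'L02_TR': [0.25, 0.1], 'L02_BL': [0.25, 0.1], 'L02_BR': [0.25, 0.1],\n",
    "})\n",
    "\n",
    "\n",
    "def to_long(df_machine):\n",
    "    # Melt the '{layer}_{region}' columns into rows: one row per (query, layer, region)\n",
    "    id_cols = ['q_idx', 'audio_abs_index']\n",
    "    rows = df_machine.melt(id_vars=id_cols, var_name='layer_region', value_name='share')\n",
    "    # Split on the first underscore only: 'L01_TL' -> ('L01', 'TL')\n",
    "    rows[['layer', 'region']] = rows['layer_region'].str.split('_', n=1, expand=True)\n",
    "    return rows.drop(columns='layer_region')\n",
    "\n",
    "\n",
    "long_df = to_long(df)\n",
    "# Mean share per layer and region, averaged over audio query tokens\n",
    "summary = long_df.pivot_table(index='layer', columns='region', values='share', aggfunc='mean')\n",
    "print(summary)\n",
    "```\n",
    "\n",
    "With `INCLUDE_REST` enabled, the CSV also contains `{layer}_REST` columns, which the split on the first underscore handles unchanged."
   ]
  },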
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [Run Region Analysis]\n# Reads .pkl files from the extraction step, maps attention to four quadrants,\n# and exports normalized results as CSV.\n# Output files (saved to {ROOT_DIR}/{csv_subdir}/):\n#   - human_readable/*.human.csv   (one string column per layer)\n#   - machine_friendly/*.machine.csv (separate numeric columns per layer per region)\n\npkl_dir = ROOT_DIR / ACTIVE_CONDITION[\"pkl_subdir\"]\ncsv_out = ROOT_DIR / ACTIVE_CONDITION[\"csv_subdir\"]\n\nprint(f\"AUDIO_MODE:  {AUDIO_MODE}\")\nprint(f\"PKL input:   {pkl_dir}\")\nprint(f\"CSV output:  {csv_out}\")\nprint(\"=\" * 60)\n\nprocess_dir_fixed_boxes(\n    pkl_dir=pkl_dir,\n    out_root=csv_out,\n    pattern=\"*.pkl\",\n    agg=REGION_AGG,\n    decimals=DECIMALS,\n    fail_fast=FAIL_FAST,\n    include_rest=INCLUDE_REST,\n)\n\nprint(\"\\n=== Region analysis complete ===\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}