{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Import libraries\"\"\"\n",
    "from typing import Optional, Union, Tuple, List, Callable, Dict\n",
    "import torch\n",
    "from diffusers import StableDiffusionPipeline\n",
    "import torch.nn.functional as nnf\n",
    "import numpy as np\n",
    "import abc\n",
    "import ptp_utils\n",
    "import seq_aligner\n",
    "import os\n",
    "import csv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the below codes block by block to save files that will be used in the rest of the codes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Your prompt and description names here\"\"\"\n",
    "prompt = \"merged\"                           # Specify the name of your generation prompt file saved under \"./prompt\"; [merged: 1000 prompts from ImageNet 1000 classes + 1100 prompts from Prompt-Hero]\n",
    "description = \"340_final_text_descriptions\" # The filename of your concept-words set saved under \"./descriptions\"\n",
    "negative_prompt = False                     # We will not use negative_prompt in this project\n",
    "model_version = \"sd_v1_4\"                   # Model version to use; [sd_v1_4: Stable Diffusion v1.4], [sd_xl_base_1_0: Stable Diffusion XL Base 1.0]\n",
    "epochs = 1                                  # Epochs you have ran with \"head_relevance_calculation.py\" or \"head_relevance_calculation_sdxl.py\"\n",
    "denominator = 5                             # (When using the --subset_running flag) denominator you have used in \"head_relevance_calculation.py\" or \"head_relevance_calculation_sdxl.py\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Load the (averaged) similarity scores if it exists, otherwise concatenate the subfiles.\n",
    "   In total_data,\n",
    "   - 0-th dim: the number of prompts (2100 for prompt=\"merged\", 1000 for prompt=\"imagenet\")\n",
    "   - 1-th dim: timestamp (50)\n",
    "   - 2-th dim: down_cross, mid_cross, up_cross, down_self, mid_self, up_self\n",
    "   - 3-th dim: number of attention layers\n",
    "   - 4-th dim: number of attention heads\n",
    "   - 5-th dim: number of concepts\"\"\"\n",
    "cwd = os.getcwd()\n",
    "file_names = []\n",
    "file_paths = []\n",
    "for epoch in range(1, epochs+1):\n",
    "    file_name = f\"{prompt}_{description}_ba_epoch_{epoch}_neg_prompt_{negative_prompt}_{model_version}.pt\"\n",
    "    file_path = os.path.join(\"results\", file_name)\n",
    "    file_names.append(file_name)\n",
    "    file_paths.append(file_path)\n",
    "save_file_name = f\"{prompt}_{description}_ba_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{model_version}.pt\"\n",
    "save_file_path = os.path.join(\"results\", save_file_name)\n",
    "\n",
    "if not os.path.exists(save_file_path):\n",
    "    total_data = []\n",
    "    for i in range(len(file_names)):\n",
    "        if os.path.exists(file_paths[i]):\n",
    "            print(\"The file already exists. Load it and extend.\")\n",
    "            if i == 0:\n",
    "                total_data = torch.load(file_paths[i])\n",
    "            else:\n",
    "                total_data.extend(torch.load(file_paths[i]))\n",
    "        else:\n",
    "            print(\"The file does not exist. Load the subfiles and concatenate them.\")\n",
    "            # Load the saved (averaged) similarity subfiles and concatenate them\n",
    "            for numerator in range(1, denominator + 1):\n",
    "                subfile_path = file_paths[i].replace(f\"_{model_version}.pt\", f\"_{numerator}_{denominator}_{model_version}.pt\")\n",
    "                total_data.extend(torch.load(subfile_path))\n",
    "\n",
    "    # Save the concatenated file\n",
    "    torch.save(total_data, save_file_path)\n",
    "else:\n",
    "    print(\"The file already exists. Load it and exit.\")\n",
    "    total_data = torch.load(save_file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"final_result\", exist_ok=True)\n",
    "\"\"\"Move 0-th dim to the last-dim\"\"\"\n",
    "num_categories = len(total_data[0][0][\"down_cross\"][0][0])\n",
    "num_prompts = len(total_data)\n",
    "num_timestamps = len(total_data[0])\n",
    "permuted_data = []\n",
    "\n",
    "for _ in range(num_timestamps):\n",
    "    permuted_data.append({key: [[] for _ in range(len(total_data[0][0][key]))] for key in total_data[0][0].keys()})\n",
    "\n",
    "for i in range(num_prompts):\n",
    "    for j in range(num_timestamps):\n",
    "        for key in total_data[0][0].keys():\n",
    "            for l in range(len(total_data[i][j][key])):\n",
    "                permuted_data[j][key][l].append(total_data[i][j][key][l])\n",
    "\n",
    "\"\"\"Sum over generation prompts\"\"\"\n",
    "sum_values = []\n",
    "\n",
    "for _ in range(num_timestamps):\n",
    "    sum_values.append({key: [0 for _ in range(len(permuted_data[0][key]))] for key in permuted_data[0].keys()})\n",
    "\n",
    "for j in range(num_timestamps):\n",
    "    for key in total_data[0][0].keys():\n",
    "        for l in range(len(permuted_data[j][key])):\n",
    "            sum_values[j][key][l] = torch.stack(permuted_data[j][key][l]).sum(dim=0)\n",
    "\n",
    "\"\"\"Load concept names\"\"\"\n",
    "cwd = os.getcwd()\n",
    "description_file_name = description.replace(\"descriptions\", \"list.csv\")\n",
    "description_file_path = os.path.join(cwd, \"descriptions\", description_file_name)\n",
    "\n",
    "category_names = []\n",
    "with open(description_file_path, \"r\") as f:\n",
    "    reader = csv.reader(f)\n",
    "    for row in reader:\n",
    "        category_names.append(row)\n",
    "\n",
    "num_list = np.arange(1, 11).astype(str)\n",
    "category_names = [row for row in category_names if row != []]\n",
    "category_names = [row[0] for row in category_names if not any(num in row[0] for num in num_list)]\n",
    "category_names[0] = category_names[0][1:]\n",
    "category_names = np.array(category_names)\n",
    "\n",
    "# ------------------------------------------------------------------------------------ #\n",
    "\"\"\"Save the ranking of concepts for each CA head: This is for the reference purpose only\"\"\"\n",
    "sum_idx_rank = []\n",
    "header = [f\"Top {i}\"for i in range(1, len(category_names) + 1)]\n",
    "column_names = [\"[Name]\", \"[Layer]\", \"[Head]\", \"[Timestamp]\"] + header\n",
    "notable_heads = {}\n",
    "\n",
    "save_file_name = f\"output_ba_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}.csv\"\n",
    "save_path = os.path.join(cwd, \"final_result\", save_file_name)\n",
    "\n",
    "for _ in range(num_timestamps):\n",
    "    sum_idx_rank.append({key: [torch.zeros(len(permuted_data[0][key][l][0]), num_categories, dtype=torch.int) for l in range(len(permuted_data[0][key]))] for key in permuted_data[0].keys()})\n",
    "\n",
    "with open(save_path, \"w\", newline='') as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow(column_names)\n",
    "    for key in total_data[0][0].keys():\n",
    "        for l in range(len(permuted_data[0][key])):\n",
    "            for h in range(permuted_data[0][key][l][0].shape[0]): \n",
    "                for j in range(num_timestamps):\n",
    "                    head_position = [\"\"] * 4\n",
    "                    count = [0] * len(category_names)\n",
    "                    sum_idx_rank[j][key][l][h] = sum_values[j][key][l][h].argsort(descending=True)\n",
    "                    head_position = [key] + [str(l)] + [str(h)] + [str(j)]\n",
    "                    category_rank = list(category_names[sum_idx_rank[j][key][l][h]])\n",
    "                    writer.writerow(head_position + category_rank)\n",
    "# ------------------------------------------------------------------------------------ #\n",
    "\"\"\"Sum over timestamps\"\"\"\n",
    "sum_values_over_timestamps = {key: [0 for _ in range(len(permuted_data[0][key]))] for key in permuted_data[0].keys()}\n",
    "\n",
    "for key in total_data[0][0].keys():\n",
    "    for l in range(len(sum_values_over_timestamps[key])):\n",
    "        sum_value = torch.stack([sum_values[j][key][l] for j in range(num_timestamps)]).sum(dim=0)\n",
    "        normalized_value = sum_value / sum_value.sum(dim=-1, keepdim=True)\n",
    "        sum_values_over_timestamps[key][l] = normalized_value\n",
    "\n",
    "\n",
    "\"\"\"Check the number of CA heads for each CA layer\"\"\"\n",
    "head_cnt = 0\n",
    "for place in sum_values_over_timestamps.keys():\n",
    "    if \"cross\" in place:\n",
    "        for l in range(len(sum_values_over_timestamps[place])):\n",
    "            print(f\"{place}: {l}-th layer, num_heads: {len(sum_values_over_timestamps[place][l])}\")\n",
    "            for h in range(len(sum_values_over_timestamps[place][l])):\n",
    "                head_cnt += 1\n",
    "        print()\n",
    "print(f\"total number of CA heads: {head_cnt}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Extract concept vectors\"\"\"\n",
    "head_index = []\n",
    "if model_version == \"sd_v1_4\":\n",
    "    category_vectors = np.zeros((len(category_names), head_cnt)) \n",
    "elif model_version == \"sd_xl_base_1_0\":\n",
    "    category_vectors = np.zeros((len(category_names), head_cnt))\n",
    "\n",
    "head_cnt = 0\n",
    "for place in sum_values_over_timestamps.keys():\n",
    "    if \"cross\" in place:\n",
    "        for l in range(len(sum_values_over_timestamps[place])):\n",
    "            for h in range(len(sum_values_over_timestamps[place][l])):\n",
    "                head_index.append(f\"{place.replace(\"_cross\", \"\")}, layer: {l}, head: {h}\")\n",
    "                category_vectors[:, head_cnt] = sum_values_over_timestamps[place][l][h]\n",
    "                head_cnt += 1\n",
    "head_index = np.array(head_index)\n",
    "\n",
    "category_vectors_tensor = torch.tensor(category_vectors, dtype=torch.float32)\n",
    "category_vectors_tensor = nnf.normalize(category_vectors_tensor, p=1, dim=1)\n",
    "category_vectors = category_vectors_tensor.numpy() * category_vectors.shape[1]\n",
    "\n",
    "np.save(f\"./final_result/category_vectors_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}.npy\", category_vectors)\n",
    "np.save(f\"./final_result/head_index_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}.npy\", head_index)\n",
    "# ------------------------------------------------------------------------------------ #\n",
    "\"\"\"Extract category vectors per timestep: This is for the reference purpose only.\n",
    "   This does not include sum over timesteps\"\"\"\n",
    "sum_values_all_timestamps = []\n",
    "for j in range(num_timestamps):\n",
    "    sum_values_tmp = {key: [0 for _ in range(len(permuted_data[0][key]))] for key in permuted_data[0].keys()}\n",
    "    for key in total_data[0][0].keys():\n",
    "        for l in range(len(sum_values_tmp[key])):\n",
    "            sum_value = sum_values[j][key][l]\n",
    "            normalized_value = sum_value / sum_value.sum(dim=-1, keepdim=True)\n",
    "            sum_values_tmp[key][l] = normalized_value\n",
    "    sum_values_all_timestamps.append(sum_values_tmp)\n",
    "\n",
    "head_index = []\n",
    "if model_version == \"sd_v1_4\":\n",
    "    category_vectors = np.zeros((num_timestamps, len(category_names), 128)) # 128 is the number of CA heads\n",
    "elif model_version == \"sd_xl_base_1_0\":\n",
    "    category_vectors = np.zeros((num_timestamps, len(category_names), 1300))\n",
    "for t in range(num_timestamps):\n",
    "    head_cnt = 0\n",
    "    for place in sum_values_all_timestamps[t].keys():\n",
    "        if \"cross\" in place:\n",
    "            for l in range(len(sum_values_all_timestamps[t][place])):\n",
    "                for h in range(len(sum_values_all_timestamps[t][place][l])):\n",
    "                    head_index.append(f\"{place.replace(\"_cross\", \"\")}, layer: {l}, head: {h}\")\n",
    "                    category_vectors[t, :, head_cnt] = sum_values_all_timestamps[t][place][l][h]\n",
    "                    head_cnt += 1\n",
    "head_index = np.array(head_index)\n",
    "category_vectors_tensor = torch.tensor(category_vectors, dtype=torch.float32)\n",
    "category_vectors_tensor = nnf.normalize(category_vectors_tensor, p=1, dim=2)\n",
    "category_vectors = category_vectors_tensor.numpy() * category_vectors.shape[2]\n",
    "np.save(f\"./final_result/category_vectors_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}_per_timesteps.npy\", category_vectors)\n",
    "np.save(f\"./final_result/head_index_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}_per_timesteps.npy\", head_index)\n",
    "# ------------------------------------------------------------------------------------ #"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"[Save top_k head_lists for each concept]\n",
    "   These files will be used in our analysis of weakening each concept\"\"\"\n",
    "if model_version == \"sd_v1_4\":\n",
    "    num_heads = head_cnt # 128\n",
    "    head_iterate = list(range(1, num_heads + 1, 10)) + [num_heads] # For memory efficiency, we will save the top/bottom 1, 11, 21, ..., 128 heads for each visual concept\n",
    "elif model_version == \"sd_xl_base_1_0\":\n",
    "    num_heads = head_cnt # 1300\n",
    "    head_iterate = list(range(11, num_heads + 1, 100)) + [num_heads] # For memoery efficiency, we will save the top/bottom 11, 111, 211, ..., 1300 heads for each visual concept\n",
    "\n",
    "for top_k in head_iterate:\n",
    "    save_file_name = f\"head_roles_ba_epoch_1_to_{epochs}_top_{top_k}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}.csv\"\n",
    "    save_path = os.path.join(\"final_result\", save_file_name)\n",
    "    with open(save_path, 'w', newline=\"\") as f:\n",
    "        writer = csv.writer(f)\n",
    "        writer.writerow([\"Category\", \"Head position\"])\n",
    "        head_index = []\n",
    "        if model_version == \"sd_v1_4\":\n",
    "            ranking_list = np.zeros((len(category_names), 128)) # 128 is the number of heads\n",
    "        elif model_version == \"sd_xl_base_1_0\":\n",
    "            ranking_list = np.zeros((len(category_names), 1300))\n",
    "        \n",
    "        head_cnt = 0\n",
    "        for place in sum_values_over_timestamps.keys():\n",
    "            if \"cross\" in place:\n",
    "                for l in range(len(sum_values_over_timestamps[place])):\n",
    "                    for h in range(len(sum_values_over_timestamps[place][l])):\n",
    "                        head_index.append(f\"{place.replace(\"_cross\", \"\")}, layer: {l}, head: {h}\")\n",
    "                        ranking_list[:, head_cnt] = sum_values_over_timestamps[place][l][h].numpy()\n",
    "                        head_cnt += 1\n",
    "        head_index = np.array(head_index)\n",
    "        for i in range(len(category_names)):\n",
    "            top_k_indices = np.argsort(ranking_list[i])[::-1][:top_k]\n",
    "            head_positions = head_index[top_k_indices]\n",
    "            head_positions_str = \"['\"+\"' '\".join(head_positions)+\"']\"  # Join the list into a single string\n",
    "            writer.writerow([category_names[i], head_positions_str])\n",
    "\n",
    "for bottom_k in head_iterate:\n",
    "    save_file_name = f\"head_roles_ba_epoch_1_to_{epochs}_bottom_{bottom_k}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}.csv\"\n",
    "    save_path = os.path.join(\"final_result\", save_file_name)\n",
    "    with open(save_path, 'w', newline=\"\") as f:\n",
    "        writer = csv.writer(f)\n",
    "        writer.writerow([\"Category\", \"Head position\"])\n",
    "        head_index = []\n",
    "        if model_version == \"sd_v1_4\":\n",
    "            ranking_list = np.zeros((len(category_names), 128)) # 128 is the number of heads\n",
    "        elif model_version == \"sd_xl_base_1_0\":\n",
    "            ranking_list = np.zeros((len(category_names), 1300))\n",
    "        \n",
    "        head_cnt = 0\n",
    "        for place in sum_values_over_timestamps.keys():\n",
    "            if \"cross\" in place:\n",
    "                for l in range(len(sum_values_over_timestamps[place])):\n",
    "                    for h in range(len(sum_values_over_timestamps[place][l])):\n",
    "                        head_index.append(f\"{place.replace(\"_cross\", \"\")}, layer: {l}, head: {h}\")\n",
    "                        ranking_list[:, head_cnt] = sum_values_over_timestamps[place][l][h].numpy()\n",
    "                        head_cnt += 1\n",
    "        head_index = np.array(head_index)\n",
    "        for i in range(len(category_names)):\n",
    "            bottom_k_indices = np.argsort(ranking_list[i])[:bottom_k]\n",
    "            head_positions = head_index[bottom_k_indices]\n",
    "            head_positions_str = \"['\"+\"' '\".join(head_positions)+\"']\"  # Join the list into a single string\n",
    "            writer.writerow([category_names[i], head_positions_str])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "if model_version == \"sd_v1_4\":\n",
    "    np.save(f\"./final_result/{description.replace(\"_text_descriptions\", \"\")}_ranking_list.npy\", ranking_list)\n",
    "else:\n",
    "    np.save(f\"./final_result/{description.replace(\"_text_descriptions\", \"\")}_ranking_list_{model_version}.npy\", ranking_list)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "diffuser",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
