{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "library(ggplot2)\n",
    "library(tidyverse)\n",
    "library(grid) # grid.text() and gpar() are used for the plot annotations below\n",
    "library(gfile)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# set an experiment name\n",
    "experiment_name <- \"gpt-j-6B_counterfact_k0_sd1_tracing_sweep_n2000\"\n",
    "DATA_DIR <- \"\" # your data directory here\n",
    "data_path <- sprintf(\"%s/%s.csv\", DATA_DIR, experiment_name)\n",
    "\n",
    "# read the CSV at data_path (a bare `file` is base R's connection constructor, not a path)\n",
    "orig_data <- read_csv(data_path)\n",
    "head(orig_data)\n",
    "names(orig_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Globals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Shared ggplot theme that every figure below adds with `+ theme`. Binding the value to the\n",
    "# name `theme` does not hide ggplot2::theme(): R skips non-function bindings when resolving\n",
    "# a call such as theme(...), so later theme(...) calls still reach the ggplot2 function.\n",
    "# NOTE(review): ggplot2 >= 3.4 prefers `linewidth` over `size` in element_line().\n",
    "theme = theme(axis.ticks = element_blank(),\n",
    "        axis.text = element_text(size=15, color='black'),\n",
    "        axis.line.x = element_line(colour = 'black', size = .6),\n",
    "        axis.line.y = element_line(colour = 'black', size = .6),\n",
    "        panel.background = element_blank(),\n",
    "        panel.border = element_blank(),\n",
    "        panel.grid = element_line(colour = '#DFDFDF', size = .2),\n",
    "        text = element_text(size=18, family=\"serif\"),\n",
    "        axis.title.x = element_text(size = 18),\n",
    "        axis.title.y = element_text(size = 18),\n",
    "        plot.title = element_text(size = 20, hjust=0.5),\n",
    "        legend.text = element_text(size=16),\n",
    "        legend.box.background = element_blank(),\n",
    "        legend.position = \"right\",\n",
    "        panel.spacing=unit(1.5,\"lines\"),\n",
    "        )\n",
    "\n",
    "# Color palette; the hex values match the Okabe-Ito colorblind-safe set (presumably intentional)\n",
    "cbp1 <- c(\"#E69F00\", \"#56B4E9\", \"#009E73\",\n",
    "          \"#0072B2\", \"#D55E00\", \"#999999\", \"#F0E442\",  \"#CC79A7\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Postprocess tracing data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "options(repr.plot.width=8, repr.plot.height=6)\n",
    "# add a few variables to data\n",
    "orig_data <- orig_data %>%\n",
    "  mutate(\n",
    "    last_subj_token = (subj_end_idx-1 == token_idx),\n",
    "    restore_effect = restore_prob - corrupted_pred_prob,\n",
    "    module = ifelse(module=='None', 'all', module),\n",
    "    corruption_effect = corrupted_pred_prob - orig_pred_prob,\n",
    "    fraction_restored = pmax(0, restore_effect / abs(corruption_effect)) # bound this below by 0\n",
    "  )\n",
    "# get per-data-point statistics (one row per input_id x module x trace_window_size)\n",
    "per_data_point <- orig_data %>%\n",
    "  group_by(input_id, module, trace_window_size) %>% \n",
    "  summarize(\n",
    "    experiment_name=min(experiment_name),\n",
    "    task=min(task),\n",
    "    split=min(split),\n",
    "    orig_pred_prob=min(orig_pred_prob), # constant within a point; min() just extracts it\n",
    "    corrupted_pred_prob=min(corrupted_pred_prob), # constant within a point; min() just extracts it\n",
    "    corruption_effect = min(corruption_effect), # constant within a point; min() just extracts it\n",
    "    max_effect = max(restore_effect),\n",
    "    mean_effect = mean(restore_effect),\n",
    "    seq_len = max(token_idx)+1,\n",
    "    max_fraction_restored = max(fraction_restored),\n",
    "    ) \n",
    "# add extra variables that rely on per-data-point statistics.\n",
    "# seq_len has to be carried over: last_seq_token below reads it, and if the column is\n",
    "# absent the name resolves to base::seq_len() and `seq_len-1` fails.\n",
    "# NOTE(review): assumes the tracing CSV has no seq_len column of its own - confirm.\n",
    "data <- left_join(\n",
    "  orig_data,\n",
    "  per_data_point %>% \n",
    "    select(input_id, module, trace_window_size, max_effect, max_fraction_restored, seq_len),\n",
    "  by=c('input_id', 'module', 'trace_window_size')\n",
    ")\n",
    "data <- data %>%\n",
    "  mutate(\n",
    "    last_seq_token = (token_idx == seq_len-1),\n",
    "    # every row tying the per-point maximum is flagged, so a point can have several max rows\n",
    "    is_max_effect = (restore_effect >= max_effect),\n",
    "  )\n",
    "\n",
    "# turn trace_window into leveled factor (the levels fix the facet order 1, 3, 5, 10)\n",
    "data <- data %>%\n",
    "  mutate(trace_window_size_str = sprintf(\"Tracing Window Size: %s\", trace_window_size),\n",
    "         trace_window_size_str = factor(trace_window_size_str, \n",
    "         levels = c(\"Tracing Window Size: 1\", \"Tracing Window Size: 3\", \"Tracing Window Size: 5\", \"Tracing Window Size: 10\"))\n",
    "  )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Final plots 1: tracing effects by window size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "options(repr.plot.width=13, repr.plot.height=6)\n",
    "\n",
    "# x-axis ticks are labeled with 1-indexed layer numbers (1, 4, 8, ..., 28); the breaks are\n",
    "# shifted by one so each label sits on the matching 0-indexed layer_idx value\n",
    "xticks <- seq(0, 28, by=4)\n",
    "xticks[1] <- 1\n",
    "grid_ymax = .3\n",
    "# row filters applied to both plots in this cell (later cells reuse these variables)\n",
    "is_correct_filter <- 1\n",
    "min_orig_prob <- 0\n",
    "min_restoration_effect <- 0\n",
    "MODULE = 'mlp'\n",
    "\n",
    "TITLE = expression(\"Causal Tracing shows larger effects when multiple layers are denoised\")\n",
    "# TITLE = expression(paste(\"Earlier layers show the strongest tracing effects \", italic(\"on average\")))\n",
    "# Plot 1: for each point take its max fraction_restored per layer (max over the remaining\n",
    "# rows, i.e. token positions), then average those maxima over points; one facet per window size.\n",
    "(avg_plot <- data %>% \n",
    "  filter(module==MODULE) %>%\n",
    "  filter(orig_pred_prob > min_orig_prob) %>%\n",
    "  filter(restore_effect > min_restoration_effect) %>%\n",
    "  filter(is_correct >= is_correct_filter) %>% \n",
    "  group_by(input_id, layer_idx, trace_window_size_str) %>%\n",
    "  summarise(max_effect = max(fraction_restored)) %>%\n",
    "  group_by(layer_idx, trace_window_size_str) %>%\n",
    "  summarise(mean_effect = mean(max_effect)) %>%\n",
    "  ggplot(aes(layer_idx, mean_effect)) + \n",
    "  geom_point() + \n",
    "  ggtitle(TITLE) + \n",
    "  xlab(\"Layer in GPT-J\") + \n",
    "  ylab(\"Denoising Effect\") + \n",
    "  theme + \n",
    "  # overlays: red bar at layer_idx 5 (ROME) and blue band over layer_idx 3-8 (MEMIT);\n",
    "  # grid.text()/gpar() come from the grid package, which must be attached\n",
    "  annotate(\"rect\", xmin = 4.9, xmax = 5.1, ymin = 0, ymax = grid_ymax,\n",
    "         alpha = .7, fill = \"#FF0909\") + \n",
    "  annotate(\"rect\", xmin = 3, xmax = 8, ymin = 0, ymax = grid_ymax,\n",
    "         alpha = .2, fill = \"#2009FF\") + \n",
    "  annotation_custom(grid.text(\"ROME Edit Layer \", x=0.694,  y=0.89, gp=gpar(col = \"#FF0909\", fontsize=14, fontfamily='serif'))) +   \n",
    "  annotation_custom(grid.text(label=\"MEMIT Edit Layers\", check.overlap = TRUE, x=0.71,  y=0.815, gp=gpar(col = \"#2009FF\", fontsize=14, fontfamily='serif'))) + \n",
    "  scale_x_continuous(labels = xticks, breaks=xticks-1) + \n",
    "  theme(axis.title.x = element_text(size=18, color='black', angle=0, vjust=-.5, hjust=.5),\n",
    "        axis.title.y = element_text(size=18, color='black', angle=90, vjust=2, hjust=.5),\n",
    "        plot.title = element_text(size=20, color='black', angle=0, vjust=.5, hjust=.5)\n",
    "  ) + \n",
    "  facet_wrap(~trace_window_size_str,  nrow=1)\n",
    ")\n",
    "\n",
    "# TITLE = sprintf(\"Causal Tracing often shows information is localized in mid-to-late layers\")\n",
    "TITLE = sprintf(\"Causal Tracing peak distribution shifts outward with lower window size\")\n",
    "# Plot 2: histogram of the layer_idx values of rows flagged is_max_effect (rows whose\n",
    "# restore_effect reaches the per-point maximum), one facet per tracing window size.\n",
    "(distr_plot <- data %>% \n",
    "  filter(module==MODULE) %>%\n",
    "  filter(orig_pred_prob > min_orig_prob) %>%\n",
    "  filter(restore_effect > min_restoration_effect) %>%\n",
    "  filter(is_correct >= is_correct_filter) %>%\n",
    "  filter(is_max_effect == 1) %>%\n",
    "  ggplot(aes(layer_idx)) + \n",
    "  geom_histogram(binwidth = 1, \n",
    "                 size=.1) +  \n",
    "  ggtitle(TITLE) + \n",
    "  xlab(\"Layer in GPT-J where Causal Tracing effects peak\") + \n",
    "  ylab(\"Count\") + \n",
    "  theme + \n",
    "  annotate(\"rect\", xmin = 4.9, xmax = 5.1, ymin = 0, ymax = 200,\n",
    "         alpha = .7, fill = \"#FF0909\") + \n",
    "  annotate(\"rect\", xmin = 3, xmax = 8, ymin = 0, ymax = 200,\n",
    "         alpha = .2, fill = \"#2009FF\") + \n",
    "  annotation_custom(grid.text(\"ROME Edit Layer \", x=0.694,  y=0.89, gp=gpar(col = \"#FF0909\", fontsize=14, fontfamily='serif'))) +   \n",
    "  annotation_custom(grid.text(label=\"MEMIT Edit Layers\", check.overlap = TRUE, x=0.71,  y=0.815, gp=gpar(col = \"#2009FF\", fontsize=14, fontfamily='serif'))) + \n",
    "  scale_x_continuous(labels = xticks, breaks=xticks-1) + \n",
    "  theme(axis.title.x = element_text(size=18, color='black', angle=0, vjust=-.5, hjust=.5),\n",
    "        plot.title = element_text(size=20, color='black', angle=0, vjust=.5, hjust=.5)\n",
    "  ) + \n",
    "  facet_wrap(~trace_window_size_str,  nrow=1)\n",
    ")\n",
    "\n",
    "# save both figures as PDFs; colab::download_file presumably pushes them to the browser (Colab only)\n",
    "options(repr.plot.width=8, repr.plot.height=6)\n",
    "ggsave('avg_plot.pdf', avg_plot, width=16, height=4, device=cairo_pdf)\n",
    "colab::download_file('avg_plot.pdf') \n",
    "ggsave('distr_plot.pdf', distr_plot, width=16, height=4, device=cairo_pdf)\n",
    "colab::download_file('distr_plot.pdf') "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final plots 2: Tracing effects with overlay for ROME/MEMIT choices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "options(repr.plot.width=8, repr.plot.height=6)\n",
    "\n",
    "TRACE_WINDOW_SIZE = 5\n",
    "# tick labels are 1-indexed layers; breaks are shifted onto the 0-indexed layer_idx values\n",
    "xticks <- seq(0, 28, by=4)\n",
    "xticks[1] <- 1\n",
    "\n",
    "grid_ymax = .06\n",
    "TITLE = expression(paste(\"Causal Tracing effects by layer \", italic(\"on average across data\")))\n",
    "# TITLE = expression(paste(\"Causal Tracing effects are largest in earlier layers \", italic(\"on average across data\")))\n",
    "# Plot 1: mean restore_effect at the last subject token, per layer, for a single window size\n",
    "(avg_plot <- data %>% \n",
    "  filter(module==MODULE) %>%\n",
    "  filter(orig_pred_prob > min_orig_prob) %>%\n",
    "  filter(restore_effect > min_restoration_effect) %>%\n",
    "  filter(is_correct >= is_correct_filter) %>% \n",
    "  filter(trace_window_size == TRACE_WINDOW_SIZE) %>%\n",
    "  # filter to last subj token positions for comparison with ROME design choices\n",
    "  select(input_id, last_subj_token, layer_idx, trace_window_size_str, restore_effect) %>%\n",
    "  filter(last_subj_token == TRUE) %>%\n",
    "  # group_by(input_id, layer_idx, trace_window_size_str) %>%\n",
    "  # summarise(max_effect = max(restore_effect)) %>%\n",
    "  group_by(layer_idx, trace_window_size_str) %>%\n",
    "  # summarise(mean_effect = mean(max_effect)) %>%\n",
    "  summarise(mean_effect = mean(restore_effect)) %>%\n",
    "  ggplot(aes(layer_idx, mean_effect)) + \n",
    "  geom_point() + \n",
    "  ggtitle(TITLE) + \n",
    "  xlab(\"Layer in GPT-J\") + \n",
    "  ylab(\"Denoising Effect\") + \n",
    "  annotate(\"rect\", xmin = 4.9, xmax = 5.1, ymin = 0, ymax = grid_ymax,\n",
    "         alpha = .7, fill = \"#FF0909\") + \n",
    "  annotate(\"rect\", xmin = 3, xmax = 8, ymin = 0, ymax = grid_ymax,\n",
    "         alpha = .2, fill = \"#2009FF\") + \n",
    "  annotation_custom(grid.text(\"ROME Edit Layer \", x=0.802,  y=0.88, gp=gpar(col = \"#FF0909\", fontsize=16, fontfamily='serif'))) +   \n",
    "  annotation_custom(grid.text(label=\"MEMIT Edit Layers\", check.overlap = TRUE, x=0.81,  y=0.82, gp=gpar(col = \"#2009FF\", fontsize=16, fontfamily='serif'))) + \n",
    "  theme + \n",
    "  scale_x_continuous(labels = xticks, breaks=xticks-1) + \n",
    "  theme(axis.title.x = element_text(size=18, color='black', angle=0, vjust=-.5, hjust=.5),\n",
    "        plot.title = element_text(size=20, color='black', angle=0, vjust=.5, hjust=.5)\n",
    "  )\n",
    ")\n",
    "\n",
    "TITLE = \"How often does Causal Tracing peak in each layer?\"\n",
    "# TITLE = sprintf(\"Peak Causal Tracing effects often lie outside layers chosen for editing\")\n",
    "YMAX = 200\n",
    "# Plot 2: histogram of the layers where each point's tracing effect peaks, for a single window size\n",
    "(distr_plot <- data %>% \n",
    "  filter(module==MODULE) %>%\n",
    "  filter(orig_pred_prob > min_orig_prob) %>%\n",
    "  filter(restore_effect > min_restoration_effect) %>%\n",
    "  filter(is_correct >= is_correct_filter) %>%\n",
    "  filter(is_max_effect) %>%\n",
    "  filter(trace_window_size == TRACE_WINDOW_SIZE) %>%\n",
    "  ggplot(aes(layer_idx)) + \n",
    "  geom_histogram(binwidth = 1, \n",
    "                 size=.1) +  \n",
    "  ggtitle(TITLE) + \n",
    "  xlab(\"Layer in GPT-J where Causal Tracing effects peak\") + \n",
    "  ylab(\"Num. Points\") + \n",
    "  annotate(\"rect\", xmin = 4.9, xmax = 5.1, ymin = 0, ymax = YMAX,\n",
    "         alpha = 1, fill = \"#FF0909\") + \n",
    "  annotate(\"rect\", xmin = 3, xmax = 8, ymin = 0, ymax = YMAX,\n",
    "         alpha = .2, fill = \"#2009FF\") + \n",
    "  annotation_custom(grid.text(\"ROME Edit Layer \", x=0.788,  y=0.876, gp=gpar(col = \"#FF0909\", fontsize=20, fontfamily='serif'))) +   \n",
    "  annotation_custom(grid.text(label=\"MEMIT Edit Layers\", check.overlap = TRUE, x=0.801,  y=0.810, gp=gpar(col = \"#2009FF\", fontsize=20, fontfamily='serif'))) + \n",
    "  theme + \n",
    "  scale_x_continuous(labels = xticks, breaks=xticks-1) + \n",
    "  theme(axis.text = element_text(size=20, color='black'),\n",
    "        axis.title.x = element_text(size=24, color='black', angle=0, vjust=-.5, hjust=.5),\n",
    "        axis.title.y = element_text(size=24, color='black', angle=90, vjust=.5, hjust=.5),\n",
    "        plot.title = element_text(size=25, color='black', angle=0, vjust=.5, hjust=.5)\n",
    "  )\n",
    ")\n",
    "\n",
    "# count and fraction of peak rows inside layer_idx 3-8 (layers 4-9 when 1-indexed) for the\n",
    "# current TRACE_WINDOW_SIZE. mutate(), not summarise(): this yields one value per row, which\n",
    "# summarise() rejects (dplyr < 1.0) or deprecates (dplyr >= 1.1).\n",
    "peak_in_memit_layers <- data %>% \n",
    "  filter(module==MODULE) %>%\n",
    "  filter(trace_window_size == TRACE_WINDOW_SIZE) %>%\n",
    "  filter(orig_pred_prob > min_orig_prob) %>%\n",
    "  filter(restore_effect > min_restoration_effect) %>%\n",
    "  filter(is_correct >= is_correct_filter) %>%\n",
    "  filter(is_max_effect == 1) %>%\n",
    "  mutate(inside = layer_idx %in% c(3,4,5,6,7,8)) %>%\n",
    "  pull(inside) %>% table\n",
    "peak_in_memit_layers\n",
    "prop.table(peak_in_memit_layers)\n",
    "\n",
    "ggsave('avg_plot.pdf', avg_plot, width=8, height=6, device=cairo_pdf)\n",
    "colab::download_file('avg_plot.pdf') \n",
    "ggsave('distr_plot.pdf', distr_plot, width=8, height=6, device=cairo_pdf)\n",
    "colab::download_file('distr_plot.pdf') "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Join tracing and editing results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# tracing rows keyed by case_id so they can be joined to editing results\n",
    "# (MLP-only filtering is applied downstream where it is needed)\n",
    "tracing_data <- data %>%\n",
    "  mutate(case_id = input_id)\n",
    "\n",
    "# set an individual experiment to load\n",
    "\n",
    "# FT fact forcing\n",
    "# experiment_name <- \"gpt-j-6B_FT_outputs_cf_editing_sweep_ws-_5__layer-all_fact-forcing_n2000\"\n",
    "\n",
    "# ROME error injection\n",
    "experiment_name <- \"gpt-j-6B_ROME_outputs_cf_editing_sweep_ws-_1__layer-all_n2000\"\n",
    "\n",
    "editing_path <- sprintf(\"%s/%s.csv\", DATA_DIR, experiment_name)\n",
    "editing_data <- read_csv(editing_path)\n",
    "editing_data %>% select(case_id) %>% unique() %>% nrow()\n",
    "\n",
    "TRACE_WINDOW_SIZE = 5\n",
    "data <- data %>%\n",
    "  mutate(case_id = input_id)\n",
    "\n",
    "# add essence_ppl_diff transformations to editing_data\n",
    "MAX_ppl_diff <- 5\n",
    "# Elementwise 1/x with zeros mapped to a large finite value. Must be vectorized:\n",
    "# mutate() passes whole columns, and `if` only accepts a length-1 condition.\n",
    "safe_inverse <- function(x){\n",
    "  ifelse(x == 0, 2e16, 1/x)\n",
    "}\n",
    "editing_data <- editing_data %>%\n",
    "  mutate(essence_ppl_diff_bounded = pmax(0, pmin(essence_ppl_diff, MAX_ppl_diff)),\n",
    "         essence_diff_normed = 1 - essence_ppl_diff_bounded / MAX_ppl_diff,\n",
    "         # harmonic mean of the four scores; a zero score drives it to ~0\n",
    "         target_score_v2 = 4/(safe_inverse(rewrite_score) + safe_inverse(paraphrase_score) + safe_inverse(neighborhood_score) + safe_inverse(essence_diff_normed)),\n",
    "         target_score_mean = (rewrite_score + paraphrase_score + neighborhood_score) / 3,\n",
    "         target_score_mean_v2 = (rewrite_score + paraphrase_score + neighborhood_score + essence_diff_normed) / 4,\n",
    "         )\n",
    "\n",
    "# MAKE FIRST STYLE OF JOINED DATA. \n",
    "# one record per edit per datapoint. includes all tracing variables pertaining to the MAX effect per point\n",
    "combined_data <- left_join(tracing_data, editing_data, by='case_id')\n",
    "combined_data_max_MLP <- combined_data %>%\n",
    "  filter(is_max_effect == TRUE, module=='mlp') %>%\n",
    "  mutate(layer_discrepancy = edit_central_layer - layer_idx,\n",
    "         max_tracing_layer = layer_idx)\n",
    "combined_data_max_MLP <- combined_data_max_MLP %>% select(-layer_idx)\n",
    "combined_data_max_MLP %>% select(case_id) %>% unique() %>% nrow()\n",
    "\n",
    "# SECOND STYLE OF JOINED DATA.\n",
    "# one record per edit per datapoint.\n",
    "\n",
    "# add max-token and subj-token effects for each point+layer to editing data\n",
    "max_per_layer_per_record <- tracing_data %>% \n",
    "  filter(module=='mlp', trace_window_size==TRACE_WINDOW_SIZE) %>%\n",
    "  group_by(case_id, layer_idx) %>%\n",
    "  summarise(\n",
    "    max_token_effect=max(restore_effect),\n",
    "    max_fraction_restored=max(fraction_restored),\n",
    "    # variables below are constant per point, so mean does nothing. need them for later filtering \n",
    "    is_correct=mean(is_correct),\n",
    "    corruption_effect=mean(corruption_effect),\n",
    "    orig_pred_prob=mean(orig_pred_prob),\n",
    "    ) \n",
    "subj_effect_per_layer_per_record <- tracing_data %>% \n",
    "  filter(module=='mlp', trace_window_size==TRACE_WINDOW_SIZE) %>%\n",
    "  group_by(case_id, layer_idx) %>%\n",
    "  filter(last_subj_token==TRUE) %>%\n",
    "  summarise(subj_effect=restore_effect,\n",
    "            subj_end_idx=subj_end_idx,\n",
    "            seq_len=seq_len,\n",
    "            subj_effect_fraction = restore_effect / abs(corruption_effect))\n",
    "per_edit_data <- editing_data %>%\n",
    "  mutate(layer_idx = edit_central_layer) %>% # need to add layer_idx for matching tracing by layer_idx\n",
    "  left_join(\n",
    "    max_per_layer_per_record,\n",
    "    by=c(\"case_id\", \"layer_idx\")\n",
    "  ) %>%\n",
    "  left_join(\n",
    "    subj_effect_per_layer_per_record,\n",
    "    by=c(\"case_id\", \"layer_idx\")\n",
    "  )\n",
    "# add indicator for whether the max tracing effect exceeds the 95th percentile of MLP restore effects\n",
    "effect_cutoff <- quantile(data %>% filter(module == 'mlp') %>% pull(restore_effect), .95)\n",
    "print(effect_cutoff)\n",
    "# add some more variables\n",
    "# - discrete restoration effect variable\n",
    "# - subj position\n",
    "per_edit_data <- per_edit_data %>%\n",
    "  mutate(large_tracing_effect = max_token_effect > effect_cutoff,\n",
    "         orig_pred_prob_disc = ifelse(orig_pred_prob < .05, \"<.05\", \n",
    "                    ifelse(orig_pred_prob < .1, \".05-.1\", \n",
    "                    ifelse(orig_pred_prob < .15, \".1-.15\", \n",
    "                    ifelse(orig_pred_prob < .2, \".15-.2\", \n",
    "                    ifelse(orig_pred_prob < .25, \".2-.25\", '>.25'))))),\n",
    "        subj_position = subj_end_idx / seq_len\n",
    "  )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FINAL PLOT: score vs. restoration effect by layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "is_correct_filter = 1\n",
    "# alternative edit-layer selections (single-row layout) kept for reference:\n",
    "# edit_layers = c(0,12,16,20)\n",
    "# edit_layers = c(0,4,8,12)\n",
    "# edit_layers = c(0, 4, 5, 8)\n",
    "# edit_layers = c(0, 12, 16, 20)\n",
    "# edit_layers = c(8, 12, 16, 20)\n",
    "# edit_layers = c(0, 8, 16, 24)\n",
    "# edit_layers = c(0, 4, 8, 12)\n",
    "# options(repr.plot.width=13, repr.plot.height=4)\n",
    "# NROW=1\n",
    "\n",
    "# 0-indexed edit layers to plot, one facet each; facets are labeled with 1-indexed layer numbers\n",
    "edit_layers = c(0, 4, 8, 12, 16, 20, 24, 27)\n",
    "options(repr.plot.width=13, repr.plot.height=8)\n",
    "NROW=2\n",
    "\n",
    "x_ub = 1\n",
    "min_orig_prob = 0\n",
    "# facet labels in edit_layers order (used as factor levels so facets keep that order)\n",
    "layer_levels <- c()\n",
    "for (i in edit_layers){\n",
    "  layer_levels <- c(layer_levels, sprintf(\"Layer %s\", i+1))\n",
    "}\n",
    "\n",
    "# x-axis tracing measure, given as a column name (read via get() and aes_string())\n",
    "x = \"max_fraction_restored\"\n",
    "# x = \"max_token_effect\"\n",
    "# x = \"subj_effect_fraction\"\n",
    "# x = \"subj_effect\"\n",
    "point_alpha = .1\n",
    "CI_alpha = .15\n",
    "CI_fill = 'orange'\n",
    "show_se = TRUE\n",
    "smooth_method = 'lm'\n",
    "# smooth_method = 'loess'\n",
    "# the y-axis column is rewrite_<outcome>, i.e. rewrite_score\n",
    "outcome = \"score\"\n",
    "# NOTE(review): ovr_name and qs are not used in this cell\n",
    "ovr_name = \"target_score_mean\"\n",
    "line_size=1.3\n",
    "\n",
    "# distinct case_ids passing the filters with x below x_ub; the plot below does not drop\n",
    "# rows with x >= x_ub (coord_cartesian only zooms), so this count can differ from what is drawn\n",
    "n_unique_points <- per_edit_data %>%\n",
    "  filter(is_correct >= is_correct_filter) %>%\n",
    "  filter(orig_pred_prob >= min_orig_prob) %>%\n",
    "  filter(get(x) < x_ub) %>% \n",
    "  pull(case_id) %>% \n",
    "  unique() %>% \n",
    "  length()\n",
    "sprintf(\"Plotting with %s points\", n_unique_points)\n",
    "\n",
    "qs <- quantile(per_edit_data$rewrite_score, c(.4, 1))\n",
    "TITLE=\"Rewrite Score by Tracing Effect (Grouped by Edit Layer)\"\n",
    "# TITLE=\"Fact Forcing Rewrite Score by Tracing Effect (Grouped by Edit Layer)\"\n",
    "# TITLE=\"ROME Rewrite Score by Tracing Effect (Error Injection)\"\n",
    "# TITLE=\"ROME Rewrite Score by Last Subject Token Tracing Effect (Error Injection)\"\n",
    "# faceted scatter with a fitted smooth (orange) and a dashed y = x reference line\n",
    "(rewrite_plot <- per_edit_data %>%\n",
    "  filter(is_correct >= is_correct_filter) %>%\n",
    "  filter(orig_pred_prob >= min_orig_prob) %>%\n",
    "  filter(edit_central_layer %in% edit_layers) %>%\n",
    "  mutate(edit_central_layer = sprintf(\"Layer %s\", edit_central_layer+1),\n",
    "           edit_central_layer = factor(edit_central_layer, levels=layer_levels)\n",
    "         ) %>%\n",
    "  ggplot(aes_string(x, sprintf(\"rewrite_%s\", outcome))) + \n",
    "  geom_point(alpha=point_alpha) +\n",
    "  geom_smooth(method=smooth_method, se=show_se, color='orange', alpha=CI_alpha, fill=CI_fill, size=line_size) + \n",
    "  geom_abline(slope=1, intercept=0, color='#F94343', linetype=2, size=.8, inherit.aes=FALSE) + \n",
    "  # geom_segment(aes(x=0, xend=1, y=0, yend=1), color='Red', linetype=1, size=.25, inherit.aes = FALSE) + \n",
    "  ggtitle(TITLE) + \n",
    "  xlab(\"Tracing Effect\") + \n",
    "  # xlab(\"Tracing Effect at Last Subject Token\") + \n",
    "  ylab(\"Rewrite Score\") + \n",
    "  theme + \n",
    "  coord_cartesian(xlim=c(0,x_ub), ylim=c(0, 1)) + \n",
    "  theme(axis.title.y = element_text(size=20, color='black', angle=90, vjust=1.5, hjust=.5),\n",
    "        strip.text.x = element_text(size=16, color='black', angle=0, vjust=.5, hjust=.5),\n",
    "        axis.text = element_text(size=16, color='black'),\n",
    "        axis.title.x = element_text(size=20, color='black', angle=0, vjust=0, hjust=.5),\n",
    "        plot.title = element_text(size=22, color='black', angle=0, vjust=.5, hjust=.5)      \n",
    "  ) + \n",
    "  facet_wrap(~edit_central_layer, nrow=NROW))\n",
    "  \n",
    "# saved height scales with the number of facet rows\n",
    "ggsave('rewrite_plot.pdf', rewrite_plot, width=13, height=4*NROW, dpi=600, device=cairo_pdf)\n",
    "colab::download_file('rewrite_plot.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Single-panel version of the previous figure for edit_central_layer == 5 (Layer 6 when\n",
    "# 1-indexed), plus correlation tests. Reuses x, outcome, point_alpha, CI_alpha, CI_fill,\n",
    "# smooth_method, x_ub and the row filters set in the previous cell.\n",
    "options(repr.plot.width=8, repr.plot.height=6)\n",
    "TITLE=\"ROME Rewrite Score by Tracing Effect at Layer 6\"\n",
    "show_se=TRUE\n",
    "line_size=1.4\n",
    "(rewrite_plot <- per_edit_data %>%\n",
    "  filter(is_correct >= is_correct_filter) %>%\n",
    "  filter(orig_pred_prob >= min_orig_prob) %>%\n",
    "  filter(edit_central_layer == 5) %>%\n",
    "  ggplot(aes_string(x, sprintf(\"rewrite_%s\", outcome))) + \n",
    "  geom_point(alpha=point_alpha) +\n",
    "  geom_abline(slope=1, intercept=0, color='#F94343', linetype=2, size=.8) + \n",
    "  geom_smooth(method=smooth_method, se=show_se, color='orange', alpha=CI_alpha, fill=CI_fill, size=line_size) + \n",
    "  ggtitle(TITLE) + \n",
    "  xlab(\"Tracing Effect (Fraction Restored)\") + \n",
    "  ylab(\"Rewrite Score\") + \n",
    "  theme + \n",
    "  coord_cartesian(xlim=c(0,x_ub), ylim=c(0, 1)) +\n",
    "  theme(axis.title.y = element_text(size=21, color='black', angle=90, vjust=1.5, hjust=.5),\n",
    "        axis.text = element_text(size=18, color='black'),\n",
    "        axis.title.x = element_text(size=21, color='black', angle=0, vjust=0, hjust=.5),\n",
    "        plot.title = element_text(size=25, color='black', angle=0, vjust=.5, hjust=.5)      \n",
    "  )\n",
    ")\n",
    "# separate file name so the faceted rewrite_plot.pdf from the previous cell is not overwritten\n",
    "ggsave('rewrite_plot_layer6.pdf', rewrite_plot, width=8, height=6, device=cairo_pdf)\n",
    "colab::download_file('rewrite_plot_layer6.pdf')\n",
    "\n",
    "\n",
    "# separate name so the global `x` (the x-axis column name used above) is left intact\n",
    "layer6_data <- per_edit_data %>%\n",
    "  filter(is_correct >= is_correct_filter) %>%\n",
    "  filter(orig_pred_prob >= min_orig_prob) %>%\n",
    "  filter(edit_central_layer == 5) %>%\n",
    "  select(rewrite_score, max_fraction_restored, subj_effect_fraction)\n",
    "\n",
    "cor.test(layer6_data$rewrite_score, layer6_data$max_fraction_restored)\n",
    "cor.test(layer6_data$rewrite_score, layer6_data$subj_effect_fraction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ALL experiments: Tracing vs. editing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "conditions <- list()\n",
    "conditions[[1]] <- c(\"FT\", \"1\")\n",
    "conditions[[2]] <- c(\"FT\", \"5\")\n",
    "conditions[[3]] <- c(\"ROME\", \"1\")\n",
    "conditions[[4]] <- c(\"MEMIT\", \"5\")\n",
    "objectives <- c(\"Falsehood Injection\", \"Tracing Reversal\", \"Fact Erasure\", \"Fact Forcing\", \"Fact Amplification\")\n",
    "# experiment-name suffix for each editing objective\n",
    "obj_tags <- c(\n",
    "  \"Falsehood Injection\" = \"\",\n",
    "  \"Tracing Reversal\" = \"_trace-reverse\",\n",
    "  \"Fact Erasure\" = \"_fact-erasure\",\n",
    "  \"Fact Forcing\" = \"_fact-forcing\",\n",
    "  \"Fact Amplification\" = \"_fact-amplification\"\n",
    ")\n",
    "\n",
    "# Read one editing-results CSV; returns an empty data.frame if the read fails for any reason.\n",
    "# NOTE(review): hardcoded internal path, unlike the DATA_DIR-based reads in earlier cells.\n",
    "safe_load_data <- function(experiment_name){  \n",
    "  possibleError <- tryCatch(\n",
    "    read_csv(gfile::GFile(sprintf(\"/cns/mf-d/home/brain-frameworks/cloudsync/belief-localization-xgcp/output/%s.csv\", experiment_name))),\n",
    "    error=function(e) e\n",
    "  )\n",
    "  if(!inherits(possibleError, \"error\")){\n",
    "    return(possibleError)\n",
    "  } else{\n",
    "    return(data.frame())\n",
    "  }\n",
    "}\n",
    "\n",
    "# load every (method, window size, objective) combination that exists, then stack them once\n",
    "loaded_frames <- list()\n",
    "for (condition in conditions){\n",
    "  method <- condition[1]\n",
    "  window_size <- condition[2]\n",
    "  for (objective in objectives){\n",
    "    obj_tag <- obj_tags[[objective]]\n",
    "    experiment_name <- sprintf(\"gpt-j-6B_%s_outputs_cf_editing_sweep_ws-_%s__layer-all%s_n2000\", method, window_size, obj_tag)\n",
    "    print(sprintf(\"Trying to load %s\", experiment_name))\n",
    "    loaded <- safe_load_data(experiment_name)\n",
    "    if (nrow(loaded) == 0) next\n",
    "    loaded$objective <- objective\n",
    "    print(\"Loaded!\")\n",
    "    loaded_frames[[length(loaded_frames) + 1]] <- loaded\n",
    "  }\n",
    "}\n",
    "running_editing_data <- bind_rows(loaded_frames)\n",
    "editing_data <- running_editing_data %>%\n",
    "  filter(edit_central_layer >= 0) %>%\n",
    "  select(case_id, rewrite_score, paraphrase_score, neighborhood_score, essence_ppl_diff, target_score, edit_method, edit_central_layer, objective, edit_window_size) %>%\n",
    "  # factor(), not as.factor() (which has no levels argument), so objectives keep display order\n",
    "  mutate(objective = factor(objective, levels=objectives))\n",
    "\n",
    "TRACE_WINDOW_SIZE = 5\n",
    "tracing_data <- data %>%\n",
    "  filter(trace_window_size==TRACE_WINDOW_SIZE) %>%\n",
    "  mutate(case_id = input_id) %>%\n",
    "  select(case_id, trace_window_size, layer_idx, token_idx, module, orig_pred_prob, corrupted_pred_prob, corruption_effect, seq_len, is_correct, last_seq_token, is_subj_token, last_subj_token, restore_effect, fraction_restored, max_effect, max_fraction_restored, is_max_effect, trace_window_size_str)\n",
    "print(\"Tracing # per trace window size\")\n",
    "table(tracing_data$trace_window_size)\n",
    "\n",
    "data <- data %>%\n",
    "  mutate(case_id = input_id)\n",
    "\n",
    "# add essence_ppl_diff transformations to editing_data\n",
    "MAX_ppl_diff <- 5\n",
    "# Elementwise 1/x with zeros mapped to a large finite value. Must be vectorized:\n",
    "# mutate() passes whole columns, and `if` only accepts a length-1 condition.\n",
    "safe_inverse <- function(x){\n",
    "  ifelse(x == 0, 2e16, 1/x)\n",
    "}\n",
    "# add cols and is_correct variable\n",
    "editing_data <- editing_data %>%\n",
    "  mutate(essence_ppl_diff_bounded = pmax(0, pmin(essence_ppl_diff, MAX_ppl_diff)),\n",
    "         essence_diff_normed = 1 - essence_ppl_diff_bounded / MAX_ppl_diff,\n",
    "         # harmonic mean of the four scores; a zero score drives it to ~0\n",
    "         target_score_v2 = 4/(safe_inverse(rewrite_score) + safe_inverse(paraphrase_score) + safe_inverse(neighborhood_score) + safe_inverse(essence_diff_normed)),\n",
    "         target_score_mean = (rewrite_score + paraphrase_score + neighborhood_score) / 3,\n",
    "         target_score_mean_v2 = (rewrite_score + paraphrase_score + neighborhood_score + essence_diff_normed) / 4,\n",
    "         ) %>%\n",
    "  left_join(tracing_data %>% select(case_id, is_correct) %>% unique(), by = \"case_id\")\n",
    "\n",
    "# MAKE FIRST STYLE OF JOINED DATA. \n",
    "# one record per edit per datapoint. includes all tracing variables pertaining to the MAX effect per point\n",
    "combined_data <- left_join(tracing_data, editing_data, by='case_id')\n",
    "combined_data_max_MLP <- combined_data %>%\n",
    "  filter(is_max_effect == TRUE, module=='mlp') %>%\n",
    "  mutate(layer_discrepancy = edit_central_layer - layer_idx,\n",
    "         max_tracing_layer = layer_idx)\n",
    "combined_data_max_MLP <- combined_data_max_MLP %>% select(-layer_idx)\n",
    "\n",
    "# SECOND STYLE OF JOINED DATA.\n",
    "# one record per edit per datapoint.\n",
    "\n",
    "# add max-token and subj-token effects for each point+layer to editing data\n",
    "max_per_layer_per_record <- tracing_data %>% \n",
    "  filter(module=='mlp', trace_window_size==TRACE_WINDOW_SIZE) %>%\n",
    "  group_by(case_id, layer_idx) %>%\n",
    "  summarise(\n",
    "    max_token_effect=max(restore_effect),\n",
    "    max_fraction_restored=max(fraction_restored),\n",
    "    # variables below are constant per point, so mean does nothing. need them for later filtering \n",
    "    corruption_effect=mean(corruption_effect),\n",
    "    orig_pred_prob=mean(orig_pred_prob),\n",
    "    ) \n",
    "subj_effect_per_layer_per_record <- tracing_data %>% \n",
    "  filter(module=='mlp', trace_window_size==TRACE_WINDOW_SIZE) %>%\n",
    "  group_by(case_id, layer_idx) %>%\n",
    "  filter(last_subj_token==TRUE) %>%\n",
    "  summarise(subj_effect=restore_effect,\n",
    "            subj_effect_fraction = restore_effect / abs(corruption_effect))\n",
    "per_edit_data <- editing_data %>%\n",
    "  mutate(layer_idx = edit_central_layer) %>% # need to add layer_idx for matching tracing by layer_idx\n",
    "  left_join(\n",
    "    max_per_layer_per_record,\n",
    "    by=c(\"case_id\", \"layer_idx\")\n",
    "  ) %>%\n",
    "  left_join(\n",
    "    subj_effect_per_layer_per_record,\n",
    "    by=c(\"case_id\", \"layer_idx\")\n",
    "  )\n",
    "# add indicator for whether the max tracing effect exceeds the 95th percentile of MLP restore effects\n",
    "effect_cutoff <- quantile(data %>% filter(module == 'mlp') %>% pull(restore_effect), .95)\n",
    "# add discrete restoration effect variable\n",
    "per_edit_data <- per_edit_data %>%\n",
    "  mutate(large_tracing_effect = max_token_effect > effect_cutoff,\n",
    "         orig_pred_prob_disc = ifelse(orig_pred_prob < .05, \"<.05\", \n",
    "                    ifelse(orig_pred_prob < .1, \".05-.1\", \n",
    "                    ifelse(orig_pred_prob < .15, \".1-.15\", \n",
    "                    ifelse(orig_pred_prob < .2, \".15-.2\", \n",
    "                    ifelse(orig_pred_prob < .25, \".2-.25\", '>.25')))))\n",
    "  )\n",
    "\n",
    "editing_data %>%\n",
    "  select(edit_method, edit_window_size, objective, case_id) %>%\n",
    "  unique() %>%\n",
    "  group_by(edit_method, edit_window_size, objective) %>%\n",
    "  summarise(n = n())\n",
    "# editing_data %>%\n",
    "#   group_by(edit_method, edit_window_size, objective, case_id, edit_central_layer) %>%\n",
    "#   unique() %>%\n",
    "#   group_by(edit_method, edit_window_size, objective, edit_central_layer) %>%\n",
    "#   summarise(n = n())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# keep only datapoints whose is_correct is at least this value\n",
    "correctness_filter = 1\n",
    "\n",
    "# mean editing metrics per (method, central layer, window size, objective)\n",
    "editing_data_table <- per_edit_data %>%\n",
    "  filter(is_correct >= correctness_filter) %>%\n",
    "  group_by(edit_method, edit_central_layer, edit_window_size, objective) %>%\n",
    "  summarise(n=n(),\n",
    "            # the sd must be taken before rewrite_score is overwritten by its mean\n",
    "            rewrite_score_sd = sd(rewrite_score),\n",
    "            rewrite_score = mean(rewrite_score),\n",
    "            paraphrase_score = mean(paraphrase_score),\n",
    "            neighborhood_score = mean(neighborhood_score),\n",
    "            target_score_mean = mean(target_score_mean),\n",
    "            target_score_mean_v2 = mean(target_score_mean_v2),\n",
    "            essence_ppl_diff = mean(essence_ppl_diff),\n",
    "            essence_score = mean(essence_diff_normed),\n",
    "            .groups = 'drop') %>%\n",
    "  # one multi-key sort: chained arrange() calls only compose into this order if every sort is stable\n",
    "  arrange(edit_window_size, edit_central_layer, edit_method, objective) %>%\n",
    "  mutate(across(where(is.double), ~ round(.x, 3)))\n",
    "\n",
    "# inspect one objective (swap the filter value to look at another)\n",
    "editing_data_table %>%\n",
    "  select(-rewrite_score_sd) %>%\n",
    "  filter(objective == \"Tracing Reversal\") %>%\n",
    "  arrange(edit_method, objective, edit_window_size, edit_central_layer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Performance plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "options(repr.plot.width=11, repr.plot.height=6)\n",
    "editing_data_table <- editing_data_table %>%\n",
    "  mutate(Method = sprintf(\"%s (ws=%s)\", edit_method, edit_window_size),\n",
    "         Method = factor(Method, levels=c(\"FT (ws=1)\", \"FT (ws=5)\", \"ROME (ws=1)\", \"MEMIT (ws=5)\")))\n",
    "# layer indices are 0-based: ticks sit at layers 0, 3, 7, ..., 27 and are labelled 1, 4, 8, ..., 28\n",
    "xticks = seq(0, 28, by=4)\n",
    "xticks[1] <- 1\n",
    "\n",
    "OBJECTIVE = \"Falsehood Injection\"\n",
    "# OBJECTIVE = \"Tracing Reversal\"\n",
    "# OBJECTIVE = \"Fact Erasure\"\n",
    "# OBJECTIVE = \"Fact Amplification\"\n",
    "# OBJECTIVE = \"Fact Forcing\"\n",
    "line_size = 1\n",
    "\n",
    "# Per-layer line plot of one mean editing score for a single objective, one line per Method.\n",
    "#   score_col          column of `table` plotted on the y axis\n",
    "#   y_label            y-axis label; also the metric name in the title unless title_label is given\n",
    "#   legend_pos         legend position inside the panel, as c(x, y) in [0, 1]\n",
    "#   y_range            optional c(lo, hi) y-axis zoom. coord_cartesian() only zooms, whereas ylim()\n",
    "#                      would turn out-of-range points into NA and break the lines\n",
    "#   title_label        metric name used in the title \"<objective> <title_label> by Edit Layer\"\n",
    "#   table              per-layer summary table (defaults to editing_data_table)\n",
    "#   selected_objective objective to plot; \"Falsehood Injection\" is titled \"Error Injection\"\n",
    "# Returns the ggplot object.\n",
    "plot_score_by_layer <- function(score_col, y_label, legend_pos, y_range = NULL, title_label = y_label,\n",
    "                                table = editing_data_table, selected_objective = OBJECTIVE) {\n",
    "  title_objective <- if (selected_objective == \"Falsehood Injection\") \"Error Injection\" else selected_objective\n",
    "  table %>%\n",
    "    filter(objective == selected_objective) %>%\n",
    "    ggplot(aes(edit_central_layer, .data[[score_col]], color=Method)) +\n",
    "    geom_line(size=line_size) +\n",
    "    ggtitle(sprintf(\"%s %s by Edit Layer\", title_objective, title_label)) +\n",
    "    xlab(\"(Central) Edit Layer\") +\n",
    "    ylab(y_label) +\n",
    "    scale_x_continuous(labels = xticks, breaks=xticks-1) +\n",
    "    coord_cartesian(ylim = y_range) +\n",
    "    theme +\n",
    "    theme(axis.title.x = element_text(size=18, color='black', angle=0, vjust=-.5, hjust=.5),\n",
    "          axis.title.y = element_text(size=18, color='black', angle=90, vjust=2, hjust=.5),\n",
    "          plot.title = element_text(size=20, color='black', angle=0, vjust=.5, hjust=.5),\n",
    "          legend.title = element_text(size=16),\n",
    "          legend.position = legend_pos,\n",
    "          legend.background = element_rect(fill = \"white\", color = \"#555555\"),\n",
    "          legend.key = element_blank())\n",
    "}\n",
    "\n",
    "(rewrite_plot <- plot_score_by_layer(\"rewrite_score\", \"Rewrite Score\", c(0.2, 0.2)))\n",
    "(paraphrase_plot <- plot_score_by_layer(\"paraphrase_score\", \"Paraphrase Score\", c(0.8, 0.8), y_range = c(0, 1)))\n",
    "(neighborhood_plot <- plot_score_by_layer(\"neighborhood_score\", \"Neighborhood Score\", c(0.2, 0.2), y_range = c(.9, 1)))\n",
    "(essence_plot <- plot_score_by_layer(\"essence_score\", \"Essence Score\", c(0.8, 0.2), y_range = c(.4, 1)))\n",
    "(ovr_plot <- plot_score_by_layer(\"target_score_mean\", \"Overall Score\", c(0.2, 0.2), y_range = c(.3, 1)))\n",
    "(ovr_v2_plot <- plot_score_by_layer(\"target_score_mean_v2\", \"Overall Score\", c(0.2, 0.2), y_range = c(.3, 1),\n",
    "                                    title_label = \"Overall Score (+Essence)\"))\n",
    "\n",
    "# save every figure as <variable name>.pdf and download it\n",
    "figures <- list(rewrite_plot = rewrite_plot, paraphrase_plot = paraphrase_plot,\n",
    "                neighborhood_plot = neighborhood_plot, essence_plot = essence_plot,\n",
    "                ovr_plot = ovr_plot, ovr_v2_plot = ovr_v2_plot)\n",
    "for (figure_name in names(figures)) {\n",
    "  pdf_path <- sprintf(\"%s.pdf\", figure_name)\n",
    "  ggsave(pdf_path, figures[[figure_name]], width=8, height=6, device=cairo_pdf)\n",
    "  colab::download_file(pdf_path)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## R2 values for each method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# (method, edit window size) pairs to analyse, and the editing objectives\n",
    "conditions <- list()\n",
    "conditions[[1]] <- c(\"FT\", \"1\")\n",
    "conditions[[2]] <- c(\"FT\", \"5\")\n",
    "conditions[[3]] <- c(\"ROME\", \"1\")\n",
    "conditions[[4]] <- c(\"MEMIT\", \"5\")\n",
    "objectives <- c(\"Falsehood Injection\", \"Tracing Reversal\", \"Fact Erasure\", \"Fact Forcing\", \"Fact Amplification\")\n",
    "correctness_filter = 1\n",
    "\n",
    "# empty template: fixes the types and leading column order of the results table\n",
    "results_df <- data.frame(\n",
    "  edit_method = character(),\n",
    "  objective = character(),\n",
    "  edit_window_size = character(),\n",
    "  rewrite_R2_lyr = double(),\n",
    "  rewrite_R2_trace = double(),\n",
    "  rewrite_R2_both = double(),\n",
    "  paraphrase_R2_lyr = double(),\n",
    "  paraphrase_R2_trace = double(),\n",
    "  paraphrase_R2_both = double(),\n",
    "  neighborhood_R2_lyr = double(),\n",
    "  neighborhood_R2_trace = double(),\n",
    "  neighborhood_R2_both = double(),\n",
    "  essence_R2_lyr = double(),\n",
    "  essence_R2_trace = double(),\n",
    "  essence_R2_both = double()\n",
    ")\n",
    "\n",
    "# For every (method, window size, objective) and every metric, fit three linear models of the score:\n",
    "#   lyr   ~ edit layer (as a factor)\n",
    "#   trace ~ max_fraction_restored at the edit layer\n",
    "#   both  ~ edit layer * max_fraction_restored\n",
    "# and record their R2, the R2 gained over layer alone, and the F-test p-value comparing both vs lyr.\n",
    "result_rows <- list()\n",
    "for (condition in conditions){\n",
    "  METHOD <- condition[1]\n",
    "  EDIT_WINDOW_SIZE <- condition[2]\n",
    "  for (OBJECTIVE in objectives){\n",
    "    model_data <- per_edit_data %>%\n",
    "      filter(is_correct >= correctness_filter) %>%\n",
    "      filter(edit_method == METHOD, edit_window_size == EDIT_WINDOW_SIZE, objective == OBJECTIVE) %>%\n",
    "      # lm() silently drops rows whose predictor is NA, which only affects the models that use\n",
    "      # max_fraction_restored. Drop those rows up front so all three models see the same data:\n",
    "      # otherwise the R2 values are not comparable and anova() stops with a dataset-size mismatch.\n",
    "      filter(!is.na(max_fraction_restored)) %>%\n",
    "      mutate(essence_score = essence_diff_normed)\n",
    "    if (nrow(model_data) == 0) next\n",
    "    row <- data.frame(\n",
    "        edit_method=METHOD,\n",
    "        objective=OBJECTIVE,\n",
    "        edit_window_size=EDIT_WINDOW_SIZE\n",
    "      )\n",
    "    for (metric in c(\"rewrite\", \"paraphrase\", \"neighborhood\", \"essence\")){\n",
    "      score_name <- sprintf(\"%s_score\", metric)\n",
    "      model_edit_layer_only <- lm(get(score_name) ~ as.factor(edit_central_layer), data = model_data)\n",
    "      r2_lyr <- summary(model_edit_layer_only)$r.squared\n",
    "      model_fraction_restored_only <- lm(get(score_name) ~ max_fraction_restored, data = model_data)\n",
    "      r2_trace <- summary(model_fraction_restored_only)$r.squared\n",
    "      model_both <- lm(get(score_name) ~ as.factor(edit_central_layer) * max_fraction_restored, data = model_data)\n",
    "      r2_both <- summary(model_both)$r.squared\n",
    "      # nested-model F-test: does the tracing effect add explanatory power beyond the layer?\n",
    "      f_test <- anova(model_both, model_edit_layer_only)\n",
    "      p_value <- f_test$`Pr(>F)`[2]\n",
    "      row[[sprintf(\"%s_R2_lyr\", metric)]] = r2_lyr\n",
    "      row[[sprintf(\"%s_R2_trace\", metric)]] = r2_trace\n",
    "      row[[sprintf(\"%s_R2_both\", metric)]] = r2_both\n",
    "      row[[sprintf(\"%s_R2_diff\", metric)]] = r2_both - r2_lyr\n",
    "      row[[sprintf(\"%s_F_pvalue\", metric)]] <- p_value\n",
    "    }\n",
    "    result_rows[[length(result_rows) + 1]] <- row\n",
    "  }\n",
    "}\n",
    "# bind once instead of growing the table inside the loop\n",
    "results_df <- bind_rows(results_df, result_rows)\n",
    "# values are kept at full precision; rounding happens in the display cells (rounding to 3 dp here\n",
    "# made their round(., 4) a no-op and collapsed small p-values to 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Per-objective summary of the R2 gained from the tracing effect, with its F-test p-value, for every\n",
    "# metric: shown inline, written to tmp.csv and downloaded.\n",
    "r2_summary <- results_df %>%\n",
    "  mutate_if(is.double, ~round(.,4)) %>%\n",
    "  select(objective, edit_method, edit_window_size,\n",
    "         rewrite_R2_diff, rewrite_F_pvalue,\n",
    "         paraphrase_R2_diff, paraphrase_F_pvalue,\n",
    "         neighborhood_R2_diff, neighborhood_F_pvalue,\n",
    "         essence_R2_diff, essence_F_pvalue) %>%\n",
    "  arrange(objective, edit_method, edit_window_size)\n",
    "write_csv(r2_summary, 'tmp.csv')\n",
    "r2_summary\n",
    "\n",
    "colab::download_file('tmp.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Final R2 plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "options(repr.plot.width=13, repr.plot.height=4)\n",
    "\n",
    "# metric whose R2 decomposition is plotted\n",
    "obj = 'rewrite'\n",
    "R2_lyr_name = sprintf(\"%s_R2_lyr\", obj)\n",
    "R2_diff_name = sprintf(\"%s_R2_diff\", obj)\n",
    "objective_levels <- c(\"Error Injection\",\"Tracing Reversal\",\"Fact Amplification\", \"Fact Erasure\",\"Fact Forcing\")\n",
    "# long format, one row per (objective, method, model). The 'Layer + Tracing Effect' value is the R2\n",
    "# gained over layer alone, so the stacked bar height equals the R2 of the combined model.\n",
    "plot_data <- results_df %>%\n",
    "  mutate(method=sprintf('%s-%s', edit_method, edit_window_size)) %>%\n",
    "  select(objective, method, all_of(R2_lyr_name), all_of(R2_diff_name)) %>%\n",
    "  # all_of(): bare external character vectors in tidyselect are deprecated and ambiguous with column names\n",
    "  pivot_longer(cols=all_of(c(R2_lyr_name, R2_diff_name)), names_to = \"model\", values_to = \"R2\") %>%\n",
    "  mutate(model = ifelse(model == R2_lyr_name, \"Layer\", \"Layer + Tracing Effect\"),\n",
    "         objective = ifelse(objective == \"Falsehood Injection\", \"Error Injection\", objective),\n",
    "         objective = factor(objective, levels=objective_levels),\n",
    "         model = factor(model, levels = c(\"Layer + Tracing Effect\", \"Layer\")),\n",
    "         method = ifelse(method == \"ROME-1\", \"ROME\", ifelse(method==\"MEMIT-5\", \"MEMIT\", method)),\n",
    "         method = factor(method, levels = c(\"FT-1\", \"FT-5\", \"ROME\", \"MEMIT\"))\n",
    "  )\n",
    "# a single row (Fact Forcing facet, FT-1 'Layer' bar) so the p-value annotation is drawn exactly once\n",
    "# NOTE(review): 'p < 1e-4' is hard-coded; confirm it against the F-test p-values in results_df\n",
    "text_data <- plot_data %>%\n",
    "  filter(objective == \"Fact Forcing\", method==\"FT-1\", model==\"Layer\")\n",
    "(bar_plot <- plot_data %>%\n",
    "  filter(objective != \"Error Injection\") %>%\n",
    "  ggplot(aes(x=method, y=R2, fill=model)) +\n",
    "    geom_col(position='stack', width = 0.5) +\n",
    "    geom_text(data=text_data, label = \"p < 1e-4\", size=5, hjust = 0, vjust=-1.5) +\n",
    "    xlab(\"\") +\n",
    "    ylab(bquote(R^2)) +\n",
    "    ylim(c(0,1)) +\n",
    "    scale_fill_manual(name = \"Explanatory Variable(s):\", values=c(cbp1[2], cbp1[1]), limits = c(\"Layer\", \"Layer + Tracing Effect\")) +\n",
    "    theme +\n",
    "    ggtitle(\"Tracing effects are very weakly predictive of edit success\") +\n",
    "    theme(axis.title.y = element_text(size=23, angle=0, vjust=0.45, hjust=-2),\n",
    "          axis.text.y = element_text(size=18),\n",
    "          axis.text.x = element_text(size=16, angle=0, vjust=0, hjust=.5),\n",
    "          axis.title.x = element_text(size=23, angle=0, vjust=-1, hjust=.5),\n",
    "          plot.title = element_text(size=24, angle=0, vjust=1, hjust=.5),\n",
    "          strip.text.x = element_text(size=16, color='black', angle=0, vjust=.5, hjust=.5),\n",
    "          legend.position=c(.5, -0.25),\n",
    "          legend.direction='horizontal',\n",
    "          legend.title = element_text(size=22, vjust=.55),\n",
    "          legend.text = element_text(size=22),\n",
    "          legend.box.margin = margin(100,100,100,100)\n",
    "        ) +\n",
    "    facet_wrap(~objective, nrow=1)\n",
    ")\n",
    "\n",
    "ggsave('R2-bar-plot.pdf', bar_plot, width=14.5, height=3.6, device=cairo_pdf)\n",
    "colab::download_file('R2-bar-plot.pdf')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
