{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import glob\n",
    "import re\n",
    "import sys\n",
    "from typing import Dict, List\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "import plotly.io as pio\n",
    "import chart_studio\n",
    "from chart_studio import plotly as py\n",
    "\n",
    "from IPython.display import display, clear_output, HTML\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "from utils.data_processing import (\n",
    "    load_edge_scores_into_dictionary,\n",
    "    compute_weighted_jaccard_similarity,\n",
    "    compute_weighted_jaccard_similarity_to_reference,\n",
    "    compute_ewma_weighted_jaccard_similarity,\n",
    "    generate_in_circuit_df_files\n",
    ")\n",
    "from utils.result_plotting import plot_head_circuit_scores, plot_graph_metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional chart_studio setup - only needed to publish graphs online.\n",
    "# To publish a graph, simply include upload=True in the plotting function call.\n",
    "# Credentials are read from the key file below, which should contain the\n",
    "# username and API key separated by a comma. If the file is missing this cell\n",
    "# does nothing, so the notebook still runs end to end without chart_studio.\n",
    "CHART_STUDIO_KEY_FILE = \"../auth/chart_studio_api_key.txt\"\n",
    "\n",
    "if os.path.exists(CHART_STUDIO_KEY_FILE):\n",
    "    with open(CHART_STUDIO_KEY_FILE) as f:\n",
    "        # split on the first comma only and strip whitespace around both fields\n",
    "        username, api_key = (field.strip() for field in f.read().strip().split(\",\", 1))\n",
    "    chart_studio.tools.set_credentials_file(username=username, api_key=api_key)\n",
    "else:\n",
    "    print(f\"chart_studio key file not found at {CHART_STUDIO_KEY_FILE}; skipping credential setup\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- run configuration: which model, task and performance metric to analyse ---\n",
    "MODEL_NAME = 'pythia-2.8b'\n",
    "TASK = 'ioi'\n",
    "PERFORMANCE_METRIC = 'logit_diff'\n",
    "\n",
    "# the plotting cells below pass this folder as output_path; ensure it exists up front\n",
    "OUTPUT_DIR = f\"../results/plots/graph_metrics/{TASK}/\"\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regenerate the in-circuit edge files from the EAP graphs in ../results/graphs\n",
    "# (presumably the in_circuit_edges.feather file loaded in the next cell).\n",
    "# ONLY NEEDS TO BE RUN IF EAP IS REPEATED FOR MODEL/TASK OR NEW CHECKPOINTS ARE ADDED:\n",
    "# set REGENERATE_IN_CIRCUIT_DF = True in that case.\n",
    "REGENERATE_IN_CIRCUIT_DF = False\n",
    "\n",
    "if REGENERATE_IN_CIRCUIT_DF:\n",
    "    generate_in_circuit_df_files('../results/graphs', start_checkpoint=3000, limit_to_model=MODEL_NAME, limit_to_task=TASK)\n",
    "    # hide whatever output the regeneration step printed\n",
    "    clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load circuit graph dataframe from file\n",
    "in_circuit_df = pd.read_feather(f'../results/graphs/{MODEL_NAME}/{TASK}/in_circuit_edges.feather')\n",
    "# one row per checkpoint with the number of in-circuit edge rows at that checkpoint\n",
    "edge_count_df = in_circuit_df.groupby('checkpoint').size().reset_index(name='num_edges')\n",
    "\n",
    "# load performance metrics, e.g. logit diff, nested as [model][task][metric] -> {checkpoint: value}\n",
    "performance_metrics_file = '../results/task_performance_metrics/all_models_task_performance.pt'\n",
    "# NOTE: torch.load unpickles the file - only load results files you trust\n",
    "if os.path.exists(performance_metrics_file):\n",
    "    perf_metrics_by_model = torch.load(performance_metrics_file)\n",
    "else:\n",
    "    perf_metrics_by_model = {}\n",
    "\n",
    "has_perf_metric = (\n",
    "    MODEL_NAME in perf_metrics_by_model\n",
    "    and TASK in perf_metrics_by_model[MODEL_NAME]\n",
    "    and PERFORMANCE_METRIC in perf_metrics_by_model[MODEL_NAME][TASK]\n",
    ")\n",
    "if has_perf_metric:\n",
    "    # The following can be replaced with any dictionary with (checkpoint: metric) structure,\n",
    "    # e.g. from baselines or other model task runs\n",
    "    perf_metric_dict = perf_metrics_by_model[MODEL_NAME][TASK][PERFORMANCE_METRIC]\n",
    "else:\n",
    "    # fall back to a flat zero line so the plots below still render, and say so explicitly\n",
    "    print(f\"No '{PERFORMANCE_METRIC}' results for {MODEL_NAME}/{TASK}; using a zero placeholder\")\n",
    "    perf_metric_dict = {c: 0.0 for c in in_circuit_df['checkpoint'].unique()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# quick look at the first rows of the in-circuit edge table\n",
    "in_circuit_df.head(n=5)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Graph Size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# in-circuit edge count per checkpoint, plotted alongside the performance metric\n",
    "plot_graph_metric(\n",
    "    edge_count_df,\n",
    "    'num_edges',\n",
    "    perf_metric_dict,\n",
    "    f'Graph Size for {MODEL_NAME}',\n",
    "    x_axis_col='checkpoint',\n",
    "    log_x=True,\n",
    "    left_y_title=\"Edge Count\",\n",
    "    right_y_title=\"Logit Diff\",\n",
    "    y_ranges=((0, 1500), (0, 6)),\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Graph Similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# weighted Jaccard similarity between checkpoint circuits (presumably consecutive pairs),\n",
    "# plotted against checkpoint_2 of each pair\n",
    "weighted_jaccard_results = compute_weighted_jaccard_similarity(in_circuit_df)\n",
    "\n",
    "plot_graph_metric(\n",
    "    weighted_jaccard_results,\n",
    "    'jaccard_similarity',\n",
    "    perf_metric_dict,\n",
    "    f'Jaccard Similarity for {MODEL_NAME}',\n",
    "    x_axis_col='checkpoint_2',\n",
    "    log_x=True,\n",
    "    left_y_title=\"Jaccard Similarity\",\n",
    "    metric_legend_name=\"Jaccard Sim\",\n",
    "    y_ranges=((0, 1), (0, 6)),\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# weighted Jaccard similarity of each checkpoint's circuit to the circuit at comparison_checkpoint\n",
    "comparison_checkpoint = 5000\n",
    "jaccard_reference_results = compute_weighted_jaccard_similarity_to_reference(\n",
    "    in_circuit_df,\n",
    "    comparison_checkpoint,\n",
    ")\n",
    "\n",
    "plot_graph_metric(\n",
    "    jaccard_reference_results,\n",
    "    'jaccard_similarity',\n",
    "    perf_metric_dict,\n",
    "    f'Weighted Jaccard Similarity to Checkpoint {comparison_checkpoint} for {MODEL_NAME}',\n",
    "    x_axis_col='checkpoint',\n",
    "    log_x=True,\n",
    "    left_y_title=\"Jaccard Similarity\",\n",
    "    metric_legend_name=\"Jaccard Sim\",\n",
    "    y_ranges=((0, 1), (0, 6)),\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# same comparison as above, but against the circuit at a later reference checkpoint\n",
    "comparison_checkpoint = 143000\n",
    "jaccard_reference_results = compute_weighted_jaccard_similarity_to_reference(\n",
    "    in_circuit_df,\n",
    "    comparison_checkpoint,\n",
    ")\n",
    "\n",
    "plot_graph_metric(\n",
    "    jaccard_reference_results,\n",
    "    'jaccard_similarity',\n",
    "    perf_metric_dict,\n",
    "    f'Weighted Jaccard Similarity to Checkpoint {comparison_checkpoint} for {MODEL_NAME}',\n",
    "    x_axis_col='checkpoint',\n",
    "    log_x=True,\n",
    "    left_y_title=\"Jaccard Similarity\",\n",
    "    metric_legend_name=\"Jaccard Sim\",\n",
    "    y_ranges=((0, 1), (0, 6)),\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exponentially weighted moving average of the graph change rate between checkpoints\n",
    "# (derived from weighted Jaccard similarity by compute_ewma_weighted_jaccard_similarity)\n",
    "EWMA_ALPHA = 0.1  # smoothing factor passed as alpha to the EWMA computation\n",
    "ewma_results = compute_ewma_weighted_jaccard_similarity(in_circuit_df, alpha=EWMA_ALPHA)\n",
    "\n",
    "plot_graph_metric(\n",
    "    ewma_results,\n",
    "    'ewma_change_rate',\n",
    "    perf_metric_dict,\n",
    "    f'Exponential Weighted Average Graph Change Rate for {MODEL_NAME}',\n",
    "    y_ranges=((0, 1), (0, 6)),\n",
    "    # label axis and legend with the metric actually plotted (change rate, not similarity)\n",
    "    left_y_title=\"EWMA Change Rate\",\n",
    "    x_axis_col='checkpoint_2',\n",
    "    log_x=True,\n",
    "    metric_legend_name=\"EWMA Change Rate\",\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
