{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0722ea91-f8f3-4267-9c46-d7c7e9f3f7da",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens import SAE\n",
    "import torch\n",
    "import os\n",
    "import yaml\n",
    "with open('global_config.yaml') as global_stream:\n",
    "    global_cfg = yaml.safe_load(global_stream)\n",
    "CACHE_DIR = global_cfg['CACHE_DIR']\n",
    "os.environ[\"HF_TOKEN\"] = global_cfg['hf_access_token']\n",
    "from transformer_lens import HookedTransformer\n",
    "from transformers import AutoTokenizer\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from functools import partial\n",
    "import scipy\n",
    "from scipy.stats import spearmanr, pearsonr\n",
    "\n",
    "from stitching.losses import next_token_cross_entropy_loss, kl_div_loss\n",
    "from stitching.stitching_utils import open_experiment, load_activation_store\n",
    "from stitching.sae_utils import BaseSAE, forward_modified, generate_modified, max_csim_transfer_to_orig, convert_eleuther_sae_to_BaseSAE\n",
    "from stitching.losses import get_ignore_mask\n",
    "from stitching.generic_experiments import *\n",
    "from stitching.sae_utils import feature_activations, precision_recall_f1, jumprelu_activation\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset\n",
    "device ='cuda'\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fce064d5-e9fb-4fcb-a6de-46274f9943d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run-level configuration read by the cells below.\n",
    "# run_name_id selects the stitch checkpoint that is opened and names this run's output directory.\n",
    "run_name_id = 'fallen-glitter-8'\n",
    "# Directory holding the cached tokens and per-model activations loaded in Part 2;\n",
    "# the placeholder must be replaced with a real path before those cells can run.\n",
    "acts_dir = '[FILL IN]'\n",
    "modelA_name = 'gemma-2-2b'\n",
    "modelB_name = 'gemma-2-9b'\n",
    "checkpoints_dir = f'checkpoints/stitch_training_{modelA_name}_to_{modelB_name}_bidirectional_mse'\n",
    "# Layer indices used in the cached-activation file names (layer_A is also used to build the SAE id).\n",
    "layer_A = 20\n",
    "layer_B = 33\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20d01d69-5285-423a-9fb7-7c16e7ca1ba0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for this run: the cached logit weights and the per-batch correlation metrics\n",
    "# are written here by the cells in Part 1 and Part 2.\n",
    "activations_dir = f'activations_store/fr_fixed_attribution_correlation/{run_name_id}/'\n",
    "os.makedirs(activations_dir,exist_ok=True)\n",
    "print(activations_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ad2f761-07eb-445f-b135-e1e91f0aed71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the trained stitch for this run. P and bias are applied on the decoder side, Pinv and biasinv on the\n",
    "# encoder side, when the model-A SAE is transferred to model B below; beta is not used elsewhere in this notebook.\n",
    "# NOTE(review): 2304 and 3584 are presumably the residual widths of modelA and modelB -- confirm against open_experiment.\n",
    "P, Pinv, beta, bias, biasinv = open_experiment(2304, 3584, checkpoints_dir, run_name_id, biases=True, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e222e041-7b01-4d30-9539-575dc87ad99f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a50b96fc-5fca-4124-9c36-4bf026f2e61c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load and transfer sae\n",
    "# Load the pretrained SAE for model A, wrap it as a BaseSAE, and build two model-B-side SAEs from it:\n",
    "#   * transferred_sae_B -- the model-A SAE pushed through the stitch (P / Pinv and their biases)\n",
    "#   * random_sae        -- a randomly initialised SAE of the same shape, used as a baseline\n",
    "# layer_A comes from the configuration cell. It is deliberately not reassigned here, so the SAE and the\n",
    "# cached activations loaded in Part 2 always follow the same layer setting.\n",
    "sae_A, cfg_dict, _ = SAE.from_pretrained(\n",
    "    release = 'gemma-scope-2b-pt-res-canonical', # see other options in sae_lens/pretrained_saes.yaml\n",
    "    sae_id = f\"layer_{layer_A-1}/width_16k/canonical\", # won't always be a hook point\n",
    "    device = device\n",
    ")\n",
    "if 'BaseSAE' in str(type(sae_A)):\n",
    "    apply_b_dec = sae_A.apply_b_dec\n",
    "else:\n",
    "    apply_b_dec = sae_A.cfg.apply_b_dec_to_input\n",
    "thresholds = sae_A.threshold\n",
    "orig_sae_A = BaseSAE(\n",
    "    sae_A.W_enc.detach().clone(),\n",
    "    sae_A.W_dec.detach().clone(),\n",
    "    sae_A.b_enc.detach().clone(),\n",
    "    sae_A.b_dec.detach().clone(),\n",
    "    partial(jumprelu_activation, thresholds=thresholds),\n",
    "    apply_b_dec=apply_b_dec\n",
    ")\n",
    "# both done already.\n",
    "#orig_sae_A.normalize_decoder_vectors()\n",
    "#orig_sae_A.get_rid_of_decoder_sub()\n",
    "\n",
    "# Transfer: the encoder is pre-composed with Pinv (plus biasinv folded into b_enc) and the decoder is\n",
    "# post-composed with P (plus bias folded into b_dec), so the SAE reads and writes model-B activations.\n",
    "transferred_sae_B = BaseSAE(\n",
    "    Pinv @ orig_sae_A.W_enc.detach().clone(),\n",
    "    orig_sae_A.W_dec.detach().clone() @ P,\n",
    "    orig_sae_A.b_enc.detach().clone() + biasinv @ orig_sae_A.W_enc.detach().clone(),\n",
    "    orig_sae_A.b_dec.detach().clone() @ P + bias,\n",
    "    orig_sae_A.activation_fn,\n",
    "    apply_b_dec=False\n",
    ")\n",
    "#transferred_sae_B.normalize_decoder_vectors()\n",
    "\n",
    "# Seed the RNG so the random baseline is the same SAE after a kernel restart: its logit weights are\n",
    "# cached to disk in Part 1 and have to match the encoder that is applied in Part 2.\n",
    "RANDOM_SAE_SEED = 0\n",
    "torch.manual_seed(RANDOM_SAE_SEED)\n",
    "random_sae = BaseSAE(\n",
    "    torch.nn.init.kaiming_uniform_(torch.zeros_like(transferred_sae_B.W_enc)),\n",
    "    torch.nn.init.kaiming_uniform_(torch.zeros_like(transferred_sae_B.W_dec)),\n",
    "    torch.nn.init.kaiming_uniform_(torch.zeros((1, transferred_sae_B.b_enc.shape[0]))).flatten(),\n",
    "    torch.nn.init.kaiming_uniform_(torch.zeros((1, transferred_sae_B.b_dec.shape[0]))).flatten(),\n",
    "    sae_A.activation_fn\n",
    ")\n",
    "random_sae = random_sae.to(sae_A.device)\n",
    "#random_sae.normalize_decoder_vectors()\n",
    "#random_sae.get_rid_of_decoder_sub()\n",
    "\n",
    "# Only the sae_lens object is released. transferred_sae_B must stay alive: the 'Actual B' cell in Part 2\n",
    "# encodes model-B activations with it.\n",
    "del sae_A\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f90a11c2-0e76-4587-8c8e-8beb94341756",
   "metadata": {},
   "source": [
    "# Part 1: Cache logit weights for all features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef577ffb-5b33-4fad-bc05-2997db7343c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only model B is loaded here (float16); the model-A load is left commented out.\n",
    "#modelA = HookedTransformer.from_pretrained(modelA_name, cache_dir=CACHE_DIR, device=device, torch_dtype=torch.float16)\n",
    "modelB = HookedTransformer.from_pretrained(modelB_name, cache_dir=CACHE_DIR, device=device, torch_dtype=torch.float16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d4b7631-1015-4581-a703-2b7428353609",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache, one feature batch at a time, the logit weights that model B assigns to the random SAE's\n",
    "# unit-normalised decoder rows.\n",
    "feature_batch_size = 1000\n",
    "n_sae_features = orig_sae_A.d_sae\n",
    "for i in np.arange(0, n_sae_features, feature_batch_size):\n",
    "    feature_subset = np.arange(i, min(n_sae_features, i+feature_batch_size))\n",
    "\n",
    "    # Normalise each selected decoder row to unit norm before the final layer norm and the unembedding.\n",
    "    dec_rows = random_sae.W_dec[feature_subset]\n",
    "    dec = dec_rows / dec_rows.norm(dim=-1, keepdim=True)\n",
    "    normed_dec = modelB.ln_final(dec)\n",
    "    random_logit_weights_B = modelB.unembed(normed_dec).cpu().numpy()\n",
    "    random_out_path = os.path.join(activations_dir, f'random_logit_weights_B_size_{feature_batch_size}_batch_{i}.npy')\n",
    "    np.save(random_out_path, random_logit_weights_B)\n",
    "    # The analogous caches for the original and the transferred SAE are disabled below;\n",
    "    # Part 2 reads those files, so they have to exist on disk already.\n",
    "    #logit_weights_A = modelA.unembed(modelA.ln_final(orig_sae_A.W_dec[feature_subset])).cpu().numpy()\n",
    "    #np.save(os.path.join(activations_dir, f'logit_weights_A_size_{feature_batch_size}_batch_{i}.npy'), logit_weights_A)\n",
    "    #logit_weights_B = modelB.unembed(modelB.ln_final(transferred_sae_B.W_dec[feature_subset])).cpu().numpy()\n",
    "    #np.save(os.path.join(activations_dir, f'logit_weights_B_size_{feature_batch_size}_batch_{i}.npy'), logit_weights_B)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "999bd705-7c1a-44a6-ae38-5d906a3caf0e",
   "metadata": {},
   "source": [
    "# Part 2: Compute attribution scores and their per-feature correlations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94e1d60f-af14-4504-97c0-c6139c8bc5ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer used in Part 2 to look up the special tokens that are masked out of the correlations.\n",
    "tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3811b2c-fe47-4284-8f56-c0b3f7ee9814",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on the first cached shard only (it already holds a lot of tokens): the token ids plus the\n",
    "# residual activations of each model at its configured layer.\n",
    "tokens_path = os.path.join(acts_dir, 'tokens_1.pt')\n",
    "acts_A_path = os.path.join(acts_dir, f'{modelA_name}/', f'{modelA_name}_layer_{layer_A}_cached_activations_1.pt')\n",
    "acts_B_path = os.path.join(acts_dir, f'{modelB_name}/', f'{modelB_name}_layer_{layer_B}_cached_activations_1.pt')\n",
    "\n",
    "tokens = torch.load(tokens_path, weights_only=True)\n",
    "# The activations are converted to float32 right after loading.\n",
    "activations_A = torch.load(acts_A_path, weights_only=True).float()\n",
    "activations_B = torch.load(acts_B_path, weights_only=True).float()\n",
    "\n",
    "feature_batch_size = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cc0fe02-7fdd-4842-927e-26929ce4e09e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One item per step and no shuffling, so the three loaders stay aligned when they are zipped together below.\n",
    "dataloader1 = torch.utils.data.DataLoader(activations_A, batch_size = 1, shuffle=False)\n",
    "dataloader2 = torch.utils.data.DataLoader(activations_B, batch_size = 1, shuffle=False)\n",
    "tokens_dataloader = torch.utils.data.DataLoader(tokens, batch_size = 1, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebfc3976-6c72-4497-9ade-bf14266f654c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Actual B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "810540e5-04a8-4df4-81f7-cbe003b83e05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attribution correlation between the original SAE on model A and the transferred SAE on model B.\n",
    "# Requires transferred_sae_B to be in scope and the cached logit_weights_A / logit_weights_B files to\n",
    "# exist in activations_dir.\n",
    "spec_tokens = get_all_special_tokens(tokenizer) # get special tokens\n",
    "\n",
    "for i in np.arange(0, orig_sae_A.d_sae, feature_batch_size):\n",
    "    print(\"batch\", i // feature_batch_size)\n",
    "    feature_subset = np.arange(i, min(orig_sae_A.d_sae, i+feature_batch_size))\n",
    "\n",
    "    # Cached logit weights for this feature batch; indexed below as [feature, token id].\n",
    "    logit_weights_A = torch.tensor(np.load(os.path.join(activations_dir, f'logit_weights_A_size_{feature_batch_size}_batch_{i}.npy'))).to(device)\n",
    "    logit_weights_B = torch.tensor(np.load(os.path.join(activations_dir, f'logit_weights_B_size_{feature_batch_size}_batch_{i}.npy'))).to(device)\n",
    "\n",
    "    # For every kept position: feature activation times that feature's logit weight on the next token.\n",
    "    all_scores_A = []\n",
    "    all_scores_B = []\n",
    "    # The loop variable is intentionally not named `tokens`: rebinding that name would replace the full\n",
    "    # token tensor loaded earlier with a single batch and break a later re-run of the dataloader cell.\n",
    "    for token_batch, resid_A, resid_B in tqdm(zip(tokens_dataloader, dataloader1, dataloader2)):\n",
    "        token_batch = token_batch.to(device)\n",
    "        resid_A = resid_A.to(device)\n",
    "        resid_B = resid_B.to(device)\n",
    "        next_tokens = token_batch[..., 1:]\n",
    "        # Drop the last position of each sequence: it has no next token to attribute to.\n",
    "        resid_A = resid_A[..., :-1, :]\n",
    "        resid_B = resid_B[..., :-1, :]\n",
    "        ignore_mask = get_ignore_mask(token_batch, spec_tokens) # (b, nseq)\n",
    "        ignore_mask = torch.logical_or(ignore_mask[:, 1:], ignore_mask[:, :-1])  # (b, nseq-1) # either this token or next token is padding\n",
    "        acts_A = orig_sae_A.encode(resid_A)[..., feature_subset]\n",
    "        acts_B = transferred_sae_B.encode(resid_B)[..., feature_subset]\n",
    "        acts_A = acts_A[~ignore_mask]\n",
    "        acts_B = acts_B[~ignore_mask]\n",
    "        next_tokens = next_tokens[~ignore_mask]\n",
    "        all_scores_A.append((acts_A * logit_weights_A[:, next_tokens].T).cpu().numpy())\n",
    "        all_scores_B.append((acts_B * logit_weights_B[:, next_tokens].T).cpu().numpy())\n",
    "\n",
    "    scores_A = np.concatenate(all_scores_A)\n",
    "    scores_B = np.concatenate(all_scores_B)\n",
    "\n",
    "    # One Pearson coefficient per feature (column), comparing the model-A and model-B scores.\n",
    "    corrs_attribution = []\n",
    "    print(\"Computing correlations\")\n",
    "    for j in tqdm(range(scores_A.shape[1])):\n",
    "        pearson_res_attribution = pearsonr(scores_A[:,j], scores_B[:,j], axis=0)\n",
    "        corrs_attribution.append(pearson_res_attribution.statistic)\n",
    "    results_dict = {\n",
    "        'attribution_correlation': np.array(corrs_attribution)\n",
    "    }\n",
    "    np.savez(os.path.join(activations_dir, f'metrics_size_{feature_batch_size}_batch_{i}.npz'), **results_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e9b4a84-1b9e-4596-a806-d5866b680349",
   "metadata": {},
   "source": [
    "### Random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fccd18d-51f9-4eb2-b3d4-ce6376bebbce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at the shape of the random baseline's encoder weight.\n",
    "random_sae.W_enc.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0df2820d-689d-44c6-b93b-44a5cf0a13a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: the same attribution correlation, but with the randomly initialised SAE on the model-B side.\n",
    "# Requires the cached logit_weights_A and random_logit_weights_B files to exist in activations_dir.\n",
    "spec_tokens = get_all_special_tokens(tokenizer) # get special tokens\n",
    "\n",
    "# Part 3 loads the random baseline as 'metrics' from the random/ subdirectory of the run folder, so the\n",
    "# results are written there with the same file naming as the stitched metrics.\n",
    "# NOTE(review): this relies on load_activation_store building its paths the same way for both calls -- confirm.\n",
    "random_metrics_dir = os.path.join(activations_dir, 'random')\n",
    "os.makedirs(random_metrics_dir, exist_ok=True)\n",
    "\n",
    "for i in np.arange(0, orig_sae_A.d_sae, feature_batch_size):\n",
    "    print(\"batch\", i // feature_batch_size)\n",
    "    feature_subset = np.arange(i, min(orig_sae_A.d_sae, i+feature_batch_size))\n",
    "\n",
    "    # Cached logit weights for this feature batch; indexed below as [feature, token id].\n",
    "    logit_weights_A = torch.tensor(np.load(os.path.join(activations_dir, f'logit_weights_A_size_{feature_batch_size}_batch_{i}.npy'))).to(device)\n",
    "    logit_weights_B = torch.tensor(np.load(os.path.join(activations_dir, f'random_logit_weights_B_size_{feature_batch_size}_batch_{i}.npy'))).to(device)\n",
    "\n",
    "    # For every kept position: feature activation times that feature's logit weight on the next token.\n",
    "    all_scores_A = []\n",
    "    all_scores_B = []\n",
    "    # The loop variable is intentionally not named `tokens`: rebinding that name would replace the full\n",
    "    # token tensor loaded earlier with a single batch and break a later re-run of the dataloader cell.\n",
    "    for token_batch, resid_A, resid_B in tqdm(zip(tokens_dataloader, dataloader1, dataloader2)):\n",
    "        token_batch = token_batch.to(device)\n",
    "        resid_A = resid_A.to(device)\n",
    "        resid_B = resid_B.to(device)\n",
    "        next_tokens = token_batch[..., 1:]\n",
    "        # Drop the last position of each sequence: it has no next token to attribute to.\n",
    "        resid_A = resid_A[..., :-1, :]\n",
    "        resid_B = resid_B[..., :-1, :]\n",
    "        ignore_mask = get_ignore_mask(token_batch, spec_tokens) # (b, nseq)\n",
    "        ignore_mask = torch.logical_or(ignore_mask[:, 1:], ignore_mask[:, :-1])  # (b, nseq-1) # either this token or next token is padding\n",
    "        acts_A = orig_sae_A.encode(resid_A)[..., feature_subset]\n",
    "        acts_B = random_sae.encode(resid_B)[..., feature_subset]\n",
    "        acts_A = acts_A[~ignore_mask]\n",
    "        acts_B = acts_B[~ignore_mask]\n",
    "        next_tokens = next_tokens[~ignore_mask]\n",
    "        all_scores_A.append((acts_A * logit_weights_A[:, next_tokens].T).cpu().numpy())\n",
    "        all_scores_B.append((acts_B * logit_weights_B[:, next_tokens].T).cpu().numpy())\n",
    "\n",
    "    scores_A = np.concatenate(all_scores_A)\n",
    "    scores_B = np.concatenate(all_scores_B)\n",
    "\n",
    "    # One Pearson coefficient per feature (column), comparing the model-A and random-SAE scores.\n",
    "    corrs_attribution = []\n",
    "    print(\"Computing correlations\")\n",
    "    for j in tqdm(range(scores_A.shape[1])):\n",
    "        pearson_res_attribution = pearsonr(scores_A[:,j], scores_B[:,j], axis=0)\n",
    "        corrs_attribution.append(pearson_res_attribution.statistic)\n",
    "    results_dict = {\n",
    "        'attribution_correlation': np.array(corrs_attribution)\n",
    "    }\n",
    "    np.savez(os.path.join(random_metrics_dir, f'metrics_size_{feature_batch_size}_batch_{i}.npz'), **results_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "657843a9-7490-48bb-805f-bef1a17a9cde",
   "metadata": {},
   "source": [
    "# Part 3: Pull in batched feature activations one by one and compute Pearson."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13307606-4257-440a-a22d-91af8b8d8a0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-batch 'metrics' results for the stitched SAE from this run's folder, and for the random\n",
    "# baseline from its random/ subdirectory.\n",
    "# NOTE(review): 1000 and 16384 presumably mirror feature_batch_size and the SAE width -- confirm against load_activation_store.\n",
    "stitched = load_activation_store('metrics', 1000, 16384, subdir=f\"fr_fixed_attribution_correlation/{run_name_id}/\") #f\"eleuther/{run_name_id}/\"\n",
    "random_baseline = load_activation_store('metrics', 1000, 16384, subdir=f\"fr_fixed_attribution_correlation/{run_name_id}/random/\") #f\"eleuther/{run_name_id}/\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2a2dfab-2077-44b7-8c78-06a6fd0275e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlaid histograms (log-scaled counts, bin width 0.02 on [-1, 1]) of the per-feature attribution\n",
    "# correlation for the random baseline and for the stitched SAE.\n",
    "sns.histplot(random_baseline['attribution_correlation'], binrange=(-1, 1), binwidth=0.02, label='random SAE', log=True, alpha=0.5)\n",
    "sns.histplot(stitched['attribution_correlation'], binrange=(-1, 1), binwidth=0.02, label='stitched', log=True, alpha=0.5)\n",
    "plt.legend()\n",
    "plt.title('gemma')\n",
    "plt.xlabel('correlation')\n",
    "#plt.savefig('results/figures/gemma_correlations.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2e4722b-333d-4234-a4da-46d6d6e4570b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature metadata loaded from disk: 'structural' is used as a boolean feature mask below and\n",
    "# 'densities' is thresholded to drop dead features.\n",
    "res_dict = torch.load(\"gemma-2-2b.20_structural.pt\", weights_only=True)\n",
    "structural_features, densities = res_dict['structural'], res_dict['densities']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f262fc85-b234-4067-95a3-9a1acd364272",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation histograms for the structural features and for all remaining features, drawn on the same axes.\n",
    "plt.hist(stitched['attribution_correlation'][structural_features], bins=20)\n",
    "plt.hist(stitched['attribution_correlation'][~structural_features], bins=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2b26980-4832-4694-a522-c1f3b11cf509",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared settings for the histograms and the figure below.\n",
    "bins = 50\n",
    "key = 'attribution_correlation'\n",
    "binrange = (-1, 1)\n",
    "# Features whose density is at or below 1e-6 are treated as dead and excluded where this mask is applied.\n",
    "not_dead = densities > 1e-6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c12c5e87-dfd1-44b0-b177-679410df3965",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram the correlations of the alive, the structural and the non-structural (semantic) features\n",
    "# on one shared binning. Only the structural and semantic histograms are density-normalised.\n",
    "correlations = stitched[key]\n",
    "shared_binning = dict(bins=bins, range=binrange)\n",
    "hist_orig, bin_edges = np.histogram(correlations[not_dead], **shared_binning)\n",
    "hist_good, bin_edges = np.histogram(correlations[structural_features], density=True, **shared_binning)\n",
    "hist_semantic, bin_edges = np.histogram(correlations[~structural_features], density=True, **shared_binning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c822e225-f2e6-4c52-b382-d70a7c2845c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict to alive features, then split their correlations into the structural and the semantic group.\n",
    "# Despite the variable names, these are the Pearson attribution correlations computed with pearsonr in\n",
    "# Part 2, not Spearman coefficients.\n",
    "alives = stitched[key][not_dead]\n",
    "structural_spearmans = alives[structural_features[not_dead]]\n",
    "semantic_spearmans = alives[~structural_features[not_dead]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87cb6ef2-bdc7-4e2a-80b9-4bbdec2c8549",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Joint figure: density histograms of the attribution correlation for the semantic and the structural\n",
    "# alive features, with the per-bin ratio of structural to semantic density shaded in the top margin.\n",
    "g = sns.JointGrid()\n",
    "\n",
    "# Both groups share the joint axes and the same binning.\n",
    "g.plot_joint(sns.histplot, data=semantic_spearmans, binrange=binrange, bins=bins, stat='density', label='semantic', alpha=0.5)\n",
    "g.plot_joint(sns.histplot, data=structural_spearmans, binrange=binrange, bins=bins, stat='density', label='structural', alpha=0.5)\n",
    "\n",
    "bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2\n",
    "# Upper bound of the shaded region; far above the visible y-range set below.\n",
    "RATIO_CEILING = 1000.0\n",
    "# Structural / semantic density ratio per bin. The replacement values are passed by keyword: the second\n",
    "# positional argument of np.nan_to_num is `copy`, not the NaN replacement. Bins that are empty in both\n",
    "# histograms (0/0) become 0; bins where only the semantic density is empty (x/0) are capped at the\n",
    "# ceiling instead of turning into a ~1e308 float.\n",
    "with np.errstate(divide='ignore', invalid='ignore'):\n",
    "    density_ratio = np.nan_to_num(hist_good / hist_semantic, nan=0.0, posinf=RATIO_CEILING)\n",
    "\n",
    "g.ax_marg_x.fill_between(bin_centers, 0, density_ratio, color='tab:orange',alpha=0.5)\n",
    "g.ax_marg_x.fill_between(bin_centers, density_ratio, RATIO_CEILING, color='tab:blue',alpha=0.5)\n",
    "g.ax_marg_x.set(ylim=[0,4.0])\n",
    "\n",
    "# Hide the y marginal\n",
    "g.ax_marg_y.remove()\n",
    "# Re-apply the top marginal's position (height factor 1.0 keeps its current size)\n",
    "box = g.ax_marg_x.get_position()\n",
    "g.ax_marg_x.set_position([box.x0, box.y0, box.width, box.height * 1.0])\n",
    "# JointGrid hides the marginal's y axis by default; switch its axis, grid, spine, ticks and label back on.\n",
    "g.ax_marg_x.yaxis.set_visible(True)\n",
    "g.ax_marg_x.grid(True, color='gray')\n",
    "g.ax_marg_x.spines['left'].set_visible(True)\n",
    "g.ax_marg_x.tick_params(axis='y', which='both', left=True, labelleft=True)  # Turn on ticks & labels\n",
    "g.ax_marg_x.set_ylabel(\"frac\")  # Optional\n",
    "g.ax_joint.set_xlabel('correlation')\n",
    "g.ax_joint.set_ylabel('density')\n",
    "g.ax_joint.set_xlim(binrange)\n",
    "g.ax_joint.legend()\n",
    "# Make sure the output folder exists so savefig does not fail on a fresh checkout.\n",
    "os.makedirs('results/figures', exist_ok=True)\n",
    "plt.savefig('results/figures/gemma_transfer_difference.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb5cd578-cfbd-48a4-9dbe-5acbe963f036",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
