{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import json\n",
    "import sys\n",
    "import numpy as np\n",
    "sys.path.append(\"..\")\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models, data, operators, utils, functional, metrics, lens\n",
    "from src.utils import logging_utils\n",
    "import logging\n",
    "import torch\n",
    "import baukit\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format = logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mt = models.load_model(name = \"gptj\", fp16 = True, device = \"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relation = data.load_dataset().filter(relation_names=[\"country capital city\"])[0].set(prompt_templates=[\" {}:\"])\n",
    "train, test = relation.split(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icl_prompt = functional.make_prompt(\n",
    "    prompt_template = train.prompt_templates[0],\n",
    "    subject = \"{}\",\n",
    "    examples = train.samples,\n",
    "    mt = mt\n",
    ")\n",
    "print(icl_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = functional.filter_relation_samples_based_on_provided_fewshots(\n",
    "    mt = mt, \n",
    "    test_relation=test,\n",
    "    prompt_template = icl_prompt,\n",
    "    batch_size=4\n",
    ")\n",
    "len(test.samples)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Current Method => Calculate $b_r$ and $W_r$ individually and average them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator = operators.JacobianIclMeanEstimator(\n",
    "    mt = mt, h_layer=7, beta=0.2\n",
    ")\n",
    "operator = estimator(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_operator(operator, test_samples):\n",
    "    \"\"\"Score `operator` on `test_samples` using `metrics.recall`.\n",
    "\n",
    "    Each sample's `subject` is passed to the operator with `k=3`; the tokens of\n",
    "    all returned predictions are collected alongside the sample's `object`.\n",
    "    One line per sample is printed showing the first prediction.\n",
    "\n",
    "    Returns the value of `metrics.recall(predicted_tokens, expected_objects)`.\n",
    "    \"\"\"\n",
    "    expected, predicted = [], []\n",
    "    for sample in test_samples:\n",
    "        expected.append(sample.object)\n",
    "        preds = operator(sample.subject, k=3)\n",
    "        # Only the first prediction is echoed; every prediction is kept for scoring.\n",
    "        pred = str(preds.predictions[0])\n",
    "        print(f\"{sample.subject=} -> {sample.object=} | {pred=}\")\n",
    "        predicted.append([candidate.token for candidate in preds.predictions])\n",
    "    return metrics.recall(predicted, expected)\n",
    "\n",
    "\n",
    "evaluate_operator(operator, test.samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lens.logit_lens(mt = mt, h = operator.bias, get_proba=True, k = 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models.determine_layer_paths(mt, layers=[\"emb\", \"ln_f\"])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### At `ln_f`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator_lnf = operators.JacobianIclMeanEstimator(\n",
    "    mt = mt, \n",
    "    h_layer=7, \n",
    "    z_layer=\"ln_f\", \n",
    "    beta=1\n",
    ")\n",
    "operator_lnf = estimator_lnf(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate_operator(operator_lnf, test.samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lens.logit_lens(mt = mt, h = operator_lnf.bias, get_proba=True, k = 10, after_layer_norm=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale factors for the ln_f operator: `omega` multiplies the weight and\n",
    "# `beta` is handed to the operator unchanged.\n",
    "# Alternative settings that tie the two together via beta / omega = c:\n",
    "#   omega = beta / c\n",
    "#   omega = 5; beta = omega * c\n",
    "c = 0.2\n",
    "beta = 0.1\n",
    "omega = 1\n",
    "\n",
    "print(f\"{beta=} | {omega=}\")\n",
    "\n",
    "# Start from the ln_f operator's attributes and override beta / weight only.\n",
    "operator_dct = operator_lnf.__dict__.copy()\n",
    "operator_dct[\"beta\"] = beta\n",
    "# Scale the ln_f operator's own weight so it stays consistent with the bias\n",
    "# and z_layer copied from `operator_lnf`.\n",
    "operator_dct[\"weight\"] = operator_lnf.weight * omega\n",
    "operator_no_beta_lnf = operators.LinearRelationOperator(**operator_dct)\n",
    "\n",
    "evaluate_operator(operator_no_beta_lnf, test.samples)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get rid of the $\\beta$ by setting $\\beta = 1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose omega and beta such that beta / omega = c.\n",
    "# With c = 0.2 and omega = 5 this yields beta = 1.0, i.e. the weight is\n",
    "# scaled by omega while the operator itself is built with beta = 1.\n",
    "c = 0.2\n",
    "omega = 5\n",
    "beta = omega * c\n",
    "\n",
    "# Copy the baseline operator's attributes and override beta / weight only.\n",
    "operator_dct = operator.__dict__.copy()\n",
    "operator_dct[\"beta\"] = beta\n",
    "operator_dct[\"weight\"] = operator.weight * omega\n",
    "operator_no_beta = operators.LinearRelationOperator(**operator_dct)\n",
    "\n",
    "evaluate_operator(operator_no_beta, test.samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "1 / operator_dct[\"beta\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt = mt,\n",
    "    prompt_template = train.prompt_templates[0],\n",
    "    subjects = [sample.subject for sample in relation.samples],\n",
    "    h_layer= operator.h_layer,\n",
    "    z_layer=-1,\n",
    "    batch_size=4,\n",
    "    examples= train.samples\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for sample in train.samples:\n",
    "    subj = sample.subject\n",
    "    obj = sample.object\n",
    "    pred = functional.predict_next_token(\n",
    "        mt = mt,\n",
    "        prompt = functional.make_prompt(\n",
    "            prompt_template = train.prompt_templates[0],\n",
    "            subject = subj,\n",
    "            examples = train.samples,\n",
    "            mt = mt\n",
    "        )\n",
    "    )[0][0]\n",
    "    h_norm = hs_and_zs.h_by_subj[subj].norm().item()\n",
    "    z_norm = hs_and_zs.z_by_subj[subj].norm().item()\n",
    "    print(f\"{subj=} -> {obj=} | {h_norm=} | {z_norm=} || {pred=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every subject, compare the norm of h with the norm of operator.weight @ h.\n",
    "h_norms, jh_norms = [], []\n",
    "for subj, h in hs_and_zs.h_by_subj.items():\n",
    "    jh = operator.weight @ h\n",
    "    h_norms.append(h.norm().item())\n",
    "    jh_norms.append(jh.norm().item())\n",
    "    print(f\"{subj=} | {h.norm()=} | {jh.norm()=} | {h.mean()=} | {h.std()=}\")\n",
    "\n",
    "# Summary statistics over all subjects.\n",
    "print(f\"h_norms: {np.mean(h_norms):.2f} +/- {np.std(h_norms):.2f}\")\n",
    "print(f\"jh_norms: {np.mean(jh_norms):.2f} +/- {np.std(jh_norms):.2f}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### $b_r = \\mathbf{o}_{mean} - J\\mathbf{s}_{mean}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_mean = torch.stack([hs_and_zs.z_by_subj[sample.subject] for sample in train.samples]).mean(dim = 0)\n",
    "h_mean = torch.stack([hs_and_zs.h_by_subj[sample.subject] for sample in train.samples]).mean(dim = 0)\n",
    "\n",
    "bias_mean = z_mean - operator.weight @ h_mean\n",
    "print(torch.dist(bias_mean, operator.bias))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "operator_dct = operator.__dict__.copy()\n",
    "operator_dct[\"beta\"] = 1 #.2\n",
    "operator_dct[\"bias\"] = bias_mean\n",
    "operator_bias_J = operators.LinearRelationOperator(**operator_dct)\n",
    "\n",
    "evaluate_operator(operator_bias_J, test.samples)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### $b_r = F(\\mathbf{s}_{mean}) -J\\mathbf{s}_{mean}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h_layer_name, z_layer_name = models.determine_layer_paths(mt, layers = [operator.h_layer, operator.z_layer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hs_and_zs.h_by_subj.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_intervention(h, int_layer, subj_idx):\n",
    "    \"\"\"Build an `edit_output` hook for `baukit.TraceDict`.\n",
    "\n",
    "    The returned hook leaves every layer untouched except `int_layer`, whose\n",
    "    output is modified in place: position `subj_idx` (second axis of the\n",
    "    untupled output) is overwritten with `h`.\n",
    "    \"\"\"\n",
    "    def edit_output(output, layer):\n",
    "        if layer == int_layer:\n",
    "            # In-place write; the (possibly tuple) output object is returned as-is.\n",
    "            functional.untuple(output)[:, subj_idx] = h\n",
    "        return output\n",
    "\n",
    "    return edit_output\n",
    "\n",
    "\n",
    "subject = \"Russia\"\n",
    "prompt = icl_prompt.format(subject)\n",
    "\n",
    "h_index, inputs = functional.find_subject_token_index(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    subject=subject,\n",
    ")\n",
    "\n",
    "# Run the model once while replacing the subject representation at the h layer\n",
    "# with `h_mean`; both the h and z layers are traced for later inspection.\n",
    "with baukit.TraceDict(\n",
    "    mt.model,\n",
    "    layers=[h_layer_name, z_layer_name],\n",
    "    edit_output=get_intervention(h_mean, h_layer_name, h_index),\n",
    ") as traces:\n",
    "    outputs = mt.model(\n",
    "        input_ids=inputs.input_ids,\n",
    "        attention_mask=inputs.attention_mask,\n",
    "    )\n",
    "\n",
    "lens.interpret_logits(\n",
    "    mt=mt,\n",
    "    logits=outputs.logits[0][-1],\n",
    "    get_proba=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s = functional.untuple(traces[h_layer_name].output)[0][h_index]\n",
    "s.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s.norm().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h_mean.norm().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# z-layer output of the patched forward pass: first batch row, last position.\n",
    "# Uses `functional.untuple` like the h-layer read, instead of raw tuple indexing.\n",
    "# NOTE(review): assumes the untupled output is (batch, seq, dim) - confirm.\n",
    "z_mean_F = functional.untuple(traces[z_layer_name].output)[0][-1]\n",
    "\n",
    "# Bias implied by F(h_mean) under the baseline operator's weight.\n",
    "bias_F = z_mean_F - operator.weight @ h_mean\n",
    "print(torch.dist(bias_F, operator.bias))\n",
    "lens.logit_lens(mt = mt, h = bias_F, get_proba=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline operator with beta = 1 and the bias replaced by `bias_F`.\n",
    "# Stored under its own name so it does not overwrite `operator_bias_J`.\n",
    "operator_dct = operator.__dict__.copy()\n",
    "operator_dct[\"beta\"] = 1 #.2\n",
    "operator_dct[\"bias\"] = bias_F\n",
    "operator_bias_F = operators.LinearRelationOperator(**operator_dct)\n",
    "\n",
    "evaluate_operator(operator_bias_F, test.samples)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Make all the $|\\mathbf{o}|$ similar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_s = torch.stack([hs_and_zs.z_by_subj[sample.subject] for sample in train.samples])\n",
    "min_norm = z_s.norm(dim = 1).min()\n",
    "z_s = torch.stack([(z*min_norm)/z.norm() for z in z_s])\n",
    "\n",
    "z_mean = z_s.mean(dim = 0)\n",
    "bias_mean = z_mean - operator.weight @ h_mean\n",
    "\n",
    "print(torch.dist(bias_mean, operator.bias))\n",
    "lens.logit_lens(mt = mt, h = bias_mean, get_proba=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "operator_dct = operator.__dict__.copy()\n",
    "operator_dct[\"beta\"] = 1 #.2\n",
    "operator_dct[\"bias\"] = bias_mean\n",
    "operator_similar_o = operators.LinearRelationOperator(**operator_dct)\n",
    "\n",
    "evaluate_operator(operator_similar_o, test.samples)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Automatically *tune* $\\beta$ for each training sample "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each training sample, sweep beta over 10 evenly spaced values in [0, 1]\n",
    "# and keep the first scaled bias whose top logit-lens token is a non-trivial\n",
    "# prefix of the target object. Samples with no such beta contribute nothing.\n",
    "biases = []\n",
    "for sample in train.samples:\n",
    "    subj = sample.subject\n",
    "    obj = sample.object\n",
    "    h = hs_and_zs.h_by_subj[subj]\n",
    "    z = hs_and_zs.z_by_subj[subj]\n",
    "    # operator.weight @ h does not depend on beta, so compute it once per sample.\n",
    "    jh = operator.weight @ h\n",
    "    b_sample = z - jh\n",
    "    print(f\"{subj=} | h_norm={h.norm()} | z_norm={z.norm()} || b_norm={b_sample.norm()}\")\n",
    "    for beta in np.linspace(0, 1, 10):\n",
    "        z_est = jh + b_sample * beta\n",
    "        pred, _ = lens.logit_lens(mt = mt, h = z_est, get_proba=True, k = 3)\n",
    "        print(f\"{obj=} | {beta=} | z_est={z_est.norm()} | {pred=}\")\n",
    "        top_token = pred[0][0]\n",
    "        if functional.is_nontrivial_prefix(prediction=top_token, target=sample.object):\n",
    "            biases.append(b_sample * beta)\n",
    "            break\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "operator = estimator(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "biases = []\n",
    "for sample, approx in zip(train.samples, operator.metadata[\"approxes\"]):\n",
    "    subj = sample.subject\n",
    "    obj = sample.object\n",
    "    h = approx.h\n",
    "    z = approx.z\n",
    "    print(f\"{subj=} | h_norm={h.norm()} | z_norm={z.norm()} || b_norm={approx.bias.norm()}\")\n",
    "    for beta in np.linspace(0, 1, 10):\n",
    "        z_est = approx.weight @ h + approx.bias * beta\n",
    "        pred, _ = lens.logit_lens(mt = mt, h = z_est, get_proba=True, k = 3)\n",
    "        print(f\"{obj=} | {beta=} | z_est={z_est.norm()} | {pred=}\")\n",
    "        top_token = pred[0][0]\n",
    "        if functional.is_nontrivial_prefix(prediction=top_token, target=sample.object):\n",
    "            biases.append(b_sample * beta)\n",
    "            break\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "b_mean = torch.stack(biases).mean(dim = 0)\n",
    "\n",
    "operator_dct = operator.__dict__.copy()\n",
    "operator_dct[\"beta\"] = 1 #.2\n",
    "operator_dct[\"bias\"] = b_mean\n",
    "operator_auto_beta = operators.LinearRelationOperator(**operator_dct)\n",
    "\n",
    "evaluate_operator(operator_auto_beta, test.samples)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Drop $J\\mathbf{s}$ entirely from bias estimation. (basically the corner method?)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "biases = torch.stack([approx.bias[0] for approx in operator.metadata[\"approxes\"]])\n",
    "min_norm = biases.norm(dim = 1).min()\n",
    "\n",
    "biases = torch.stack([(b*min_norm)/b.norm() for b in biases])\n",
    "b_mean = biases.mean(dim = 0)\n",
    "\n",
    "print(torch.dist(b_mean, operator.bias))\n",
    "lens.logit_lens(mt = mt, h = b_mean, get_proba=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "b_mean.norm().item(), operator.bias.norm().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "operator_dct = operator.__dict__.copy()\n",
    "operator_dct[\"beta\"] = .2 #.2\n",
    "operator_dct[\"bias\"] = b_mean\n",
    "operator_dropped_js = operators.LinearRelationOperator(**operator_dct)\n",
    "\n",
    "evaluate_operator(operator_dropped_js, test.samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "translation_estimator = operators.OffsetEstimatorBaseline(\n",
    "    mt = mt,\n",
    "    h_layer = operator.h_layer,\n",
    "    z_layer=operator.z_layer,\n",
    "    mode = \"icl\",\n",
    ")\n",
    "\n",
    "translation = translation_estimator(\n",
    "    relation.set(\n",
    "        samples = train.samples + test.samples\n",
    "    )\n",
    ")\n",
    "corner = translation.bias\n",
    "\n",
    "print(torch.dist(corner, operator.bias))\n",
    "lens.logit_lens(mt = mt, h = corner, get_proba=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate_operator(translation, test.samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "operator_dct = operator.__dict__.copy()\n",
    "operator_dct[\"beta\"] = 1 #.2\n",
    "operator_dct[\"bias\"] = corner\n",
    "operator_corner = operators.LinearRelationOperator(**operator_dct)\n",
    "\n",
    "evaluate_operator(operator_corner, test.samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.explain_beta import TrialResult, AllTrialResults"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta_path = \"../results/explain_beta/gptj\"\n",
    "\n",
    "# results[relation_name][n_train_samples] -> AllTrialResults; trials from\n",
    "# multiple files with the same key are merged into the first object loaded.\n",
    "# Layout walked on disk: <beta_path>/<relation>/<n_train>/<file>.\n",
    "results = {}\n",
    "for relation_folder in os.listdir(beta_path):\n",
    "    relation_path = os.path.join(beta_path, relation_folder)\n",
    "    for n_train in os.listdir(relation_path):\n",
    "        n_train_path = os.path.join(relation_path, n_train)\n",
    "        for file in os.listdir(n_train_path):\n",
    "            with open(os.path.join(n_train_path, file)) as f:\n",
    "                # Named `trial_results` (not `data`) so the imported `data`\n",
    "                # module from `src` is not rebound by this cell.\n",
    "                trial_results = AllTrialResults.from_dict(json.load(f))\n",
    "            by_n_train = results.setdefault(trial_results.relation_name, {})\n",
    "            # Key by the actual number of training samples in the first trial,\n",
    "            # not by the folder name.\n",
    "            _n_train = len(trial_results.trials[0].train_samples)\n",
    "            if _n_train not in by_n_train:\n",
    "                by_n_train[_n_train] = trial_results\n",
    "            else:\n",
    "                by_n_train[_n_train].trials.extend(trial_results.trials)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relations = list(results.keys())\n",
    "relation = relations[0]\n",
    "# relations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_options = sorted(list(results[relation].keys()))\n",
    "train_options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bias_norms = [\n",
    "    np.array([trial.bias_norm for trial in results[relation][n_train].trials])\n",
    "    for n_train in train_options\n",
    "]\n",
    "\n",
    "means = np.array([np.mean(bias_norm) for bias_norm in bias_norms])\n",
    "stds = np.array([np.std(bias_norm) for bias_norm in bias_norms])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset matplotlib to its defaults, then apply the figure styling used here.\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 18\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 24\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # default text size\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE, labelsize=MEDIUM_SIZE + 5)  # axes title, x/y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # x tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # y tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend text\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # figure title\n",
    "\n",
    "# Mean bias norm per training-set size, with a +/- one standard deviation band.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(train_options, means)\n",
    "ax.fill_between(train_options, means - stds, means + stds, alpha=0.2)\n",
    "ax.set_xticks(train_options)\n",
    "ax.set_xlabel(\"n_train\")\n",
    "ax.set_ylabel(\"bias norm\")\n",
    "ax.set_title(f\"{relation}\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
