{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install dependencies with %pip so they land in the running kernel's environment\n",
    "# (a bare `!pip` runs in a subshell and may target a different interpreter).\n",
    "# Paths are relative to the notebook directory, so run this before changing directory.\n",
    "%pip install -e ../../nnpatch ../../pycolors -e ../../pyvene\n",
    "%pip install -U transformers kaleido\n",
    "%pip install circuitsvis python-dotenv --no-deps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Move up one level only if the `analysis.circuit_utils` package is not already\n",
    "# reachable from the working directory, so re-running this cell does not keep\n",
    "# climbing the directory tree.\n",
    "if not Path(\"analysis/circuit_utils\").is_dir():\n",
    "    %cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from nnsight import NNsight\n",
    "from tqdm.notebook import tqdm, trange\n",
    "\n",
    "# Project helpers. The star imports are kept in their original order because a\n",
    "# later star import shadows same-named symbols from an earlier one.\n",
    "from analysis.circuit_utils.visualisation import *\n",
    "from analysis.circuit_utils.model import *\n",
    "from analysis.circuit_utils.validation import *\n",
    "from analysis.circuit_utils.decoding import *\n",
    "from analysis.circuit_utils.utils import *\n",
    "from analysis.circuit_utils.decoding import get_decoding_args, get_data, generate_title, get_plot_prior_patch, get_plot_context_patch, get_plot_weightcp_patch, get_plot_weightpc_patch\n",
    "from analysis.circuit_utils.das import *\n",
    "\n",
    "from main import load_model_and_tokenizer\n",
    "\n",
    "from nnpatch.api.mistral import Mistral\n",
    "\n",
    "jupyter_enable_mathjax()\n",
    "\n",
    "# Notebook-wide configuration: where plots are written and where the model weights live.\n",
    "plot_dir = \"plots/Mistral-7B-Instruct-v0.3\"\n",
    "MODEL_STORE=\"/dlabscratch1/public/llm_weights/mistral_hf/\"\n",
    "PATHS, args = get_decoding_args(finetuned=True, load_in_4bit=False, cwf=\"instruction\", model_id=\"Mistral-7B-Instruct-v0.3\", model_store=MODEL_STORE, n_samples=100)\n",
    "os.makedirs(plot_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# NOTE(review): loads a file named after gemma-2-9b-it from a hard-coded absolute path,\n",
    "# while the rest of this notebook targets Mistral-7B-Instruct-v0.3 - looks like a scratch check; confirm it is still needed.\n",
    "tmp = torch.load(\"/dlabscratch1/jminder/repositories/context-vs-prior-finetuning/analysis/ctxprior-gemma-2-9b-it-L27.pt\")\n",
    "tmp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the object loaded with torch.load in the previous cell.\n",
    "tmp.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runs analysis/scripts/merge_model.py for Mistral-7B-Instruct-v0.3 (cwf=instruction, dataset=BaseFakepedia).\n",
    "# NOTE(review): --model-store repeats the MODEL_STORE value set in the config cell; keep the two in sync.\n",
    "!python analysis/scripts/merge_model.py --model-id Mistral-7B-Instruct-v0.3 --model-store /dlabscratch1/public/llm_weights/mistral_hf/ --cwf instruction --dataset BaseFakepedia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model and tokenizer described by PATHS/args, then wrap the model with NNsight.\n",
    "# NOTE(review): load_model_and_tokenizer_from_args is not imported by name; presumably it comes from one of the star imports - confirm.\n",
    "model, tokenizer = load_model_and_tokenizer_from_args(PATHS, args)\n",
    "nnmodel = NNsight(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the loaded model's repr.\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Patch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild PATHS/args for the patching experiments (200 samples, no filtering).\n",
    "PATHS, args = get_decoding_args(\n",
    "    finetuned=True,\n",
    "    no_filtering=True,\n",
    "    load_in_4bit=True,\n",
    "    cwf=\"instruction\",\n",
    "    model_id=\"Mistral-7B-Instruct-v0.3\",\n",
    "    model_store=MODEL_STORE,\n",
    "    n_samples=200,\n",
    ")\n",
    "\n",
    "# get_data's 17 return values, unpacked in the order it returns them.\n",
    "(\n",
    "    all_tokens, all_attn_mask,\n",
    "    context_1_tokens, context_2_tokens, context_3_tokens,\n",
    "    prior_1_tokens, prior_2_tokens,\n",
    "    context_1_attention_mask, context_2_attention_mask, context_3_attention_mask,\n",
    "    prior_1_attention_mask, prior_2_attention_mask,\n",
    "    context_1_answer, context_2_answer, context_3_answer,\n",
    "    prior_1_answer, prior_2_answer,\n",
    ") = get_data(args, PATHS, tokenizer)\n",
    "\n",
    "# Positional argument bundles for the four experiment types. Each one is\n",
    "# [all tokens, all mask, tokens A, tokens B, mask A, mask B, answer A, answer B].\n",
    "prior_args = [\n",
    "    all_tokens, all_attn_mask,\n",
    "    prior_1_tokens, prior_2_tokens,\n",
    "    prior_1_attention_mask, prior_2_attention_mask,\n",
    "    prior_1_answer, prior_2_answer,\n",
    "]\n",
    "ctx_args = [\n",
    "    all_tokens, all_attn_mask,\n",
    "    context_1_tokens, context_2_tokens,\n",
    "    context_1_attention_mask, context_2_attention_mask,\n",
    "    context_1_answer, context_2_answer,\n",
    "]\n",
    "cp_args = [\n",
    "    all_tokens, all_attn_mask,\n",
    "    context_1_tokens, prior_1_tokens,\n",
    "    context_1_attention_mask, prior_1_attention_mask,\n",
    "    context_1_answer, prior_1_answer,\n",
    "]\n",
    "pc_args = [\n",
    "    all_tokens, all_attn_mask,\n",
    "    prior_1_tokens, context_1_tokens,\n",
    "    prior_1_attention_mask, context_1_attention_mask,\n",
    "    prior_1_answer, context_1_answer,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the first prior prompt and its answer, keeping special tokens visible.\n",
    "# Two separate statements: joining the prints with a comma builds a tuple, which\n",
    "# made the cell additionally display `(None, None)`.\n",
    "print(tokenizer.decode(prior_1_tokens[0], skip_special_tokens=False))\n",
    "print(tokenizer.decode(prior_1_answer[0], skip_special_tokens=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automatic layer search (prior, context, CP, PC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# auto_search over the prior inputs. Unlike the ctx/cp/pc searches in the following\n",
    "# cells, this one is given explicit lower_bound/upper_bound values.\n",
    "# (Mistral is already imported in the imports cell, so it is not re-imported here.)\n",
    "prior_range = auto_search(\n",
    "    model, tokenizer, prior_args,\n",
    "    n_layers=32, phi=0.05, eps=0.3, thres=0.85,\n",
    "    batch_size=10, api=Mistral,\n",
    "    lower_bound=13, upper_bound=19,\n",
    ")\n",
    "print(prior_range)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# auto_search over the context inputs (no explicit layer bounds).\n",
    "ctx_range = auto_search(\n",
    "    model, tokenizer, ctx_args,\n",
    "    n_layers=32, phi=0.05, eps=0.3, thres=0.85,\n",
    "    batch_size=10, api=Mistral,\n",
    ")\n",
    "print(ctx_range)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# auto_search over the context->prior (cp) inputs.\n",
    "cp_range = auto_search(\n",
    "    model, tokenizer, cp_args,\n",
    "    n_layers=32, phi=0.05, eps=0.3, thres=0.85,\n",
    "    batch_size=10, api=Mistral,\n",
    ")\n",
    "print(cp_range)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# auto_search over the prior->context (pc) inputs.\n",
    "pc_range = auto_search(\n",
    "    model, tokenizer, pc_args,\n",
    "    n_layers=32, phi=0.05, eps=0.3, thres=0.85,\n",
    "    batch_size=10, api=Mistral,\n",
    ")\n",
    "print(pc_range)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "site_1_config = {}  # empty site config\n",
    "\n",
    "# This is a context-patching run on ctx_args, so it is titled \"CTX - \" like the\n",
    "# other get_plot_context_patch cells (it was previously labelled \"PRIOR - \").\n",
    "figr, figp = get_plot_context_patch(nnmodel, tokenizer, *ctx_args, site_1_config, N_LAYERS=32, batch_size=2, output_dir=plot_dir, api=Mistral, title=generate_title(site_1_config, \"CTX - \"))\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "site_1_config = {}  # PRIOR, empty site config\n",
    "\n",
    "figr, figp = get_plot_prior_patch(\n",
    "    nnmodel, tokenizer, *prior_args, site_1_config,\n",
    "    N_LAYERS=32, batch_size=2, output_dir=plot_dir, api=Mistral,\n",
    "    title=generate_title(site_1_config, \"PRIOR - \"),\n",
    ")\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prior patching with the \"o\" site on layers 13-19, 21, 24 and 29.\n",
    "site_1_config = {\"o\": {\"layers\": [*range(13, 20), 21, 24, 29]}}\n",
    "\n",
    "figr, figp = get_plot_prior_patch(\n",
    "    nnmodel, tokenizer, *prior_args, site_1_config,\n",
    "    N_LAYERS=32, batch_size=2, output_dir=plot_dir, api=Mistral,\n",
    "    title=generate_title(site_1_config, \"PRIOR - \"),\n",
    ")\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prior patching with the \"o\" site on the last four layers (28-31).\n",
    "site_1_config = {\"o\": {\"layers\": [28, 29, 30, 31]}}\n",
    "\n",
    "figr, figp = get_plot_prior_patch(\n",
    "    nnmodel, tokenizer, *prior_args, site_1_config,\n",
    "    N_LAYERS=32, batch_size=10, output_dir=plot_dir, api=Mistral,\n",
    "    title=generate_title(site_1_config, \"PRIOR - \"),\n",
    ")\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Context patching with the \"o\" site on layers 19-22 and 28-31.\n",
    "site_1_config = {\"o\": {\"layers\": [*range(19, 23), *range(28, 32)]}}\n",
    "\n",
    "figr, figp = get_plot_context_patch(\n",
    "    nnmodel, tokenizer, *ctx_args, site_1_config,\n",
    "    N_LAYERS=32, batch_size=10, output_dir=plot_dir, api=Mistral,\n",
    "    title=generate_title(site_1_config, \"CTX - \"),\n",
    ")\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Context->prior weight patching with the \"o\" site on layers 8-16, 24, 28 and 29.\n",
    "site_1_config = {\n",
    "    \"o\": {\n",
    "        \"layers\": [8, 9, 10, 11, 12, 13, 14, 15, 16, 24, 28, 29]\n",
    "    },\n",
    "}\n",
    "\n",
    "# Write to the shared plot_dir (same path as before) instead of repeating the literal.\n",
    "figr, figp = get_plot_weightcp_patch(\n",
    "    nnmodel, tokenizer, *cp_args, site_1_config,\n",
    "    N_LAYERS=32, batch_size=10, output_dir=plot_dir, api=Mistral,\n",
    "    title=generate_title(site_1_config, \"CP - \"),\n",
    ")\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prior->context weight patching with the \"o\" site on layers 8-16, 24, 28, 29 and 31.\n",
    "site_1_config = {\n",
    "    \"o\": {\n",
    "        \"layers\": [8, 9, 10, 11, 12, 13, 14, 15, 16, 24, 28, 29, 31]\n",
    "    },\n",
    "}\n",
    "\n",
    "# Write to the shared plot_dir (same path as before) instead of repeating the literal.\n",
    "figr, figp = get_plot_weightpc_patch(\n",
    "    nnmodel, tokenizer, *pc_args, site_1_config,\n",
    "    N_LAYERS=32, batch_size=10, output_dir=plot_dir, api=Mistral,\n",
    "    title=generate_title(site_1_config, \"PC - \"),\n",
    ")\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { # PRIOR\n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(25, 30)),\n",
    "    },\n",
    "}\n",
    "\n",
    "figr, figp = get_plot_prior_patch(nnmodel, tokenizer, *prior_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PRIOR - \"))\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { # PRIOR\n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(25, 30)) + [37],\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "figr, figp = get_plot_prior_patch(nnmodel, tokenizer, *prior_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PRIOR - \"))\n",
    "figp.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { # PRIOR\n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(25, 30)) + [37, 40],\n",
    "    },\n",
    "}\n",
    "\n",
    "figr, figp = get_plot_prior_patch(nnmodel, tokenizer, *prior_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PRIOR - \"))\n",
    "figp.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { # PRIOR\n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(28, 42)),\n",
    "    },\n",
    "}\n",
    "\n",
    "figr, figp = get_plot_prior_patch(nnmodel, tokenizer, *prior_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PRIOR - \"))\n",
    "figp.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "}\n",
    "figr, figp = get_plot_context_patch(nnmodel, tokenizer, *ctx_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"CTX - \"))\n",
    "figp.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(25, 30)),\n",
    "    },\n",
    "}\n",
    "figr, figp = get_plot_context_patch(nnmodel, tokenizer, *ctx_args, site_1_config, N_LAYERS=42, batch_size=10, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"CTX - \"))\n",
    "figp.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(29, 42)),\n",
    "    },\n",
    "}\n",
    "figr, figp = get_plot_context_patch(nnmodel, tokenizer, *ctx_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"CTX - \"))\n",
    "figp.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "}\n",
    "# CP run: get_plot_weightcp_patch takes cp_args (context first, prior second), not pc_args.\n",
    "figr, figp = get_plot_weightcp_patch(nnmodel, tokenizer, *cp_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"CP - \"))\n",
    "figp.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(0, 28)),\n",
    "    },\n",
    "}\n",
    "# CP run: get_plot_weightcp_patch takes cp_args (context first, prior second), not pc_args.\n",
    "figr, figp = get_plot_weightcp_patch(nnmodel, tokenizer, *cp_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"CP - \"))\n",
    "figp.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(20, 28)),\n",
    "    },\n",
    "}\n",
    "# CP run: get_plot_weightcp_patch takes cp_args (context first, prior second), not pc_args.\n",
    "figr, figp = get_plot_weightcp_patch(nnmodel, tokenizer, *cp_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"CP - \"))\n",
    "figp.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "}\n",
    "figr, figp = get_plot_weightpc_patch(nnmodel, tokenizer, *pc_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PC - \"))\n",
    "figp.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(20, 28)) + [37, 40]\n",
    "    },\n",
    "}\n",
    "figr, figp = get_plot_weightpc_patch(nnmodel, tokenizer, *pc_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PC - \"))\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(20, 28)) + [37]\n",
    "    },\n",
    "}\n",
    "figr, figp = get_plot_weightpc_patch(nnmodel, tokenizer, *pc_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PC - \"))\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(20, 28))\n",
    "    },\n",
    "}\n",
    "figr, figp = get_plot_weightpc_patch(nnmodel, tokenizer, *pc_args, site_1_config, N_LAYERS=42, batch_size=8, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PC - \"))\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(0, 28)),\n",
    "    },\n",
    "}\n",
    "figr, figp = get_plot_weightpc_patch(nnmodel, tokenizer, *pc_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PC - \"))\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(20, 30)),\n",
    "    },\n",
    "}\n",
    "figr, figp = get_plot_weightpc_patch(nnmodel, tokenizer, *pc_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PC - \"))\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Gemma-2 settings (api=Gemma2, N_LAYERS=42, plots/gemma2-9b-it) in a notebook that loads\n",
    "# Mistral-7B-Instruct-v0.3 and never imports Gemma2 by name - confirm before running.\n",
    "site_1_config = { \n",
    "    \"o\":\n",
    "    {\n",
    "        \"layers\": list(range(25, 30)),\n",
    "    },\n",
    "}\n",
    "figr, figp = get_plot_weightpc_patch(nnmodel, tokenizer, *pc_args, site_1_config, N_LAYERS=42, batch_size=20, output_dir=\"plots/gemma2-9b-it\", api=Gemma2, title=generate_title(site_1_config, \"PC - \"))\n",
    "figp.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DAS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Only needed when the DAS section is run from a fresh kernel. If the working\n",
    "# directory already contains `analysis/circuit_utils` (e.g. after Run All from the\n",
    "# top), stay put instead of climbing one level too far.\n",
    "if not Path(\"analysis/circuit_utils\").is_dir():\n",
    "    %cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %reload_ext loads the extension if it is not loaded yet and reloads it otherwise,\n",
    "# so this cell works both from a fresh kernel and after the first section has run.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "from analysis.circuit_utils.das import *\n",
    "from functools import partial\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from nnsight import NNsight\n",
    "import torch\n",
    "import os\n",
    "from tqdm.notebook import tqdm, trange\n",
    "\n",
    "# Project helpers. The star imports are kept in their original order because a\n",
    "# later star import shadows same-named symbols from an earlier one.\n",
    "from analysis.circuit_utils.visualisation import *\n",
    "from analysis.circuit_utils.model import *\n",
    "from analysis.circuit_utils.validation import *\n",
    "from analysis.circuit_utils.decoding import *\n",
    "from analysis.circuit_utils.utils import *\n",
    "from analysis.circuit_utils.decoding import get_decoding_args, get_data, generate_title, get_plot_prior_patch, get_plot_context_patch, get_plot_weightcp_patch, get_plot_weightpc_patch\n",
    "\n",
    "from main import load_model_and_tokenizer\n",
    "from nnpatch.subspace.interventions import train_projection, create_dataset, LowRankOrthogonalProjection\n",
    "\n",
    "\n",
    "from nnpatch.api.mistral import Mistral\n",
    "\n",
    "jupyter_enable_mathjax()\n",
    "\n",
    "# Configuration for the DAS section.\n",
    "plot_dir = \"plots/Mistral-7B-Instruct-v0.3\"\n",
    "MODEL_STORE=\"/dlabscratch1/public/llm_weights/mistral_hf/\"\n",
    "os.makedirs(plot_dir, exist_ok=True)\n",
    "\n",
    "device = \"cuda:0\"\n",
    "\n",
    "PATHS, args = get_decoding_args(finetuned=True, load_in_4bit=False, cwf=\"instruction\", model_id=\"Mistral-7B-Instruct-v0.3\", model_store=MODEL_STORE, n_samples=1000, no_filtering=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Re)load the model and tokenizer using the PATHS/args built in the DAS configuration cell.\n",
    "model, tokenizer = load_model_and_tokenizer_from_args(PATHS, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check which dtype the model was loaded with.\n",
    "model.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training inputs on `device`.\n",
    "# Judging by how these are passed to create_dataset in the next cell, the first six\n",
    "# are: source tokens, target tokens, source label index, target label index,\n",
    "# source attention mask, target attention mask.\n",
    "# NOTE(review): the meaning of `tit` / `amti` is not visible here - confirm against prepare_train_data.\n",
    "(st, tt, si, ti, ams, amt, tit, amti) = prepare_train_data(\n",
    "    args,\n",
    "    PATHS,\n",
    "    tokenizer,\n",
    "    device,\n",
    "    same_query=True,\n",
    "    remove_weight=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "confident_indices = filter_confident_samples(\n",
    "    args, model, tt, tit, ti, si, amt, amti, batch_size=32\n",
    ")\n",
    "# Index every training input with the confident indices, in the positional order\n",
    "# create_dataset expects.\n",
    "train_dataset = create_dataset(\n",
    "    *(part[confident_indices] for part in (st, tt, si, ti, ams, amt))\n",
    ")\n",
    "train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the evaluation data on the same `device` that prepare_train_data was given,\n",
    "# rather than a separately hard-coded \"cuda\" string.\n",
    "(\n",
    "    source_prompt, target_prompt,\n",
    "    source_tokens, target_tokens,\n",
    "    source_label_index, target_label_index,\n",
    "    source_attn_mask, target_attn_mask,\n",
    ") = collect_data(args, PATHS, tokenizer, device)\n",
    "test_dataset = create_dataset(source_tokens, target_tokens, source_label_index, target_label_index, source_attn_mask, target_attn_mask)\n",
    "test_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from a fresh rank-1 projection with embed_dim=4096.\n",
    "# To resume from a saved projection instead, use the commented-out load_pretrained line.\n",
    "# proj = LowRankOrthogonalProjection.load_pretrained(\"analysis/results_das/Mistral-7B-Instruct-v0.3/Mistral-7B-Instruct-v0.3-L16.pt\")\n",
    "proj = LowRankOrthogonalProjection(embed_dim=4096, rank=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the projection at layer 16 for one epoch, validating on the test dataset.\n",
    "proj, projection = train_projection(\n",
    "    model,\n",
    "    proj,\n",
    "    layer=16,\n",
    "    train_dataset=train_dataset,\n",
    "    val_dataset=test_dataset,\n",
    "    epochs=1,\n",
    "    batch_size=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trained projection; the L16 suffix matches layer=16 used in train_projection.\n",
    "# NOTE(review): this writes under projections/, whereas the commented-out load_pretrained example reads from analysis/results_das/ - confirm the intended location.\n",
    "proj.save_pretrained(\"projections/Mistral-7B-Instruct-v0.3-L16\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "default",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
