{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "908907bb-6b23-4530-99ee-b297256936f9",
   "metadata": {},
   "source": [
    "# Activations visualized - 1\n",
    "\n",
    "Loads a model checkpoint (LoRA adapters and head) and runs validation data through it to capture the head-input activations, their gradients, and the gradients w.r.t. the LoRA parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "963e59ff-24c1-414a-ae30-afa29b1bdfc6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Restrict this kernel to GPU 0 and cap the OpenMP / MKL thread pools at 16 threads each.\n",
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "%env OMP_NUM_THREADS=16 \n",
    "%env MKL_NUM_THREADS=16 \n",
    "# Uncomment the two lines below to auto-reload edited modules while developing:\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fb1a542-8260-47ee-9ee0-b2248692aea6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys, pathlib, os\n",
    "sys.path.append(str(pathlib.Path('./src').resolve()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "998dec4b-a244-4c59-8a4f-c946bd2a237e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import transformers\n",
    "device = torch.device('cuda:0')\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "print(f\"{torch.__version__=}, {transformers.__version__=}, {device=}\")\n",
    "\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a32a67d6-2fe7-4029-b0f0-5481a6adb2a8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from utils_data import get_data, preprocess_datasets\n",
    "from utils_trainer import TrainerWithMetrics\n",
    "from utils_model import (\n",
    "    dc_regularizing_loss,\n",
    "    get_fitted_logreg,\n",
    "    get_base_model,\n",
    "    get_tokenizer,\n",
    "    ModelWithMultipleLoras,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60fd1e4f-49a6-4b56-8416-a087420ff890",
   "metadata": {},
   "source": [
    "## Loading data and model\n",
    "The setup code below is taken from `main.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dc066a9-070a-4b31-9eda-9e5c8cff8824",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Restore the three argument objects stored in baseline_args.pt.\n",
    "# NOTE(review): torch.load unpickles arbitrary Python objects here - only load files you trust.\n",
    "# Presumably recent PyTorch releases need `weights_only=False` for non-tensor payloads - confirm\n",
    "# against the installed version.\n",
    "model_args, data_args, training_args = torch.load(\"baseline_args.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ccc70de-0005-41fd-a4ec-e7545a6054b5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# get_data is a project helper (utils_data). Going by the names it is unpacked into, it returns\n",
    "# the raw datasets, a regression flag and the list of labels - confirm against its definition.\n",
    "raw_datasets, is_regression, label_list = get_data(\n",
    "    model_args, data_args, training_args\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c519816-2146-4c80-9894-3662307782ae",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Base model for the task; the number of output labels is taken from label_list.\n",
    "model = get_base_model(\n",
    "    model_args, finetuning_task=data_args.task_name, num_labels=len(label_list),\n",
    ")\n",
    "\n",
    "tokenizer = get_tokenizer(model_args)\n",
    "\n",
    "# preprocess_datasets is a project helper (utils_data). Its result is unpacked into the\n",
    "# train / eval / predict datasets, and raw_datasets is rebound to the fourth returned value.\n",
    "train_dataset, eval_dataset, predict_dataset, raw_datasets = preprocess_datasets(\n",
    "    raw_datasets,\n",
    "    data_args,\n",
    "    training_args,\n",
    "    model,\n",
    "    tokenizer,\n",
    "    label_list,\n",
    "    is_regression,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "873ee9b8-dd7a-4ad9-8aa1-fa696317492d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pick the collator used to assemble batches:\n",
    "#   * inputs already padded to max length -> the plain default collator;\n",
    "#   * fp16                                -> pad dynamically to a multiple of 8;\n",
    "#   * otherwise                           -> None is passed on to the trainer.\n",
    "# DataCollatorWithPadding is reached through the `transformers` namespace because this\n",
    "# notebook only imports the package itself; the bare name raised NameError in the fp16 branch.\n",
    "if data_args.pad_to_max_length:\n",
    "    data_collator = transformers.default_data_collator\n",
    "elif training_args.fp16:\n",
    "    data_collator = transformers.DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "else:\n",
    "    data_collator = None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "108306e0-b04a-49b5-8ca3-e12cb99e5b8d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class ModelWithMultipleLoras_2(ModelWithMultipleLoras):\n",
    "    \"\"\"ModelWithMultipleLoras whose forward pass also records the classifier-head input.\n",
    "\n",
    "    After every call, `self.saved` is a one-element list holding the tensor that was fed to\n",
    "    `self.classifier`. The tensor is stored without detaching, so it is still attached to\n",
    "    the autograd graph and gradients can be taken with respect to it.\n",
    "    \"\"\"\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels):\n",
    "        \"\"\"Run every adapter, compute the loss and keep the head input in `self.saved`.\n",
    "\n",
    "        Args:\n",
    "            input_ids: forwarded unchanged to `_choose_adapter_and_forward`.\n",
    "            attention_mask: forwarded unchanged to `_choose_adapter_and_forward`.\n",
    "            labels: targets for the cross-entropy loss and for the optional regularisers.\n",
    "\n",
    "        Returns:\n",
    "            A 5-tuple `(loss, logits, per_adapter_activations, per_adapter_shifts, head_input)`:\n",
    "            the cross-entropy loss plus any enabled regularisation terms, the classifier\n",
    "            output, a list with one activation tensor per adapter, a list with one\n",
    "            `activation - baseline` tensor per adapter (empty when no baseline pass was run),\n",
    "            and the tensor that was passed to the classifier.\n",
    "        \"\"\"\n",
    "        \n",
    "        self.saved = [] # capture buffer, reset on every call\n",
    "        \n",
    "        # `rank` is never reassigned below, so every `if not rank` branch is taken and the\n",
    "        # matching `else` branches (plain norm penalties) never run.\n",
    "        rank = 0\n",
    "\n",
    "        # Baseline pass with adapter index -1 (presumably the un-adapted model - confirm in\n",
    "        # _choose_adapter_and_forward). It runs inside fork_rng so the CPU / device RNG state is\n",
    "        # restored afterwards. Only needed when shifts are regularised or the model is not training.\n",
    "        if self.n_of_loras and (self.shift_lr_rw or self.shift_dc_rw or not self.model.training):\n",
    "            with torch.random.fork_rng(\n",
    "                devices=(torch.device(\"cpu\"), self.device), enabled=True\n",
    "            ):\n",
    "                baseline_activations = self._choose_adapter_and_forward(\n",
    "                    -1, input_ids, attention_mask\n",
    "                )\n",
    "                if self.model_type == \"deberta\":\n",
    "                    # keep index 0 along dim 1 (presumably the first token - confirm)\n",
    "                    baseline_activations = baseline_activations[:, 0]\n",
    "        different_loras_activations = []\n",
    "        # One forward pass per adapter (at least one pass even when n_of_loras is 0).\n",
    "        # All passes except the last run under fork_rng; the last one runs without it.\n",
    "        for i in range(max(self.n_of_loras, 1)):\n",
    "            if i < self.n_of_loras - 1:\n",
    "                with torch.random.fork_rng(\n",
    "                    devices=(torch.device(\"cpu\"), self.device), enabled=True\n",
    "                ):\n",
    "                    activations = self._choose_adapter_and_forward(\n",
    "                        i, input_ids, attention_mask\n",
    "                    )\n",
    "            else:\n",
    "                activations = self._choose_adapter_and_forward(\n",
    "                    i, input_ids, attention_mask\n",
    "                )\n",
    "            different_loras_activations.append(activations)\n",
    "            \n",
    "        # Combine the per-adapter activations into the single tensor fed to the classifier.\n",
    "        activations = self.get_head_input(different_loras_activations, self.coefs)\n",
    "\n",
    "        self.saved.append(activations) # the only addition over the copied forward: expose the head input\n",
    "\n",
    "        logits = self.classifier(activations)\n",
    "        loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))\n",
    "        if self.model_type == \"deberta\":\n",
    "            # From here on the per-adapter activations are reduced to index 0 along dim 1,\n",
    "            # matching what was done to the baseline activations.\n",
    "            for i in range(len(different_loras_activations)):\n",
    "                different_loras_activations[i] = different_loras_activations[i][:, 0]\n",
    "\n",
    "        # Optional regularisers on the raw per-adapter activations.\n",
    "        if self.activation_lr_rw or self.activation_dc_rw:\n",
    "            for cur_activation in different_loras_activations:\n",
    "                if not rank:\n",
    "                    if self.activation_lr_rw:\n",
    "                        # Not every label is present in this batch: skip this activation entirely\n",
    "                        # (the `continue` also skips the dc term below for it).\n",
    "                        if len(torch.unique(labels)) < self.num_labels:\n",
    "                            continue\n",
    "                        logreg = get_fitted_logreg(\n",
    "                            cur_activation.detach().cpu().numpy(),\n",
    "                            labels.cpu().numpy(),\n",
    "                            seed=self.seed,\n",
    "                        )\n",
    "                        loss += (\n",
    "                            self.regularizing_logreg_loss(\n",
    "                                cur_activation,\n",
    "                                labels,\n",
    "                                logreg,\n",
    "                                neck_width=self.neck_width,\n",
    "                                device=self.device,\n",
    "                                dtype=self.model.dtype\n",
    "                            ) * self.activation_lr_rw\n",
    "                        )\n",
    "                    if self.activation_dc_rw:\n",
    "                        loss += (self.activation_dc_rw \n",
    "                                * dc_regularizing_loss(cur_activation, labels))\n",
    "                else:\n",
    "                    loss += cur_activation.norm()\n",
    "\n",
    "        # Shifts = per-adapter activation minus the baseline activation. They are computed under\n",
    "        # the same condition as the baseline pass above, so baseline_activations is defined here.\n",
    "        different_loras_shifts = []\n",
    "        if self.n_of_loras and (self.shift_lr_rw or self.shift_dc_rw or not self.model.training):\n",
    "            for cur_activation in different_loras_activations:\n",
    "                cur_shift = cur_activation - baseline_activations\n",
    "                different_loras_shifts.append(cur_shift)\n",
    "                if not rank:\n",
    "                    if self.shift_lr_rw:\n",
    "                        # Same guard as above; the shift itself has already been recorded.\n",
    "                        if len(torch.unique(labels)) < self.num_labels:\n",
    "                            continue\n",
    "                        logreg = get_fitted_logreg(\n",
    "                            cur_shift.detach().cpu().numpy(),\n",
    "                            labels.cpu().numpy(),\n",
    "                            seed=self.seed,\n",
    "                        )\n",
    "                        loss += (\n",
    "                            self.regularizing_logreg_loss(\n",
    "                                cur_shift,\n",
    "                                labels,\n",
    "                                logreg,\n",
    "                                neck_width=self.neck_width,\n",
    "                                device=self.device,\n",
    "                                dtype=self.model.dtype\n",
    "                            ) * self.shift_lr_rw\n",
    "                        )\n",
    "                    if self.shift_dc_rw:\n",
    "                        loss += (self.shift_dc_rw \n",
    "                                * dc_regularizing_loss(cur_shift, labels))\n",
    "                else:\n",
    "                    loss += cur_shift.norm()\n",
    "\n",
    "        return (\n",
    "            loss,\n",
    "            logits,\n",
    "            # shallow copy of the per-adapter activation list\n",
    "            [cur_acts for cur_acts in different_loras_activations],\n",
    "            different_loras_shifts,\n",
    "            activations,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fb9e734-5e85-433d-ba03-60b0af36ae2b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Wrap the base model with the LoRA adapters and move the result to the GPU.\n",
    "# num_labels follows label_list, exactly like the base model built earlier, so the wrapper's\n",
    "# head cannot drift from the task (it was hard-coded to 2 before).\n",
    "# NOTE(review): model_type is still hard-coded to 'deberta' - confirm it matches model_args.\n",
    "model_multiple_loras = ModelWithMultipleLoras_2(\n",
    "    base_model=model,\n",
    "    num_labels=len(label_list),\n",
    "    model_type='deberta',\n",
    "    n_of_loras=model_args.n_of_loras,\n",
    "    lora_rank=model_args.lora_rank,\n",
    "    device=training_args.device,\n",
    "    lora_alpha=model_args.lora_alpha,\n",
    "    lora_dropout=model_args.lora_dropout,\n",
    "    seed=training_args.seed,\n",
    "    mult_std=model_args.mult_std,\n",
    "    method_name=model_args.coefs_method_name,\n",
    "    activation_lr_rw=model_args.activation_lr_rw,\n",
    "    shift_lr_rw=model_args.shift_lr_rw,\n",
    "    activation_dc_rw=model_args.activation_dc_rw,\n",
    "    shift_dc_rw=model_args.shift_dc_rw,\n",
    "    loras_gradient_checkpointing=model_args.loras_gradient_checkpointing,\n",
    "    model_gradient_checkpointing=model_args.model_gradient_checkpointing,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84fabe36-9869-4193-a75d-17347c4d5174",
   "metadata": {},
   "source": [
    "## Loading checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "142efa16-003b-4ce3-9ed7-136ab7f14a7f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Directory holding the saved checkpoints; the sorted listing shows which files are available.\n",
    "checkpoints_path = \"./deberta_sst2/checkpoints\"\n",
    "sorted(os.listdir(checkpoints_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a5a6656-641c-4bda-ac7c-57be2763c926",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# insert step here\n",
    "step = 16000\n",
    "filename = os.path.join(checkpoints_path, f\"checkpoint_{step}.pt\")\n",
    "cp = torch.load(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7d67dab-9ef3-4653-bb77-00d7593ce66a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def update_state_dict_from_checkpoint(self, checkpoint_state_dict):\n",
    "    # updates ModelWithMultipleLoras with params from checkpoint\n",
    "    # only takes params present in CP, keeps rest intact\n",
    "    # assuming that the CP has just the loras and head where `requires_grad==True`\n",
    "    print(f\"checkpoint contains {len(checkpoint_state_dict)} modules\")\n",
    "    sd0 = model_multiple_loras.state_dict()\n",
    "    print(f\"state_dict contains {len(sd0)} modules\")\n",
    "    counter = 0 \n",
    "    for n, p in self.state_dict().items():    \n",
    "        if n in checkpoint_state_dict.keys():\n",
    "            sd0[n] = checkpoint_state_dict[n]\n",
    "            counter += 1\n",
    "    self.load_state_dict(sd0)\n",
    "    print(f\"updated {counter} modules\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf0f3fa-dc65-4db5-86f1-0a897486d47e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Copy the checkpointed entries into the wrapped model; the call prints how many were updated.\n",
    "update_state_dict_from_checkpoint(model_multiple_loras, cp)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7365dc39-659b-40d1-8ece-15f917ed79a1",
   "metadata": {},
   "source": [
    "## Run validation data through the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e095129-539c-4924-92db-c5782007d53d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Trainable parameters whose name mentions \"lora\" (case-insensitive), collected in a single\n",
    "# pass and then split into parallel name / tensor lists.\n",
    "trainable_lora = [\n",
    "    (name, param)\n",
    "    for name, param in model_multiple_loras.named_parameters()\n",
    "    if param.requires_grad and 'lora' in name.lower()\n",
    "]\n",
    "lora_names = [name for name, _ in trainable_lora]\n",
    "lora_parameters = [param for _, param in trainable_lora]\n",
    "assert all(param.requires_grad for param in lora_parameters)\n",
    "len(lora_names), len(lora_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b441be31-95c6-499d-b9b9-cd40cc91f7bb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One example per batch: the capture loop further down asserts a batch dimension of 1\n",
    "# when it collects per-sample LoRA gradients.\n",
    "bsize = 1\n",
    "training_args.per_device_eval_batch_size = bsize\n",
    "training_args.per_device_train_batch_size = bsize\n",
    "\n",
    "# Datasets are only handed to the trainer when the corresponding do_train / do_eval flag is set.\n",
    "trainer = TrainerWithMetrics(\n",
    "    model=model_multiple_loras,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset if training_args.do_train else None,\n",
    "    eval_dataset=eval_dataset if training_args.do_eval else None,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e350926-02cc-4fc2-87ca-7935d67d13fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Ask the trainer for a DataLoader over eval_dataset (passed explicitly, so this works even\n",
    "# when the trainer itself was constructed with eval_dataset=None).\n",
    "eval_dataloader = trainer.get_eval_dataloader(eval_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e30ff73c-9aab-45fb-b8e0-df99dca1a8e2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# For every validation example, capture the classifier-head input, the loss gradient w.r.t. that\n",
    "# input and (optionally) the gradient back-propagated from the head input to the LoRA parameters.\n",
    "# k_reg scales Gaussian noise that is added to the activation gradient before it is pushed back\n",
    "# to the LoRAs; k_reg = 0 gives the plain gradient. One file is written per k_reg value.\n",
    "# NOTE(review): no train/eval mode is set in the cells visible here - if dropout should be off\n",
    "# during capture, call model_multiple_loras.eval() beforehand; confirm the intent.\n",
    "get_loras_grads = True  # set to False to skip the per-sample LoRA gradients\n",
    "\n",
    "# Create the output directory up front so a long run cannot fail at save time.\n",
    "os.makedirs('outs', exist_ok=True)\n",
    "\n",
    "for k_reg in [0, 1000]:\n",
    "    acts = []\n",
    "    grads = []\n",
    "    lora_grads = []\n",
    "    labels = []\n",
    "\n",
    "    for batch in tqdm(eval_dataloader):\n",
    "        bout = model_multiple_loras(input_ids=batch['input_ids'],\n",
    "                                    attention_mask=batch['attention_mask'],\n",
    "                                    labels=batch['labels'])\n",
    "        # The subclass stores the tensor fed to the classifier in `saved[0]`.\n",
    "        head_inputs = model_multiple_loras.saved[0]\n",
    "        batch_loss = bout[0]\n",
    "\n",
    "        # d(loss) / d(head input)\n",
    "        grad_wrt_activations = torch.autograd.grad(batch_loss, head_inputs)\n",
    "\n",
    "        if get_loras_grads:\n",
    "            # Per-sample gradients only make sense with one example per batch.\n",
    "            assert head_inputs.shape[0] == 1\n",
    "            z = torch.randn_like(grad_wrt_activations[0])\n",
    "            # Back-propagate (activation gradient + k_reg * noise) from the head input\n",
    "            # to every LoRA parameter.\n",
    "            grad_wrt_loras = torch.autograd.grad(\n",
    "                outputs=[head_inputs], inputs=lora_parameters,\n",
    "                grad_outputs=[grad_wrt_activations[0] + k_reg * z]\n",
    "            )\n",
    "\n",
    "        with torch.no_grad():\n",
    "            acts.append(head_inputs[:, 0].detach())\n",
    "            grads.append(grad_wrt_activations[0][:, 0])\n",
    "            labels.append(batch['labels'])\n",
    "            if get_loras_grads:\n",
    "                # Flatten and move to CPU / fp16 immediately so the gradients of the whole\n",
    "                # dataset never pile up on the GPU; the saved values are the same as before.\n",
    "                # Stacking requires every LoRA tensor to have the same number of elements.\n",
    "                lora_grads.append(\n",
    "                    torch.stack([g.flatten() for g in grad_wrt_loras]).cpu().half()\n",
    "                )\n",
    "                del grad_wrt_loras\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    to_save = {\n",
    "        'acts': torch.concat(acts).cpu().half(),\n",
    "        'grads': torch.concat(grads).cpu().half(),\n",
    "    }\n",
    "    if get_loras_grads:\n",
    "        # Only present when LoRA gradients were collected (stacking an empty list would raise).\n",
    "        to_save['lora_grads'] = torch.stack(lora_grads)\n",
    "    to_save['labels'] = torch.concat(labels).cpu()\n",
    "    torch.save(to_save, f'outs/outs_{step}_reg_{k_reg}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d93f7c-4cf9-4ad6-99a5-22f25113f640",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `%stop` does not look like a registered IPython magic; presumably it is here\n",
    "# to raise an error and halt 'Run All' at this point - confirm that this is the intent.\n",
    "%stop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8812fc19",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
