{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "126f9c18",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-15T22:36:39.821707Z",
     "iopub.status.busy": "2024-05-15T22:36:39.821493Z",
     "iopub.status.idle": "2024-05-15T22:36:40.263148Z",
     "shell.execute_reply": "2024-05-15T22:36:40.262556Z"
    },
    "papermill": {
     "duration": 0.445547,
     "end_time": "2024-05-15T22:36:40.264583",
     "exception": false,
     "start_time": "2024-05-15T22:36:39.819036",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import wandb\n",
    "import os\n",
    "# Expose only GPU 0; this is set before torch is imported in the next cell.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "# W&B entity and project that hold the runs to evaluate.\n",
    "workspace = \"username\"\n",
    "project = \"lora\"\n",
    "api = wandb.Api()\n",
    "\n",
    "# Tags that mark the experiments of interest.\n",
    "experiment_tags = [\"mt_eval\"]\n",
    "\n",
    "# Select every run carrying at least one of the tags above.\n",
    "# NOTE(review): no run-state condition is applied here; add one if only finished runs are wanted.\n",
    "run_filter = {\"$and\": [{\"tags\": {\"$in\": experiment_tags}}]}\n",
    "runs = api.runs(f\"{workspace}/{project}\", run_filter)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8a565d5d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-15T22:36:40.277092Z",
     "iopub.status.busy": "2024-05-15T22:36:40.276955Z",
     "iopub.status.idle": "2024-05-15T22:36:43.051536Z",
     "shell.execute_reply": "2024-05-15T22:36:43.050986Z"
    },
    "papermill": {
     "duration": 2.777311,
     "end_time": "2024-05-15T22:36:43.052797",
     "exception": false,
     "start_time": "2024-05-15T22:36:40.275486",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/chiche/miniconda3/envs/lorta/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/home/chiche/miniconda3/envs/lorta/lib/python3.11/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n",
      "/home/chiche/miniconda3/envs/lorta/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import lightning as L\n",
    "from lit_gpt.lora import GPT,  Config\n",
    "from lit_gpt.utils import (\n",
    "    check_valid_checkpoint_dir,\n",
    "    lazy_load,\n",
    ")\n",
    "from pathlib import Path\n",
    "\n",
    "import gc\n",
    "\n",
    "import lightning as L\n",
    "from functools import partial\n",
    "\n",
    "from scripts.convert_lit_checkpoint import check_conversion_supported, copy_weights_llama, incremental_save\n",
    "\n",
    "from lit_gpt.lora import GPT, Config, lora_filter, merge_lora_weights\n",
    "from lit_gpt.model import Config as ModelConfig\n",
    "from lit_gpt.utils import check_valid_checkpoint_dir,  lazy_load\n",
    "\n",
    "import contextlib\n",
    "\n",
    "from finetune.lora import validate, get_max_seq_length\n",
    "\n",
    "from lit_gpt.tokenizer import Tokenizer\n",
    "\n",
    "import torch\n",
    "\n",
    "def setup(\n",
    "    run,\n",
    "    checkpoints=\"checkpoints/meta-llama/Llama-2-7b-hf\",\n",
    "    adapter_path=\"downloads/\",\n",
    "    out_dir=\"merged/\",\n",
    "    out_name=\"model.pth\",\n",
    "):\n",
    "    \"\"\"Rebuild the LoRA model of a W&B run, merge its adapter and validate it.\n",
    "\n",
    "    Args:\n",
    "        run: W&B run object; its ``config`` supplies the LoRA hyper-parameters.\n",
    "        checkpoints: Directory of the base checkpoint (must contain ``lit_model.pth``).\n",
    "        adapter_path: Path of the fine-tuned adapter checkpoint; its ``\"model\"``\n",
    "            entry is used as the adapter state dict.\n",
    "        out_dir: Directory created if missing. Nothing is written to it here.\n",
    "        out_name: Currently unused.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``validate`` returns for the merged model on ``data/alpaca/test.pt``.\n",
    "    \"\"\"\n",
    "    base_dir = Path(checkpoints)\n",
    "    run_cfg = run.config\n",
    "\n",
    "    # LoRA hyper-parameters come straight from the logged run configuration.\n",
    "    lora_kwargs = {\n",
    "        \"r\": run_cfg[\"lora_r\"],\n",
    "        \"alpha\": run_cfg[\"alpha\"],\n",
    "        \"dropout\": run_cfg[\"lora_dropout\"],\n",
    "        \"to_query\": run_cfg[\"lora_query\"],\n",
    "        \"to_key\": run_cfg[\"lora_key\"],\n",
    "        \"to_value\": run_cfg[\"lora_value\"],\n",
    "        \"to_projection\": run_cfg[\"lora_projection\"],\n",
    "        \"to_mlp\": run_cfg[\"lora_mlp\"],\n",
    "        \"to_head\": run_cfg[\"lora_head\"],\n",
    "        \"joint_qkvp\": run_cfg[\"joint_qkvp\"],\n",
    "        \"tensor_lora\": run_cfg[\"tensor_lora\"],\n",
    "    }\n",
    "    lora_config = Config.from_name(name=base_dir.name, **lora_kwargs)\n",
    "\n",
    "    init_fabric = L.Fabric(devices=1, strategy=\"auto\", precision=\"bf16-true\", plugins=None)\n",
    "    init_fabric.seed_everything(0)  # fixed seed so model initialisation is repeatable\n",
    "\n",
    "    with init_fabric.init_module(empty_init=False):\n",
    "        model = GPT(lora_config)\n",
    "    check_valid_checkpoint_dir(base_dir)\n",
    "\n",
    "    if init_fabric.global_rank == 0:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    # Combine adapter and base weights; on a key clash the base checkpoint wins.\n",
    "    adapter_ckpt = lazy_load(Path(adapter_path))\n",
    "    base_ckpt = lazy_load(Path(base_dir / \"lit_model.pth\"))\n",
    "    state = dict(adapter_ckpt.sd[\"model\"])\n",
    "    state.update(base_ckpt.sd)\n",
    "    model.load_state_dict(state, strict=True)\n",
    "\n",
    "    print(\"merging weights\")\n",
    "    merge_lora_weights(model)\n",
    "\n",
    "    print(\"evaluating\")\n",
    "    alpaca_dir = Path(\"data/alpaca\")\n",
    "    val_data = torch.load(alpaca_dir / \"test.pt\")\n",
    "    train_data = torch.load(alpaca_dir / \"train.pt\")\n",
    "    tokenizer = Tokenizer(base_dir)\n",
    "    eval_fabric = L.Fabric(devices=1, strategy=\"auto\", precision=\"bf16-true\", loggers=None)\n",
    "    # Only the second value returned for the training set is passed on to validate().\n",
    "    _, longest_seq_length, _ = get_max_seq_length(train_data)\n",
    "\n",
    "    return validate(eval_fabric, model, val_data, tokenizer, longest_seq_length)\n",
    "\n",
    "\n",
    "    \n",
    "    #wandb.init(id=run.id, project=run.project, resume=\"allow\")\n",
    "    #wandb.log(results['results'])\n",
    "    #wandb.finish()\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b431b57c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-15T22:36:43.056126Z",
     "iopub.status.busy": "2024-05-15T22:36:43.055773Z",
     "iopub.status.idle": "2024-05-15T22:45:47.083153Z",
     "shell.execute_reply": "2024-05-15T22:45:47.082595Z"
    },
    "papermill": {
     "duration": 544.030169,
     "end_time": "2024-05-15T22:45:47.084381",
     "exception": false,
     "start_time": "2024-05-15T22:36:43.054212",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Run hounie/lora/jsyf5mt1 (finished)>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "merging weights\n",
      "evaluating\n",
      "Validating ...\n",
      "Recommend a movie for me to watch during the weekend and explain the reason.\n",
      "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "\n",
      "### Instruction:\n",
      "Recommend a movie for me to watch during the weekend and explain the reason.\n",
      "\n",
      "### Response:\n",
      "The movie is 'The Avengers', the superheroes from the Marvel Universe.\n",
      "\n",
      "The movie is a must-watch because it is based on the comic books of the same name and shows an epic battle between good and evil.\n",
      "\n",
      "The plot revolves around the superheroes from the Marvel Universe as they fight against an evil alien race called the Chitauri. The superheroes fight for the survival of the human race and\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'val_loss' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 61\u001b[0m\n\u001b[1;32m     58\u001b[0m fabric \u001b[38;5;241m=\u001b[39m L\u001b[38;5;241m.\u001b[39mFabric(devices\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, strategy\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m, precision\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbf16-true\u001b[39m\u001b[38;5;124m\"\u001b[39m, loggers\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m     59\u001b[0m max_seq_length, longest_seq_length, longest_seq_ix \u001b[38;5;241m=\u001b[39m get_max_seq_length(train_data)\n\u001b[0;32m---> 61\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mvalidate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfabric\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlongest_seq_length\u001b[49m\u001b[43m)\u001b[49m \n",
      "File \u001b[0;32m~/miniconda3/envs/lorta/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    114\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/lit-gpt/finetune/lora.py:359\u001b[0m, in \u001b[0;36mvalidate\u001b[0;34m(fabric, model, val_data, tokenizer, longest_seq_length)\u001b[0m\n\u001b[1;32m    356\u001b[0m fabric\u001b[38;5;241m.\u001b[39mprint(output)\n\u001b[1;32m    358\u001b[0m model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[0;32m--> 359\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mval_loss\u001b[49m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'val_loss' is not defined"
     ]
    }
   ],
   "source": [
    "# Id of the single W&B run whose adapter is evaluated in this notebook.\n",
    "target_run_id = 'jsyf5mt1'\n",
    "\n",
    "for run in runs:\n",
    "    print(run)\n",
    "    if run.id != target_run_id:\n",
    "        continue\n",
    "\n",
    "    # Fetch the fine-tuned adapter checkpoint that the run uploaded to W&B.\n",
    "    weights = run.file(f\"checkpoints/meta-llama/Llama-2-7b-hf/{run.name}_lora_finetuned.pth\")\n",
    "    download_dir = f\"download/{run.id}\"\n",
    "    os.makedirs(download_dir, exist_ok=True)\n",
    "\n",
    "    # download() hands back an open file handle; only its path is needed, so\n",
    "    # close it immediately. `out` stays bound and `out.name` remains readable.\n",
    "    with weights.download(download_dir, replace=True) as out:\n",
    "        adapter_path = out.name\n",
    "\n",
    "    # Reuse setup() from the previous cell rather than repeating its body here:\n",
    "    # it rebuilds the model, merges the adapter and returns validate()'s result.\n",
    "    result = setup(run, adapter_path=adapter_path, out_dir=\"sandbox/\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91bd248f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(12.1462)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Value returned by validate() for the selected run in the evaluation cell.\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f7c2d1b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-15T22:45:47.093197Z",
     "iopub.status.busy": "2024-05-15T22:45:47.092915Z",
     "iopub.status.idle": "2024-05-15T22:45:47.095943Z",
     "shell.execute_reply": "2024-05-15T22:45:47.095497Z"
    },
    "papermill": {
     "duration": 0.010736,
     "end_time": "2024-05-15T22:45:47.096771",
     "exception": false,
     "start_time": "2024-05-15T22:45:47.086035",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "download/wieemkak/checkpoints/meta-llama/Llama-2-7b-hf/tensor_lora_r_48_joint_heads_joint_layers_joint_qkvp_lora_finetuned.pth\n"
     ]
    }
   ],
   "source": [
    "# Local path of the adapter checkpoint file downloaded in the evaluation cell.\n",
    "print(out.name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "notebook",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 549.222446,
   "end_time": "2024-05-15T22:45:48.016365",
   "environment_variables": {},
   "exception": null,
   "input_path": "MT_bench.ipynb",
   "output_path": "MT_bench.ipynb",
   "parameters": {},
   "start_time": "2024-05-15T22:36:38.793919",
   "version": "2.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
