{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "\n",
    "    \n",
    "api = wandb.Api()\n",
    "project = \"lora\"\n",
    "workspace = \"username\"\n",
    "\n",
    "# Tags of the experiments whose runs should be processed.\n",
    "experiment_tags = [\"baseline\"]\n",
    "\n",
    "# Select runs that both: 1. carry any of the experiment tags and 2. are finished.\n",
    "# The state clause implements the second condition; the tag clause alone would\n",
    "# also return runs that are still running, crashed or failed.\n",
    "runs = api.runs(\n",
    "    f\"{workspace}/{project}\",\n",
    "    {\"$and\": [\n",
    "        {\"tags\": {\"$in\": experiment_tags}},\n",
    "        {\"state\": \"finished\"},\n",
    "    ]},\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#for run in runs:\n",
    "    #print(run.name) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import lightning as L\n",
    "from lit_gpt.lora import GPT,  Config\n",
    "from lit_gpt.utils import (\n",
    "    check_valid_checkpoint_dir,\n",
    "    lazy_load,\n",
    ")\n",
    "from pathlib import Path\n",
    "\n",
    "import gc\n",
    "\n",
    "import lightning as L\n",
    "from functools import partial\n",
    "\n",
    "from scripts.convert_lit_checkpoint import check_conversion_supported, copy_weights_llama, incremental_save\n",
    "\n",
    "from lit_gpt.lora import GPT, Config, lora_filter, merge_lora_weights\n",
    "from lit_gpt.model import Config as ModelConfig\n",
    "from lit_gpt.utils import check_valid_checkpoint_dir,  lazy_load\n",
    "\n",
    "import contextlib\n",
    "\n",
    "\n",
    "def setup(\n",
    "    run,\n",
    "    checkpoints=\"checkpoints/meta-llama/Llama-2-7b-hf\",\n",
    "    adapter_path=\"downloads/\",\n",
    "    out_dir=\"merged/\",\n",
    "    out_name=\"model.pth\",\n",
    "):\n",
    "    \"\"\"Merge a run's LoRA adapter into the base model and export the merged weights.\n",
    "\n",
    "    A LoRA-enabled GPT is built from the hyper-parameters stored in ``run.config``,\n",
    "    the base checkpoint and the adapter weights are loaded into it, the adapter is\n",
    "    merged with ``merge_lora_weights``, and the remaining (non-LoRA) weights are\n",
    "    passed through ``copy_weights_llama`` and written with ``incremental_save``.\n",
    "\n",
    "    Args:\n",
    "        run: Object exposing a ``config`` mapping with the keys ``lora_r``, ``alpha``,\n",
    "            ``lora_dropout``, ``lora_query``, ``lora_key``, ``lora_value``,\n",
    "            ``lora_projection``, ``lora_mlp``, ``lora_head``, ``joint_qkvp`` and\n",
    "            ``tensor_lora``.\n",
    "        checkpoints: Base checkpoint directory (str or Path). ``lit_model.pth`` is\n",
    "            read from it and its last path component is used as the model name.\n",
    "        adapter_path: Path of the adapter checkpoint file; its ``'model'`` entry is\n",
    "            used as the adapter state dict.\n",
    "        out_dir: Output directory (str or Path); created when missing.\n",
    "        out_name: Output file name; its suffix is replaced by ``.bin``.\n",
    "\n",
    "    Returns:\n",
    "        None. The merged weights are written to ``out_dir / out_name`` with a\n",
    "        ``.bin`` suffix.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: presumably, when ``run.config`` lacks one of the keys listed\n",
    "            above (assumes a dict-like config -- TODO confirm).\n",
    "    \"\"\"\n",
    "    config = run.config\n",
    "    checkpoints = Path(checkpoints)\n",
    "    # The default (and any caller-supplied str) must become a Path: the `/` joins\n",
    "    # below raise TypeError on a plain string.\n",
    "    out_dir = Path(out_dir)\n",
    "\n",
    "    # Validate the checkpoint directory before the expensive model construction.\n",
    "    check_valid_checkpoint_dir(checkpoints)\n",
    "\n",
    "    lora_conf = Config.from_name(\n",
    "        name=checkpoints.name,\n",
    "        r=config[\"lora_r\"],\n",
    "        alpha=config[\"alpha\"],\n",
    "        dropout=config[\"lora_dropout\"],\n",
    "        to_query=config[\"lora_query\"],\n",
    "        to_key=config[\"lora_key\"],\n",
    "        to_value=config[\"lora_value\"],\n",
    "        to_projection=config[\"lora_projection\"],\n",
    "        to_mlp=config[\"lora_mlp\"],\n",
    "        to_head=config[\"lora_head\"],\n",
    "        joint_qkvp=config[\"joint_qkvp\"],\n",
    "        tensor_lora=config[\"tensor_lora\"],\n",
    "    )\n",
    "    fabric = L.Fabric(devices=1, strategy=\"auto\", precision=\"bf16-true\", plugins=None)\n",
    "    fabric.seed_everything(0)  # same seed for every process to init model (FSDP)\n",
    "\n",
    "    with fabric.init_module(empty_init=False):\n",
    "        model = GPT(lora_conf)\n",
    "\n",
    "    if fabric.global_rank == 0:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    checkpoint_path = checkpoints / \"lit_model.pth\"\n",
    "    adapter = lazy_load(Path(adapter_path))\n",
    "    base = lazy_load(Path(checkpoint_path))\n",
    "    # Base entries come last, so on a key collision the base weight wins.\n",
    "    params = {**adapter.sd[\"model\"], **base.sd}\n",
    "    model.load_state_dict(params, strict=True)\n",
    "\n",
    "    merge_lora_weights(model)\n",
    "\n",
    "    # remove lora parameters and the lora linear substring\n",
    "    state_dict = {\n",
    "        k.replace(\"linear.\", \"\"): v\n",
    "        for k, v in model.state_dict().items()\n",
    "        if not lora_filter(k, v)\n",
    "    }\n",
    "\n",
    "    base_conf = ModelConfig.from_name(name=checkpoints.name)\n",
    "    copy_fn = partial(copy_weights_llama, base_conf)\n",
    "\n",
    "    pth_file = out_dir / out_name\n",
    "    bin_file = pth_file.with_suffix(\".bin\")\n",
    "    # Report the file that is actually written (previously an unused\n",
    "    # 'lit_model.pth' path was printed here).\n",
    "    fabric.print(f\"Saving weights to {str(bin_file)!r}\")\n",
    "\n",
    "    # initialize a new empty state dict to hold our new weights\n",
    "    sd = {}\n",
    "\n",
    "    with incremental_save(bin_file) as saver:\n",
    "        # Falls back to the dict itself when there is no nested 'model' entry.\n",
    "        lit_weights = state_dict.get(\"model\", state_dict)\n",
    "        check_conversion_supported(lit_weights)\n",
    "        copy_fn(sd, lit_weights, saver=saver)\n",
    "        gc.collect()\n",
    "        saver.save(sd)\n",
    "\n",
    "\n",
    "    \n",
    "    #wandb.init(id=run.id, project=run.project, resume=\"allow\")\n",
    "    #wandb.log(results['results'])\n",
    "    #wandb.finish()\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Run username/lora/7c1k3q9r (finished)>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/chiche/miniconda3/envs/lorta/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/chiche/miniconda3/envs/lorta/lib/python3.9/sit ...\n",
      "Seed set to 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using original matrix low rank adapters\n",
      "Saving weights to 'merged/7c1k3q9r/lit_model.pth'\n"
     ]
    }
   ],
   "source": [
    "# Download each selected run's fine-tuned adapter and merge it into the base model.\n",
    "for run in runs:\n",
    "    print(run)\n",
    "    weights = run.file(f\"checkpoints/meta-llama/Llama-2-7b-hf/{run.name}_lora_finetuned.pth\")\n",
    "    download_dir = f\"download/{run.id}\"\n",
    "    # create the directory if it doesn't exist\n",
    "    os.makedirs(download_dir, exist_ok=True)\n",
    "    # download() returns an open file handle; only its path is needed, so close it\n",
    "    # right away instead of leaking one handle per run.\n",
    "    with weights.download(download_dir, replace=True) as out:\n",
    "        adapter_file = out.name\n",
    "    setup(run, adapter_path=adapter_file, out_dir=Path(f\"merged/{run.id}\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "download/wieemkak/checkpoints/meta-llama/Llama-2-7b-hf/tensor_lora_r_48_joint_heads_joint_layers_joint_qkvp_lora_finetuned.pth\n"
     ]
    }
   ],
   "source": [
    "# Path of the adapter file fetched in the last iteration of the merge loop.\n",
    "# Relies on `out` still being bound from that loop.\n",
    "last_adapter_path = out.name\n",
    "print(last_adapter_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "notebook",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
