{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "sys.path.insert(0, \"..\")\n",
    "import os\n",
    "import glob\n",
    "import hydra\n",
    "from floral.training.utils import instantiate_model\n",
    "\n",
    "TASK_CONFIG_DIR = \"../floral/conf/task\"\n",
    "TASK = \"shakespeare\"  # <--- choose task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_config_names(conf_dir):\n",
    "    config_names = []\n",
    "    for fname in glob.glob(os.path.join(conf_dir, \"*\")):\n",
    "        config_name = os.path.basename(fname).replace('.yaml', '')\n",
    "        if not config_name.startswith(\"_\"):\n",
    "            config_names.append(config_name)\n",
    "    return list(sorted(config_names))\n",
    "\n",
    "\n",
    "available_tasks = get_config_names(TASK_CONFIG_DIR)\n",
    "assert TASK in available_tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with hydra.initialize(version_base=None, config_path=\"../floral/conf\"):\n",
    "    cfg = hydra.compose(config_name=\"base\", overrides=[f\"task@_global_={TASK}\",\n",
    "                                                       f\"method@_global_=floral\"])\n",
    "    cfg.task = TASK\n",
    "    cfg.method = \"floral\"\n",
    "    # cfg.floral.rank = 0.999\n",
    "    cfg.floral.num_clusters = 4\n",
    "    cfg.floral.min_rank = 1\n",
    "    # cfg.floral.bias = False\n",
    "    # cfg.floral.use_embeddinglora = False\n",
    "    # cfg.floral.use_convlora = False\n",
    "    # cfg.floral.use_normlora = True\n",
    "    # cfg.floral.convlora_method = \"out\"\n",
    "\n",
    "base_model = hydra.utils.instantiate(cfg.model)\n",
    "model = instantiate_model(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name, p in model.state_dict().items():\n",
    "    print(name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
