{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28e3208a-f524-40f2-bbe2-74aab0519b42",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import os\n",
    "import numpy as np\n",
    "import argparse\n",
    "import itertools\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "45457a81-435d-45f1-a972-efeb50bb4782",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "path = \"/root/autodl-tmp/runs/sorsa_qv_ana/\"\n",
    "file_pattern = re.compile(r\"metadata\\.pt_(\\d+)\\.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3fd7b001-ce6f-4b26-a3e7-ac9b4bbb3fb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_x_y(w_0, w_t):\n",
    "    u_t, s_t, vt_t = torch.linalg.svd(w_t, full_matrices=False)\n",
    "    u_0, s_0, vt_0 = w_0\n",
    "    ds = (s_t - s_0).abs().mean()\n",
    "    dd = 1 - ((u_t * u_0).sum(dim=0).abs() + (vt_t * vt_0).sum(dim=1).abs()).mean() / 2\n",
    "    return ds.item(), dd.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0176d29b-ed88-4a84-ba1a-c968d6710e59",
   "metadata": {},
   "outputs": [],
   "source": [
    "w_0 = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "aa3bd5d3-ce74-446e-a279-7b3e284451ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d39b6b04-f807-4d79-9d37-78a91909f4be",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for filename in os.listdir(path):\n",
    "    match = file_pattern.match(filename)\n",
    "    if match:\n",
    "        step = int(match.group(1))\n",
    "        file_path = os.path.join(path, filename)\n",
    "        if os.path.exists(file_path):\n",
    "            print(f\"Step: {step}\")\n",
    "            weight_dict = torch.load(file_path, map_location=\"cpu\")\n",
    "            if step == 0:\n",
    "                if bool(w_0) is False:\n",
    "                    progress_bar = tqdm(range(len(weight_dict.keys())))\n",
    "                    for key, value in weight_dict.items():\n",
    "                        u, s, vt = torch.linalg.svd(value.T, full_matrices=False)\n",
    "                        w_0[key] = (u, s, vt)\n",
    "                        progress_bar.update(1)\n",
    "                    progress_bar.close()\n",
    "            elif len(data.get(step, {}).keys()) is not len(weight_dict.keys()):\n",
    "                progress_bar = tqdm(range(len(weight_dict.keys())))\n",
    "                data[step] = {}\n",
    "                for key, value in weight_dict.items():\n",
    "                    x, y = calc_x_y(w_0[key], value.T)\n",
    "                    data[step][key] = (x, y)\n",
    "                    progress_bar.update(1)\n",
    "                progress_bar.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f32d82e7-ed5b-4680-a5de-9ad54c043b40",
   "metadata": {},
   "outputs": [],
   "source": [
    "steps = sorted(data.keys())\n",
    "colors = plt.cm.viridis(np.linspace(0, 1, len(steps)))\n",
    "markers_list = [\n",
    "    \"$0$\",\n",
    "    \"$1$\",\n",
    "    \"$2$\",\n",
    "    \"$3$\",\n",
    "    \"$4$\",\n",
    "    \"$5$\",\n",
    "    \"$6$\",\n",
    "    \"$7$\",\n",
    "    \"$8$\",\n",
    "    \"$9$\",\n",
    "    \"$10$\",\n",
    "    \"$11$\",\n",
    "    \"$12$\",\n",
    "    \"$13$\",\n",
    "    \"$14$\",\n",
    "    \"$15$\",\n",
    "    \"$16$\",\n",
    "    \"$17$\",\n",
    "    \"$18$\",\n",
    "    \"$19$\",\n",
    "    \"$20$\",\n",
    "    \"$21$\",\n",
    "    \"$22$\",\n",
    "    \"$23$\",\n",
    "    \"$24$\",\n",
    "    \"$25$\",\n",
    "    \"$26$\",\n",
    "    \"$27$\",\n",
    "    \"$28$\",\n",
    "    \"$29$\",\n",
    "    \"$30$\",\n",
    "    \"$31$\",\n",
    "]\n",
    "markers = itertools.cycle(markers_list)\n",
    "\n",
    "# Initialize plot\n",
    "plt.figure(figsize=(15, 10))\n",
    "\n",
    "# Plot data\n",
    "layer_points = {name: [] for step in data for name in data[step]}\n",
    "for step in steps:\n",
    "    for i, name in enumerate(data[step]):\n",
    "        x, y = data[step][name]\n",
    "        marker = markers_list[i % len(markers_list)]\n",
    "        color = colors[steps.index(step) % len(colors)]\n",
    "        plt.scatter(\n",
    "            x,\n",
    "            y,\n",
    "            label=f\"Step {step}, Layer {name}\",\n",
    "            marker=marker,\n",
    "            color=color,\n",
    "            alpha=0.6,\n",
    "        )\n",
    "        layer_points[name].append((x, y, color))\n",
    "\n",
    "# Calculate and plot mean points for each step\n",
    "plt.scatter(0, 0, color=\"black\", s=100)\n",
    "mean_positions = []\n",
    "for step in steps:\n",
    "    xs, ys = zip(*[data[step][name] for name in data[step]])\n",
    "    mean_x = np.mean(xs)\n",
    "    mean_y = np.mean(ys)\n",
    "    mean_positions.append((mean_x, mean_y, colors[steps.index(step)]))\n",
    "    plt.scatter(mean_x, mean_y, color=colors[steps.index(step)], s=100)\n",
    "\n",
    "# Connect mean points\n",
    "for i in range(len(mean_positions) - 1):\n",
    "    plt.plot(\n",
    "        [mean_positions[i][0], mean_positions[i + 1][0]],\n",
    "        [mean_positions[i][1], mean_positions[i + 1][1]],\n",
    "        color=mean_positions[i][2],\n",
    "        linestyle=\"-\",\n",
    "        linewidth=4,\n",
    "    )\n",
    "\n",
    "# Connect the first mean point to (0, 0) with a black line\n",
    "plt.plot(\n",
    "    [0, mean_positions[0][0]],\n",
    "    [0, mean_positions[0][1]],\n",
    "    color=\"black\",\n",
    "    linestyle=\"-\",\n",
    "    linewidth=4,\n",
    ")\n",
    "\n",
    "# Connect points with same name\n",
    "for name, points in layer_points.items():\n",
    "    if len(points) > 1:\n",
    "        # points.sort()  # Ensure points are sorted by step if needed\n",
    "        xs, ys, cs = zip(*points)\n",
    "        for i in range(len(xs) - 1):\n",
    "            plt.plot(\n",
    "                [xs[i], xs[i + 1]],\n",
    "                [ys[i], ys[i + 1]],\n",
    "                color=cs[i],\n",
    "                linestyle=\"-\",\n",
    "                linewidth=2,\n",
    "                alpha=0.1,  # Set transparency for the connecting lines\n",
    "            )\n",
    "\n",
    "            \n",
    "# Custom legend for steps\n",
    "handles = [\n",
    "    plt.Line2D([0], [0], marker=\"o\", color=color, linestyle=\"\", markersize=10)\n",
    "    for color in colors\n",
    "]\n",
    "labels = [f\"Step {step}\" for step in steps]\n",
    "plt.legend(handles, labels, title=\"Steps\", loc=\"upper left\", bbox_to_anchor=(1, 1))\n",
    "\n",
    "# Custom legend for layers (markers)\n",
    "# handles = [\n",
    "#     plt.Line2D([0], [0], marker=marker, color=\"k\", linestyle=\"\", markersize=10)\n",
    "#     for marker in markers_list\n",
    "# ]\n",
    "# labels = [f\"Layer {i}\" for i in range(len(markers_list))]\n",
    "# layer_legend = plt.legend(\n",
    "#     handles, labels, title=\"Layers\", loc=\"upper right\", bbox_to_anchor=(1, 0.5)\n",
    "# )\n",
    "# plt.gca().add_artist(layer_legend)\n",
    "\n",
    "# Add labels and title\n",
    "plt.xlabel(\"$\\Delta \\Sigma$\")\n",
    "plt.ylabel(\"$\\Delta D$\")\n",
    "if \"sorsa\" in path:\n",
    "    plt.title(\"SORSA\")\n",
    "elif \"LoRA\" in path:\n",
    "    plt.title(\"LoRA\")\n",
    "else:\n",
    "    plt.title(\"FT\")\n",
    "plt.grid(True)\n",
    "\n",
    "# plt.show()\n",
    "plt.savefig(f\"{path}graph.svg\", format=\"svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d74f23f-d0c7-4849-935f-f33f315f2826",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
