{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import hydra\n",
    "import pandas as pd\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import LogNorm\n",
    "import seaborn as sns\n",
    "sys.path.insert(0, \"..\")\n",
    "from floral.dataset import get_data\n",
    "from floral.floral import Floral\n",
    "from floral.client import FloralClient\n",
    "from floral.model import Router\n",
    "\n",
    "mpl.rcParams['pdf.fonttype'] = 42\n",
    "mpl.rcParams['ps.fonttype'] = 42\n",
    "mpl.rcParams['figure.figsize'] = (12, 9)\n",
    "\n",
    "TASK = \"synthetic_linear\"\n",
    "# MODEL_DIR = f\"/Users/zelig/Desktop/code/zeligism/FLoRAL-flower/outputs/test_{TASK}/id=test_clustering,lr=0.1/seed=0/\"\n",
    "MODEL_DIR = f\"/Users/zelig/Desktop/code/zeligism/FLoRAL-flower/outputs/test_{TASK}/id=test_lr_lessdata,lr=0.1/seed=0/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'task': 'synthetic_linear', 'logdir': '???', 'show_cfg': False, 'wandb': False, 'is_rnn': False, 'experiment': 'experiment', 'identifier': 'identifier', 'num_rounds': 1000, 'local_epochs': 1, 'model': {'_target_': 'torch.nn.Linear', 'in_features': '${dataset.dim}', 'out_features': '${dataset.dim_out}'}, 'dataset': {'_target_': 'floral.dataset.SyntheticDataset', 'linear': True, 'simple': False, 'data_path': 'data', 'num_clients': 10, 'num_clusters': 2, 'samples_per_client': '???', 'dim': 10, 'dim_out': 3, 'uv_constant': 2.0, 'rank': 1, 'label_noise_std': 0.0}, 'deterministic': True, 'seed': 0, 'batch_size': 4, 'test_batch_size': 128, 'train_proportion': 0.8, 'dataloader': {'num_workers': 0}, 'task_dir_prefix': '', 'lr': 0.1, 'lora_lr': '${lr}', 'router_lr': '${lr}', 'router_entropy': 0.0, 'lora_penalty': 0.0, 'weight_decay': 0.0, 'task_dir': '${task_dir_prefix}id=${identifier},lr=${lr}', 'floral': {'num_clients': '${dataset.num_clients}', 'num_clusters': '${dataset.num_clusters}', 'rank': 1, 'constant': 1.0, 'num_clusters_mult': 1.0, 'init_lora_later': False, 'add_batchnorm': True, 'router_noise_std': 1.0, 'router_temp': 1.0, 'router_per_layer': False, 'router_top2_gating': False, 'fuse_params': False}, 'private_modules': None, 'router_diagonal_init': False, 'optimizer': {'_target_': 'torch.optim.SGD', 'lr': '${lr}', 'weight_decay': '${weight_decay}'}, 'loss_fn': {'_target_': 'torch.nn.MSELoss'}, 'regularizer': {'_target_': 'floral.utils.Regularizer', 'regularizers': {'router_entropy': {'parameter': '${router_entropy}', 'function': {'_target_': 'floral.floral.utils.get_router_regularizer'}}}}, 'client': {'_target_': 'floral.client.FloralClient', 'private_dir': '???', 'full_comm': True, 'batch_to_tuple': {'_target_': 'floral.utils.get_batch_to_tuple'}, 'local_epochs': '${local_epochs}', 'is_rnn': '${is_rnn}', 'train_opts': {'router_em': False, 'precond_lora': True}}, 'strategy': {'_target_': 'floral.strategy.FloralAvg', 'proxy_client': '???', 'reweight_loras': True, 'global_lr': 1.0, 'fraction_fit': 1.0, 'fraction_evaluate': 1.0, 'min_available_clients': '${dataset.num_clients}', 'evaluate_metrics_aggregation_fn': {'_target_': 'floral.strategy.get_metrics_aggregation_fn'}, 'fit_metrics_aggregation_fn': {'_target_': 'floral.strategy.get_metrics_aggregation_fn'}}, 'server_config': {'_target_': 'flwr.server.ServerConfig', 'num_rounds': '${num_rounds}'}, 'client_resources': {'num_cpus': 1.0, 'num_gpus': 0.0}, 'client_data_to_param_ratio': 0.25}"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with hydra.initialize(version_base=None, config_path=\"../floral/conf\"):\n",
    "    cfg = hydra.compose(config_name=\"base\", overrides=[f\"task@_global_={TASK}\"])\n",
    "    cfg.task = TASK\n",
    "    cfg.dataset.simple = \"simple\" in TASK\n",
    "cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Floral(\n",
       "  (base_model): Linear(in_features=10, out_features=3, bias=True)\n",
       "  (router): FloralRouter()\n",
       "  (lora_modules): ModuleDict(\n",
       "    (/): LoraLinearExperts()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = hydra.utils.instantiate(cfg.model)\n",
    "model = Floral(model, **cfg.floral)\n",
    "state_dict = torch.load(os.path.join(MODEL_DIR, 'model.pt'))\n",
    "model.load_state_dict(state_dict)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['/']"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(model.lora_modules.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0,  0,  0,  0,  0,  0,  0,  0, -1, -1],\n",
       "        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],\n",
       "        [ 0,  1,  0,  0,  0,  0,  0,  0,  0,  0]], dtype=torch.int32)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.base_model.weight.round().int()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0200, -0.0100, -0.1300, -0.2400, -0.1300,  0.0100, -0.2800,  0.1100,\n",
       "         -0.6300, -0.6500],\n",
       "        [-0.1600,  0.1300,  0.1900, -0.4200, -0.1400,  0.2800,  0.0200,  0.0200,\n",
       "          0.0600, -0.0800],\n",
       "        [ 0.0600,  0.7300, -0.1500,  0.2900, -0.1700, -0.1500,  0.2000,  0.0300,\n",
       "         -0.1200, -0.2900]], grad_fn=<RoundBackward1>)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.round(model.base_model.weight, decimals=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],\n",
       "\n",
       "        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]], dtype=torch.int32)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.lora_modules[\"/\"].fuse().round().int()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0, -1, -1,  0,  1,  0,  0,  0,  0,  0]],\n",
       "\n",
       "        [[-1, -1,  0,  0,  0,  0,  0,  0,  1,  0]]], dtype=torch.int32)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[0],\n",
       "         [0],\n",
       "         [0]],\n",
       "\n",
       "        [[0],\n",
       "         [0],\n",
       "         [0]]], dtype=torch.int32)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(model.lora_modules[\"/\"].weight_in.round().int())\n",
    "display(model.lora_modules[\"/\"].weight_out.round().int())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0., -0., -0., -0., 0., 0., 0., 0., -0., 0.],\n",
       "         [-0., -0., -0., -0., 0., 0., 0., 0., -0., 0.],\n",
       "         [-0., -0., -0., -0., 0., 0., 0., 0., -0., 0.]],\n",
       "\n",
       "        [[-0., -0., -0., 0., 0., -0., -0., 0., 0., 0.],\n",
       "         [-0., -0., -0., 0., 0., -0., -0., 0., 0., 0.],\n",
       "         [-0., -0., -0., 0., 0., -0., -0., 0., 0., 0.]]],\n",
       "       grad_fn=<RoundBackward1>)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.round(model.lora_modules[\"/\"].fuse(), decimals=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[100.,   0.],\n",
       "        [  0., 100.],\n",
       "        [100.,   0.],\n",
       "        [  0., 100.],\n",
       "        [100.,   0.],\n",
       "        [  0., 100.],\n",
       "        [100.,   0.],\n",
       "        [  0., 100.],\n",
       "        [100.,   0.],\n",
       "        [  0., 100.]])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "router_weights = Router.load_router_weights(os.path.join(MODEL_DIR, 'pvt'))\n",
    "torch.stack(router_weights).softmax(-1).mul(100).round()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg.dataset.uv_constant = 0.0\n",
    "cfg.floral.constant = 0.0\n",
    "data_loaders, _ = get_data(cfg)\n",
    "train_loaders, test_loaders = zip(*data_loaders)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO flwr 2024-03-16 23:14:04,255 | training.py:60 | Test | Client 0: [1/1] loss=3.0597\tacc=0.0000\n",
      "INFO flwr 2024-03-16 23:14:04,259 | training.py:60 | Test | Client 1: [1/1] loss=1.9917\tacc=0.0000\n",
      "INFO flwr 2024-03-16 23:14:04,263 | training.py:60 | Test | Client 2: [1/1] loss=1.0573\tacc=0.0000\n",
      "INFO flwr 2024-03-16 23:14:04,268 | training.py:60 | Test | Client 3: [1/1] loss=2.5298\tacc=0.0000\n",
      "INFO flwr 2024-03-16 23:14:04,274 | training.py:60 | Test | Client 4: [1/1] loss=0.4687\tacc=0.0000\n",
      "INFO flwr 2024-03-16 23:14:04,279 | training.py:60 | Test | Client 5: [1/1] loss=1.9839\tacc=0.0000\n",
      "INFO flwr 2024-03-16 23:14:04,283 | training.py:60 | Test | Client 6: [1/1] loss=2.4523\tacc=0.0000\n",
      "INFO flwr 2024-03-16 23:14:04,288 | training.py:60 | Test | Client 7: [1/1] loss=2.2146\tacc=0.0000\n",
      "INFO flwr 2024-03-16 23:14:04,292 | training.py:60 | Test | Client 8: [1/1] loss=2.5278\tacc=0.0000\n",
      "INFO flwr 2024-03-16 23:14:04,297 | training.py:60 | Test | Client 9: [1/1] loss=0.5341\tacc=0.0000\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss = 3.0597381591796875\n",
      "Loss = 1.9916549921035767\n",
      "Loss = 1.0573395490646362\n",
      "Loss = 2.5298025608062744\n",
      "Loss = 0.46870318055152893\n",
      "Loss = 1.9838873147964478\n",
      "Loss = 2.4523427486419678\n",
      "Loss = 2.2145512104034424\n",
      "Loss = 2.5277822017669678\n",
      "Loss = 0.5340678095817566\n"
     ]
    }
   ],
   "source": [
    "from floral.utils import init_device, evaluate\n",
    "loss_fn = hydra.utils.instantiate(cfg.loss_fn)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i in range(cfg.floral.num_clients):\n",
    "        # get model\n",
    "        model = hydra.utils.instantiate(cfg.model)\n",
    "        model = Floral(model, **cfg.floral)\n",
    "        model.client_id = i\n",
    "        state_dict = torch.load(os.path.join(MODEL_DIR, 'model.pt'))\n",
    "        state_dict[\"router.weight\"] = router_weights[i].clone().detach()\n",
    "        model.load_state_dict(state_dict)\n",
    "        # evaluate\n",
    "        metrics = evaluate(model, loss_fn, init_device(), test_loaders[i], i)\n",
    "        print(\"Loss =\", metrics[\"loss\"].get_avg())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = train_loaders[0].dataset.dataset.fl_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "W = dataset.W[0, :, :dataset.dim_out]\n",
    "biases = dataset.biases[:, :dataset.samples_per_client, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-5.9892e-01, -8.4366e-02,  9.8465e-02],\n",
      "        [ 2.9673e-01,  7.3350e-01,  8.4598e-03],\n",
      "        [ 3.2082e-01,  9.7339e-02,  5.2858e-01],\n",
      "        [-3.0456e-01,  1.9034e-01,  1.0540e-01],\n",
      "        [-1.9617e-02,  1.0636e-01,  1.1560e-01],\n",
      "        [-5.0183e-01, -5.1696e-01, -4.8517e-02],\n",
      "        [ 4.4820e-04,  1.3353e-01,  2.0249e-01],\n",
      "        [ 5.4738e-02, -3.0238e-01,  5.7204e-01],\n",
      "        [-1.9363e-01,  4.9456e-01, -3.4126e-01],\n",
      "        [ 1.2901e-01,  1.8344e-01, -1.8919e-01]])\n",
      "tensor([[ 0.0200, -0.1600,  0.0600],\n",
      "        [-0.0100,  0.1300,  0.7300],\n",
      "        [-0.1300,  0.1900, -0.1500],\n",
      "        [-0.2400, -0.4200,  0.2900],\n",
      "        [-0.1300, -0.1400, -0.1700],\n",
      "        [ 0.0100,  0.2800, -0.1500],\n",
      "        [-0.2800,  0.0200,  0.2000],\n",
      "        [ 0.1100,  0.0200,  0.0300],\n",
      "        [-0.6300,  0.0600, -0.1200],\n",
      "        [-0.6500, -0.0800, -0.2900]], grad_fn=<RoundBackward1>)\n"
     ]
    }
   ],
   "source": [
    "print(W)\n",
    "print(torch.round(model.base_model.weight.T, decimals=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(W)\n",
    "print(torch.round(model.base_model.weight.T, decimals=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.3700,  0.5300, -0.2200],\n",
      "        [ 1.2700,  3.5800, -0.3800],\n",
      "        [-1.9200, -2.8400,  0.7200]])\n",
      "tensor([[ 0.3600,  0.1400,  0.5100],\n",
      "        [-0.5400, -0.9300,  2.7200],\n",
      "        [ 2.2200, -1.2000, -1.1300]], grad_fn=<RoundBackward1>)\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(3, dataset.dim)\n",
    "h_true = x.matmul(W)\n",
    "h_hat = x.matmul(model.base_model.weight.T)\n",
    "print(torch.round(h_true, decimals=2))\n",
    "print(torch.round(h_hat, decimals=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Lora"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cluster = 0\n",
      "tensor([[0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.]])\n",
      "tensor([[-0., -0., -0.],\n",
      "        [-0., -0., -0.],\n",
      "        [-0., -0., -0.],\n",
      "        [-0., -0., -0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [-0., -0., -0.],\n",
      "        [0., 0., 0.]], grad_fn=<RoundBackward1>)\n",
      "Error = 0.0\n",
      "\n",
      "Cluster = 1\n",
      "tensor([[0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.]])\n",
      "tensor([[-0., -0., -0.],\n",
      "        [-0., -0., -0.],\n",
      "        [-0., -0., -0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [-0., -0., -0.],\n",
      "        [-0., -0., -0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.]], grad_fn=<RoundBackward1>)\n",
      "Error = 0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for cluster_id in range(cfg.floral.num_clusters):\n",
    "    print(\"Cluster =\", cluster_id)\n",
    "    uv = dataset.Ws[cluster_id] - dataset.W[0]\n",
    "    uv = uv[:, :cfg.dataset.dim_out]\n",
    "    uv_hats = model.lora_modules['/'].fuse().transpose(1,2)\n",
    "    print(uv)\n",
    "    print(torch.round(uv_hats[cluster_id], decimals=2))\n",
    "    print(\"Error =\", (uv.unsqueeze(0) - uv_hats).flatten(1).pow(2).sum(1).sqrt().min().item())\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cluster = 0\n",
      "tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])\n",
      "tensor([[0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.]], grad_fn=<RoundBackward1>)\n",
      "\n",
      "Cluster = 1\n",
      "tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])\n",
      "tensor([[0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.]], grad_fn=<RoundBackward1>)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for cluster_id in range(cfg.floral.num_clusters):\n",
    "    print(\"Cluster =\", cluster_id)\n",
    "    h_true = x.matmul(dataset.Ws[cluster_id] - dataset.W)\n",
    "    h_hat = x.matmul(model.lora_modules['/'].fuse()[cluster_id].T)\n",
    "    print(torch.round(h_true, decimals=2))\n",
    "    print(torch.round(h_hat, decimals=2))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<RoundBackward1>)"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "probs = torch.stack(list(router_weights.values())).softmax(-1)\n",
    "# loras = torch.einsum(\"kc,c...->k...\", (probs, model.lora_modules['/'].fuse()))\n",
    "# signs = loras.sign().prod(0)\n",
    "# gmean = loras.abs().add(1e-10).log().mean(dim=0).exp()\n",
    "# gmean[signs <= 0] = 0.\n",
    "loras = model.lora_modules['/']\n",
    "lora = loras.fuse()\n",
    "signs = lora.sign().prod(0)\n",
    "lora_gmean = lora.abs().add(1e-10).log().mean(dim=0).exp()\n",
    "lora_gmean[signs <= 0] = 0.\n",
    "torch.round(lora_gmean, decimals=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
