{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "================================================================================\n",
      "CUDA SETUP: CUDA runtime path found: /home/kongxy/miniconda3/envs/MOE4REC/lib/libcudart.so\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 8.0\n",
      "CUDA SETUP: Detected CUDA version 118\n",
      "CUDA SETUP: Loading binary /home/kongxy/miniconda3/envs/MOE4REC/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda118.so...\n"
     ]
    }
   ],
   "source": [
    "import os.path as op\n",
    "import os\n",
    "\n",
    "import inspect\n",
    "import torch\n",
    "import importlib\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "import torch.optim.lr_scheduler as lrs\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from transformers import LlamaForCausalLM, LlamaTokenizer\n",
    "import random\n",
    "from pandas.core.frame import DataFrame\n",
    "\n",
    "import numpy as np\n",
    "from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel, MoeLoraConfig, MoeLoraModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example inputs: assume user_embeds and gate_weights are corresponding (paired) tensors.\n",
    "# Random stand-ins for real model outputs; seed first so these values -- and everything\n",
    "# derived from them in later cells -- are reproducible under Restart & Run All.\n",
    "torch.manual_seed(0)\n",
    "user_embeds = torch.randn(4, 1, 4096)   # presumably (batch, 1, hidden) -- TODO confirm vs. model\n",
    "gate_weights = torch.randn(4, 1, 4, 1)  # presumably (batch, 1, n_experts, 1) -- TODO confirm vs. model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the batched tensors into per-sample 1-D vectors and pair them up in one dict.\n",
    "# Assumes the layouts of the example tensors above: user_embeds (batch, 1, hidden),\n",
    "# gate_weights (batch, 1, n_experts, 1) -- TODO confirm for real model outputs.\n",
    "# Squeeze only the known singleton axes: a bare .squeeze() would also collapse the\n",
    "# hidden / expert axis (down to a 0-d scalar) whenever that axis has size 1.\n",
    "user_embed_vectors = list(user_embeds.squeeze(1).unbind(0))                # batch x (hidden,)\n",
    "gate_weight_vectors = list(gate_weights.squeeze(-1).squeeze(1).unbind(0))  # batch x (n_experts,)\n",
    "\n",
    "# Build the dict: gate_weight_1..N first, then user_embed_1..N (1-based keys).\n",
    "# .detach().cpu() lets .numpy() work even if the tensors require grad or live on a GPU.\n",
    "paired_dict = {\n",
    "    f'gate_weight_{i}': gw.detach().cpu().numpy()\n",
    "    for i, gw in enumerate(gate_weight_vectors, start=1)\n",
    "}\n",
    "paired_dict.update({\n",
    "    f'user_embed_{i}': ue.detach().cpu().numpy()\n",
    "    for i, ue in enumerate(user_embed_vectors, start=1)\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bare last expression: Jupyter shows it via rich display, so no print() is needed.\n",
    "paired_dict"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MOE4REC",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
