{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/kongxy/miniconda3/envs/LLaRA/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import inspect\n",
    "import torch\n",
    "import importlib\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "import torch.optim.lr_scheduler as lrs\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from transformers import LlamaForCausalLM, LlamaTokenizer\n",
    "import random\n",
    "from pandas.core.frame import DataFrame\n",
    "import os.path as op\n",
    "import os\n",
    "from optims import LinearWarmupCosineLRScheduler\n",
    "import numpy as np\n",
    "from model import *\n",
    "from recommender import *\n",
    "# from .peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel, MoeLoraConfig, MoeLoraModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from recommender.A_SASRec_final_bce_llm import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_rec_model(rec_model_path):\n",
    "        print('Loading Rec Model')\n",
    "        # torch.load加载模型\n",
    "        rec_model = torch.load(rec_model_path, map_location=\"cpu\")\n",
    "        rec_model.eval()\n",
    "        # 冻结参数\n",
    "        for name, param in rec_model.named_parameters():\n",
    "            param.requires_grad = False\n",
    "        print('Loding Rec model Done')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "item_embeddings.weight torch.Size([4551, 64])\n",
      "positional_embeddings.weight torch.Size([10, 64])\n",
      "ln_1.weight torch.Size([64])\n",
      "ln_1.bias torch.Size([64])\n",
      "ln_2.weight torch.Size([64])\n",
      "ln_2.bias torch.Size([64])\n",
      "ln_3.weight torch.Size([64])\n",
      "ln_3.bias torch.Size([64])\n",
      "mh_attn.linear_q.weight torch.Size([64, 64])\n",
      "mh_attn.linear_q.bias torch.Size([64])\n",
      "mh_attn.linear_k.weight torch.Size([64, 64])\n",
      "mh_attn.linear_k.bias torch.Size([64])\n",
      "mh_attn.linear_v.weight torch.Size([64, 64])\n",
      "mh_attn.linear_v.bias torch.Size([64])\n",
      "feed_forward.w_1.weight torch.Size([64, 64, 1])\n",
      "feed_forward.w_1.bias torch.Size([64])\n",
      "feed_forward.w_2.weight torch.Size([64, 64, 1])\n",
      "feed_forward.w_2.bias torch.Size([64])\n",
      "feed_forward.layer_norm.weight torch.Size([64])\n",
      "feed_forward.layer_norm.bias torch.Size([64])\n",
      "s_fc.weight torch.Size([4550, 64])\n",
      "s_fc.bias torch.Size([4550])\n"
     ]
    }
   ],
   "source": [
    "model_weights = torch.load('/home/kongxy/LLaRA_MOE/rec_model/SASRec_gr_best_model.pt', map_location=torch.device('cpu'))\n",
    "# 打印出所有的键和对应的数据尺寸\n",
    "for key in model_weights:\n",
    "    print(key, model_weights[key].size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "item_embeddings.weight torch.Size([4551, 64])\n",
      "positional_embeddings.weight torch.Size([10, 64])\n",
      "ln_1.weight torch.Size([64])\n",
      "ln_1.bias torch.Size([64])\n",
      "ln_2.weight torch.Size([64])\n",
      "ln_2.bias torch.Size([64])\n",
      "ln_3.weight torch.Size([64])\n",
      "ln_3.bias torch.Size([64])\n",
      "mh_attn.linear_q.weight torch.Size([64, 64])\n",
      "mh_attn.linear_q.bias torch.Size([64])\n",
      "mh_attn.linear_k.weight torch.Size([64, 64])\n",
      "mh_attn.linear_k.bias torch.Size([64])\n",
      "mh_attn.linear_v.weight torch.Size([64, 64])\n",
      "mh_attn.linear_v.bias torch.Size([64])\n",
      "feed_forward.w_1.weight torch.Size([64, 64, 1])\n",
      "feed_forward.w_1.bias torch.Size([64])\n",
      "feed_forward.w_2.weight torch.Size([64, 64, 1])\n",
      "feed_forward.w_2.bias torch.Size([64])\n",
      "feed_forward.layer_norm.weight torch.Size([64])\n",
      "feed_forward.layer_norm.bias torch.Size([64])\n",
      "s_fc.weight torch.Size([4550, 64])\n",
      "s_fc.bias torch.Size([4550])\n"
     ]
    }
   ],
   "source": [
    "# 直接迭代状态字典\n",
    "for key, value in model_weights.items():\n",
    "    print(key, value.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rec_model = torch.load('/home/kongxy/LLaRA_MOE/rec_model/SASRec_steam.pt', map_location=\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rec_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name, param in rec_model.named_parameters():\n",
    "    print(name, param.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 转化为模型实例\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = SASRec(hidden_size=64, item_num=4550, state_size=10, dropout=0.1, device=device, num_heads=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(model_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SASRec(\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       "  (item_embeddings): Embedding(4551, 64)\n",
       "  (positional_embeddings): Embedding(10, 64)\n",
       "  (emb_dropout): Dropout(p=0.1, inplace=False)\n",
       "  (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "  (ln_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "  (ln_3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "  (mh_attn): MultiHeadAttention(\n",
       "    (linear_q): Linear(in_features=64, out_features=64, bias=True)\n",
       "    (linear_k): Linear(in_features=64, out_features=64, bias=True)\n",
       "    (linear_v): Linear(in_features=64, out_features=64, bias=True)\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "    (softmax): Softmax(dim=-1)\n",
       "  )\n",
       "  (feed_forward): PositionwiseFeedForward(\n",
       "    (w_1): Conv1d(64, 64, kernel_size=(1,), stride=(1,))\n",
       "    (w_2): Conv1d(64, 64, kernel_size=(1,), stride=(1,))\n",
       "    (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       "  (s_fc): Linear(in_features=64, out_features=4550, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model, './rec_model/SASRec_goodreads.pt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "LLaRA",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
