{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "64464f37-25a4-47ad-9aa7-e212e58eebc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import transformers\n",
    "import torch\n",
    "import numpy as np\n",
    "from Code.utinity import *\n",
    "from Code.models import *\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "800f3c6d-dc4b-42bd-92b6-ea72cbf2b9ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LLM_u(LLM):\n",
    "    \"\"\"``LLM`` subclass that supplies its own ``get_attention_matrix``.\"\"\"\n",
    "\n",
    "    def get_attention_matrix(self, outputs):\n",
    "        \"\"\"Merge every attention map in ``outputs.attentions`` into one matrix.\n",
    "\n",
    "        Each head's map is symmetrised as ``(X + X.T) / 2``. The element-wise\n",
    "        maximum over all heads of all layers is then taken, and everything\n",
    "        below the main diagonal is zeroed.\n",
    "\n",
    "        Args:\n",
    "            outputs: object exposing ``attentions``, a sequence with one tensor\n",
    "                per layer. Only index 0 of each tensor is read (presumably the\n",
    "                first batch item, shaped (heads, seq, seq) -- TODO confirm).\n",
    "\n",
    "        Returns:\n",
    "            ``numpy.ndarray`` holding the merged, upper-triangular matrix.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``outputs.attentions`` is ``None`` or empty.\n",
    "        \"\"\"\n",
    "        if not outputs.attentions:\n",
    "            raise ValueError('outputs.attentions is missing or empty; attention maps are required')\n",
    "\n",
    "        merged = 0  # scalar seed; broadcasts against the first head's map\n",
    "        for layer_attention in outputs.attentions:\n",
    "            heads = layer_attention[0].to(torch.float32).cpu().detach().numpy()\n",
    "            for head in heads:\n",
    "                symmetric = (head + head.T) / 2\n",
    "                # np.maximum is the exact element-wise max; the algebraic identity\n",
    "                # (a + b) / 2 + |a - b| / 2 picks up float32 rounding error.\n",
    "                merged = np.maximum(merged, symmetric)\n",
    "        # Keep only the upper triangle (diagonal included).\n",
    "        return np.triu(merged)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dc43ba7-3c39-49d9-b89c-2d6f15ae6c1e",
   "metadata": {},
   "source": [
    "## 算法设定\n",
    "\n",
    "以 Falcon 为基础模型，比较现有算法（baseline）和我们设计的算法（ours）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "75aedcf7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
      "You are using a model of type RefinedWebModel to instantiate a model of type falcon. This is not supported for all configurations of models and can yield errors.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c7e055d291a5457dbd7191e7b7f2dd29",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "falcon 调用成功!\n"
     ]
    }
   ],
   "source": [
    "# Load the Falcon model once; every solver below shares this instance.\n",
    "llm = LLM_u('falcon', load_model=True)\n",
    "# Alternative backbone: llm = LLM_u('gpt2', load_model=True)\n",
    "# Run tag: it shows up in the Temp/ cache file names and the result row labels.\n",
    "llm.model_name = 'falconuu_test_echo'\n",
    "\n",
    "solver_baseline = Solver_Baseline(llm)  # baseline\n",
    "# Echo solvers with repeat = 2, 3, 4, 5 (constructed in that order).\n",
    "solver_ours_vv, solver_ours_vvv, solver_ours_vvvv, solver_ours_vvvvv = (\n",
    "    Solver_Echo(llm, repeat=times) for times in (2, 3, 4, 5)\n",
    ")\n",
    "# Alternative: solver_ours_vvvv = Solver_Ours(llm, repeat=4, attention=True)\n",
    "solvers = [\n",
    "    solver_baseline,\n",
    "    solver_ours_vv,\n",
    "    solver_ours_vvv,\n",
    "    solver_ours_vvvv,\n",
    "    solver_ours_vvvvv,\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b5a9fa-a236-48b7-b20b-f7b2313d8cf4",
   "metadata": {},
   "source": [
    "## 算法验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "83619f50-4c36-46b1-9edd-fb45ffba5e03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset stores filled by RunEval, keyed by problem.name.\n",
    "results = {}\n",
    "accuracy = {}\n",
    "\n",
    "\n",
    "def RunEval(problem, use_temp=True):\n",
    "    \"\"\"Evaluate ``problem`` with every solver in ``solvers`` and tabulate the scores.\n",
    "\n",
    "    For each solver an ``Evaluation`` is built and saved to\n",
    "    ``result/<problem.name>_<solver.name>.xlsx``. The two values returned by\n",
    "    ``save`` are appended to ``results[problem.name]`` and\n",
    "    ``accuracy[problem.name]`` respectively; both lists are reset on each call.\n",
    "\n",
    "    Args:\n",
    "        problem: dataset wrapper; its ``name`` keys the global stores and the\n",
    "            output file name.\n",
    "        use_temp: forwarded unchanged to ``Evaluation`` (presumably toggles\n",
    "            reuse of cached embeddings under Temp/ -- TODO confirm).\n",
    "\n",
    "    Returns:\n",
    "        The per-solver accuracy objects concatenated column-wise and then\n",
    "        transposed.\n",
    "    \"\"\"\n",
    "    detail_frames = results[problem.name] = []\n",
    "    score_frames = accuracy[problem.name] = []\n",
    "    for solver in solvers:\n",
    "        evaluation = Evaluation(problem, solver, later=False, use_temp=use_temp)\n",
    "        workbook_path = 'result/' + problem.name + '_' + solver.name + '.xlsx'\n",
    "        detail_df, score_df = evaluation.save(filename=workbook_path)\n",
    "        detail_frames.append(detail_df)\n",
    "        score_frames.append(score_df)\n",
    "    return pd.concat(score_frames, axis=1).T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8adb696-771c-45e2-bb89-6ac87d510876",
   "metadata": {},
   "source": [
    "### 数据集 1: SLPWC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "037d3ae7-83b0-4faa-a37a-a6a993dac4f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "读取文件： Temp/temp_embeddings_SLPWC_falconuu_test_echo_baseline.dat\n",
      "读取文件： Temp/temp_embeddings_SLPWC_falconuu_test_echo_echovv.dat\n",
      "读取文件： Temp/temp_embeddings_SLPWC_falconuu_test_echo_echovvv.dat\n",
      "利用 falconuu_test_echo_echovvvv 计算目标词的向量嵌入：\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "48e644a0fc1442a28f8dc4226b43a31c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/300 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call `model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "保存文件： Temp/temp_embeddings_SLPWC_falconuu_test_echo_echovvvv.dat\n",
      "利用 falconuu_test_echo_echovvvvv 计算目标词的向量嵌入：\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "67757dd101f245a18e4605b2d9afb7fe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/300 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "保存文件： Temp/temp_embeddings_SLPWC_falconuu_test_echo_echovvvvv.dat\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>distance_shift</th>\n",
       "      <th>cos_similarity_shift</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>falconuu_test_echo_baseline</th>\n",
       "      <td>0.583333</td>\n",
       "      <td>0.590000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>falconuu_test_echo_echovv</th>\n",
       "      <td>0.540000</td>\n",
       "      <td>0.546667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>falconuu_test_echo_echovvv</th>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.496667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>falconuu_test_echo_echovvvv</th>\n",
       "      <td>0.463333</td>\n",
       "      <td>0.466667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>falconuu_test_echo_echovvvvv</th>\n",
       "      <td>0.483333</td>\n",
       "      <td>0.490000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                              distance_shift  cos_similarity_shift\n",
       "falconuu_test_echo_baseline         0.583333              0.590000\n",
       "falconuu_test_echo_echovv           0.540000              0.546667\n",
       "falconuu_test_echo_echovvv          0.500000              0.496667\n",
       "falconuu_test_echo_echovvvv         0.463333              0.466667\n",
       "falconuu_test_echo_echovvvvv        0.483333              0.490000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate every configured solver on SLPWC; the returned table is the cell output.\n",
    "RunEval(\n",
    "    problem=Problem(\"Data/SLPWC_v1.csv\", name='SLPWC'),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96992458-6533-4a2d-8877-03a571d53d78",
   "metadata": {},
   "source": [
    "### 数据集 2: WSD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d2b4ca51-56a4-4461-a14f-af96071ba0f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "读取文件： Temp/temp_embeddings_WSD_falconuu_test_echo_baseline.dat\n",
      "读取文件： Temp/temp_embeddings_WSD_falconuu_test_echo_echovv.dat\n",
      "读取文件： Temp/temp_embeddings_WSD_falconuu_test_echo_echovvv.dat\n",
      "利用 falconuu_test_echo_echovvvv 计算目标词的向量嵌入：\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e246d4630fdf4a59ba46e2477d1683a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/862 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "没有找到词: 名称 \"君子疾没世而名不称焉。\"君子疾没世而名不称焉。\"君子疾没世而名不称焉。\"君子疾没世而名不称焉。\n",
      "没有找到词: 那里 这儿瞧瞧,那儿看看,觉得什么都是新鲜的这儿瞧瞧,那儿看看,觉得什么都是新鲜的这儿瞧瞧,那儿看看,觉得什么都是新鲜的这儿瞧瞧,那儿看看,觉得什么都是新鲜的\n",
      "没有找到词: 精简 精选人物精选人物精选人物精选人物\n",
      "没有找到词: 摄影 骋足则能追风摄景骋足则能追风摄景骋足则能追风摄景骋足则能追风摄景\n",
      "保存文件： Temp/temp_embeddings_WSD_falconuu_test_echo_echovvvv.dat\n",
      "利用 falconuu_test_echo_echovvvvv 计算目标词的向量嵌入：\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "56602962c20d448c9df6fb4d84273d1d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/862 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "没有找到词: 名称 \"君子疾没世而名不称焉。\"君子疾没世而名不称焉。\"君子疾没世而名不称焉。\"君子疾没世而名不称焉。\"君子疾没世而名不称焉。\n",
      "没有找到词: 那里 这儿瞧瞧,那儿看看,觉得什么都是新鲜的这儿瞧瞧,那儿看看,觉得什么都是新鲜的这儿瞧瞧,那儿看看,觉得什么都是新鲜的这儿瞧瞧,那儿看看,觉得什么都是新鲜的这儿瞧瞧,那儿看看,觉得什么都是新鲜的\n",
      "没有找到词: 精简 精选人物精选人物精选人物精选人物精选人物\n",
      "没有找到词: 摄影 骋足则能追风摄景骋足则能追风摄景骋足则能追风摄景骋足则能追风摄景骋足则能追风摄景\n",
      "保存文件： Temp/temp_embeddings_WSD_falconuu_test_echo_echovvvvv.dat\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>distance_shift</th>\n",
       "      <th>cos_similarity_shift</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>falconuu_test_echo_baseline</th>\n",
       "      <td>0.446636</td>\n",
       "      <td>0.444316</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>falconuu_test_echo_echovv</th>\n",
       "      <td>0.334107</td>\n",
       "      <td>0.339907</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>falconuu_test_echo_echovvv</th>\n",
       "      <td>0.315545</td>\n",
       "      <td>0.328306</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>falconuu_test_echo_echovvvv</th>\n",
       "      <td>0.323666</td>\n",
       "      <td>0.321346</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>falconuu_test_echo_echovvvvv</th>\n",
       "      <td>0.324826</td>\n",
       "      <td>0.319026</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                              distance_shift  cos_similarity_shift\n",
       "falconuu_test_echo_baseline         0.446636              0.444316\n",
       "falconuu_test_echo_echovv           0.334107              0.339907\n",
       "falconuu_test_echo_echovvv          0.315545              0.328306\n",
       "falconuu_test_echo_echovvvv         0.323666              0.321346\n",
       "falconuu_test_echo_echovvvvv        0.324826              0.319026"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate every configured solver on WSD; the returned table is the cell output.\n",
    "RunEval(\n",
    "    problem=Problem('Data/WSD_v1.csv', name='WSD'),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
