{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the OPT-350m tokenizer and the bare OPT model (AutoModel: no task-specific head).\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "\n",
    "MODEL_NAME = \"facebook/opt-350m\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = AutoModel.from_pretrained(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save_pretrained treats its argument as a directory (despite the \".pt\" suffix);\n",
    "# model and tokenizer files are written side by side into it, model first.\n",
    "OPT_SAVE_DIR = \"/cognitive_comp/user/source/model_base/model_en/opt_350m.pt\"\n",
    "for component in (model, tokenizer):\n",
    "    component.save_pretrained(OPT_SAVE_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total parameter count of the loaded model (every tensor yielded by model.parameters()).\n",
    "n_params = sum(param.numel() for param in model.parameters())\n",
    "print(n_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载数据集"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### QQP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "qqp_train = pd.read_csv('/cognitive_comp/user/source/sim_data/raw_data_en/QQP/qqp_train.tsv', sep='\\t')\n",
    "qqp_dev = pd.read_csv('/cognitive_comp/user/source/sim_data/raw_data_en/QQP/qqp_dev.tsv', sep='\\t')\n",
    "\n",
    "# Export only duplicate (paraphrase) pairs from train + dev as JSON lines {text1, text2}.\n",
    "# Both questions must be non-empty. pandas reads empty cells as NaN and str(NaN) is 'nan',\n",
    "# so missing values are mapped to '' before the emptiness check.\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data_en/qqp_sim_data.json', 'w') as wp:\n",
    "    for frame in (qqp_train, qqp_dev):\n",
    "        rows = zip(frame['question1'], frame['question2'], frame['is_duplicate'])\n",
    "        for question1, question2, is_duplicate in tqdm(rows, total=len(frame)):\n",
    "            text1 = '' if pd.isna(question1) else str(question1)\n",
    "            text2 = '' if pd.isna(question2) else str(question2)\n",
    "            if is_duplicate == 1 and text1 != '' and text2 != '':\n",
    "                wp.write(json.dumps({'text1': text1, 'text2': text2}, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Turn the positive-pair JSON lines into an Arrow dataset with an explicit string schema.\n",
    "qqp_pos_json = '/cognitive_comp/user/source/sim_data/raw_data_en/qqp_sim_data.json'\n",
    "pair_features = datasets.Features({\n",
    "    \"text1\": datasets.Value('string'),\n",
    "    \"text2\": datasets.Value('string'),\n",
    "})\n",
    "qqp_loaded = datasets.load_dataset(\n",
    "    'json',\n",
    "    data_files=qqp_pos_json,\n",
    "    cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "    features=pair_features,\n",
    ")\n",
    "ds = qqp_loaded['train']\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/qqp_sim_data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Hold out 8% of qqp_sim_data (fixed seed) as a test split and save both halves.\n",
    "QQP_OUT_DIR = '/cognitive_comp/user/source/sim_data/similarity_data_en'\n",
    "qqp_data = datasets.load_from_disk(QQP_OUT_DIR + '/qqp_sim_data')\n",
    "split_data = qqp_data.train_test_split(test_size=0.08, seed=42)\n",
    "for split_name, target_dir in (('train', 'qqp_train'), ('test', 'qqp_test')):\n",
    "    split_data[split_name].save_to_disk(QQP_OUT_DIR + '/' + target_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# NOTE(review): this writes the same qqp_sim_data.json as the positive-only QQP export,\n",
    "# so whichever of the two cells ran last decides what downstream cells load.\n",
    "qqp_train = pd.read_csv('/cognitive_comp/user/source/sim_data/raw_data_en/QQP/qqp_train.tsv', sep='\\t')\n",
    "qqp_dev = pd.read_csv('/cognitive_comp/user/source/sim_data/raw_data_en/QQP/qqp_dev.tsv', sep='\\t')\n",
    "\n",
    "# Export every QQP pair from train + dev as JSON lines {text1, text2, score}, where score is\n",
    "# the is_duplicate label. Both questions must be non-empty; pandas reads empty cells as NaN,\n",
    "# so missing values are mapped to '' instead of being written as the string 'nan'.\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data_en/qqp_sim_data.json', 'w') as wp:\n",
    "    for frame in (qqp_train, qqp_dev):\n",
    "        rows = zip(frame['question1'], frame['question2'], frame['is_duplicate'])\n",
    "        for question1, question2, is_duplicate in tqdm(rows, total=len(frame)):\n",
    "            text1 = '' if pd.isna(question1) else str(question1)\n",
    "            text2 = '' if pd.isna(question2) else str(question2)\n",
    "            if text1 != '' and text2 != '':\n",
    "                wp.write(json.dumps({'text1': text1,\n",
    "                                     'text2': text2,\n",
    "                                     'score': int(is_duplicate),\n",
    "                                     }, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Arrow dataset for the labeled QQP pairs: string texts plus an int8 duplicate label.\n",
    "# NOTE(review): saves to the same directory as the positive-only QQP dataset, so the last\n",
    "# of those cells to run decides whether qqp_sim_data has a `score` column.\n",
    "labeled_features = datasets.Features({\n",
    "    \"text1\": datasets.Value('string'),\n",
    "    \"text2\": datasets.Value('string'),\n",
    "    \"score\": datasets.Value(\"int8\"),\n",
    "})\n",
    "qqp_labeled_json = '/cognitive_comp/user/source/sim_data/raw_data_en/qqp_sim_data.json'\n",
    "ds = datasets.load_dataset('json',\n",
    "                           data_files=qqp_labeled_json,\n",
    "                           cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                           features=labeled_features)['train']\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/qqp_sim_data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### MRPC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "MRPC_DIR = '/cognitive_comp/user/source/sim_data/raw_data_en/MRPC'\n",
    "\n",
    "# Convert the tab-separated MRPC test file into JSON lines {text1, text2, score}: the label is\n",
    "# the last character of the first column, the sentence pair is the last two columns.\n",
    "with open(MRPC_DIR + '/msr_paraphrase_test.txt', 'r') as rp:\n",
    "    lines = rp.readlines()\n",
    "\n",
    "with open(MRPC_DIR + '/mrpc_dev.json', 'w') as wp:\n",
    "    for line in tqdm(lines):\n",
    "        # Strip the line terminator; readlines() keeps it and it would end up inside text2.\n",
    "        line_list = line.rstrip('\\r\\n').split('\\t')\n",
    "        label_field = line_list[0].strip()\n",
    "        # Skip blank lines and any header row (first field does not end in a 0/1 label).\n",
    "        if len(line_list) < 3 or label_field[-1:] not in ('0', '1'):\n",
    "            continue\n",
    "        wp.write(json.dumps({'text1': line_list[-2],\n",
    "                             'text2': line_list[-1],\n",
    "                             'score': int(label_field[-1]),\n",
    "                             }, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Arrow dataset for the MRPC dev pairs (string texts + int8 paraphrase label).\n",
    "mrpc_features = datasets.Features({\n",
    "    \"text1\": datasets.Value('string'),\n",
    "    \"text2\": datasets.Value('string'),\n",
    "    \"score\": datasets.Value('int8'),\n",
    "})\n",
    "mrpc_dev_json = '/cognitive_comp/user/source/sim_data/raw_data_en/MRPC/mrpc_dev.json'\n",
    "mrpc_loaded = datasets.load_dataset('json', data_files=mrpc_dev_json,\n",
    "                                    cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                                    features=mrpc_features)\n",
    "ds = mrpc_loaded['train']\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/mrpc_dev')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, json\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Flatten the MRPC training pairs into one JSON line per sentence ({\"sentence\": ...}),\n",
    "# writing text1 and then text2 for every pair. Only the training split is exported.\n",
    "train_mrpc = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/mrpc_train_ds')\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/similarity_data_en/mrpc_data.json', 'w') as sentence_file:\n",
    "    pairs = zip(train_mrpc['text1'], train_mrpc['text2'])\n",
    "    for first, second in tqdm(pairs, total=train_mrpc.num_rows):\n",
    "        for sentence in (first, second):\n",
    "            sentence_file.write(json.dumps({'sentence': sentence}, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Load the one-sentence-per-line MRPC JSON and save it as an Arrow dataset of sentences.\n",
    "# `datasets` is imported here so the cell also runs on a fresh kernel.\n",
    "ds = datasets.load_dataset('json',\n",
    "                           data_files='/cognitive_comp/user/source/sim_data/similarity_data_en/mrpc_data.json',\n",
    "                           cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache')['train']\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/predict_sentences/mrpc_sentence')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### PAWS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Fetch the English configuration of PAWS-X and keep a local copy of all its splits.\n",
    "paws_x_en = datasets.load_dataset(\"paws-x\", \"en\")\n",
    "paws_x_en.save_to_disk('/cognitive_comp/user/source/sim_data/raw_data_en/paws')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, json\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def _dump_paws_split(split, out_path):\n",
    "    \"\"\"Write one PAWS-X split to `out_path` as JSON lines {text1, text2, score}.\"\"\"\n",
    "    columns = zip(split['sentence1'], split['sentence2'], split['label'])\n",
    "    with open(out_path, 'w') as out_file:\n",
    "        for sentence1, sentence2, label in tqdm(columns, total=split.num_rows):\n",
    "            record = {'text1': str(sentence1), 'text2': str(sentence2), 'score': int(label)}\n",
    "            out_file.write(json.dumps(record, ensure_ascii=False) + '\\n')\n",
    "\n",
    "\n",
    "# PAWS-X train + validation form our training data; the PAWS-X test split is our dev set.\n",
    "paws = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/raw_data_en/paws')\n",
    "paws_train = datasets.concatenate_datasets([paws['train'], paws['validation']])\n",
    "paws_val = paws['test']\n",
    "\n",
    "_dump_paws_split(paws_train, '/cognitive_comp/user/source/sim_data/similarity_data_en/paws-x_train.json')\n",
    "_dump_paws_split(paws_val, '/cognitive_comp/user/source/sim_data/similarity_data_en/paws-x_dev.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Convert both PAWS-X JSON exports into Arrow datasets in one pass; later cells load both\n",
    "# sim_train_data/paws_train_ds and sim_test_data/paws, so neither may be skipped.\n",
    "paws_features = datasets.Features({\n",
    "    \"text1\": datasets.Value('string'),\n",
    "    \"text2\": datasets.Value('string'),\n",
    "    \"score\": datasets.Value('int8'),\n",
    "})\n",
    "paws_exports = [\n",
    "    ('/cognitive_comp/user/source/sim_data/similarity_data_en/paws-x_train.json',\n",
    "     '/cognitive_comp/user/source/sim_data/sim_train_data/paws_train_ds'),\n",
    "    ('/cognitive_comp/user/source/sim_data/similarity_data_en/paws-x_dev.json',\n",
    "     '/cognitive_comp/user/source/sim_data/sim_test_data/paws'),\n",
    "]\n",
    "for path, out_dir in paws_exports:\n",
    "    ds = datasets.load_dataset('json', data_files=path,\n",
    "                               cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                               features=paws_features)['train']\n",
    "    ds.save_to_disk(out_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, json\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Flatten the PAWS train and dev pairs into one JSON line per sentence ({\"sentence\": ...}):\n",
    "# train rows first, and for every pair text1 before text2.\n",
    "train_paws = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/paws_train_ds')\n",
    "dev_paws = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_test_data/paws')\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/similarity_data_en/paws_data.json', 'w') as sentence_file:\n",
    "    for split in (train_paws, dev_paws):\n",
    "        pairs = zip(split['text1'], split['text2'])\n",
    "        for first, second in tqdm(pairs, total=split.num_rows):\n",
    "            sentence_file.write(json.dumps({'sentence': first}, ensure_ascii=False) + '\\n')\n",
    "            sentence_file.write(json.dumps({'sentence': second}, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Load the one-sentence-per-line PAWS JSON and save it as an Arrow dataset of sentences.\n",
    "# `datasets` is imported here so the cell also runs on a fresh kernel.\n",
    "ds = datasets.load_dataset('json',\n",
    "                           data_files='/cognitive_comp/user/source/sim_data/similarity_data_en/paws_data.json',\n",
    "                           cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache')['train']\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/predict_sentences/paws_sentence')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, json\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Keep only PAWS pairs labeled as paraphrases (score == 1) whose texts are both non-empty,\n",
    "# written as JSON lines {text1, text2}.\n",
    "train_paws = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/paws_train_ds')\n",
    "dev_paws = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_test_data/paws')\n",
    "paws = datasets.concatenate_datasets([train_paws, dev_paws])\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data_en/paws_sim_data.json', 'w') as pos_file:\n",
    "    columns = zip(paws['text1'], paws['text2'], paws['score'])\n",
    "    for text1, text2, label in tqdm(columns, total=paws.num_rows):\n",
    "        text1, text2 = str(text1), str(text2)\n",
    "        if label != 1 or text1 == \"\" or text2 == \"\":\n",
    "            continue\n",
    "        pos_file.write(json.dumps({'text1': text1, 'text2': text2}, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Arrow dataset of the positive PAWS pairs (string texts only).\n",
    "paws_pos_json = '/cognitive_comp/user/source/sim_data/raw_data_en/paws_sim_data.json'\n",
    "pair_features = datasets.Features({\n",
    "    \"text1\": datasets.Value('string'),\n",
    "    \"text2\": datasets.Value('string'),\n",
    "})\n",
    "ds = datasets.load_dataset(\n",
    "    'json',\n",
    "    data_files=paws_pos_json,\n",
    "    cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "    features=pair_features,\n",
    ")['train']\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/paws_sim_data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### STS-B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "STSB_DIR = '/cognitive_comp/user/source/sim_data/raw_data_en/STS-B'\n",
    "MIN_SIMILARITY = 4.0  # keep pairs whose gold similarity score is at least this\n",
    "\n",
    "# Export highly similar STS-B pairs from train + dev as JSON lines {text1, text2}. Each row is\n",
    "# tab-separated: sentence pair in the 3rd/2nd-to-last columns, gold score in the last column.\n",
    "with open(STSB_DIR + '/sts-b.json', 'w') as wp:\n",
    "    for split_file in ('train.tsv', 'dev.tsv'):\n",
    "        with open(STSB_DIR + '/' + split_file, 'r') as rp:\n",
    "            lines = rp.readlines()\n",
    "        for line in tqdm(lines):\n",
    "            # rstrip rather than [:-1]: the last line may not end with a newline.\n",
    "            line_list = line.rstrip('\\r\\n').split('\\t')\n",
    "            try:\n",
    "                score = float(line_list[-1])\n",
    "            except ValueError:\n",
    "                continue  # header row (non-numeric score column) or blank line\n",
    "            if len(line_list) >= 3 and score >= MIN_SIMILARITY:\n",
    "                wp.write(json.dumps({'text1': line_list[-3],\n",
    "                                     'text2': line_list[-2],\n",
    "                                     }, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Arrow dataset of the high-similarity STS-B pairs (string texts only).\n",
    "stsb_json = '/cognitive_comp/user/source/sim_data/raw_data_en/STS-B/sts-b.json'\n",
    "stsb_features = datasets.Features({\"text1\": datasets.Value('string'),\n",
    "                                   \"text2\": datasets.Value('string')})\n",
    "stsb_loaded = datasets.load_dataset('json', data_files=stsb_json,\n",
    "                                    cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                                    features=stsb_features)\n",
    "ds = stsb_loaded['train']\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/sts-b_data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### WikiText"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Local WikiText copy; the cell output shows its structure and row count.\n",
    "# NOTE(review): the next cell uses `wikitext.num_rows` and `wikitext[idx]['text']`, so this\n",
    "# path must hold a single split (Dataset), not a DatasetDict -- confirm.\n",
    "WIKITEXT_RAW_DIR = '/cognitive_comp/user/source/sim_data/raw_data_en/wikitext'\n",
    "wikitext = datasets.load_from_disk(WIKITEXT_RAW_DIR)\n",
    "wikitext"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "# First cleaning pass over WikiText: drop empty lines and any line containing '='\n",
    "# (presumably to remove ' = Heading = ' lines), writing {text1: line, text2: 'general'}.\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data_en/wikitext.json', 'w') as out_file:\n",
    "    for row_idx in tqdm(range(wikitext.num_rows)):\n",
    "        line = wikitext[row_idx]['text']\n",
    "        if len(line) == 0 or '=' in line:\n",
    "            continue\n",
    "        out_file.write(json.dumps({'text1': str(line), 'text2': 'general'}, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Arrow dataset built from the filtered WikiText JSON lines (string text1/text2).\n",
    "wikitext_json = '/cognitive_comp/user/source/sim_data/raw_data_en/wikitext.json'\n",
    "text_pair_features = datasets.Features({\"text1\": datasets.Value('string'),\n",
    "                                       \"text2\": datasets.Value('string')})\n",
    "ds = datasets.load_dataset('json', data_files=wikitext_json,\n",
    "                           cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                           features=text_pair_features)['train']\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/wikitext_data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, json\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Second cleaning pass: keep lines of at least MIN_CHARS characters that contain neither\n",
    "# '<unk>' nor '@', and rewrite wikitext.json with them.\n",
    "# NOTE(review): this overwrites the first-pass JSON, and the next cell re-saves over the\n",
    "# wikitext_data directory read here, so the final content depends on which cells ran last.\n",
    "MIN_CHARS = 50\n",
    "wiki_data = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/wikitext_data')\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data_en/wikitext.json', 'w') as out_file:\n",
    "    for text in tqdm(wiki_data['text1']):\n",
    "        if len(text) < MIN_CHARS or '<unk>' in text or '@' in text:\n",
    "            continue\n",
    "        out_file.write(json.dumps({'text1': str(text), 'text2': 'general'}, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Re-build wikitext_data from the second-pass JSON. `datasets` is imported here so the cell\n",
    "# also runs on a fresh kernel.\n",
    "wikitext_json = '/cognitive_comp/user/source/sim_data/raw_data_en/wikitext.json'\n",
    "text_pair_features = datasets.Features({\n",
    "    \"text1\": datasets.Value('string'),\n",
    "    \"text2\": datasets.Value('string'),\n",
    "})\n",
    "ds = datasets.load_dataset('json', data_files=wikitext_json,\n",
    "                           cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                           features=text_pair_features)['train']\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/wikitext_data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 合并数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Merge the unlabeled positive-pair corpora into one shuffled pre-training set.\n",
    "# concatenate_datasets needs matching schemas, and qqp_sim_data may hold the labeled\n",
    "# {text1, text2, score} QQP variant (two QQP cells save to that directory), so check the\n",
    "# columns up front and fail with an actionable message.\n",
    "SIM_EN_DIR = '/cognitive_comp/user/source/sim_data/similarity_data_en'\n",
    "PRETRAIN_SOURCES = ['wikitext_data', 'qqp_sim_data', 'sts-b_data', 'paws_sim_data']\n",
    "\n",
    "pretrain_parts = []\n",
    "for source_dir in PRETRAIN_SOURCES:\n",
    "    part = datasets.load_from_disk(SIM_EN_DIR + '/' + source_dir)\n",
    "    assert sorted(part.column_names) == ['text1', 'text2'], (\n",
    "        f'{source_dir} has columns {part.column_names}; expected only text1/text2 '\n",
    "        '(re-run the positive-pair export cells for this corpus)')\n",
    "    pretrain_parts.append(part)\n",
    "\n",
    "pretrain_data = datasets.concatenate_datasets(pretrain_parts).shuffle(seed=42)\n",
    "pretrain_data.save_to_disk(SIM_EN_DIR + '/pretrain_data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Hold out 3% of the merged pre-training data for validation. The fixed seed (the same one\n",
    "# used for the shuffle when merging) makes the split reproducible when the notebook is re-run.\n",
    "pretrain_data = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/pretrain_data')\n",
    "pretrain_data = pretrain_data.train_test_split(test_size=0.03, seed=42)\n",
    "pretrain_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist both halves of the pre-training split: train -> pre_train, test -> pre_val.\n",
    "for split_name, dir_name in (('train', 'pre_train'), ('test', 'pre_val')):\n",
    "    pretrain_data[split_name].save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/' + dir_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Labeled paraphrase pairs (text1, text2, score) from QQP plus MRPC train/dev, shuffled\n",
    "# with a fixed seed and saved as one dataset.\n",
    "# NOTE(review): this relies on qqp_sim_data holding the *labeled* QQP variant (the positive-only\n",
    "# export saves to the same directory without `score`). Also, sim_test_data/mrpc is not written\n",
    "# by any visible cell (the MRPC dev pairs go to similarity_data_en/mrpc_dev) -- confirm.\n",
    "SIM_ROOT = '/cognitive_comp/user/source/sim_data'\n",
    "labeled_sources = [\n",
    "    SIM_ROOT + '/similarity_data_en/qqp_sim_data',\n",
    "    SIM_ROOT + '/sim_train_data/mrpc_train_ds',\n",
    "    SIM_ROOT + '/sim_test_data/mrpc',\n",
    "]\n",
    "labeled_parts = [datasets.load_from_disk(source) for source in labeled_sources]\n",
    "data4paws = datasets.concatenate_datasets(labeled_parts)\n",
    "data4paws = data4paws.shuffle(seed=42)\n",
    "data4paws.save_to_disk(SIM_ROOT + '/similarity_data_en/labeled4paws')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 3% of the labeled pairs as a test split. The result gets its own name so re-running\n",
    "# this cell splits the original dataset again instead of the already-split result, and the\n",
    "# fixed seed makes the saved splits reproducible.\n",
    "data4paws_split = data4paws.train_test_split(test_size=0.03, seed=42)\n",
    "data4paws_split['train'].save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/labeled_train_paws')\n",
    "data4paws_split['test'].save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data_en/labeled_test_paws')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### BERTScore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, os\n",
    "from bert_score import score\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "# BERTScore F1 of the generated sentences (text2, candidates) against their sources\n",
    "# (text1, references) for AFQMC consistency cycles 0 and 8.\n",
    "for cycle in [0, 8]:  # [0, 10]\n",
    "    cycle_data = datasets.load_from_disk('/cognitive_comp/user/similarity_generation/consistency/afqmc/data_cycle_' + str(cycle))\n",
    "    references = [row['text1'] for row in cycle_data]\n",
    "    candidates = [row['text2'] for row in cycle_data]\n",
    "    P, R, F1 = score(candidates, references, lang=\"zh\", verbose=True)\n",
    "    print(F1.mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Perplexity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### en"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, os, evaluate\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "# Mean GPT-2 perplexity (`evaluate` perplexity metric) of the non-empty generated sentences\n",
    "# (text2) in MRPC consistency cycles 0 and 8.\n",
    "perplexity = evaluate.load(\"perplexity\", module_type=\"metric\")\n",
    "for cycle in [0, 8]:\n",
    "    cycle_data = datasets.load_from_disk('/cognitive_comp/user/similarity_generation/consistency/mrpc/data_cycle_' + str(cycle))\n",
    "    generated = [cycle_data[row]['text2'] for row in tqdm(range(cycle_data.num_rows)) if cycle_data[row]['text2']]\n",
    "    result = perplexity.compute(input_texts=generated, model_id='gpt2')\n",
    "    print(result['mean_perplexity'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### zh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
    "\n",
    "# Only takes effect if CUDA has not been initialized yet in this kernel.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "# Chinese GPT-2 (IDEA-CCNL Wenzhong, 110M) used as the scoring LM.\n",
    "# NOTE: rebinds the global `tokenizer` / `model` names used by earlier cells.\n",
    "tokenizer = GPT2Tokenizer.from_pretrained('IDEA-CCNL/Wenzhong-GPT2-110M')\n",
    "model = GPT2LMHeadModel.from_pretrained('IDEA-CCNL/Wenzhong-GPT2-110M')\n",
    "model.to('cuda').eval()\n",
    "\n",
    "\n",
    "def gpt_ppl(sent):\n",
    "    \"\"\"Return the perplexity of one sentence under the GPT-2 model above.\n",
    "\n",
    "    The sentence's token ids are passed as both inputs and labels; the model's returned\n",
    "    loss is the mean next-token cross-entropy, and perplexity is exp(loss).\n",
    "\n",
    "    Args:\n",
    "        sent: the sentence to score (str).\n",
    "\n",
    "    Returns:\n",
    "        Perplexity as a float. NOTE(review): a sentence that tokenizes to a single token\n",
    "        leaves nothing to predict, so the loss is presumably NaN -- callers may filter those.\n",
    "    \"\"\"\n",
    "    inputs = tokenizer(sent, return_tensors='pt')\n",
    "    input_ids = inputs[\"input_ids\"].cuda()\n",
    "    # Inference only: no_grad avoids building an autograd graph on every call.\n",
    "    with torch.no_grad():\n",
    "        loss = model(input_ids=input_ids,\n",
    "                     attention_mask=inputs[\"attention_mask\"].cuda(),\n",
    "                     labels=input_ids).loss\n",
    "    return np.exp(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Mean perplexity (gpt_ppl) of the non-empty generated sentences (text2) in QQP\n",
    "# consistency cycles 0 and 10.\n",
    "# NOTE(review): gpt_ppl scores with the Chinese Wenzhong GPT-2, while QQP is presumably\n",
    "# English -- confirm this pairing is intended.\n",
    "for cycle in [0, 10]:\n",
    "    cycle_data = datasets.load_from_disk('/cognitive_comp/user/similarity_generation/consistency/qqp/data_cycle_' + str(cycle))\n",
    "    cycle_ppl = []\n",
    "    for row_idx in tqdm(range(cycle_data.num_rows)):\n",
    "        generated = cycle_data[row_idx]['text2']\n",
    "        if generated:\n",
    "            cycle_ppl.append(gpt_ppl(generated))\n",
    "    print(np.mean(np.array(cycle_ppl)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Consistency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, torch\n",
    "from transformers import BertTokenizer, BertModel\n",
    "\n",
    "# Only takes effect if CUDA has not been initialized yet in this kernel.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "# Supervised SimCSE BERT encoder used to embed sentence pairs for the consistency metric.\n",
    "# NOTE: rebinds the global `tokenizer` / `model` names used by earlier cells.\n",
    "tokenizer = BertTokenizer.from_pretrained(\"princeton-nlp/sup-simcse-bert-base-uncased\")  # princeton-nlp/sup-simcse-bert-base-uncased / SimCSE-bert-base\n",
    "model = BertModel.from_pretrained(\"princeton-nlp/sup-simcse-bert-base-uncased\")\n",
    "model.to('cuda').eval()\n",
    "\n",
    "\n",
    "def get_emb(sent_list):\n",
    "    \"\"\"Embed a batch of sentences with the SimCSE encoder above.\n",
    "\n",
    "    Args:\n",
    "        sent_list: list of sentences (str).\n",
    "\n",
    "    Returns:\n",
    "        CUDA tensor of shape (len(sent_list), hidden_size) holding BERT's pooler output.\n",
    "        The batch axis is kept even for a one-sentence batch, so callers can always\n",
    "        iterate over rows.\n",
    "    \"\"\"\n",
    "    torch.cuda.empty_cache()\n",
    "    # truncation=True: inputs longer than the tokenizer's max length would otherwise\n",
    "    # overflow BERT's position embeddings and raise.\n",
    "    inputs = tokenizer(sent_list, padding=True, truncation=True, return_tensors=\"pt\")\n",
    "    # Inference only: no_grad stops autograd from keeping activations alive.\n",
    "    with torch.no_grad():\n",
    "        outputs = model(input_ids=inputs[\"input_ids\"].cuda(),\n",
    "                        attention_mask=inputs[\"attention_mask\"].cuda()).pooler_output\n",
    "    # No .squeeze(): pooler_output is already (batch, hidden), and squeezing would turn a\n",
    "    # one-sentence batch into a 1-D vector that callers then iterate element by element.\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from tqdm import tqdm\n",
    "from scipy.stats import entropy\n",
    "\n",
    "BATCH_SIZE = 100  # sentences per get_emb call\n",
    "\n",
    "\n",
    "def _embed(sentences):\n",
    "    \"\"\"Embed `sentences` with get_emb and return a 2-D (n, hidden) numpy array.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        emb = get_emb(sentences)\n",
    "    # get_emb may squeeze a one-sentence batch to 1-D; restore the batch axis.\n",
    "    return np.atleast_2d(emb.detach().cpu().numpy())\n",
    "\n",
    "\n",
    "# For each MRPC consistency cycle, compare the stored per-pair `prob` values with the SimCSE\n",
    "# cosine similarity of each (text1, text2) pair via scipy.stats.entropy(prob, cos_sim),\n",
    "# i.e. KL divergence after scipy normalizes both vectors to sum to 1.\n",
    "all_kl = []\n",
    "for i in range(9):  # cycles 0..8\n",
    "    data = datasets.load_from_disk('/cognitive_comp/user/similarity_generation/consistency/mrpc/data_cycle_' + str(i))\n",
    "    raw_text, sim_text, prob = [], [], []\n",
    "    for row in tqdm(data):\n",
    "        if row['text1'] and row['text2']:\n",
    "            raw_text.append(row['text1'])\n",
    "            sim_text.append(row['text2'])\n",
    "            prob.append(row['prob'])  # assumes a scalar per row -- TODO confirm\n",
    "\n",
    "    # Embed in fixed-size batches; every kept pair lands in exactly one batch, including a\n",
    "    # final partial batch of any size (even a single pair).\n",
    "    cos_sim = []\n",
    "    for start in range(0, len(raw_text), BATCH_SIZE):\n",
    "        e1 = _embed(raw_text[start:start + BATCH_SIZE])\n",
    "        e2 = _embed(sim_text[start:start + BATCH_SIZE])\n",
    "        with np.errstate(invalid='ignore', divide='ignore'):\n",
    "            batch_cos = (e1 * e2).sum(axis=1) / (np.linalg.norm(e1, axis=1) * np.linalg.norm(e2, axis=1))\n",
    "        cos_sim.extend(batch_cos.tolist())\n",
    "\n",
    "    # Drop pairs with undefined (NaN) or non-positive similarity, keeping prob aligned.\n",
    "    cos_sim = np.asarray(cos_sim, dtype=float)\n",
    "    prob = np.asarray(prob, dtype=float)\n",
    "    keep = ~np.isnan(cos_sim) & (cos_sim > 0)\n",
    "    cos_sim, prob = cos_sim[keep], prob[keep]\n",
    "    print(len(cos_sim))\n",
    "\n",
    "    all_kl.append(entropy(prob, cos_sim))\n",
    "    print(all_kl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2579772665e14a92e42a466d49ec9ff85683bc5df8d0c675aa9afdde4fd8e604"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
