{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AlbertTokenizer\n",
    "from torch.utils.data import Dataset\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "\n",
    "\n",
    "def discriminator_collate_fn(batch_data, tokenizer):\n",
    "    dis_text_input_ids, labels = [], []\n",
    "    for item in batch_data:\n",
    "        # sentence1\n",
    "        dis_text = item['text1'] + '[SEP]' + item['text2']\n",
    "        input_ids = tokenizer(dis_text, return_tensors='pt').input_ids.squeeze()\n",
    "\n",
    "        dis_text_input_ids.append(input_ids)\n",
    "        labels.append(torch.tensor(int(item['score']), dtype=torch.long))\n",
    "\n",
    "    dis_text_input_ids = pad_sequence([x for x in dis_text_input_ids],\n",
    "                                      batch_first=True, \n",
    "                                      padding_value=tokenizer.pad_token_id)\n",
    "\n",
    "    return {\n",
    "        'dis_text_input_ids': dis_text_input_ids,\n",
    "        'labels': torch.stack(labels),\n",
    "    }\n",
    "\n",
    "\n",
    "class Config:\n",
    "    cycle = 0\n",
    "    zero_shot = 0  #\n",
    "    data_name = 'mrpc'  #\n",
    "    chinese = 0 # \n",
    "    warm_up_model = True  #\n",
    "    pretrain_dis = False\n",
    "    discriminator_en = 'albert_xxlarge'\n",
    "    discriminator_zh = 'albert_xxlarge'  # roformer_large / roberta_large\n",
    "    pretrained_en = '/cognitive_comp/user/source/model_base/pretrained_en/'\n",
    "    pretrained_zh = '/cognitive_comp/user/source/model_base/pretrained_zh/'\n",
    "    ckpt_model_path = '/cognitive_comp/user/similarity_generation/experiments/lightning_logs/checkpoints/2'\n",
    "    # ckpt_model_path = '/cognitive_comp/user/similarity_generation/all_checkpoints/new_exp7'\n",
    "\n",
    "class SimGanDataset(Dataset):\n",
    "    def __init__(self, data) -> None:\n",
    "        super().__init__()\n",
    "        self.data = data\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.data[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.data.num_rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import f1_score, accuracy_score\n",
    "import sys, datasets\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "sys.path.append('/cognitive_comp/user/similarity_generation/')\n",
    "from model_utils.sim_gen_model import Discriminator\n",
    "\n",
    "\n",
    "config = Config()\n",
    "data = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_test_data/' + config.data_name)\n",
    "dataset = SimGanDataset(data)\n",
    "dis_tokenizer = AlbertTokenizer.from_pretrained(config.pretrained_en + config.discriminator_en)\n",
    "# dis_tokenizer = AutoTokenizer.from_pretrained(config.pretrained_zh + config.discriminator_zh)\n",
    "def collate_fn(batch_data):\n",
    "    return discriminator_collate_fn(batch_data, dis_tokenizer)\n",
    "dataloader = DataLoader(\n",
    "        dataset=dataset,\n",
    "        batch_size=512,\n",
    "        shuffle=False,\n",
    "        num_workers=4,\n",
    "        pin_memory=True,\n",
    "        collate_fn=collate_fn,\n",
    "    )\n",
    "\n",
    "pred_result = []\n",
    "f1_result, acc_result = [], []\n",
    "for idx in range(11):\n",
    "    config.cycle = idx\n",
    "    discriminator = Discriminator(config)\n",
    "    discriminator.cuda().eval()\n",
    "    with torch.no_grad():\n",
    "        pred_list = []\n",
    "        f1_score_list, acc_score_list = [], []\n",
    "        for batch in dataloader:\n",
    "            torch.cuda.empty_cache()\n",
    "            logits = discriminator.forward(\n",
    "                batch['dis_text_input_ids'].cuda(),\n",
    "                None\n",
    "            )\n",
    "            \n",
    "            predictions = torch.argmax(logits, dim=1).tolist()\n",
    "            f1_score_list.append(\n",
    "                f1_score(batch['labels'].cuda().tolist(), predictions)\n",
    "            )\n",
    "            acc_score_list.append(\n",
    "                accuracy_score(batch['labels'].cuda().tolist(), predictions)\n",
    "            )\n",
    "        print(sum(f1_score_list) / len(f1_score_list))\n",
    "        f1_result.append(sum(f1_score_list) / len(f1_score_list))\n",
    "        print(sum(acc_score_list) / len(acc_score_list))\n",
    "        acc_result.append(sum(acc_score_list) / len(acc_score_list))\n",
    "\n",
    "print(f1_result)\n",
    "print(acc_result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 处理数据集(json->datasets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 处理AFQMC数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "path = '/cognitive_comp/user/source/sim_data/raw_data/AFQMC/afqmc_test.json'\n",
    "feats = datasets.Features({\"text1\": datasets.Value('string'), \n",
    "                           \"text2\": datasets.Value('string'),\n",
    "                        #    \"score\": datasets.Value('int8'),\n",
    "                           \"id\": datasets.Value('int64'),\n",
    "                           })\n",
    "ds = (datasets.load_dataset('json', data_files=path, \n",
    "                            cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                            features=feats)['train'])\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data/afqmc_test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "label_ds = datasets.load_from_disk('/cognitive_comp/user/source/data_base/similarity_data/labeled_data')\n",
    "afqmc_ds = datasets.load_from_disk('/cognitive_comp/user/source/data_base/similarity_data/afqmc_train')\n",
    "label_afqmc_ds = datasets.concatenate_datasets([label_ds, afqmc_ds])\n",
    "label_afqmc_ds.save_to_disk('/cognitive_comp/user/source/data_base/similarity_data/labeled_afqmc_ds')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, json\n",
    "from tqdm import tqdm\n",
    "\n",
    "train_afqmc = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/afqmc_train_ds')\n",
    "# dev_afqmc = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_test_data/afqmc')\n",
    "test_afqmc = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/similarity_data/afqmc_test')\n",
    "afqmc = datasets.concatenate_datasets([train_afqmc, test_afqmc])\n",
    "afqmc = afqmc.shuffle(seed=42)\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/similarity_data/afqmc_train_ds.json', 'w') as wp:\n",
    "    for idx in tqdm(range(afqmc.num_rows)):\n",
    "        wp.write(json.dumps({'sentence': afqmc[idx]['text1']}, ensure_ascii=False) + '\\n')\n",
    "        wp.write(json.dumps({'sentence': afqmc[idx]['text2']}, ensure_ascii=False) + '\\n')\n",
    "        wp.flush()\n",
    "wp.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = (datasets.load_dataset('json', data_files='/cognitive_comp/user/source/sim_data/similarity_data/afqmc_train_ds.json',\n",
    "                            cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache')['train'])\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/predict_sentences/afqmc_sentence')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 处理QQP数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, glob\n",
    "from concurrent.futures import ProcessPoolExecutor\n",
    "\n",
    "feats = datasets.Features({\"text1\": datasets.Value('string'), \n",
    "                           \"text2\": datasets.Value('string'),\n",
    "                           \"score\": datasets.Value('int8')})\n",
    "def _generate_cache_arrow(index, path):\n",
    "    print('saving dataset shard {}'.format(index))\n",
    "    ds = (datasets.load_dataset('json', data_files=path,\n",
    "                                cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                                features=feats)['train'])\n",
    "    ds.save_to_disk(os.path.join('/cognitive_comp/user/source/sim_data/translate_data/translate_cache_data', f'0{index}'))\n",
    "    return 'saving dataset shard {} done'.format(index)\n",
    "\n",
    "\n",
    "def generate_cache_arrow(num_proc=1) -> None:\n",
    "    data_dict_paths = []\n",
    "    data_dict_paths = glob.glob('/cognitive_comp/user/source/sim_data/translate_data/translate_json_data/*')\n",
    "    print(data_dict_paths)\n",
    "    \n",
    "    p = ProcessPoolExecutor(max_workers=num_proc)\n",
    "    res = []\n",
    "\n",
    "    for index, path in enumerate(data_dict_paths):\n",
    "        res.append(p.submit(_generate_cache_arrow, index, path))\n",
    "\n",
    "    p.shutdown(wait=True)\n",
    "    for future in res:\n",
    "        print(future.result(), flush=True)\n",
    "\n",
    "generate_cache_arrow()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "cache_dict_paths = glob.glob('/cognitive_comp/user/source/sim_data/translate_data/translate_cache_data/*')\n",
    "sim_ds_list = []\n",
    "for path in cache_dict_paths:\n",
    "    sim_ds_list.append(datasets.load_from_disk(path))\n",
    "sim_dataset = datasets.concatenate_datasets(sim_ds_list)\n",
    "sim_dataset.save_to_disk('/cognitive_comp/user/source/sim_data/translate_data/qqp_data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_dataset = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/translate_data/qqp_data')\n",
    "split_ds = sim_dataset.train_test_split(test_size=0.4)\n",
    "split_ds['train'].save_to_disk('/cognitive_comp/user/source/sim_data/sim_train_data/qqp_train_ds')\n",
    "split_ds['test'].save_to_disk('/cognitive_comp/user/source/sim_data/sim_test_data/qqp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, json\n",
    "from tqdm import tqdm\n",
    "\n",
    "train_qqp = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/qqp_train_ds')\n",
    "with open('/cognitive_comp/user/source/sim_data/similarity_data/qqp_train_data.json', 'w') as wp:\n",
    "    for idx in tqdm(range(train_qqp.num_rows)):\n",
    "        wp.write(json.dumps({'sentence': train_qqp[idx]['text1']}, ensure_ascii=False) + '\\n')\n",
    "        wp.write(json.dumps({'sentence': train_qqp[idx]['text2']}, ensure_ascii=False) + '\\n')\n",
    "        wp.flush()\n",
    "wp.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = (datasets.load_dataset('json', data_files='/cognitive_comp/user/source/sim_data/similarity_data/qqp_train_data.json',\n",
    "                            cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache')['train'])\n",
    "ds.save_to_disk(os.path.join('/cognitive_comp/user/source/sim_data/predict_sentences/qqp_sentence'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 处理CHIP数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data/test.json', 'r') as rp:\n",
    "    data = json.load(rp)\n",
    "rp.close()\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data/chip_test.json', 'w') as wp:\n",
    "    for item in tqdm(data):\n",
    "        wp.write(json.dumps({'text1': item['text1'],\n",
    "                             'text2': item['text2'],\n",
    "                            #  'score': item['label']\n",
    "                             },\n",
    "                            ensure_ascii=False) + '\\n')\n",
    "wp.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "path = '/cognitive_comp/user/source/sim_data/raw_data/CHIP/chip_test.json'\n",
    "feats = datasets.Features({\"text1\": datasets.Value('string'), \n",
    "                           \"text2\": datasets.Value('string'),\n",
    "                        #    \"score\": datasets.Value('int8')\n",
    "                           })\n",
    "ds = (datasets.load_dataset('json', data_files=path, \n",
    "                            cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                            features=feats)['train'])\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data/chip_test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, json\n",
    "from tqdm import tqdm\n",
    "\n",
    "train_chip = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/chip_train_ds')\n",
    "# dev_chip = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_test_data/chip')\n",
    "test_chip = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/similarity_data/chip_test')\n",
    "chip = datasets.concatenate_datasets([train_chip, test_chip])\n",
    "chip = chip.shuffle(seed=42)\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/similarity_data/chip_train_ds.json', 'w') as wp:\n",
    "    for idx in tqdm(range(chip.num_rows)):\n",
    "        wp.write(json.dumps({'sentence': chip[idx]['text1']}, ensure_ascii=False) + '\\n')\n",
    "        wp.write(json.dumps({'sentence': chip[idx]['text2']}, ensure_ascii=False) + '\\n')\n",
    "        wp.flush()\n",
    "wp.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = (datasets.load_dataset('json', data_files='/cognitive_comp/user/source/sim_data/similarity_data/chip_train_ds.json',\n",
    "                            cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache')['train'])\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/predict_sentences/chip_sentence')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 处理OPPO数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "# test, train, dev: 50000 167173 10000\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data/oppp.json', 'r') as rp:\n",
    "    data = json.load(rp)\n",
    "rp.close()\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data/oppo_train.json', 'w') as wp:\n",
    "    for item in tqdm(data['train']):\n",
    "        wp.write(json.dumps({'text1': item['q1'],\n",
    "                             'text2': item['q2'],\n",
    "                             'score': item['label'],\n",
    "                             }, ensure_ascii=False) + '\\n')\n",
    "wp.close()\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data/oppo_dev.json', 'w') as wp1:\n",
    "    for item in tqdm(data['dev']):\n",
    "        wp1.write(json.dumps({'text1': item['q1'],\n",
    "                             'text2': item['q2'],\n",
    "                             'score': item['label'],\n",
    "                             }, ensure_ascii=False) + '\\n')\n",
    "wp1.close()\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data/oppo_test.json', 'w') as wp2:\n",
    "    for item in tqdm(data['test']):\n",
    "        wp2.write(json.dumps({'text1': item['q1'],\n",
    "                             'text2': item['q2'],\n",
    "                             }, ensure_ascii=False) + '\\n')\n",
    "wp2.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "path = '/cognitive_comp/user/source/sim_data/raw_data/oppo_test.json'\n",
    "feats = datasets.Features({\"text1\": datasets.Value('string'), \n",
    "                           \"text2\": datasets.Value('string'),\n",
    "                        #    \"score\": datasets.Value('int8')\n",
    "                           })\n",
    "ds = (datasets.load_dataset('json', data_files=path, \n",
    "                            cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                            features=feats)['train'])\n",
    "# ds.save_to_disk('/cognitive_comp/user/source/sim_data/sim_train_data/oppo_train')\n",
    "# ds.save_to_disk('/cognitive_comp/user/source/sim_data/sim_test_data/oppo')\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data/oppo_test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, json\n",
    "from tqdm import tqdm\n",
    "\n",
    "train_oppo = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/oppo_train_ds')\n",
    "dev_oppo = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_test_data/oppo')\n",
    "test_oppo = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/similarity_data/oppo_test')\n",
    "oppo = datasets.concatenate_datasets([train_oppo, dev_oppo, test_oppo])\n",
    "oppo = oppo.shuffle(seed=42)\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/similarity_data/oppo_train_ds.json', 'w') as wp:\n",
    "    for idx in tqdm(range(oppo.num_rows)):\n",
    "        wp.write(json.dumps({'sentence': oppo[idx]['text1']}, ensure_ascii=False) + '\\n')\n",
    "        wp.write(json.dumps({'sentence': oppo[idx]['text2']}, ensure_ascii=False) + '\\n')\n",
    "        wp.flush()\n",
    "wp.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = (datasets.load_dataset('json', data_files='/cognitive_comp/user/source/sim_data/similarity_data/oppo_train_ds.json',\n",
    "                            cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache')['train'])\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/predict_sentences/oppo_sentence')  # 454346"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 处理PAWS数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "paws_train = pd.read_csv('/cognitive_comp/user/source/sim_data/raw_data/paws_test.tsv', sep='\\t', header=None)\n",
    "with open('/cognitive_comp/user/source/sim_data/raw_data/paws_test.json', 'w') as wp:\n",
    "    for idx in tqdm(range(len(paws_train))):\n",
    "        wp.write(json.dumps({'text1': str(paws_train[0][idx]),\n",
    "                             'text2': str(paws_train[1][idx]),\n",
    "                            #  'score': int(paws_train[2][idx]),\n",
    "                             }, ensure_ascii=False) + '\\n')\n",
    "wp.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "path = '/cognitive_comp/user/source/sim_data/raw_data/paws_test.json'\n",
    "feats = datasets.Features({\"text1\": datasets.Value('string'), \n",
    "                           \"text2\": datasets.Value('string'),\n",
    "                        #    \"score\": datasets.Value('int8')\n",
    "                           })\n",
    "ds = (datasets.load_dataset('json', data_files=path, \n",
    "                            cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache',\n",
    "                            features=feats)['train'])\n",
    "# ds.save_to_disk('/cognitive_comp/user/source/sim_data/sim_train_data/paws_train')\n",
    "# ds.save_to_disk('/cognitive_comp/user/source/sim_data/sim_test_data/paws')\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data/paws_test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, json\n",
    "from tqdm import tqdm\n",
    "\n",
    "train_paws = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/paws_train')\n",
    "dev_paws = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_test_data/paws')\n",
    "test_paws = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/similarity_data/paws_test')\n",
    "paws = datasets.concatenate_datasets([train_paws, dev_paws, test_paws])\n",
    "paws = paws.shuffle(seed=42)\n",
    "\n",
    "with open('/cognitive_comp/user/source/sim_data/similarity_data/paws_train_ds.json', 'w') as wp:\n",
    "    for idx in tqdm(range(paws.num_rows)):\n",
    "        wp.write(json.dumps({'sentence': paws[idx]['text1']}, ensure_ascii=False) + '\\n')\n",
    "        wp.write(json.dumps({'sentence': paws[idx]['text2']}, ensure_ascii=False) + '\\n')\n",
    "        wp.flush()\n",
    "wp.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = (datasets.load_dataset('json', data_files='/cognitive_comp/user/source/sim_data/similarity_data/paws_train_ds.json',\n",
    "                            cache_dir='/cognitive_comp/user/source/data_base/huggingface-cache')['train'])\n",
    "ds.save_to_disk('/cognitive_comp/user/source/sim_data/predict_sentences/paws_sentence')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 合并数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets, glob\n",
    "\n",
    "\n",
    "cache_dict_paths = glob.glob('/cognitive_comp/user/source/sim_data/similarity_data/sim_cache_data/*')\n",
    "ds = []\n",
    "for path in cache_dict_paths:\n",
    "    ds.append(datasets.load_from_disk(path))\n",
    "\n",
    "afqmc_train = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/afqmc_train_ds')\n",
    "afqmc_dev = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_test_data/afqmc')\n",
    "afqmc = datasets.concatenate_datasets([afqmc_train, afqmc_dev])\n",
    "ds.append(afqmc)\n",
    "\n",
    "chip_train = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/chip_train_ds')\n",
    "chip_dev = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_test_data/chip')\n",
    "chip = datasets.concatenate_datasets([chip_train, chip_dev])\n",
    "ds.append(chip)\n",
    "\n",
    "oppo_train = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_train_data/oppo_train_ds')\n",
    "oppo_dev = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/sim_test_data/oppo')\n",
    "oppo = datasets.concatenate_datasets([oppo_train, oppo_dev])\n",
    "ds.append(oppo)\n",
    "\n",
    "print(len(ds))\n",
    "label_ds = datasets.concatenate_datasets(ds)\n",
    "label_ds.save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data/labeled4paws')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "datasets.load_from_disk('/cognitive_comp/user/source/sim_data/similarity_data/labeled_train_chip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "labeled_ds = datasets.load_from_disk('/cognitive_comp/user/source/sim_data/similarity_data/labeled4chip')\n",
    "split_data = labeled_ds.train_test_split(test_size=0.02, seed=42)\n",
    "split_data['train'].save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data/labeled_train_chip')\n",
    "split_data['test'].save_to_disk('/cognitive_comp/user/source/sim_data/similarity_data/labeled_test_chip')\n",
    "print(labeled_ds, split_data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2579772665e14a92e42a466d49ec9ff85683bc5df8d0c675aa9afdde4fd8e604"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
