{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "from transformers import BertModel\n",
    "from torch import nn\n",
    "import pandas as pd\n",
    "from transformers import BertTokenizer\n",
    "from nets_bert import BertClassifier\n",
    "from utils_bert import Dataset_2, marich\n",
    "\n",
    "# Keep the original preference for the second GPU, but fall back gracefully so the\n",
    "# notebook still runs on hosts with a single GPU or with no GPU at all.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    device = torch.device(\"cuda:1\")\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the labelled BBC articles; the next cell splits this frame into train/test.\n",
    "bbc_csv_path = \"./bert_data/bbc/bbc_train.csv\"\n",
    "test = pd.read_csv(bbc_csv_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 290 BBC articles for evaluation. A fixed random_state makes the split\n",
    "# (and therefore every accuracy reported below) reproducible across kernel restarts.\n",
    "# NOTE: `test` is rebound here, so re-run the CSV-loading cell before re-running this one.\n",
    "train, test = train_test_split(test, test_size=290, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['business', 'entertainment', 'politics', 'sport', 'tech'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Show the distinct category names present in the training split.\n",
    "category_column = train[\"Category\"]\n",
    "np.unique(category_column)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each category name to an integer id, following np.unique's sorted order.\n",
    "# Enumerating the unique values avoids hard-coding the number of classes and\n",
    "# recomputing np.unique once per class.\n",
    "labels_dict = {label: idx for idx, label in enumerate(np.unique(train[\"Category\"]))}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the held-out split is prepared here: the raw article text plus the\n",
    "# integer class id looked up in labels_dict for each row.\n",
    "test_text = test[\"Text\"]\n",
    "test_category_ids = test[\"Category\"].map(labels_dict)\n",
    "test_labels = np.array(test_category_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _strip_backslashes(text):\n",
    "    \"\"\"Return `text` with every backslash character replaced by a space.\"\"\"\n",
    "    return text.replace(\"\\\\\", \" \")\n",
    "\n",
    "\n",
    "test_text = test_text.apply(_strip_backslashes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the cleaned test texts and their integer labels in the project's Dataset_2.\n",
    "test_data = Dataset_2(\n",
    "    test_text,\n",
    "    test_labels,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "\n",
    "def _load_pickle(path):\n",
    "    \"\"\"Return the object stored in the pickle file at `path`.\"\"\"\n",
    "    # NOTE(review): unpickling can execute arbitrary code; only load trusted files.\n",
    "    with open(path, 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "\n",
    "# Serialized validation set, plus the pickled training inputs and labels.\n",
    "val_data = torch.load(\"ag_val.pt\")\n",
    "train_x = _load_pickle('ag_train_unbal_x.pkl')\n",
    "train_y = _load_pickle('ag_train_unbal_y.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the unpickled training inputs and labels in the project's Dataset_2.\n",
    "train_data = Dataset_2(\n",
    "    train_x,\n",
    "    train_y,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler, random_split\n",
    "\n",
    "batch_size = 8\n",
    "\n",
    "# Only the validation set gets a loader here; train_data is handed to marich as a\n",
    "# dataset further down. DataLoader's defaults are kept (no explicit sampler).\n",
    "validation_dataloader = DataLoader(dataset=val_data, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the Dataset_2 already built from (test_text, test_labels) instead of\n",
    "# constructing an identical second copy, and share the notebook-wide batch_size\n",
    "# rather than repeating the literal 8.\n",
    "testset = test_data\n",
    "test_loader = DataLoader(testset, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Make sure the output directory exists before spending time on training;\n",
    "# np.save does not create missing directories.\n",
    "os.makedirs(\"./results2\", exist_ok=True)\n",
    "\n",
    "acc_marich = []\n",
    "\n",
    "for i in range(10):\n",
    "    model = BertClassifier().to(device)\n",
    "    # Hand marich the same device the model was moved to. The previous hard-coded\n",
    "    # \"cuda\" means the current default GPU, which need not match `device`.\n",
    "    # NOTE(review): assumes marich accepts any torch device string -- confirm.\n",
    "    tll, vll, tal, samp = marich(\n",
    "        model,\n",
    "        train_data,\n",
    "        validation_dataloader,\n",
    "        test_loader,\n",
    "        budget=60,\n",
    "        init_points=100,\n",
    "        rounds=6,\n",
    "        epochs=3,\n",
    "        LR=5e-6,\n",
    "        model_name=f\"bert_attack{i}.pt\",\n",
    "        sampling=\"all_elg\",\n",
    "        device=str(device),\n",
    "    )\n",
    "\n",
    "    acc_marich.append(tal)\n",
    "    # Re-save after every run so finished runs survive an interruption.\n",
    "    # acc_marich accumulates all runs; samples_bert.npy is overwritten each time and\n",
    "    # so only ever holds the most recent run's `samp`.\n",
    "    np.save(\"./results2/acc_marich_bert.npy\", np.array(acc_marich))\n",
    "    np.save(\"./results2/samples_bert.npy\", np.array(samp))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
