{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27a25e95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Process-wide runtime settings; this cell sits ahead of the one that imports torch.\n",
    "os.environ.update({\n",
    "    \"CUDA_DEVICE_ORDER\": \"PCI_BUS_ID\",  # number GPUs by PCI bus id\n",
    "    \"CUDA_VISIBLE_DEVICES\": \"0\",  # expose only device 0 to this process\n",
    "    \"TOKENIZERS_PARALLELISM\": \"false\",  # switch off the tokenizers library's own parallelism\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2bf19db",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import torch\n",
    "import random\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import Dataset as ds\n",
    "from sklearn.metrics import f1_score, accuracy_score, matthews_corrcoef\n",
    "from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f71cb2a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    \"\"\"Seed the Python, NumPy and PyTorch RNGs and configure cuDNN for determinism.\n",
    "\n",
    "    Args:\n",
    "        seed: Integer seed handed to every generator.\n",
    "    \"\"\"\n",
    "    # cuDNN switches: deterministic kernels on, benchmark auto-tuning off.\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "    # Host-side generators.\n",
    "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seed_fn(seed)\n",
    "\n",
    "    # CUDA generators (both calls kept, as in the original recipe).\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "\n",
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0cd6370",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a40bfbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint identifier passed to from_pretrained for both the tokenizer and the classifier.\n",
    "# Alternatives noted by the author: bert-base-uncased, FacebookAI/roberta-base,\n",
    "# 'large' (left unqualified in the original note), microsoft/deberta-v3-base.\n",
    "model_path = \"roberta-large\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3ea970e",
   "metadata": {},
   "outputs": [],
   "source": [
    "directory = './data/sarcasm_v2/'\n",
    "\n",
    "# NOTE: despite its name, this list holds the CSV files found in `directory`.\n",
    "jsonl_files = glob.glob(os.path.join(directory, '*.csv'))\n",
    "\n",
    "# Strip only the final extension so a base name that itself contains dots stays whole;\n",
    "# cutting at the first dot would yield a stem that no longer matches the file on disk.\n",
    "file_names = [os.path.splitext(os.path.basename(path))[0] for path in jsonl_files]\n",
    "\n",
    "# A file belongs to a split when its last underscore-separated token names that split.\n",
    "# Sorting gives a stable order, since glob makes no ordering promise.\n",
    "train_file_names = sorted(name for name in file_names if name.rsplit('_', 1)[-1] == 'train')\n",
    "test_file_names = sorted(name for name in file_names if name.rsplit('_', 1)[-1] == 'test')\n",
    "\n",
    "# Later cells pair train and test files by list position, so flag a misalignment early\n",
    "# (for example a dataset that ships only one of the two splits).\n",
    "train_prefixes = [name[:-len('train')] for name in train_file_names]\n",
    "test_prefixes = [name[:-len('test')] for name in test_file_names]\n",
    "if train_prefixes != test_prefixes:\n",
    "    print('WARNING: train and test file lists do not line up by dataset prefix')\n",
    "\n",
    "print(train_file_names)\n",
    "print(test_file_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac2d4ae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Position of the dataset to use within the sorted train/test file-name lists.\n",
    "idx = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46f5428c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve the CSV paths of the selected dataset, then load each split with a fresh 0..n-1 index.\n",
    "train_csv = f'{directory}{train_file_names[idx]}.csv'\n",
    "test_csv = f'{directory}{test_file_names[idx]}.csv'\n",
    "\n",
    "train_df = pd.read_csv(train_csv).reset_index(drop=True)\n",
    "test_df = pd.read_csv(test_csv).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18e0ad68",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_class(labels):\n",
    "    \"\"\"Map one raw label value to its binary class id.\n",
    "\n",
    "    Args:\n",
    "        labels: A single label value (despite the plural name).\n",
    "\n",
    "    Returns:\n",
    "        1 when the value equals 'sarc', otherwise 0.\n",
    "    \"\"\"\n",
    "    if labels == 'sarc':\n",
    "        return 1\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62b7e08b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the integer target column from the string labels of each split.\n",
    "train_df['class'] = train_df['labels'].map(convert_to_class)\n",
    "test_df['class'] = test_df['labels'].map(convert_to_class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca9469df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer and a two-label sequence classifier, both loaded from `model_path`.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    model_path,\n",
    "    num_labels=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4a2fa31",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomTextDataset(ds):\n",
    "    \"\"\"Map-style dataset that tokenizes one sentence per item.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame providing a 'sentence' column (text) and a 'class' column (label).\n",
    "        tokenizer: Callable invoked as ``tokenizer(text, truncation=True, padding='max_length',\n",
    "            max_length=..., return_tensors='pt')`` that returns a mapping of tensors.\n",
    "        max_length: Length passed to the tokenizer for padding and truncation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, df, tokenizer, max_length=128):\n",
    "        self.texts = df['sentence']\n",
    "        self.labels = df['class']\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_length = max_length\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of examples.\"\"\"\n",
    "        return len(self.texts)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the tokenized example at position ``idx`` plus its label.\n",
    "\n",
    "        Args:\n",
    "            idx: Zero-based position of the example.\n",
    "\n",
    "        Returns:\n",
    "            dict holding every tensor produced by the tokenizer (leading size-1 axis\n",
    "            removed) and a 'labels' tensor built from the 'class' value.\n",
    "        \"\"\"\n",
    "        # Positional access: samplers hand out 0..len-1, which matches the Series labels\n",
    "        # only when the DataFrame still has its default index (e.g. not after filtering).\n",
    "        encoding = self.tokenizer(\n",
    "            self.texts.iloc[idx],\n",
    "            truncation=True,\n",
    "            padding='max_length',\n",
    "            max_length=self.max_length,\n",
    "            return_tensors='pt'\n",
    "        )\n",
    "\n",
    "        # squeeze(0) drops a size-1 leading axis so examples can be stacked into batches.\n",
    "        item = {key: val.squeeze(0) for key, val in encoding.items()}\n",
    "        item['labels'] = torch.tensor(self.labels.iloc[idx])\n",
    "        return item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa55a2fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap both splits; max_length is left at the dataset class's default.\n",
    "train_dataset = CustomTextDataset(df=train_df, tokenizer=tokenizer)\n",
    "test_dataset = CustomTextDataset(df=test_df, tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec58d479",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: the Trainer is handed train_dataset directly, so this training loader is not\n",
    "# what feeds trainer.train().\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)\n",
    "\n",
    "# Evaluation gains nothing from shuffling: a fixed order keeps predictions aligned with\n",
    "# the rows of test_df and leaves the random state untouched.\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f77dd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer configuration: 5 epochs, batch size 32 per device, a checkpoint every 500 steps,\n",
    "# a log entry every 100 steps, and no external experiment tracker.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./results\",\n",
    "    per_device_train_batch_size=32,\n",
    "    num_train_epochs=5,\n",
    "    logging_steps=100,\n",
    "    save_steps=500,\n",
    "    save_strategy=\"steps\",\n",
    "    load_best_model_at_end=False,\n",
    "    report_to=\"none\",\n",
    "    # The dataset already yields exactly the keys the model consumes, so keep them all.\n",
    "    remove_unused_columns=False,\n",
    "    # Mixed precision needs a CUDA device; gate it so the CPU fallback this notebook\n",
    "    # allows for does not fail when the arguments are validated.\n",
    "    fp16=torch.cuda.is_available(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d796388",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only a training set is supplied to the Trainer; no evaluation set is attached here.\n",
    "trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92481b49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run fine-tuning; the call's return value is shown as this cell's output.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dda72c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manual evaluation on the test loader: collect argmax predictions batch by batch,\n",
    "# then score them with accuracy, F1 and Matthews correlation.\n",
    "model.eval()\n",
    "all_preds = []\n",
    "all_labels = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(test_dataloader):\n",
    "        input_ids = batch['input_ids'].to(model.device)\n",
    "        attention_mask = batch['attention_mask'].to(model.device)\n",
    "        # Labels are only compared on the host, so they never need to visit the device.\n",
    "        labels = batch['labels']\n",
    "\n",
    "        outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n",
    "        logits = outputs.logits\n",
    "        preds = torch.argmax(logits, dim=-1)\n",
    "\n",
    "        all_preds.extend(preds.cpu().tolist())\n",
    "        all_labels.extend(labels.tolist())\n",
    "\n",
    "acc = accuracy_score(all_labels, all_preds)\n",
    "mcc = matthews_corrcoef(all_labels, all_preds)\n",
    "f1 = f1_score(all_labels, all_preds)\n",
    "\n",
    "print(f\"Accuracy: {acc:.3f}\")\n",
    "print(f\"f1: {f1:.3f}\")\n",
    "# MCC was computed but never reported; print it alongside the other metrics.\n",
    "print(f\"MCC: {mcc:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80b6bdb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the name of the training file selected by idx.\n",
    "train_file_names[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a341ce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short names used in the output path: the last '/'-separated piece of the checkpoint id\n",
    "# and the part of the training file name before its first underscore.\n",
    "model_name = model_path.rsplit('/', 1)[-1]\n",
    "data_name = train_file_names[idx].partition('_')[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d2dfd74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the per-model, per-dataset output folder; an existing folder is fine.\n",
    "save_dir = f'./models/{model_name}/{data_name}'\n",
    "os.makedirs(save_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a47b41b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model and the tokenizer into the same directory.\n",
    "output_dir = f'./models/{model_name}/{data_name}'\n",
    "model.save_pretrained(output_dir)\n",
    "tokenizer.save_pretrained(output_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "topo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
