{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c1d20b13",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
      "[nltk_data]   Package stopwords is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import fitz\n",
    "import re\n",
    "import json\n",
    "from datetime import datetime\n",
    "from typing import Optional, List, Callable, Any, Tuple, Dict, Set\n",
    "from abc import abstractmethod, ABC\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy\n",
    "import nltk\n",
    "from nltk.corpus import stopwords\n",
    "import pickle\n",
    "import itertools\n",
    "import string\n",
    "\n",
    "nltk.download('stopwords')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "50fdb12d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import file_handle, grouping\n",
    "from utils.tree import Node\n",
    "\n",
    "random.seed(420)\n",
    "np.random.seed(420)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ec55cc04",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset_utils.reader import ADIQDataset\n",
    "from dataset_utils.creation import save_dataset\n",
    "from dataset_utils.question import Question\n",
    "from dataset_utils.outputs import to_basic_prompt\n",
    "\n",
    "ds = ADIQDataset(\"datasets/complexV3.1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a9358b8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _add_left_parenthesis(s: str) -> str:\n",
    "    return f\"({s}\"\n",
    "\n",
    "\n",
    "def _add_left_bracket(s: str) -> str:\n",
    "    return f\"[{s}\"\n",
    "\n",
    "\n",
    "def _add_left_brace(s: str) -> str:\n",
    "    return f\"{{{s}\"\n",
    "\n",
    "\n",
    "def _add_left_wave(s: str) -> str:\n",
    "    return f\"~{s}\"\n",
    "\n",
    "\n",
    "def _add_right_parenthesis(s: str) -> str:\n",
    "    return f\"{s})\"\n",
    "\n",
    "\n",
    "def _add_right_bracket(s: str) -> str:\n",
    "    return f\"{s}]\"\n",
    "\n",
    "\n",
    "def _add_right_brace(s: str) -> str:\n",
    "    return f\"{s}}}\"\n",
    "\n",
    "\n",
    "def _add_right_wave(s: str) -> str:\n",
    "    return f\"{s}~\"\n",
    "\n",
    "\n",
    "def _add_right_eq(s: str) -> str:\n",
    "    return f\"{s}=\"\n",
    "\n",
    "\n",
    "def _add_parentheses(s: str) -> str:\n",
    "    return f\"({s})\"\n",
    "\n",
    "\n",
    "def _add_brackets(s: str) -> str:\n",
    "    return f\"[{s}]\"\n",
    "\n",
    "\n",
    "def _add_braces(s: str) -> str:\n",
    "    return f\"{{{s}}}\"\n",
    "\n",
    "\n",
    "def _add_waves(s: str) -> str:\n",
    "    return f\"~{s}~\"\n",
    "\n",
    "\n",
    "def _caesar(s: str, delta: int = 10) -> str:\n",
    "    return chr(ord(s) + delta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0aeb4ae1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from dataset_utils.creation import selection_question_library, elimination_question_library\n",
    "\n",
    "class KPPerturbation(ABC):\n",
    "    def __init__(self):\n",
    "        self.method = \"default\"\n",
    "        pass\n",
    "\n",
    "    @abstractmethod\n",
    "    def perturb(self, mcq: Question) -> Question:\n",
    "        pass\n",
    "\n",
    "\n",
    "class OptionFormatPerturbation(KPPerturbation):\n",
    "    def __init__(\n",
    "        self,\n",
    "        method: str = \"add_parentheses\",\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            method:str, the perturbation method\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.method = method\n",
    "        if method == \"add_left_parenthesis\":\n",
    "            self.formatter = _add_left_parenthesis\n",
    "        elif method == \"add_left_bracket\":\n",
    "            self.formatter = _add_left_bracket\n",
    "        elif method == \"add_left_brace\":\n",
    "            self.formatter = _add_left_brace\n",
    "        elif method == \"add_left_wave\":\n",
    "            self.formatter = _add_left_wave\n",
    "        elif method == \"add_right_parenthesis\":\n",
    "            self.formatter = _add_right_parenthesis\n",
    "        elif method == \"add_right_bracket\":\n",
    "            self.formatter = _add_right_bracket\n",
    "        elif method == \"add_right_brace\":\n",
    "            self.formatter = _add_right_brace\n",
    "        elif method == \"add_right_wave\":\n",
    "            self.formatter = _add_right_wave\n",
    "        elif method == \"add_right_eq\":\n",
    "            self.formatter = _add_right_eq\n",
    "        elif method == \"add_parentheses\":\n",
    "            self.formatter = _add_parentheses\n",
    "        elif method == \"add_brackets\":\n",
    "            self.formatter = _add_brackets\n",
    "        elif method == \"add_braces\":\n",
    "            self.formatter = _add_braces\n",
    "        elif method == \"add_waves\":\n",
    "            self.formatter = _add_waves\n",
    "        else:\n",
    "            raise Exception(\"Invalid option format perturbation method.\")\n",
    "\n",
    "    def __str__(self):\n",
    "        return f\"OptionFormatPerturbation.{self.method}\"\n",
    "\n",
    "    def perturb(self, mcq: Question) -> Question:\n",
    "        assert len(mcq.option_ids) == len(mcq.options)\n",
    "        try:\n",
    "            new_option_ids = [self.formatter(elem) for elem in mcq.option_ids]\n",
    "            result = copy.deepcopy(mcq)\n",
    "            result.option_ids = new_option_ids\n",
    "        except:\n",
    "            print(\"OptionFormatPerturbation error. Keep the original result.\")\n",
    "            result = copy.deepcopy(mcq)\n",
    "        return result\n",
    "    \n",
    "class CaesarPerturbation(KPPerturbation):\n",
    "    def __init__(self, delta: int = 20):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            delta:int, the offset value in ASCII of option ids.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.formatter = partial(_caesar, delta=delta)\n",
    "\n",
    "    def __str__(self):\n",
    "        return f\"CaesarPerturbation.{self.method}\"\n",
    "\n",
    "    def perturb(self, mcq: Question) -> Question:\n",
    "        assert len(mcq.option_ids) == len(mcq.options)\n",
    "\n",
    "        try:\n",
    "            new_option_ids = [self.formatter(elem) for elem in mcq.option_ids]\n",
    "            result = copy.deepcopy(mcq)\n",
    "            result.option_ids = new_option_ids\n",
    "        except:\n",
    "            print(\"CaesarPerturbation error. Keep the original result.\")\n",
    "            result = copy.deepcopy(mcq)\n",
    "        return result\n",
    "    \n",
    "class QuestionPromptPerturbation(KPPerturbation):\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            delta:int, the offset value in ASCII of option ids.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "    def __str__(self):\n",
    "        return f\"QuestionPromptPerturbation.{self.method}\"\n",
    "\n",
    "    def perturb(self, mcq: Question) -> Question:\n",
    "        assert len(mcq.option_ids) == len(mcq.options)\n",
    "\n",
    "        try:\n",
    "            result = copy.deepcopy(mcq)\n",
    "            if result.question_type == \"positive\":\n",
    "                result.question_prompt = random.choice(list(set(selection_question_library) - {result.question_prompt}))\n",
    "            else:\n",
    "                result.question_prompt = random.choice(list(set(elimination_question_library) - {result.question_prompt}))\n",
    "        except:\n",
    "            print(\"QuestionPromptPerturbation error. Keep the original result.\")\n",
    "            result = copy.deepcopy(mcq)\n",
    "        return result\n",
    "    \n",
    "class RandomizeCondOptionPerturbation(KPPerturbation):\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            delta:int, the offset value in ASCII of option ids.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "    def __str__(self):\n",
    "        return f\"RandomizeCondOptionPerturbation.{self.method}\"\n",
    "\n",
    "    def perturb(self, mcq: Question) -> Question:\n",
    "        assert len(mcq.option_ids) == len(mcq.options)\n",
    "\n",
    "\n",
    "        try:\n",
    "            result = copy.deepcopy(mcq)\n",
    "            random.shuffle(result.condition_description)\n",
    "            \n",
    "            _ids = list(range(len(result.options)))\n",
    "            random.shuffle(_ids)\n",
    "\n",
    "            result.options = [result.options[x] for x in _ids]\n",
    "            result.option_ids = [result.option_ids[x] for x in _ids]\n",
    "            result.correct = [result.correct[x] for x in _ids]\n",
    "            \n",
    "        except:\n",
    "            print(\"RandomizeCondOptionPerturbation error. Keep the original result.\")\n",
    "            result = copy.deepcopy(mcq)\n",
    "        return result\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "18b52ff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "option_id_pert = CaesarPerturbation(delta=15)\n",
    "paranthesis_pert = OptionFormatPerturbation()\n",
    "qp_pert = QuestionPromptPerturbation()\n",
    "rand_cond_obs_pert = RandomizeCondOptionPerturbation()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3c977897",
   "metadata": {},
   "outputs": [],
   "source": [
    "simple_pert_ques:List[Question] = []\n",
    "\n",
    "for q in ds.questions:\n",
    "    q1 = option_id_pert.perturb(q)\n",
    "    q2 = paranthesis_pert.perturb(q1)\n",
    "    q3 = qp_pert.perturb(q2)\n",
    "    q4 = rand_cond_obs_pert.perturb(q3)\n",
    "\n",
    "    simple_pert_ques.append(q4)\n",
    "\n",
    "ds.questions = simple_pert_ques\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "04db984d",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds.save(\n",
    "    \"complexPertV3.1\",\n",
    "    \"datasets\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
