{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "03ccece4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
      "[nltk_data]   Package stopwords is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import fitz\n",
    "import re\n",
    "import json\n",
    "from datetime import datetime\n",
    "from typing import Optional, List, Callable, Any, Tuple, Dict, Set\n",
    "from abc import abstractmethod, ABC\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy\n",
    "import nltk\n",
    "from nltk.corpus import stopwords\n",
    "import pickle\n",
    "import itertools\n",
    "import string\n",
    "\n",
    "# Fetch the NLTK stopwords corpus backing `nltk.corpus.stopwords` (imported above).\n",
    "# NOTE(review): `stopwords` is not referenced in the visible cells; confirm it is still needed.\n",
    "nltk.download('stopwords')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ce7b2c0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import file_handle, grouping\n",
    "from utils.tree import Node\n",
    "\n",
    "# Seed the stdlib and NumPy global RNGs with one shared constant.\n",
    "SEED = 420\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3aa49c3b",
   "metadata": {},
   "source": [
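    "# Dataset Creation\n",
    "\n",
    "Augment every question of `datasets/simpleV5` with extra answer options (up to 10 per question by default) and save the result as `datasets/complexV5`."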
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed742837",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset_utils.reader import ADIQDataset\n",
    "from dataset_utils.creation import save_dataset\n",
    "from dataset_utils.question import Question\n",
    "from dataset_utils.outputs import to_basic_prompt\n",
    "\n",
    "# Source dataset that gets augmented into the complex variant below.\n",
    "SIMPLE_DATASET_DIR = \"datasets/simpleV5\"\n",
    "ds = ADIQDataset(SIMPLE_DATASET_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1ce5f2df",
   "metadata": {},
   "outputs": [],
   "source": [
    "OPTION_STRING = string.ascii_uppercase\n",
    "\n",
    "\n",
    "def convert_to_complex(\n",
    "        q:\"Question\",\n",
    "        choices:Set[str],\n",
    "        rule_info:Dict[str,Any],\n",
    "        rule_similarity:Dict[str,List[str]],\n",
    "        all_rule_set:Dict[str,Dict[str,Any]],\n",
    "        uniq_obs:Dict[str,str],\n",
    "        num_options = 10,\n",
    "        n_least_sim_rules = 35\n",
    "        ) -> None:\n",
    "    \"\"\"Augment `q` in place into a complex question with up to `num_options` options.\n",
    "\n",
    "    Sets `q.subject`, then delegates to `convert_positive_question` or\n",
    "    `convert_negative_question` depending on `q.question_type`.\n",
    "\n",
    "    Args:\n",
    "        q: question to mutate (`subject`, `options`, `correct`, `option_ids`).\n",
    "        choices: not used; kept so existing callers keep working. Positive-question\n",
    "            distractors are built from `all_rule_set` and `uniq_obs` instead.\n",
    "        rule_info: rule the question was generated from; reads `\"id\"` and\n",
    "            `[\"display_text\"][\"observations\"]`.\n",
    "        rule_similarity: rule id -> list of rule ids. The first `n_least_sim_rules`\n",
    "            entries supply distractors (presumably ordered least-similar first;\n",
    "            TODO confirm against the dataset builder).\n",
    "        all_rule_set: rule id -> rule dict.\n",
    "        uniq_obs: observation text -> unique observation text used as option text.\n",
    "        num_options: target number of options per question.\n",
    "        n_least_sim_rules: how many entries of `rule_similarity` are used.\n",
    "    \"\"\"\n",
    "    q.subject = \"diagnosis_sense_complex_question_analysis\"\n",
    "\n",
    "    if q.question_type == \"positive\":\n",
    "        distractor_pool = set()\n",
    "        for other_rule_id in rule_similarity[rule_info[\"id\"]][:n_least_sim_rules]:\n",
    "            other_obs = all_rule_set[other_rule_id][\"display_text\"][\"observations\"]\n",
    "            distractor_pool.update(uniq_obs[obs] for obs in other_obs)\n",
    "        # convert_positive_question removes the rule's own observations in raw form;\n",
    "        # also drop their `uniq_obs` form so text that `uniq_obs` rewrote cannot\n",
    "        # slip an equivalent of a correct answer in as a distractor.\n",
    "        own_obs = rule_info[\"display_text\"][\"observations\"]\n",
    "        distractor_pool.difference_update(uniq_obs.get(obs, obs) for obs in own_obs)\n",
    "\n",
    "        convert_positive_question(q, distractor_pool, num_options, rule_info, rule_similarity)\n",
    "    else:\n",
    "        convert_negative_question(q, rule_info, num_options)\n",
    "\n",
    "\n",
    "def _correct_option(q:\"Question\") -> str:\n",
    "    \"\"\"Return the first option of `q` whose `correct` flag is truthy.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no option is flagged as correct.\n",
    "    \"\"\"\n",
    "    for option, is_correct in zip(q.options, q.correct):\n",
    "        if is_correct:\n",
    "            return option\n",
    "    raise ValueError(f\"question has no option flagged as correct: {q.options!r}\")\n",
    "\n",
    "\n",
    "def _pick_new_options(pool:Set[str], k:int) -> List[str]:\n",
    "    \"\"\"Return up to `k` items of `pool`, chosen with the module-level `random` generator.\n",
    "\n",
    "    The pool is sorted first: iteration order of a set of strings depends on\n",
    "    per-process hash randomisation (PYTHONHASHSEED), so sampling from\n",
    "    `list(pool)` is not reproducible even after `random.seed(...)`.\n",
    "    If the pool holds at most `k` items, all of them are returned.\n",
    "    \"\"\"\n",
    "    candidates = sorted(pool)\n",
    "    if len(candidates) <= k:\n",
    "        return candidates\n",
    "    return random.sample(candidates, k)\n",
    "\n",
    "\n",
    "def _write_shuffled_options(q:\"Question\", correct:str, others:List[str]) -> None:\n",
    "    \"\"\"Shuffle `correct` together with `others` and store the result on `q`.\n",
    "\n",
    "    Rewrites `q.options`, `q.correct` (True only at the position of `correct`)\n",
    "    and `q.option_ids` (\"A\", \"B\", ... from OPTION_STRING).\n",
    "    \"\"\"\n",
    "    all_options = [correct] + list(others)\n",
    "    order = list(range(len(all_options)))\n",
    "    random.shuffle(order)\n",
    "\n",
    "    q.options = [all_options[i] for i in order]\n",
    "    q.correct = [i == 0 for i in order]\n",
    "    # OPTION_STRING has 26 letters, so at most 26 options can be labelled.\n",
    "    q.option_ids = [OPTION_STRING[i] for i in range(len(order))]\n",
    "\n",
    "\n",
    "def convert_positive_question(\n",
    "        q:\"Question\", choices:Set[str], num_options:int,\n",
    "        rule_info:Dict[str,Any], rule_similarity:Dict[str,List[str]]\n",
    "        ) -> None:\n",
    "    \"\"\"Pad a positive question with distractors drawn from `choices`.\n",
    "\n",
    "    Candidates are `choices` minus the options already present and minus the\n",
    "    rule's own observations. Nothing changes when the question already has\n",
    "    `num_options` distinct options or no candidate is left; when fewer\n",
    "    candidates than needed exist, all of them are added.\n",
    "    `rule_similarity` is accepted for interface compatibility but not used.\n",
    "    \"\"\"\n",
    "    options_in_question = set(q.options)\n",
    "    n_missing = num_options - len(options_in_question)\n",
    "    available_options = choices - options_in_question - set(rule_info[\"display_text\"][\"observations\"])\n",
    "    if n_missing <= 0 or not available_options:\n",
    "        return\n",
    "\n",
    "    correct = _correct_option(q)\n",
    "    new_options = _pick_new_options(available_options, n_missing)\n",
    "    # sorted() keeps the pre-shuffle order independent of string hashing.\n",
    "    new_options.extend(sorted(options_in_question - {correct}))\n",
    "    _write_shuffled_options(q, correct, new_options)\n",
    "\n",
    "\n",
    "def convert_negative_question(q:\"Question\", rule_info:Dict[str,Any], num_options:int) -> None:\n",
    "    \"\"\"Pad a negative question with the rule's other observations.\n",
    "\n",
    "    Candidates are the rule's observations not already among the options.\n",
    "    Nothing changes when the question already has `num_options` distinct\n",
    "    options or no candidate is left; otherwise up to the missing number of\n",
    "    candidates is added and all options are reshuffled.\n",
    "    \"\"\"\n",
    "    options_in_question = set(q.options)\n",
    "    n_missing = num_options - len(options_in_question)\n",
    "    available_options = set(rule_info[\"display_text\"][\"observations\"]) - options_in_question\n",
    "    if n_missing <= 0 or not available_options:\n",
    "        return\n",
    "\n",
    "    correct = _correct_option(q)\n",
    "    new_options = _pick_new_options(available_options, n_missing)\n",
    "    # sorted() keeps the pre-shuffle order independent of string hashing.\n",
    "    new_options.extend(sorted(options_in_question - {correct}))\n",
    "    _write_shuffled_options(q, correct, new_options)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b2eb100b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# options augmentation\n",
    "# Lookups that do not depend on the question are built once, not once per question.\n",
    "all_observations = set(ds.unique_observations.values())\n",
    "rules_by_id = {rule[\"id\"]: rule for rule in ds.rule_info[\"rule_set\"]}\n",
    "rules_by_number: Dict[Any, Dict[str, Any]] = {}\n",
    "for rule in ds.rule_info[\"rule_set\"]:\n",
    "    # keep the first rule listed for each \"#n\", matching a first-match scan\n",
    "    rules_by_number.setdefault(rule[\"#n\"], rule)\n",
    "\n",
    "cmplx_ques:List[Dict[str,Any]] = []\n",
    "\n",
    "# NOTE: convert_to_complex mutates each Question in ds.questions in place.\n",
    "for q in ds.questions:\n",
    "    convert_to_complex(\n",
    "        q,\n",
    "        all_observations,\n",
    "        rules_by_number[q.rule_id],\n",
    "        ds.rule_id_similarity_map,\n",
    "        rules_by_id,\n",
    "        ds.unique_observations\n",
    "        )\n",
    "    cmplx_ques.append(q.to_dict())\n",
    "\n",
    "\n",
    "# Everything from the simple dataset except its \"data\" entry, with the augmented questions.\n",
    "ds_complex = copy.deepcopy(ds.__dict__)\n",
    "ds_complex.pop(\"data\")\n",
    "ds_complex[\"questions\"] = cmplx_ques"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6cd34bef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the augmented dataset as \"complexV5\" under 'datasets'.\n",
    "save_dataset(ds_complex, \"complexV5\", 'datasets')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d12b56f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved complex dataset; from here on `ds` refers to complexV5.\n",
    "COMPLEX_DATASET_DIR = \"datasets/complexV5\"\n",
    "ds = ADIQDataset(COMPLEX_DATASET_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3ae7edad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt skeleton passed to to_basic_prompt below; presumably the {placeholder} fields are filled per question.\n",
    "question_template = \"\"\"\n",
    "## Asset Description:\n",
    "{asset_type}: {asset_description}\n",
    "\n",
    "## Conditions:\n",
    "{conditions}\n",
    "\n",
    "## How long the conditions were met:\n",
    "{temporal_condition}\n",
    "\n",
    "{question_prompt}\n",
    "{options}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ef5372cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "## Asset Description:\n",
      "AHU: Air Handling Unit: A device used to condition and circulate air as part of a heating, ventilating, and air-conditioning (HVAC) system.\n",
      "\n",
      "## Conditions:\n",
      "- AHU Running\n",
      "- OAT < 80 °F\n",
      "- Cooling Valve % > 97%\n",
      "- ABS(Supply Air Temperature Setpoint - Supply Air Temperature) > 3 IF  Setpoint Reporting\n",
      "\n",
      "## How long the conditions were met:\n",
      "Met for 4 Hours\n",
      "\n",
      "What is the MOST plausible explanation for the observed conditions of the asset?\n",
      "A). May be using mechanical cooling ie chiller\n",
      "B). BMS schedule has been changed\n",
      "C). Room Temp setpoint too low\n",
      "D). Unit bypassing or blowing off too much air\n",
      "E). Check the BMS command for inlet air control\n",
      "F). Broken Belt\n",
      "G). Too many chillers are running\n",
      "H). Mis-sized equipment\n",
      "I). UPS unit is overloaded\n",
      "J). Belt slipping or off\n",
      "\n"
     ]
    }
   ],
   "source": [
    "sample_prompt = to_basic_prompt(ds.questions[200], question_template, ds.asset_descriptions)\n",
    "print(sample_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "1a3aece2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(63, [False, False, True])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_question = ds.questions[2000]\n",
    "sample_question.rule_id, sample_question.correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "58879530",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9065"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_questions = len(ds.questions)\n",
    "n_questions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "78beb78f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# option count -> number of questions with that many options\n",
    "opt_comp: Dict[int, int] = {}\n",
    "for q in ds.questions:\n",
    "    n_opts = len(q.options)\n",
    "    opt_comp[n_opts] = opt_comp.get(n_opts, 0) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a6610c5e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{10: 5240, 5: 850, 3: 575, 4: 1325, 6: 500, 7: 400, 9: 75, 8: 100}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# option count -> number of questions (keys in order of first occurrence)\n",
    "opt_comp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1701a9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
