{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3b6154a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_321944/4288864223.py:15: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from tqdm.autonotebook import tqdm\n",
      "[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
      "[nltk_data]   Package stopwords is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import fitz\n",
    "import re\n",
    "import json\n",
    "from datetime import datetime\n",
    "from typing import Optional, List, Callable, Any, Tuple, Dict\n",
    "from abc import abstractmethod, ABC\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy\n",
    "import nltk\n",
    "from nltk.corpus import stopwords\n",
    "import pickle\n",
    "from tqdm.autonotebook import tqdm\n",
    "import itertools\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "load_dotenv(dotenv_path=\"../.env\")\n",
    "nltk.download('stopwords')\n",
    "\n",
    "random.seed(42)\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9545398a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset_utils.reader import ADIQDataset\n",
    "from dataset_utils.outputs import to_basic_prompt\n",
    "\n",
    "ds = ADIQDataset(\"datasets/simpleV3\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dc19ecb",
   "metadata": {},
   "source": [
    "## In-context example sampling groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "30740bd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Question id -> its condition lines joined into one newline-separated string.\n",
    "conditions_of_questions = {\n",
    "    question.id: \"\\n\".join(question.condition_description)\n",
    "    for question in ds.questions\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1e6fd71e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer, util\n",
    "\n",
    "\n",
    "model = SentenceTransformer(\"all-mpnet-base-v2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "50758c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One embedding row per question, in the iteration order of conditions_of_questions.\n",
    "enc = model.encode(list(conditions_of_questions.values()))\n",
    "\n",
    "# Row index <-> question id lookups for the rows of `enc`.\n",
    "index_qcond_map = dict(enumerate(conditions_of_questions.keys()))\n",
    "qcond_index_map = {qid: idx for idx, qid in index_qcond_map.items()}\n",
    "\n",
    "# Pairwise products of the row norms; the 1e-6 keeps the division finite for a zero-norm row.\n",
    "row_norms = np.linalg.norm(enc, axis=1)\n",
    "norm = np.outer(row_norms, row_norms) + 1e-6\n",
    "\n",
    "# Cosine similarity matrix; np.abs folds negative similarities onto positive ones.\n",
    "cosine_sim = np.abs((enc @ enc.T) / norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b583eb3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 15183.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.985 , len(16)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 17208.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.986 , len(20)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 19190.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.987 , len(25)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 19537.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.988 , len(33)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 11836.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.989 , len(54)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 26875.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.99 , len(62)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 27300.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.991 , len(73)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 38846.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.992 , len(384)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 40526.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.993 , len(763)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 58495.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.994 , len(947)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 120714.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.995 , len(1596)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 222305.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.996 , len(2370)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 270785.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.997 , len(3096)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 393640.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.998 , len(3802)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 550539.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r:0.999 , len(4376)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from utils import grouping\n",
    "from utils import file_handle\n",
    "\n",
    "\n",
    "def similarity_components(quantile):\n",
    "    \"\"\"Cluster questions whose condition texts are close in embedding space.\n",
    "\n",
    "    Every row of ``cosine_sim`` contributes the columns whose value lies\n",
    "    strictly above that row's ``quantile``; the row's own index is ignored.\n",
    "    Each such pair is translated to question ids with ``index_qcond_map`` and\n",
    "    stored once as ``(smaller_id, larger_id)`` before the edge set is passed\n",
    "    to ``grouping.getComponents``.\n",
    "\n",
    "    Args:\n",
    "        quantile: Per-row quantile (0..1) used as the similarity cut-off.\n",
    "\n",
    "    Returns:\n",
    "        The result of ``grouping.getComponents`` for the collected edges.\n",
    "    \"\"\"\n",
    "    edges = set()\n",
    "    for row_idx, row in enumerate(cosine_sim):\n",
    "        cutoff = np.quantile(row, quantile)\n",
    "        qid = index_qcond_map[row_idx]\n",
    "        for col_idx in np.flatnonzero(row > cutoff).tolist():\n",
    "            if col_idx == row_idx:\n",
    "                continue  # a question is never its own neighbour\n",
    "            other = index_qcond_map[col_idx]\n",
    "            # Store each undirected pair once, smaller id first.\n",
    "            edges.add((min(qid, other), max(qid, other)))\n",
    "\n",
    "    return grouping.getComponents(cosine_sim.shape[0], edges)\n",
    "\n",
    "\n",
    "# Sweep the cut-off quantile from 0.985 to 0.999 and report the group count for each.\n",
    "for r in range(985,1000):\n",
    "    _thresh = r/1000\n",
    "    group_set = similarity_components(_thresh)\n",
    "\n",
    "    print(\"r:{} , len({})\".format(_thresh, len(group_set)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c73b7dc7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6690/6690 [00:00<00:00, 20098.37it/s]\n"
     ]
    }
   ],
   "source": [
    "_thresh = 0.988\n",
    "\n",
    "# Per row: column indices whose similarity exceeds that row's own quantile cut-off.\n",
    "_groups = {}\n",
    "for row_idx, row in enumerate(cosine_sim):\n",
    "    thresh = np.quantile(row.flatten(), _thresh)\n",
    "    _groups[row_idx] = set(np.flatnonzero(row > thresh).tolist()) - {row_idx}\n",
    "\n",
    "# Same neighbourhoods, keyed by question id instead of row index.\n",
    "groups = {\n",
    "    index_qcond_map[row_idx]: {index_qcond_map[col_idx] for col_idx in neighbours}\n",
    "    for row_idx, neighbours in _groups.items()\n",
    "}\n",
    "\n",
    "# Undirected edges, each stored once as (smaller id, larger id).\n",
    "edges = set()\n",
    "for qid, neighbour_ids in tqdm(groups.items()):\n",
    "    edges.update(\n",
    "        (min(qid, other), max(qid, other))\n",
    "        for other in neighbour_ids\n",
    "        if other != qid\n",
    "    )\n",
    "\n",
    "group_set = grouping.getComponents(len(groups), edges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "038b783a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of in-context examples to draw (at most one per similarity group).\n",
    "N_INCONTEXT_EXAMPLES = 10\n",
    "\n",
    "# One random representative per group, then a random subset of those.\n",
    "# list(...) lets random.choice work even when a group is a set, and\n",
    "# min(...) avoids a ValueError when there are fewer groups than requested examples.\n",
    "_representatives = [random.choice(list(x)) for x in group_set]\n",
    "incontext_expl = random.sample(\n",
    "    _representatives,\n",
    "    min(N_INCONTEXT_EXAMPLES, len(_representatives))\n",
    ")\n",
    "\n",
    "# Uncomment to persist the selection; a later cell reads \"incontext_expl.json\".\n",
    "#file_handle.save_json(\n",
    "#    {\"example_id\":incontext_expl},\n",
    "#    \"incontext_expl.json\"\n",
    "#)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "38dbb0a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset_utils.reader import ADIQDataset\n",
    "from dataset_utils.outputs import to_basic_prompt\n",
    "from utils import file_handle\n",
    "\n",
    "ds = ADIQDataset(\"datasets/simpleV3.1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "13ccba4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read back the stored selection and keep only its list of example ids.\n",
    "_saved_selection = file_handle.load_json(\"incontext_expl.json\")\n",
    "incontext_expl = _saved_selection[\"example_id\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "42ad5684",
   "metadata": {},
   "outputs": [],
   "source": [
    "question_template = \"\"\"\n",
    "## Asset Description:\n",
    "{asset_type}: {asset_description}\n",
    "\n",
    "## Conditions:\n",
    "{conditions}\n",
    "\n",
    "## Conditions in Natural Language:\n",
    "\"\"\"\n",
    "\n",
    "# Prompt skeleton: task instructions, then the examples, then the question to complete.\n",
    "main_template = \"\"\"\n",
    "Your task is to read the asset description (## Asset Description:) and conditions (## Conditions:) applied on the asset and\n",
    "write the conditions (## Conditions:) in natural language. Several examples are provided; complete the last sample.\n",
    "\n",
    "{examples}\n",
    "{question}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "56fdebf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset_utils.question import Question\n",
    "\n",
    "def condition_list(q:Question) -> str:\n",
    "    \"\"\"Fill ``question_template`` for a single question.\n",
    "\n",
    "    The asset type, its description looked up in ``ds.asset_descriptions``\n",
    "    and the newline-joined ``q.condition_description`` entries are\n",
    "    substituted into the template. An asset type without a registered\n",
    "    description is rendered as \"NONE\" (the same placeholder\n",
    "    ``verberlizer_incontext`` uses) instead of raising ``KeyError``.\n",
    "\n",
    "    Args:\n",
    "        q: Question whose asset and conditions are rendered.\n",
    "\n",
    "    Returns:\n",
    "        The formatted prompt fragment.\n",
    "    \"\"\"\n",
    "    return question_template.format(\n",
    "        asset_type = q.asset_type,\n",
    "        asset_description = ds.asset_descriptions.get(q.asset_type,\"NONE\"),\n",
    "        conditions = \"\\n\".join(q.condition_description)\n",
    "    )\n",
    "\n",
    "def verberlizer_incontext(q:Question, incontext_expl:str) -> str:\n",
    "    \"\"\"Assemble the complete verbalisation prompt for one question.\n",
    "\n",
    "    Args:\n",
    "        q: Question to verbalise; rendered with ``question_template``.\n",
    "        incontext_expl: Text placed in the ``{examples}`` slot of\n",
    "            ``main_template``.\n",
    "\n",
    "    Returns:\n",
    "        ``main_template`` filled with the examples and the rendered question.\n",
    "    \"\"\"\n",
    "    # Asset types without a description get the \"NONE\" placeholder.\n",
    "    description = ds.asset_descriptions.get(q.asset_type,\"NONE\")\n",
    "    joined_conditions = \"\\n\".join(q.condition_description)\n",
    "\n",
    "    rendered_question = question_template.format(\n",
    "        asset_type = q.asset_type,\n",
    "        asset_description = description,\n",
    "        conditions = joined_conditions\n",
    "    )\n",
    "\n",
    "    return main_template.format(examples = incontext_expl, question = rendered_question)"
   ]
  },
  {
   "cell_type": "raw",
   "id": "f7ccb244",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "qid_map = {k.id:v for v,k in enumerate(ds.questions)}\n",
    "\n",
    "for incon in incontext_expl:\n",
    "    print(condition_list(\n",
    "        ds.questions[qid_map[incon]]\n",
    "        ), end=\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea74a147",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 105/6690 [05:01<4:00:30,  2.19s/it]"
     ]
    }
   ],
   "source": [
    "import models_utils.llm.watsonx as watsonx\n",
    "\n",
    "# Example text that is inserted into every prompt.\n",
    "_examples = file_handle.load_text(\"condition_verberlizer_examples.txt\")\n",
    "\n",
    "# One completion request per question; the cleaned reply is stored on the question object.\n",
    "for q in tqdm(ds.questions):\n",
    "    _prompt = verberlizer_incontext(q, _examples)\n",
    "    response = watsonx.get_completion_response(_prompt)\n",
    "    q.verberlized_conditions = watsonx.clean_response(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87bc24c0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The heat exchanger is off and the building 91 is alerted if the plate and frame status is off for more than one unit that is not POK. Additionally, the outside air wet-bulb temperature is less than 38 degrees Fahrenheit.'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Persist the dataset, including any attributes set on its questions in this session.\n",
    "ds.save(\"simpleV3.1\", \"datasets\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "87323d72",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.concurrency import concurrent_dict_execution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa45f07",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 1 1\n",
      "2 2 2\n",
      "4 2 7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Completed:: 100%|██████████| 3/3 [00:00<00:00, 5461.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 (1, 1, 1)\n",
      "3 (16, 4, 49)\n",
      "1 (4, 4, 4)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98847f88",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
