{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a4d55df2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
      "[nltk_data]   Package stopwords is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import fitz\n",
    "import sys\n",
    "import re\n",
    "import json\n",
    "from datetime import datetime\n",
    "from typing import Optional, List, Callable, Any, Tuple, Dict\n",
    "from abc import abstractmethod, ABC\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy\n",
    "import nltk\n",
    "from nltk.corpus import stopwords\n",
    "import pickle\n",
    "import itertools\n",
    "from dataclasses import dataclass, asdict\n",
    "from enum import Enum\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "load_dotenv(dotenv_path=\"../.env\")\n",
    "nltk.download('stopwords')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8be38832",
   "metadata": {},
   "outputs": [],
   "source": [
    "SAVE_LOC = \"rationale/saved_rationale_simple_contrast_prompt.jsonl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7988f1bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt = \"You are a assistant that analyses industrial asset health and guides humans to maintain said assets\"\n",
    "\n",
    "question_template = \"\"\"\n",
    "### Asset Description:\n",
    "{asset_type}: {asset_description}\n",
    "\n",
    "### Conditions:\n",
    "{conditions}\n",
    "\n",
    "### How long the conditions were met:\n",
    "{temporal_condition}\n",
    "\n",
    "{question_prompt}\n",
    "{options}\n",
    "\"\"\"\n",
    "\n",
    "rational_template = \"\"\"\n",
    "Generate detailed asset rationales for guidance (”Guidance:”) based on the asset description (”Asset Description:”) \n",
    "and conditions (”Conditions:”) shown by the asset. These rationales should be the crucial cue for the guidance. \n",
    "Pretend that you don’t know the guidance (“Guidance:”). Please generate as a single paragraph of text with proper line break with '\\\\n' after (\"Guidance Rationale:\")\n",
    "\n",
    "# Example 1\n",
    "{example1}\n",
    "\n",
    "# Example 2\n",
    "{example2}\n",
    "\n",
    "# Example 3\n",
    "{example3}\n",
    "\n",
    "# Question\n",
    "{question}\n",
    "\n",
    "# Answer\n",
    "{answer}\n",
    "\n",
    "# Guidance Rationale:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d3eca8f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset_utils.reader import ADIQDataset\n",
    "\n",
    "ds = ADIQDataset(\"../dataset/datasets/simpleV3.1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5c1f3f99",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import file_handle\n",
    "\n",
    "rated_examples = file_handle.load_json(\"rationale/rated_examples.json\")\n",
    "flatten_rated_examples = [x for v in rated_examples.values() for x in v]\n",
    "flatten_rated_examples = sorted(flatten_rated_examples, key= lambda x: x['rating'], reverse=True)\n",
    "\n",
    "\n",
    "def select_examples(asset_type, examples=rated_examples):\n",
    "    local_flatten = copy.deepcopy(flatten_rated_examples)\n",
    "    try:\n",
    "        examples_asset_type = examples[asset_type]\n",
    "    except KeyError as ke:\n",
    "        print(\"No Examples found for Type:\", asset_type)\n",
    "        examples_asset_type = []\n",
    "\n",
    "    num_samples = len(examples_asset_type)\n",
    "    \n",
    "    if num_samples>3:\n",
    "        return random.sample(examples_asset_type,3)\n",
    "    elif num_samples == 3:\n",
    "        return examples_asset_type\n",
    "    else:\n",
    "        num_extra = 3 - num_samples\n",
    "        if num_samples>0:\n",
    "            sel_ids = [x[\"id\"] for x in examples_asset_type]\n",
    "            local_flatten = [x for x in local_flatten if x[\"id\"] not in sel_ids]\n",
    "\n",
    "        extra = random.sample(local_flatten[:5], num_extra)\n",
    "        examples_asset_type.extend(extra)\n",
    "        return examples_asset_type "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a86e9d7b",
   "metadata": {},
   "source": [
    "### Retrieve Answers"
   ]
  },
  {
   "cell_type": "raw",
   "id": "a895922f",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "from models_utils.utils.concurrency import concurrent_dict_execution\n",
    "from tqdm import tqdm\n",
    "from benchmarking.bench_utils.inference_calls import LLMConfiguration, ModelConfig, MODEL_MAP\n",
    "\n",
    "\n",
    "model_config = ModelConfig(**{\n",
    "        \"name\":'mistral-large',\n",
    "        \"identifier\" : 'mistralai/mistral-large',\n",
    "    })\n",
    "client = LLMConfiguration(model_config)\n",
    "\n",
    "def generate_rationale_for_question(que:Question):\n",
    "    que_prompt = question_template.format(\n",
    "        asset_type = que.asset_type,\n",
    "        asset_description = ds.asset_descriptions.get(que.asset_type, \"NONE\"),\n",
    "        conditions = \"\\n\".join(list(map(lambda x:\"- \"+x, que.condition_description))),\n",
    "        temporal_condition = que.temporal_condition[0] if len(que.temporal_condition)>0 else \"NONE\",\n",
    "        question_prompt = que.question_prompt,\n",
    "        options = \"\\n\".join([\"{}. {}\".format(op_id, op) for op_id, op in zip(que.option_ids,que.options)]))\n",
    "    \n",
    "    _ind_cor = [i for i,x in enumerate(que.correct) if x][0]\n",
    "    answer = f\"\\nAnswer: {que.option_ids[_ind_cor]}. {que.answer_str}\\n\"\n",
    "\n",
    "    [ex1,ex2,ex3] = select_examples(que.asset_type)    \n",
    "\n",
    "    rational_example = rational_template.format(\n",
    "        example1 = ex1['text'],\n",
    "        example2 = ex2['text'],\n",
    "        example3 = ex3['text'],\n",
    "        question = que_prompt,\n",
    "        answer = answer\n",
    "    )\n",
    "\n",
    "    response = client.get_response(rational_example)\n",
    "\n",
    "    if response:\n",
    "        file_handle.save_jsonl({\n",
    "            \"id\":que.id,\n",
    "            \"full_id\":que.question_id,\n",
    "            \"examples\":[ex1,ex2,ex3],\n",
    "            \"prompt\":rational_example,\n",
    "            \"model_config\":{**model_config.to_dict()},\n",
    "            \"rationale\":response\n",
    "        }, SAVE_LOC)\n",
    "    else:\n",
    "        print(\"None response\")\n",
    "        print(rational_example)\n",
    "        raise ValueError(\"please check\")\n",
    "\n",
    "\n",
    "if os.path.exists(SAVE_LOC):\n",
    "    keys = [x[\"id\"] for x in file_handle.load_jsonl_generator(SAVE_LOC)]\n",
    "else:\n",
    "    keys = []\n",
    "\n",
    "params = {q.id:[q] for q in ds.questions if q.id not in keys}\n",
    "\n",
    "if params:\n",
    "    {k:v for k,v in concurrent_dict_execution(generate_rationale_for_question, params)}\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bb9e4123",
   "metadata": {},
   "outputs": [],
   "source": [
    "rationals_loaded = file_handle.load_jsonl(SAVE_LOC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "82d3a786",
   "metadata": {},
   "outputs": [],
   "source": [
    "for q in ds.questions:\n",
    "    q.rationale = rationals_loaded[q.id][\"rationale\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "42135a5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds.save(\n",
    "    \"simpleV3.3\",\n",
    "    \"datasets\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c73b99ef",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
