{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Read the API key from the environment rather than typing it into the notebook,\n",
    "# where it would be saved (and possibly committed) with the .ipynb file.\n",
    "# An unset variable falls back to \"\", the same value as the old placeholder.\n",
    "openai_api_key = os.environ.get(\"DEEPINFRA_API_KEY\", \"\")\n",
    "\n",
    "# Seed for the random question shuffling done in the scoring loop.\n",
    "# None (default) seeds from system entropy, i.e. a different order every run;\n",
    "# set an int to make the question orders reproducible.\n",
    "SEED = None\n",
    "random.seed(SEED)\n",
    "\n",
    "system_prompt = \"You are a helpful assistant who can only reply numbers from 1 to 5.\"\n",
    "# SCENARIO is a placeholder that the scoring loop replaces with each situation's text.\n",
    "prompt = \"\"\"You can only reply the numbers from 1 to 5. SCENARIO Please indicate the extent of your feeling in all the following emotions on a scale of 1 to 5. 1 denotes \"very slightly or not at all\", 2 denotes \"a little\", 3 denotes \"moderately\", 4 denotes \"quite a bit\", and 5 denotes \"extremely\". Please score all emotions one by one using the scale from 1 to 5:\"\"\"\n",
    "# The 20 emotion items (the PANAS item list), keyed by 1-based item number as strings.\n",
    "questions_dic = { \"1\": \"Interested\", \"2\": \"Distressed\", \"3\": \"Excited\", \"4\": \"Upset\", \"5\": \"Strong\", \"6\": \"Guilty\", \"7\": \"Scared\", \"8\": \"Hostile\", \"9\": \"Enthusiastic\", \"10\": \"Proud\", \"11\": \"Irritable\", \"12\": \"Alert\", \"13\": \"Ashamed\", \"14\": \"Inspired\", \"15\": \"Nervous\", \"16\": \"Determined\", \"17\": \"Attentive\", \"18\": \"Jittery\", \"19\": \"Active\", \"20\": \"Afraid\" }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tenacity import (\n",
    "    retry,\n",
    "    stop_after_attempt,\n",
    "    wait_random_exponential,\n",
    ")\n",
    "import time\n",
    "import re\n",
    "import openai\n",
    "from openai import OpenAI\n",
    "\n",
    "\n",
    "# Configure DeepInfra's OpenAI-compatible endpoint. The first line sets the key on\n",
    "# the openai module; the second then rebinds the name `openai` to a client\n",
    "# instance, which is what llama_chat calls `.chat.completions.create` on.\n",
    "openai.api_key = openai_api_key\n",
    "openai = OpenAI(base_url=\"https://api.deepinfra.com/v1/openai\", api_key=openai_api_key)\n",
    "\n",
    "def extract_json_from_string(input_string):\n",
    "    \"\"\"Return the span from the first \"{\" to the last \"}\" in `input_string`.\n",
    "\n",
    "    re.DOTALL lets the span cross newlines, so pretty-printed JSON in a model\n",
    "    reply is found too. The result is not validated as JSON; when no braces are\n",
    "    present the input is returned unchanged.\n",
    "    \"\"\"\n",
    "    json_match = re.search(r'\\{.*\\}', input_string, re.DOTALL)\n",
    "    return json_match.group(0) if json_match else input_string\n",
    "    \n",
    "# Decorator left disabled: the scoring loop retries failed requests itself.\n",
    "# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))\n",
    "def llama_chat(\n",
    "    model,                      # Model id served by the endpoint, e.g. \"mistralai/Mixtral-8x22B-Instruct-v0.1\".\n",
    "    prompt,                     # Chat messages: a list of {\"role\": ..., \"content\": ...} dicts, sent as `messages`.\n",
    "    temperature=0,              # [0, 2]: Lower values -> more focused and deterministic; Higher values -> more random.\n",
    "    n=1,                        # Completions to generate; only the first one is returned.\n",
    "    max_tokens=1024,            # The maximum number of tokens to generate in the chat completion.\n",
    "    delay=0,                    # Seconds to sleep before sending the request (simple rate limiting).\n",
    "    verbose=True,               # Print the outgoing messages before the request.\n",
    "):\n",
    "    \"\"\"Send one chat-completion request through the module-level `openai` client.\n",
    "\n",
    "    Returns the text of the first choice (`choices[0].message.content`); this may\n",
    "    be None if the endpoint returns no content.\n",
    "    \"\"\"\n",
    "    time.sleep(delay)\n",
    "    if verbose:\n",
    "        print(prompt)\n",
    "\n",
    "    # Send `n` only when it differs from the API default of 1, so the default call\n",
    "    # keeps working even if the endpoint does not accept the field.\n",
    "    extra_params = {} if n == 1 else {\"n\": n}\n",
    "    response = openai.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=prompt,\n",
    "        temperature=temperature,\n",
    "        max_tokens=max_tokens,\n",
    "        stream=False,\n",
    "        **extra_params,\n",
    "    )\n",
    "    return response.choices[0].message.content\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration.\n",
    "INPUT_CSV = \"doc.csv\"                 # one column per scenario; row 0 holds the scenario text\n",
    "OUTPUT_CSV = \"mixtral-8x22b.csv\"      # rewritten after every successfully scored column\n",
    "MODEL_NAME = \"mistralai/Mixtral-8x22B-Instruct-v0.1\"\n",
    "MAX_ATTEMPTS = 10                     # requests per column before giving up on it\n",
    "VALID_SCORES = range(1, 6)            # the 1-5 scale described in `prompt`\n",
    "\n",
    "\n",
    "def _parse_scores(response, questions_order):\n",
    "    \"\"\"Map a model reply back to scores ordered by questions_dic key (1..20).\n",
    "\n",
    "    The reply is expected to list the shuffled questions one per line, each line\n",
    "    ending in its score. Returns None unless there is exactly one score in 1-5 per\n",
    "    question, so a malformed reply is retried instead of being silently misaligned.\n",
    "    \"\"\"\n",
    "    scores = [int(s) for s in re.findall(r'\\b(\\d+)$', response, re.MULTILINE)]\n",
    "    if len(scores) != len(questions_order) or any(s not in VALID_SCORES for s in scores):\n",
    "        return None\n",
    "    # Undo the shuffle: pair each score with the question id it answers, then sort by id.\n",
    "    paired = sorted(zip(map(int, questions_order), scores), key=lambda pair: pair[0])\n",
    "    return [score for _, score in paired]\n",
    "\n",
    "\n",
    "df = pd.read_csv(INPUT_CSV)\n",
    "# A scored column is [scenario, score_1, ..., score_20]; fail fast if the sheet cannot hold that.\n",
    "assert len(df) == len(questions_dic) + 1, f\"{INPUT_CSV} must have {len(questions_dic) + 1} rows, found {len(df)}\"\n",
    "\n",
    "for col in tqdm(df.columns):\n",
    "    # Only columns that still have missing cells are (re)scored.\n",
    "    if not df[col].isnull().any():\n",
    "        continue\n",
    "    situation = df.loc[0, col]\n",
    "    if not isinstance(situation, str):\n",
    "        print(f\"Skipping column {col!r}: row 0 has no scenario text\")\n",
    "        continue\n",
    "\n",
    "    for attempt in range(1, MAX_ATTEMPTS + 1):\n",
    "        # Re-shuffle the question order on every attempt.\n",
    "        questions_order = list(questions_dic.keys())\n",
    "        random.shuffle(questions_order)\n",
    "        formatted_questions = '\\n'.join(f\"{i+1}. {questions_dic[item]}\" for i, item in enumerate(questions_order))\n",
    "\n",
    "        messages = [\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": prompt.replace(\"SCENARIO\", situation) + f\"\\n{formatted_questions}\"}\n",
    "        ]\n",
    "\n",
    "        try:\n",
    "            response = llama_chat(MODEL_NAME, messages)\n",
    "        except Exception as exc:  # log API/network failures and retry, rather than swallowing them\n",
    "            print(f\"Column {col!r}, attempt {attempt}/{MAX_ATTEMPTS}: request failed: {exc!r}\")\n",
    "            continue\n",
    "        print(response)\n",
    "\n",
    "        # Treat a None reply like any other unparseable one.\n",
    "        sorted_scores = _parse_scores(response or \"\", questions_order)\n",
    "        if sorted_scores is None:\n",
    "            print(f\"Column {col!r}, attempt {attempt}/{MAX_ATTEMPTS}: expected {len(questions_order)} scores in 1-5, retrying\")\n",
    "            continue\n",
    "\n",
    "        record = [situation] + sorted_scores\n",
    "        print(record)\n",
    "        df[col] = record\n",
    "        # Checkpoint after each column so finished work survives an interruption.\n",
    "        df.to_csv(OUTPUT_CSV, index=False)\n",
    "        break\n",
    "    else:\n",
    "        print(f\"Column {col!r}: no valid reply after {MAX_ATTEMPTS} attempts; left unscored\")\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_python",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
