{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "import os\n",
    "from openai import OpenAI\n",
    "from langchain_openai import ChatOpenAI\n",
    "from pydantic import BaseModel\n",
    "from langchain.schema import HumanMessage\n",
    "\n",
    "from Utils.Prompt import human_query_generation_agent\n",
    "\n",
    "OPENAI_API_KEY = os.getenv('OPEN_AI_API')\n",
    "\n",
    "def create_prompt():\n",
    "    count = 0\n",
    "\n",
    "    #day\n",
    "    day_list = ['2 day','3 days', '4 days']\n",
    "    day = random.choice(day_list)\n",
    "    count += 1\n",
    "\n",
    "    #price\n",
    "    price_list = [\"cheap budget\",\"moderate budget\",\"expensive budget\"]\n",
    "    price = random.choice(price_list)\n",
    "    count += 1\n",
    "\n",
    "    #attraction orientation\n",
    "    attraction_orientation_list = [\"family oriented\",\"history oriented\",\"activity oriented\",\"nature oriented\",\"food oriented\",\"shopping oriented\"]\n",
    "    attraction_orientation = random.choice(attraction_orientation_list)\n",
    "    count += 1\n",
    "\n",
    "    #cuisine\n",
    "    cuisine_list = [\"US\",\"Mexican\",\"Irish\",\"French\",\"Italian\",\"Greek\",\"Indian\",\"Chinese\",\"Japanese\",\"Korean\",\"Vietnamese\",\"Thai\",\"Asian Fusion\",\"Middle Eastern\"]\n",
    "    cuisine = random.choice(cuisine_list)\n",
    "    count += 1\n",
    "\n",
    "    #restaurant rating\n",
    "    weights = [0.6, 0.3, 0.1]  # weights for 1, 2, 3, and 4 respectively\n",
    "    samples = [1, 2, 3]\n",
    "    coins = random.choices(samples, weights=weights)[0]\n",
    "    n_list = [\"good flavor\",\"good freshness\",\"good service\",\"good environment\",\"good value\"]\n",
    "    restaurant_n = random.sample(n_list, coins)\n",
    "    restaurant_n_str = ', '.join(restaurant_n)\n",
    "    count += coins\n",
    "    \n",
    "    #hotel rating\n",
    "    weights = [0.6, 0.3, 0.1]  # weights for 1, 2, 3, and 4 respectively\n",
    "    samples = [1, 2, 3]\n",
    "    coins = random.choices(samples, weights=weights)[0]\n",
    "    n_list = [\"good quality\",\"good location\",\"good service\",\"good safety\"]\n",
    "    hotel_n = random.sample(n_list, coins)\n",
    "    hotel_n_str = ', '.join(hotel_n)\n",
    "    count += coins\n",
    "\n",
    "    #parse the prompt\n",
    "    prompt = {\n",
    "        \"general\":day + \", \" + price,\n",
    "        \"attraction\":attraction_orientation,\n",
    "        \"restaurant\":cuisine + (', ' if len(restaurant_n) > 0 else \"\") + restaurant_n_str,\n",
    "        \"hotel\":hotel_n_str\n",
    "    }\n",
    "\n",
    "    eval_info = {\n",
    "        \"day\":[day],\n",
    "        \"price\":[price],\n",
    "        \"attraction\":[attraction_orientation],\n",
    "        \"cuisine\":[cuisine],\n",
    "        \"restaurant\":restaurant_n,\n",
    "        \"hotel\":hotel_n,\n",
    "        \"preference_count\":[count]\n",
    "    }\n",
    "\n",
    "    return prompt,eval_info\n",
    "\n",
    "class QueryGenerationAgent:\n",
    "    \"\"\"Wraps a gpt-4o chat model that turns a filled prompt template into a query.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # temperature=0 keeps generations as repeatable as the API allows\n",
    "        self.llm = ChatOpenAI(\n",
    "            model_name='gpt-4o-2024-11-20',\n",
    "            temperature=0,\n",
    "            openai_api_key=OPENAI_API_KEY,\n",
    "        )\n",
    "\n",
    "    def generateHumanQuery(self, input, generation_agent):\n",
    "        \"\"\"Fill `generation_agent` with `input` and return the model's reply text.\n",
    "\n",
    "        Args:\n",
    "            input: value substituted for the ``{input}`` placeholder.\n",
    "            generation_agent: template exposing ``.format(input=...)``\n",
    "                (presumably a prompt string from Utils.Prompt - verify).\n",
    "\n",
    "        Returns:\n",
    "            The ``content`` attribute of the model's response message.\n",
    "        \"\"\"\n",
    "        message = HumanMessage(generation_agent.format(input=input))\n",
    "        response = self.llm.invoke([message])\n",
    "        return response.content\n",
    "    \n",
    "if __name__ == \"__main__\":\n",
    "    # Number of (query, eval_info) pairs to generate; each one costs an LLM call.\n",
    "    N_QUERIES = 500\n",
    "    # Seed for the preference sampler. None draws a fresh seed from the OS\n",
    "    # (a different sample every run); set an int to make runs reproducible.\n",
    "    SEED = None\n",
    "    OUTPUT_DIR = 'Prompts'\n",
    "\n",
    "    random.seed(SEED)\n",
    "    # Create the output folder up front: a missing directory should fail\n",
    "    # before any API calls are made, not after all of them.\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "    genAgent = QueryGenerationAgent()\n",
    "    humanQuerys = []\n",
    "    evals = []\n",
    "\n",
    "    # Stream records to disk as soon as each query is generated so an API\n",
    "    # error part-way through keeps the results already paid for. A record\n",
    "    # pair is written only after its query succeeds, keeping both files aligned.\n",
    "    with (\n",
    "        open(os.path.join(OUTPUT_DIR, 'humanQuerys.jsonl'), 'w', encoding='utf-8') as query_file,\n",
    "        open(os.path.join(OUTPUT_DIR, 'evals.jsonl'), 'w', encoding='utf-8') as eval_file,\n",
    "    ):\n",
    "        for index in range(1, N_QUERIES + 1):\n",
    "            prompt_input, eval_info = create_prompt()\n",
    "            humanQuery = genAgent.generateHumanQuery(prompt_input, human_query_generation_agent)\n",
    "\n",
    "            eval_record = {\"index\": index, \"eval_info\": eval_info}\n",
    "            query_record = {\"index\": index, \"query\": humanQuery}\n",
    "            evals.append(eval_record)\n",
    "            humanQuerys.append(query_record)\n",
    "\n",
    "            eval_file.write(json.dumps(eval_record) + '\\n')\n",
    "            query_file.write(json.dumps(query_record) + '\\n')\n",
    "            # flush so completed records survive a crash or kernel interrupt\n",
    "            eval_file.flush()\n",
    "            query_file.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchgpu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
