{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from copy import deepcopy\n",
    "from convlab.util import load_ontology"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the JSON file to load; a later cell unpacks every item of it into two dialogues.\n",
    "DATA_PATH = 'data/sgd/group0/type_0_context_100_aug_5_2.0x/train_aug_data_qa.json'\n",
    "\n",
    "# The context manager closes the file handle as soon as parsing is done\n",
    "# (a bare json.load(open(...)) leaves it open until garbage collection).\n",
    "with open(DATA_PATH) as data_file:\n",
    "    data = json.load(data_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the ontology for the 'sgd' dataset through ConvLab. The cells below read slot\n",
    "# descriptions from it via ontology['domains'][domain]['slots'][slot]['description'].\n",
    "ontology = load_ontology('sgd')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_state_update(prev_state, cur_state):\n",
    "    \"\"\"Return the slots whose values differ between two dialogue states.\n",
    "\n",
    "    Args:\n",
    "        prev_state: earlier state, laid out as {domain: {slot: value}}.\n",
    "        cur_state: later state in the same layout; it is copied, not modified.\n",
    "\n",
    "    Returns:\n",
    "        A new {domain: {slot: value}} dict with only the changed slots. A slot\n",
    "        that exists in prev_state but not in cur_state is reported with the\n",
    "        value ''. Domains without any changed slot are left out.\n",
    "    \"\"\"\n",
    "    # work on a copy so the caller's current state stays untouched\n",
    "    state = deepcopy(cur_state)\n",
    "    for domain in prev_state:\n",
    "        state.setdefault(domain, {})\n",
    "        for slot in prev_state[domain]:\n",
    "            if slot not in state[domain]:\n",
    "                # the slot disappeared: report it with an empty value\n",
    "                state[domain][slot] = ''\n",
    "            elif prev_state[domain][slot] == state[domain][slot]:\n",
    "                # unchanged value: not part of the update\n",
    "                state[domain].pop(slot)\n",
    "    # Prune every empty domain, including domains that only occur in cur_state.\n",
    "    # Pruning only the domains of prev_state let an empty new domain through,\n",
    "    # which was then rendered as a no-op update line.\n",
    "    return {domain: slots for domain, slots in state.items() if len(slots) > 0}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "def state2str(state, ontology, is_state_update=False):\n",
    "    \"\"\"Render a dialogue state as Python-like text.\n",
    "\n",
    "    Domains are shown by the part of their name before the first underscore,\n",
    "    slots by their ontology description.\n",
    "\n",
    "    Args:\n",
    "        state: dict laid out as {domain: {slot: value}}.\n",
    "        ontology: dict giving ontology['domains'][domain]['slots'][slot]['description'].\n",
    "        is_state_update: if True, return one 'state[\"<domain>\"].update({...})' line\n",
    "            per domain, or '# no state update' when there is no domain to show;\n",
    "            if False, return a single 'state = {...}' assignment.\n",
    "\n",
    "    Returns:\n",
    "        The rendered string.\n",
    "    \"\"\"\n",
    "    new_state = {}\n",
    "    for domain in state:\n",
    "        # setdefault rather than assigning {}: when two domains share a prefix\n",
    "        # their slots are merged instead of the later domain wiping the earlier one\n",
    "        domain_slots = new_state.setdefault(domain.split('_')[0], {})\n",
    "        for slot in state[domain]:\n",
    "            desc = ontology['domains'][domain]['slots'][slot]['description']\n",
    "            domain_slots[desc] = state[domain][slot]\n",
    "    if is_state_update:\n",
    "        state_update = [f'state[\"{domain}\"].update({new_state[domain]})' for domain in new_state]\n",
    "        if len(state_update) > 0:\n",
    "            return '\\n'.join(state_update)\n",
    "        else:\n",
    "            return '# no state update'\n",
    "    else:\n",
    "        return f'state = {new_state}'\n",
    "\n",
    "def dial2str(turns, domain, ontology, include_state_update=False):\n",
    "    \"\"\"Render a dialogue as commented text lines, optionally with state updates.\n",
    "\n",
    "    Args:\n",
    "        turns: list of turn dicts; each needs an 'utterance' and may carry a 'state'.\n",
    "            Even positions are labelled 'user', odd positions 'system'.\n",
    "        domain: domain name; only the part before the first underscore is shown.\n",
    "        ontology: passed on to state2str when state updates are rendered.\n",
    "        include_state_update: if True, every turn that has a 'state' is followed by\n",
    "            the rendered difference to the state of the turn two positions earlier\n",
    "            (an empty state is used for the first two turns).\n",
    "\n",
    "    Returns:\n",
    "        A header line followed by the rendered lines, ending with a newline.\n",
    "    \"\"\"\n",
    "    lines = []\n",
    "    for idx, turn in enumerate(turns):\n",
    "        speaker = 'system' if idx % 2 else 'user'\n",
    "        lines.append(f\"# {speaker}: {turn['utterance']}\")\n",
    "        if not ('state' in turn and include_state_update):\n",
    "            continue\n",
    "        # compare against the state two turns back\n",
    "        if idx >= 2:\n",
    "            earlier_state = turns[idx - 2]['state']\n",
    "        else:\n",
    "            earlier_state = {}\n",
    "        delta = get_state_update(earlier_state, turn['state'])\n",
    "        lines.append(state2str(delta, ontology, is_state_update=True))\n",
    "    short_domain = domain.split(\"_\")[0]\n",
    "    header = f'# Below is a conversation in the {short_domain} domain:\\n'\n",
    "    return header + '\\n'.join(lines) + '\\n'\n",
    "\n",
    "def create_prompt(mode, aug_type, src_domain, dst_domain, slot_pairs):\n",
    "    \"\"\"Build the instruction that asks for a rewrite of the destination-domain dialogue.\n",
    "\n",
    "    Args:\n",
    "        mode: only 'edit' is supported.\n",
    "        aug_type: only 'refer' is supported.\n",
    "        src_domain: source domain name; only the part before the first underscore is used.\n",
    "        dst_domain: destination domain name; shortened the same way.\n",
    "        slot_pairs: iterable of pairs\n",
    "            ((src_slot, src_slot_desc, src_slot_value), (dst_slot, dst_slot_desc, dst_slot_value)).\n",
    "            One 'replace the ...' sentence is appended per pair; descriptions are lower-cased.\n",
    "\n",
    "    Returns:\n",
    "        The prompt string, starting with '# rewrite the conversation in the <dst> domain: '.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for any other mode/aug_type combination. Without this check the\n",
    "            function reached 'return prompt' with prompt unbound (UnboundLocalError).\n",
    "    \"\"\"\n",
    "    if mode != 'edit' or aug_type != 'refer':\n",
    "        raise ValueError(f'unsupported mode/aug_type combination: {mode!r}/{aug_type!r}')\n",
    "    src_domain = src_domain.split('_')[0]\n",
    "    dst_domain = dst_domain.split('_')[0]\n",
    "    prompt = f'# rewrite the conversation in the {dst_domain} domain: '\n",
    "    for (src_slot, src_slot_desc, src_slot_value), (dst_slot, dst_slot_desc, dst_slot_value) in slot_pairs:\n",
    "        prompt += f'replace the {dst_slot_desc.lower()} (\"{dst_slot_value}\") to be the same as the {src_slot_desc.lower()} in the {src_domain} domain, but do not explicitly mention its value (\"{src_slot_value}\"). '\n",
    "    return prompt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "state = {'Trains': {'Starting city for train journey': 'LA', 'Ending city for train journey': 'Anaheim', 'Date of train journey': '14th of March', 'Time of start of train journey': '6 am', 'Number of adults to reserve train tickets for': '1', 'Fare class for train reservation': 'Value', 'Whether to add trip protection to reservation, for a fee': 'True'}}\n",
      "# Below is a conversation in the Trains domain:\n",
      "# user: Hi, could you help me search for a train please?\n",
      "# system: Which city will you be departing from and rravelling to?\n",
      "# user: I will be travelling from LA to Anaheim, CA.\n",
      "# system: And which date will you be travelling on?\n",
      "# user: Likely on the 14th of March.\n",
      "# system: There are 10 trains available, one of which departs at 6 am and costs $35.\n",
      "# user: I see, and which station will I be departing from and travelling to.\n",
      "# system: You will be travelling from Union Station to Anaheim Intermodal Center.\n",
      "# user: That sounds great!\n",
      "# system: Would you like me to make a reservation for you?\n",
      "# user: Sure, could you do so please?\n",
      "# system: How many tickets do you need and do you require insurance?\n",
      "# user: Yes, help me 1 seat with insurance please.\n",
      "# system: Please confirm your value ticket for 1 travelling from Los Angeles to Anaheim on March 14th at 6 am.\n",
      "# user: That sounds great!\n",
      "# system: Your reservation has been made.\n",
      "# user: Thank you so much for your kind assistance, that is all I need.\n",
      "# system: Sure, have a great day ahead!\n",
      "\n",
      "# Below is a conversation in the Calendar domain:\n",
      "# user: Show me my open calendar on 5th of this month.\n",
      "state[\"Calendar\"].update({'Date of event or for checking availability': '5th of this month'})\n",
      "# system: You have 4 empty slots between 8 am to 12 pm on your calendar.\n",
      "# user: Any other available time on the 14th?\n",
      "state[\"Calendar\"].update({'Date of event or for checking availability': 'the 14th'})\n",
      "# system: There are 3 empty slots between 2 pm to 2:30 pm\n",
      "# user: Any other empty slots?\n",
      "# no state update\n",
      "# system: You have empty slot from 3:30 pm to 5 pm on your calendar\n",
      "# user: ok great.\n",
      "# no state update\n",
      "# system: Do you want me to add an event to your calendar?\n",
      "# user: Yes please\n",
      "# no state update\n",
      "# system: What is the event and where is the event at what time?\n",
      "# user: Please put Property viewing from half past 11 in the morning at 49 Gold Mine Drive\n",
      "state[\"Calendar\"].update({'Start time of event': 'half past 11 in the morning', 'Location of event': '49 Gold Mine Drive', 'Title of event': 'Property viewing'})\n",
      "# system: Please confirm your calendar for Property viewing on March 14th 11:30 am at 49 Gold Mine Drive\n",
      "# user: Yes thats correct\n",
      "state[\"Calendar\"].update({'Date of event or for checking availability': 'March 14th', 'Start time of event': '11:30 am'})\n",
      "# system: Event has been added to your calendar\n",
      "# user: Thank you\n",
      "# no state update\n",
      "# system: Anything else ?\n",
      "# user: No thats it\n",
      "# no state update\n",
      "# system: Have a nice day\n",
      "\n",
      "\n",
      "# rewrite the conversation in the Calendar domain: replace the date of event or for checking availability (\"March 14th\") to be the same as the date of train journey in the Trains domain, but do not explicitly mention its value (\"14th of march\"). replace the start time of event (\"11:30 am\") to be the same as the time of start of train journey in the Trains domain, but do not explicitly mention its value (\"6 am\"). \n"
     ]
    }
   ],
   "source": [
    "# Scan the loaded pairs and print the first one whose second dialogue is in\n",
    "# 'Calendar_1' and that has at least two slot pairs, then stop.\n",
    "for item in data:\n",
    "    first_dial, second_dial = item\n",
    "    first_domain = first_dial['domains'][0]\n",
    "    second_domain = second_dial['domains'][0]\n",
    "    first_dial_str = dial2str(first_dial['turns'], first_domain, ontology)\n",
    "\n",
    "    # for each slot listed under the second domain, keep the first 'qa' entry of the first dialogue\n",
    "    slot_ref = {}\n",
    "    for second_slot, qa_entries in first_dial['qa'][second_domain].items():\n",
    "        slot_ref[second_slot] = qa_entries[0]\n",
    "\n",
    "    # pair those slots with the values in the state of the second-to-last turn of the second dialogue\n",
    "    turn = second_dial['turns'][-2]\n",
    "    slot_pairs = []\n",
    "    for second_slot, ref in slot_ref.items():\n",
    "        if second_slot not in turn['state'].get(second_domain, {}):\n",
    "            continue\n",
    "        # ref[0] is the slot name in the first domain, ref[1] the value put into the prompt\n",
    "        first_slot = ref[0]\n",
    "        first_entry = [\n",
    "            first_slot,\n",
    "            ontology['domains'][first_domain]['slots'][first_slot]['description'],\n",
    "            ref[1],\n",
    "        ]\n",
    "        second_entry = [\n",
    "            second_slot,\n",
    "            ontology['domains'][second_domain]['slots'][second_slot]['description'],\n",
    "            turn['state'][second_domain][second_slot],\n",
    "        ]\n",
    "        slot_pairs.append([first_entry, second_entry])\n",
    "    if not len(slot_pairs) >= 2:\n",
    "        continue\n",
    "\n",
    "    # ref second dial slot to first dial\n",
    "    second_dial_str = dial2str(second_dial['turns'], second_domain, ontology, include_state_update=True)\n",
    "    if second_domain != 'Calendar_1':\n",
    "        continue\n",
    "    print(state2str(first_dial['turns'][-2]['state'], ontology))\n",
    "    print(first_dial_str+'\\n'+second_dial_str+'\\n')\n",
    "    print(create_prompt('edit', 'refer', first_domain, second_domain, slot_pairs))\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rewrite the conversation in the Hotels domain: replace the check in date for reservation (\"March 1st\") to be the same as the date of train journey in the Trains domain, but do not explicitly mention its value (\"9th of March\")."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "0d7e61334dfc0ef49fed574cd0889517bf66c7c88797d6df65d4f14c89b6fa83"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
