{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "from copy import deepcopy\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sentence_transformers import SentenceTransformer, util\n",
    "\n",
    "from convlab.util import load_ontology"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_coreference(state):\n",
    "    value2slot = {}\n",
    "    for domain in state:\n",
    "        for slot in state[domain]:\n",
    "            value = state[domain][slot]\n",
    "            value2slot.setdefault(value, [])\n",
    "            value2slot[value].append(f'{domain}-{slot}')\n",
    "    return {value: value2slot[value] for value in value2slot if len(set([d_s.split('-')[0] for d_s in value2slot[value]])) > 1}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multi-domain SGD dialogues (group0), scanned below for cross-domain slot coreference\n",
    "with open('data/sgd/group0/multi_domain.json') as f:\n",
    "    multi_domain_data = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each group of cross-domain slots (sorted tuple of 'domain-slot' strings) to the\n",
    "# values they shared, counting each (value, slot group) pair at most once per dialogue.\n",
    "slot_coref = {}\n",
    "for dial in multi_domain_data:\n",
    "    # dict used as an insertion-ordered set so results do not depend on string hashing\n",
    "    slot_coref_dial = {}\n",
    "    for turn in dial['turns']:\n",
    "        if 'state' not in turn:\n",
    "            continue\n",
    "        for value, ds_list in find_coreference(turn['state']).items():\n",
    "            slot_coref_dial[(value, tuple(sorted(ds_list)))] = None\n",
    "    for value, ds_tuple in slot_coref_dial:\n",
    "        slot_coref.setdefault(ds_tuple, []).append(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "93"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of distinct slot groups found to share a value across domains\n",
    "len(slot_coref)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{('attraction-area', 'hotel-area'): 227,\n",
       " ('hotel-book day', 'train-day'): 371,\n",
       " ('hotel-book people', 'train-book people'): 324,\n",
       " ('hotel-book people', 'hotel-book stay', 'train-book people'): 43,\n",
       " ('hotel-area', 'restaurant-area'): 349,\n",
       " ('restaurant-book day', 'train-day'): 486,\n",
       " ('hotel-stars', 'train-book people'): 22,\n",
       " ('hotel-book people', 'hotel-stars', 'train-book people'): 11,\n",
       " ('restaurant-book people', 'train-book people'): 390,\n",
       " ('attraction-area', 'restaurant-area'): 457,\n",
       " ('hotel-stars', 'restaurant-book people'): 22,\n",
       " ('hotel-price range', 'restaurant-price range'): 288,\n",
       " ('hotel-book people', 'restaurant-book people'): 356,\n",
       " ('hotel-book day', 'restaurant-book day'): 416,\n",
       " ('restaurant-book time', 'taxi-arrive by'): 533,\n",
       " ('restaurant-name', 'taxi-destination'): 519,\n",
       " ('attraction-name', 'taxi-departure'): 343,\n",
       " ('hotel-name', 'taxi-departure'): 454,\n",
       " ('attraction-name', 'taxi-destination'): 224,\n",
       " ('hotel-name', 'taxi-destination'): 213,\n",
       " ('restaurant-name', 'taxi-departure'): 135,\n",
       " ('restaurant-book time', 'taxi-leave at'): 21,\n",
       " ('attraction-area', 'hotel-price range'): 4,\n",
       " ('attraction-type', 'hotel-price range'): 3,\n",
       " ('attraction-type', 'hotel-area'): 2,\n",
       " ('restaurant-food', 'taxi-destination'): 2,\n",
       " ('attraction-area', 'restaurant-food'): 3,\n",
       " ('attraction-area', 'restaurant-price range'): 3,\n",
       " ('attraction-area', 'hotel-parking'): 2,\n",
       " ('hotel-book people', 'hotel-stars', 'restaurant-book people'): 13,\n",
       " ('hotel-book stay', 'restaurant-book people'): 13,\n",
       " ('hotel-book people', 'hotel-book stay', 'restaurant-book people'): 51,\n",
       " ('hotel-area', 'restaurant-price range'): 3,\n",
       " ('hotel-area', 'restaurant-food'): 3,\n",
       " ('hotel-name', 'restaurant-name'): 3,\n",
       " ('hotel-internet', 'hotel-parking', 'restaurant-price range'): 2,\n",
       " ('hotel-book people',\n",
       "  'hotel-book stay',\n",
       "  'hotel-stars',\n",
       "  'restaurant-book people'): 3,\n",
       " ('hotel-internet', 'hotel-price range', 'restaurant-area'): 2,\n",
       " ('hotel-price range', 'restaurant-area'): 4,\n",
       " ('hotel-name', 'restaurant-price range'): 2,\n",
       " ('hotel-internet', 'restaurant-area'): 2,\n",
       " ('hotel-price range', 'restaurant-food'): 2,\n",
       " ('attraction-area', 'train-arrive by'): 2,\n",
       " ('restaurant-book time', 'train-arrive by'): 4,\n",
       " ('restaurant-area', 'train-arrive by'): 2,\n",
       " ('restaurant-area', 'train-leave at'): 2,\n",
       " ('hotel-price range', 'train-leave at'): 4,\n",
       " ('hotel-book stay', 'train-book people'): 8,\n",
       " ('hotel-area', 'train-leave at'): 2,\n",
       " ('hotel-book people',\n",
       "  'hotel-book stay',\n",
       "  'hotel-stars',\n",
       "  'train-book people'): 5,\n",
       " ('hotel-area', 'train-arrive by'): 3}"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Slot groups that shared more than one value, with the number of values recorded\n",
    "# for each (a value is counted once per dialogue it was shared in).\n",
    "{ds_group: len(shared_values)\n",
    " for ds_group, shared_values in slot_coref.items()\n",
    " if len(shared_values) > 1}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_state_update(prev_state, cur_state):\n",
    "    state = deepcopy(cur_state)\n",
    "    for domain in prev_state:\n",
    "        state.setdefault(domain, {})\n",
    "        for slot in prev_state[domain]:\n",
    "            if slot not in state[domain]:\n",
    "                state[domain][slot] = ''\n",
    "            elif prev_state[domain][slot] == state[domain][slot]:\n",
    "                state[domain].pop(slot)\n",
    "        if len(state[domain]) == 0:\n",
    "            state.pop(domain)\n",
    "    return state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-domain SGD dialogues from the same split as the multi-domain data. This must be\n",
    "# SGD data because the domains are later looked up in load_ontology('sgd').\n",
    "with open('data/sgd/group0/single_domain.json') as f:\n",
    "    single_domain_data = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect up to `num_sample_value` example values per (domain, slot) from the\n",
    "# turn-by-turn state updates of the single-domain dialogues.\n",
    "random.seed(42)\n",
    "num_sample_value = 10\n",
    "domain2slot2value = {}\n",
    "for dial in single_domain_data:\n",
    "    prev_state = {}\n",
    "    for turn in dial['turns']:\n",
    "        if 'state' not in turn:\n",
    "            continue\n",
    "        state_update = get_state_update(prev_state, turn['state'])\n",
    "        for domain, slot2value in state_update.items():\n",
    "            slot2values = domain2slot2value.setdefault(domain, {})\n",
    "            for slot, value in slot2value.items():\n",
    "                # '' means unfilled/cleared (get_state_update emits it for removed slots,\n",
    "                # and the first turn's update is the whole state), not an example value.\n",
    "                if value == '':\n",
    "                    continue\n",
    "                slot2values.setdefault(slot, []).append(value)\n",
    "        prev_state = turn['state']\n",
    "for domain in domain2slot2value:\n",
    "    for slot, values in domain2slot2value[domain].items():\n",
    "        domain2slot2value[domain][slot] = random.sample(values, min(num_sample_value, len(values)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Alarm_1</th>\n",
       "      <th>Banks_1</th>\n",
       "      <th>Buses_3</th>\n",
       "      <th>Calendar_1</th>\n",
       "      <th>Events_3</th>\n",
       "      <th>Flights_4</th>\n",
       "      <th>Homes_2</th>\n",
       "      <th>Hotels_4</th>\n",
       "      <th>Media_3</th>\n",
       "      <th>Movies_1</th>\n",
       "      <th>Music_3</th>\n",
       "      <th>Payment_1</th>\n",
       "      <th>RentalCars_3</th>\n",
       "      <th>Restaurants_1</th>\n",
       "      <th>RideSharing_2</th>\n",
       "      <th>Services_1</th>\n",
       "      <th>Trains_1</th>\n",
       "      <th>Travel_1</th>\n",
       "      <th>Weather_1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Alarm_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Banks_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Buses_3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Calendar_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Events_3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Flights_4</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Homes_2</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Hotels_4</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Media_3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Movies_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Music_3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Payment_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RentalCars_3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Restaurants_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RideSharing_2</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Services_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Trains_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Travel_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Weather_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              Alarm_1 Banks_1 Buses_3 Calendar_1 Events_3 Flights_4 Homes_2  \\\n",
       "Alarm_1           NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Banks_1           NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Buses_3           NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Calendar_1        NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Events_3          NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Flights_4         NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Homes_2           NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Hotels_4          NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Media_3           NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Movies_1          NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Music_3           NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Payment_1         NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "RentalCars_3      NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Restaurants_1     NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "RideSharing_2     NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Services_1        NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Trains_1          NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Travel_1          NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "Weather_1         NaN     NaN     NaN        NaN      NaN       NaN     NaN   \n",
       "\n",
       "              Hotels_4 Media_3 Movies_1 Music_3 Payment_1 RentalCars_3  \\\n",
       "Alarm_1            NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Banks_1            NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Buses_3            NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Calendar_1         NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Events_3           NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Flights_4          NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Homes_2            NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Hotels_4           NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Media_3            NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Movies_1           NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Music_3            NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Payment_1          NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "RentalCars_3       NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Restaurants_1      NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "RideSharing_2      NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Services_1         NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Trains_1           NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Travel_1           NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "Weather_1          NaN     NaN      NaN     NaN       NaN          NaN   \n",
       "\n",
       "              Restaurants_1 RideSharing_2 Services_1 Trains_1 Travel_1  \\\n",
       "Alarm_1                 NaN           NaN        NaN      NaN      NaN   \n",
       "Banks_1                 NaN           NaN        NaN      NaN      NaN   \n",
       "Buses_3                 NaN           NaN        NaN      NaN      NaN   \n",
       "Calendar_1              NaN           NaN        NaN      NaN      NaN   \n",
       "Events_3                NaN           NaN        NaN      NaN      NaN   \n",
       "Flights_4               NaN           NaN        NaN      NaN      NaN   \n",
       "Homes_2                 NaN           NaN        NaN      NaN      NaN   \n",
       "Hotels_4                NaN           NaN        NaN      NaN      NaN   \n",
       "Media_3                 NaN           NaN        NaN      NaN      NaN   \n",
       "Movies_1                NaN           NaN        NaN      NaN      NaN   \n",
       "Music_3                 NaN           NaN        NaN      NaN      NaN   \n",
       "Payment_1               NaN           NaN        NaN      NaN      NaN   \n",
       "RentalCars_3            NaN           NaN        NaN      NaN      NaN   \n",
       "Restaurants_1           NaN           NaN        NaN      NaN      NaN   \n",
       "RideSharing_2           NaN           NaN        NaN      NaN      NaN   \n",
       "Services_1              NaN           NaN        NaN      NaN      NaN   \n",
       "Trains_1                NaN           NaN        NaN      NaN      NaN   \n",
       "Travel_1                NaN           NaN        NaN      NaN      NaN   \n",
       "Weather_1               NaN           NaN        NaN      NaN      NaN   \n",
       "\n",
       "              Weather_1  \n",
       "Alarm_1             NaN  \n",
       "Banks_1             NaN  \n",
       "Buses_3             NaN  \n",
       "Calendar_1          NaN  \n",
       "Events_3            NaN  \n",
       "Flights_4           NaN  \n",
       "Homes_2             NaN  \n",
       "Hotels_4            NaN  \n",
       "Media_3             NaN  \n",
       "Movies_1            NaN  \n",
       "Music_3             NaN  \n",
       "Payment_1           NaN  \n",
       "RentalCars_3        NaN  \n",
       "Restaurants_1       NaN  \n",
       "RideSharing_2       NaN  \n",
       "Services_1          NaN  \n",
       "Trains_1            NaN  \n",
       "Travel_1            NaN  \n",
       "Weather_1           NaN  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Empty domain-by-domain table; its upper triangle is filled with ranked slot-pair\n",
    "# similarities further down.\n",
    "domains = sorted(domain2slot2value)\n",
    "df = pd.DataFrame(data=[], index=domains, columns=domains)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SGD ontology: source of each domain's state slots and their text descriptions\n",
    "ontology = load_ontology('sgd')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentence encoder for slot descriptions. The default is a machine-specific local copy of\n",
    "# all-mpnet-base-v2; set SBERT_MODEL_PATH (e.g. to 'sentence-transformers/all-mpnet-base-v2')\n",
    "# to run on other machines.\n",
    "model = SentenceTransformer(os.environ.get(\n",
    "    'SBERT_MODEL_PATH', '/zhangpai23/zhuqi/pre-trained-models/all-mpnet-base-v2'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed every state slot's natural-language description, per domain.\n",
    "domain2slot2embed = {\n",
    "    domain: {\n",
    "        slot: model.encode(ontology['domains'][domain]['slots'][slot]['description'])\n",
    "        for slot in ontology['state'][domain]\n",
    "    }\n",
    "    for domain in domains\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarity of slot-description embeddings for every pair of distinct domains.\n",
    "# slot_sims keeps a flat list of ('dom_i-slot_i@dom_j-slot_j', score); df[i, j] (i < j)\n",
    "# holds that domain pair's 'slot_i-slot_j' scores, ranked from most to least similar.\n",
    "slot_sims = []\n",
    "for i, domain_i in enumerate(domains):\n",
    "    slot2embed_i = domain2slot2embed[domain_i]\n",
    "    for j in range(i + 1, len(domains)):\n",
    "        domain_j = domains[j]\n",
    "        slot2embed_j = domain2slot2embed[domain_j]\n",
    "        pair2score = {}\n",
    "        for slot_i, emb_i in slot2embed_i.items():\n",
    "            for slot_j, emb_j in slot2embed_j.items():\n",
    "                score = util.cos_sim(emb_i, emb_j).item()\n",
    "                pair2score[f'{slot_i}-{slot_j}'] = score\n",
    "                slot_sims.append((f'{domain_i}-{slot_i}@{domain_j}-{slot_j}', score))\n",
    "        df.iloc[i, j] = sorted(pair2score.items(), key=lambda kv: kv[1], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Alarm_1</th>\n",
       "      <th>Banks_1</th>\n",
       "      <th>Buses_3</th>\n",
       "      <th>Calendar_1</th>\n",
       "      <th>Events_3</th>\n",
       "      <th>Flights_4</th>\n",
       "      <th>Homes_2</th>\n",
       "      <th>Hotels_4</th>\n",
       "      <th>Media_3</th>\n",
       "      <th>Movies_1</th>\n",
       "      <th>Music_3</th>\n",
       "      <th>Payment_1</th>\n",
       "      <th>RentalCars_3</th>\n",
       "      <th>Restaurants_1</th>\n",
       "      <th>RideSharing_2</th>\n",
       "      <th>Services_1</th>\n",
       "      <th>Trains_1</th>\n",
       "      <th>Travel_1</th>\n",
       "      <th>Weather_1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Alarm_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>[(alarm_name-recipient_account_name, [tensor([...</td>\n",
       "      <td>[(alarm_time-departure_time, [tensor([0.5261])...</td>\n",
       "      <td>[(alarm_time-event_time, [tensor([0.5312])]), ...</td>\n",
       "      <td>[(alarm_time-time, [tensor([0.5312])]), (alarm...</td>\n",
       "      <td>[(alarm_time-outbound_departure_time, [tensor(...</td>\n",
       "      <td>[(alarm_name-property_name, [tensor([0.3040])]...</td>\n",
       "      <td>[(alarm_name-place_name, [tensor([0.3084])]), ...</td>\n",
       "      <td>[(alarm_name-title, [tensor([0.3001])]), (new_...</td>\n",
       "      <td>[(alarm_time-show_time, [tensor([0.5145])]), (...</td>\n",
       "      <td>[(alarm_name-track, [tensor([0.3566])]), (alar...</td>\n",
       "      <td>[(new_alarm_name-receiver, [tensor([0.2461])])...</td>\n",
       "      <td>[(new_alarm_time-pickup_time, [tensor([0.4137]...</td>\n",
       "      <td>[(alarm_name-restaurant_name, [tensor([0.3168]...</td>\n",
       "      <td>[(alarm_time-wait_time, [tensor([0.3057])]), (...</td>\n",
       "      <td>[(alarm_time-appointment_time, [tensor([0.4614...</td>\n",
       "      <td>[(alarm_time-journey_start_time, [tensor([0.48...</td>\n",
       "      <td>[(new_alarm_name-attraction_name, [tensor([0.3...</td>\n",
       "      <td>[(alarm_name-city, [tensor([0.2700])]), (alarm...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Banks_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(amount-num_passengers, [tensor([0.3327])]), ...</td>\n",
       "      <td>[(recipient_account_name-event_name, [tensor([...</td>\n",
       "      <td>[(recipient_account_name-venue, [tensor([0.242...</td>\n",
       "      <td>[(amount-price, [tensor([0.3296])]), (recipien...</td>\n",
       "      <td>[(recipient_account_name-address, [tensor([0.2...</td>\n",
       "      <td>[(recipient_account_name-street_address, [tens...</td>\n",
       "      <td>[(recipient_account_type-genre, [tensor([0.258...</td>\n",
       "      <td>[(amount-number_of_tickets, [tensor([0.2898])]...</td>\n",
       "      <td>[(recipient_account_name-artist, [tensor([0.29...</td>\n",
       "      <td>[(recipient_account_name-receiver, [tensor([0....</td>\n",
       "      <td>[(recipient_account_type-car_type, [tensor([0....</td>\n",
       "      <td>[(recipient_account_name-street_address, [tens...</td>\n",
       "      <td>[(recipient_account_name-destination, [tensor(...</td>\n",
       "      <td>[(recipient_account_name-street_address, [tens...</td>\n",
       "      <td>[(amount-total, [tensor([0.2734])]), (recipien...</td>\n",
       "      <td>[(recipient_account_type-category, [tensor([0....</td>\n",
       "      <td>[(recipient_account_name-city, [tensor([0.2633...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Buses_3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(departure_time-event_time, [tensor([0.5919])...</td>\n",
       "      <td>[(price-price_per_ticket, [tensor([0.7611])]),...</td>\n",
       "      <td>[(num_passengers-number_of_tickets, [tensor([0...</td>\n",
       "      <td>[(departure_date-visit_date, [tensor([0.5183])...</td>\n",
       "      <td>[(to_city-location, [tensor([0.5785])]), (depa...</td>\n",
       "      <td>[(to_city-title, [tensor([0.1951])]), (departu...</td>\n",
       "      <td>[(price-price, [tensor([0.8033])]), (num_passe...</td>\n",
       "      <td>[(to_station-device, [tensor([0.2964])]), (fro...</td>\n",
       "      <td>[(num_passengers-amount, [tensor([0.3776])]), ...</td>\n",
       "      <td>[(to_city-city, [tensor([0.4570])]), (departur...</td>\n",
       "      <td>[(departure_date-date, [tensor([0.4911])]), (n...</td>\n",
       "      <td>[(to_city-destination, [tensor([0.5701])]), (p...</td>\n",
       "      <td>[(departure_date-appointment_date, [tensor([0....</td>\n",
       "      <td>[(to_station-from_station, [tensor([0.8183])])...</td>\n",
       "      <td>[(to_city-location, [tensor([0.5863])]), (from...</td>\n",
       "      <td>[(to_city-city, [tensor([0.5500])]), (from_cit...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Calendar_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[(event_time-time, [tensor([1.])]), (event_loc...</td>\n",
       "      <td>[(event_time-departure_date, [tensor([0.6173])...</td>\n",
       "      <td>[(event_date-visit_date, [tensor([0.5202])]), ...</td>\n",
       "      <td>[(event_date-check_in_date, [tensor([0.6681])]...</td>\n",
       "      <td>[(event_name-title, [tensor([0.4413])]), (even...</td>\n",
       "      <td>[(event_time-show_time, [tensor([0.6069])]), (...</td>\n",
       "      <td>[(event_name-artist, [tensor([0.4081])]), (eve...</td>\n",
       "      <td>[(event_name-receiver, [tensor([0.2097])]), (e...</td>\n",
       "      <td>[(event_date-end_date, [tensor([0.3158])]), (e...</td>\n",
       "      <td>[(event_date-date, [tensor([0.7632])]), (event...</td>\n",
       "      <td>[(event_location-destination, [tensor([0.3419]...</td>\n",
       "      <td>[(event_date-appointment_date, [tensor([0.5968...</td>\n",
       "      <td>[(event_time-journey_start_time, [tensor([0.61...</td>\n",
       "      <td>[(event_location-location, [tensor([0.5171])])...</td>\n",
       "      <td>[(event_date-date, [tensor([0.4321])]), (event...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Events_3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(price_per_ticket-price, [tensor([0.7486])]),...</td>\n",
       "      <td>[(venue_address-address, [tensor([0.6441])]), ...</td>\n",
       "      <td>[(number_of_tickets-number_of_rooms, [tensor([...</td>\n",
       "      <td>[(event_name-title, [tensor([0.4225])]), (even...</td>\n",
       "      <td>[(price_per_ticket-price, [tensor([0.9238])]),...</td>\n",
       "      <td>[(event_name-artist, [tensor([0.5871])]), (eve...</td>\n",
       "      <td>[(price_per_ticket-amount, [tensor([0.2941])])...</td>\n",
       "      <td>[(price_per_ticket-price_per_day, [tensor([0.3...</td>\n",
       "      <td>[(number_of_tickets-party_size, [tensor([0.685...</td>\n",
       "      <td>[(number_of_tickets-number_of_seats, [tensor([...</td>\n",
       "      <td>[(date-appointment_date, [tensor([0.5976])]), ...</td>\n",
       "      <td>[(number_of_tickets-number_of_adults, [tensor(...</td>\n",
       "      <td>[(city-location, [tensor([0.6118])]), (venue_a...</td>\n",
       "      <td>[(city-city, [tensor([0.5279])]), (date-date, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Flights_4</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(departure_date-visit_date, [tensor([0.6228])...</td>\n",
       "      <td>[(departure_date-check_in_date, [tensor([0.564...</td>\n",
       "      <td>[(seating_class-genre, [tensor([0.2168])]), (a...</td>\n",
       "      <td>[(price-price, [tensor([0.6755])]), (number_of...</td>\n",
       "      <td>[(departure_date-year, [tensor([0.2793])]), (d...</td>\n",
       "      <td>[(price-amount, [tensor([0.3588])]), (number_o...</td>\n",
       "      <td>[(return_date-end_date, [tensor([0.4634])]), (...</td>\n",
       "      <td>[(departure_date-date, [tensor([0.5904])]), (r...</td>\n",
       "      <td>[(price-ride_fare, [tensor([0.5492])]), (desti...</td>\n",
       "      <td>[(departure_date-appointment_date, [tensor([0....</td>\n",
       "      <td>[(destination_airport-from_station, [tensor([0...</td>\n",
       "      <td>[(destination_airport-location, [tensor([0.412...</td>\n",
       "      <td>[(destination_airport-city, [tensor([0.4895])]...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Homes_2</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(address-street_address, [tensor([0.6800])]),...</td>\n",
       "      <td>[(property_name-title, [tensor([0.2757])]), (p...</td>\n",
       "      <td>[(area-location, [tensor([0.4850])]), (address...</td>\n",
       "      <td>[(property_name-artist, [tensor([0.2492])]), (...</td>\n",
       "      <td>[(address-receiver, [tensor([0.3230])]), (prop...</td>\n",
       "      <td>[(area-city, [tensor([0.4658])]), (visit_date-...</td>\n",
       "      <td>[(phone_number-phone_number, [tensor([0.5414])...</td>\n",
       "      <td>[(address-destination, [tensor([0.3640])]), (a...</td>\n",
       "      <td>[(visit_date-appointment_date, [tensor([0.5784...</td>\n",
       "      <td>[(visit_date-date_of_journey, [tensor([0.4656]...</td>\n",
       "      <td>[(area-location, [tensor([0.5410])]), (phone_n...</td>\n",
       "      <td>[(area-city, [tensor([0.5535])]), (property_na...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Hotels_4</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(place_name-title, [tensor([0.2237])]), (smok...</td>\n",
       "      <td>[(street_address-street_address, [tensor([0.54...</td>\n",
       "      <td>[(place_name-artist, [tensor([0.2453])]), (che...</td>\n",
       "      <td>[(street_address-receiver, [tensor([0.3123])])...</td>\n",
       "      <td>[(location-city, [tensor([0.5493])]), (price_p...</td>\n",
       "      <td>[(check_in_date-date, [tensor([0.8381])]), (ch...</td>\n",
       "      <td>[(number_of_rooms-number_of_seats, [tensor([0....</td>\n",
       "      <td>[(check_in_date-appointment_date, [tensor([0.6...</td>\n",
       "      <td>[(number_of_rooms-number_of_adults, [tensor([0...</td>\n",
       "      <td>[(location-location, [tensor([0.6436])]), (pho...</td>\n",
       "      <td>[(location-city, [tensor([0.5304])]), (street_...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Media_3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(title-movie_name, [tensor([0.8592])]), (titl...</td>\n",
       "      <td>[(title-track, [tensor([0.4742])]), (title-art...</td>\n",
       "      <td>[(genre-payment_method, [tensor([0.2725])]), (...</td>\n",
       "      <td>[(title-car_name, [tensor([0.2346])]), (genre-...</td>\n",
       "      <td>[(title-restaurant_name, [tensor([0.3741])]), ...</td>\n",
       "      <td>[(genre-ride_type, [tensor([0.2577])]), (genre...</td>\n",
       "      <td>[(title-stylist_name, [tensor([0.2406])]), (ti...</td>\n",
       "      <td>[(genre-class, [tensor([0.1590])]), (title-to_...</td>\n",
       "      <td>[(genre-category, [tensor([0.5094])]), (genre-...</td>\n",
       "      <td>[(title-city, [tensor([0.2877])]), (subtitle_l...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Movies_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(movie_name-track, [tensor([0.5609])]), (genr...</td>\n",
       "      <td>[(number_of_tickets-amount, [tensor([0.3567])]...</td>\n",
       "      <td>[(price-price_per_day, [tensor([0.4164])]), (l...</td>\n",
       "      <td>[(location-city, [tensor([0.5669])]), (street_...</td>\n",
       "      <td>[(number_of_tickets-number_of_seats, [tensor([...</td>\n",
       "      <td>[(location-city, [tensor([0.5239])]), (show_ti...</td>\n",
       "      <td>[(price-total, [tensor([0.6083])]), (number_of...</td>\n",
       "      <td>[(location-location, [tensor([0.6035])]), (str...</td>\n",
       "      <td>[(location-city, [tensor([0.5398])]), (theater...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Music_3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(artist-receiver, [tensor([0.3562])]), (devic...</td>\n",
       "      <td>[(genre-car_type, [tensor([0.3369])]), (genre-...</td>\n",
       "      <td>[(artist-restaurant_name, [tensor([0.4061])]),...</td>\n",
       "      <td>[(genre-ride_type, [tensor([0.2799])]), (devic...</td>\n",
       "      <td>[(artist-stylist_name, [tensor([0.4598])]), (a...</td>\n",
       "      <td>[(device-from_station, [tensor([0.3036])]), (d...</td>\n",
       "      <td>[(device-location, [tensor([0.2507])]), (artis...</td>\n",
       "      <td>[(artist-city, [tensor([0.3426])]), (track-cit...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Payment_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(amount-price_per_day, [tensor([0.2821])]), (...</td>\n",
       "      <td>[(receiver-street_address, [tensor([0.3086])])...</td>\n",
       "      <td>[(amount-ride_fare, [tensor([0.3081])]), (rece...</td>\n",
       "      <td>[(receiver-street_address, [tensor([0.3170])])...</td>\n",
       "      <td>[(receiver-to_station, [tensor([0.2679])]), (a...</td>\n",
       "      <td>[(receiver-phone_number, [tensor([0.2243])]), ...</td>\n",
       "      <td>[(receiver-city, [tensor([0.2771])]), (payment...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RentalCars_3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(city-city, [tensor([0.4099])]), (start_date-...</td>\n",
       "      <td>[(city-destination, [tensor([0.4815])]), (pick...</td>\n",
       "      <td>[(end_date-appointment_date, [tensor([0.4060])...</td>\n",
       "      <td>[(end_date-date_of_journey, [tensor([0.4192])]...</td>\n",
       "      <td>[(city-location, [tensor([0.3945])]), (pickup_...</td>\n",
       "      <td>[(city-city, [tensor([0.4272])]), (end_date-da...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Restaurants_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(party_size-number_of_seats, [tensor([0.5245]...</td>\n",
       "      <td>[(date-appointment_date, [tensor([0.6524])]), ...</td>\n",
       "      <td>[(party_size-number_of_adults, [tensor([0.5328...</td>\n",
       "      <td>[(serves_alcohol-free_entry, [tensor([0.5922])...</td>\n",
       "      <td>[(city-city, [tensor([0.5751])]), (restaurant_...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RideSharing_2</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(destination-street_address, [tensor([0.3275]...</td>\n",
       "      <td>[(number_of_seats-number_of_adults, [tensor([0...</td>\n",
       "      <td>[(destination-location, [tensor([0.3454])]), (...</td>\n",
       "      <td>[(destination-city, [tensor([0.3477])]), (wait...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Services_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(appointment_date-date_of_journey, [tensor([0...</td>\n",
       "      <td>[(phone_number-phone_number, [tensor([0.5446])...</td>\n",
       "      <td>[(city-city, [tensor([0.4708])]), (appointment...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Trains_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(from_station-location, [tensor([0.4453])]), ...</td>\n",
       "      <td>[(to_station-city, [tensor([0.5665])]), (from_...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Travel_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[(location-city, [tensor([0.5348])]), (attract...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Weather_1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              Alarm_1                                            Banks_1  \\\n",
       "Alarm_1           NaN  [(alarm_name-recipient_account_name, [tensor([...   \n",
       "Banks_1           NaN                                                NaN   \n",
       "Buses_3           NaN                                                NaN   \n",
       "Calendar_1        NaN                                                NaN   \n",
       "Events_3          NaN                                                NaN   \n",
       "Flights_4         NaN                                                NaN   \n",
       "Homes_2           NaN                                                NaN   \n",
       "Hotels_4          NaN                                                NaN   \n",
       "Media_3           NaN                                                NaN   \n",
       "Movies_1          NaN                                                NaN   \n",
       "Music_3           NaN                                                NaN   \n",
       "Payment_1         NaN                                                NaN   \n",
       "RentalCars_3      NaN                                                NaN   \n",
       "Restaurants_1     NaN                                                NaN   \n",
       "RideSharing_2     NaN                                                NaN   \n",
       "Services_1        NaN                                                NaN   \n",
       "Trains_1          NaN                                                NaN   \n",
       "Travel_1          NaN                                                NaN   \n",
       "Weather_1         NaN                                                NaN   \n",
       "\n",
       "                                                         Buses_3  \\\n",
       "Alarm_1        [(alarm_time-departure_time, [tensor([0.5261])...   \n",
       "Banks_1        [(amount-num_passengers, [tensor([0.3327])]), ...   \n",
       "Buses_3                                                      NaN   \n",
       "Calendar_1                                                   NaN   \n",
       "Events_3                                                     NaN   \n",
       "Flights_4                                                    NaN   \n",
       "Homes_2                                                      NaN   \n",
       "Hotels_4                                                     NaN   \n",
       "Media_3                                                      NaN   \n",
       "Movies_1                                                     NaN   \n",
       "Music_3                                                      NaN   \n",
       "Payment_1                                                    NaN   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                      Calendar_1  \\\n",
       "Alarm_1        [(alarm_time-event_time, [tensor([0.5312])]), ...   \n",
       "Banks_1        [(recipient_account_name-event_name, [tensor([...   \n",
       "Buses_3        [(departure_time-event_time, [tensor([0.5919])...   \n",
       "Calendar_1                                                  <NA>   \n",
       "Events_3                                                     NaN   \n",
       "Flights_4                                                    NaN   \n",
       "Homes_2                                                      NaN   \n",
       "Hotels_4                                                     NaN   \n",
       "Media_3                                                      NaN   \n",
       "Movies_1                                                     NaN   \n",
       "Music_3                                                      NaN   \n",
       "Payment_1                                                    NaN   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                        Events_3  \\\n",
       "Alarm_1        [(alarm_time-time, [tensor([0.5312])]), (alarm...   \n",
       "Banks_1        [(recipient_account_name-venue, [tensor([0.242...   \n",
       "Buses_3        [(price-price_per_ticket, [tensor([0.7611])]),...   \n",
       "Calendar_1     [(event_time-time, [tensor([1.])]), (event_loc...   \n",
       "Events_3                                                     NaN   \n",
       "Flights_4                                                    NaN   \n",
       "Homes_2                                                      NaN   \n",
       "Hotels_4                                                     NaN   \n",
       "Media_3                                                      NaN   \n",
       "Movies_1                                                     NaN   \n",
       "Music_3                                                      NaN   \n",
       "Payment_1                                                    NaN   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                       Flights_4  \\\n",
       "Alarm_1        [(alarm_time-outbound_departure_time, [tensor(...   \n",
       "Banks_1        [(amount-price, [tensor([0.3296])]), (recipien...   \n",
       "Buses_3        [(num_passengers-number_of_tickets, [tensor([0...   \n",
       "Calendar_1     [(event_time-departure_date, [tensor([0.6173])...   \n",
       "Events_3       [(price_per_ticket-price, [tensor([0.7486])]),...   \n",
       "Flights_4                                                    NaN   \n",
       "Homes_2                                                      NaN   \n",
       "Hotels_4                                                     NaN   \n",
       "Media_3                                                      NaN   \n",
       "Movies_1                                                     NaN   \n",
       "Music_3                                                      NaN   \n",
       "Payment_1                                                    NaN   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                         Homes_2  \\\n",
       "Alarm_1        [(alarm_name-property_name, [tensor([0.3040])]...   \n",
       "Banks_1        [(recipient_account_name-address, [tensor([0.2...   \n",
       "Buses_3        [(departure_date-visit_date, [tensor([0.5183])...   \n",
       "Calendar_1     [(event_date-visit_date, [tensor([0.5202])]), ...   \n",
       "Events_3       [(venue_address-address, [tensor([0.6441])]), ...   \n",
       "Flights_4      [(departure_date-visit_date, [tensor([0.6228])...   \n",
       "Homes_2                                                      NaN   \n",
       "Hotels_4                                                     NaN   \n",
       "Media_3                                                      NaN   \n",
       "Movies_1                                                     NaN   \n",
       "Music_3                                                      NaN   \n",
       "Payment_1                                                    NaN   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                        Hotels_4  \\\n",
       "Alarm_1        [(alarm_name-place_name, [tensor([0.3084])]), ...   \n",
       "Banks_1        [(recipient_account_name-street_address, [tens...   \n",
       "Buses_3        [(to_city-location, [tensor([0.5785])]), (depa...   \n",
       "Calendar_1     [(event_date-check_in_date, [tensor([0.6681])]...   \n",
       "Events_3       [(number_of_tickets-number_of_rooms, [tensor([...   \n",
       "Flights_4      [(departure_date-check_in_date, [tensor([0.564...   \n",
       "Homes_2        [(address-street_address, [tensor([0.6800])]),...   \n",
       "Hotels_4                                                     NaN   \n",
       "Media_3                                                      NaN   \n",
       "Movies_1                                                     NaN   \n",
       "Music_3                                                      NaN   \n",
       "Payment_1                                                    NaN   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                         Media_3  \\\n",
       "Alarm_1        [(alarm_name-title, [tensor([0.3001])]), (new_...   \n",
       "Banks_1        [(recipient_account_type-genre, [tensor([0.258...   \n",
       "Buses_3        [(to_city-title, [tensor([0.1951])]), (departu...   \n",
       "Calendar_1     [(event_name-title, [tensor([0.4413])]), (even...   \n",
       "Events_3       [(event_name-title, [tensor([0.4225])]), (even...   \n",
       "Flights_4      [(seating_class-genre, [tensor([0.2168])]), (a...   \n",
       "Homes_2        [(property_name-title, [tensor([0.2757])]), (p...   \n",
       "Hotels_4       [(place_name-title, [tensor([0.2237])]), (smok...   \n",
       "Media_3                                                      NaN   \n",
       "Movies_1                                                     NaN   \n",
       "Music_3                                                      NaN   \n",
       "Payment_1                                                    NaN   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                        Movies_1  \\\n",
       "Alarm_1        [(alarm_time-show_time, [tensor([0.5145])]), (...   \n",
       "Banks_1        [(amount-number_of_tickets, [tensor([0.2898])]...   \n",
       "Buses_3        [(price-price, [tensor([0.8033])]), (num_passe...   \n",
       "Calendar_1     [(event_time-show_time, [tensor([0.6069])]), (...   \n",
       "Events_3       [(price_per_ticket-price, [tensor([0.9238])]),...   \n",
       "Flights_4      [(price-price, [tensor([0.6755])]), (number_of...   \n",
       "Homes_2        [(area-location, [tensor([0.4850])]), (address...   \n",
       "Hotels_4       [(street_address-street_address, [tensor([0.54...   \n",
       "Media_3        [(title-movie_name, [tensor([0.8592])]), (titl...   \n",
       "Movies_1                                                     NaN   \n",
       "Music_3                                                      NaN   \n",
       "Payment_1                                                    NaN   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                         Music_3  \\\n",
       "Alarm_1        [(alarm_name-track, [tensor([0.3566])]), (alar...   \n",
       "Banks_1        [(recipient_account_name-artist, [tensor([0.29...   \n",
       "Buses_3        [(to_station-device, [tensor([0.2964])]), (fro...   \n",
       "Calendar_1     [(event_name-artist, [tensor([0.4081])]), (eve...   \n",
       "Events_3       [(event_name-artist, [tensor([0.5871])]), (eve...   \n",
       "Flights_4      [(departure_date-year, [tensor([0.2793])]), (d...   \n",
       "Homes_2        [(property_name-artist, [tensor([0.2492])]), (...   \n",
       "Hotels_4       [(place_name-artist, [tensor([0.2453])]), (che...   \n",
       "Media_3        [(title-track, [tensor([0.4742])]), (title-art...   \n",
       "Movies_1       [(movie_name-track, [tensor([0.5609])]), (genr...   \n",
       "Music_3                                                      NaN   \n",
       "Payment_1                                                    NaN   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                       Payment_1  \\\n",
       "Alarm_1        [(new_alarm_name-receiver, [tensor([0.2461])])...   \n",
       "Banks_1        [(recipient_account_name-receiver, [tensor([0....   \n",
       "Buses_3        [(num_passengers-amount, [tensor([0.3776])]), ...   \n",
       "Calendar_1     [(event_name-receiver, [tensor([0.2097])]), (e...   \n",
       "Events_3       [(price_per_ticket-amount, [tensor([0.2941])])...   \n",
       "Flights_4      [(price-amount, [tensor([0.3588])]), (number_o...   \n",
       "Homes_2        [(address-receiver, [tensor([0.3230])]), (prop...   \n",
       "Hotels_4       [(street_address-receiver, [tensor([0.3123])])...   \n",
       "Media_3        [(genre-payment_method, [tensor([0.2725])]), (...   \n",
       "Movies_1       [(number_of_tickets-amount, [tensor([0.3567])]...   \n",
       "Music_3        [(artist-receiver, [tensor([0.3562])]), (devic...   \n",
       "Payment_1                                                    NaN   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                    RentalCars_3  \\\n",
       "Alarm_1        [(new_alarm_time-pickup_time, [tensor([0.4137]...   \n",
       "Banks_1        [(recipient_account_type-car_type, [tensor([0....   \n",
       "Buses_3        [(to_city-city, [tensor([0.4570])]), (departur...   \n",
       "Calendar_1     [(event_date-end_date, [tensor([0.3158])]), (e...   \n",
       "Events_3       [(price_per_ticket-price_per_day, [tensor([0.3...   \n",
       "Flights_4      [(return_date-end_date, [tensor([0.4634])]), (...   \n",
       "Homes_2        [(area-city, [tensor([0.4658])]), (visit_date-...   \n",
       "Hotels_4       [(location-city, [tensor([0.5493])]), (price_p...   \n",
       "Media_3        [(title-car_name, [tensor([0.2346])]), (genre-...   \n",
       "Movies_1       [(price-price_per_day, [tensor([0.4164])]), (l...   \n",
       "Music_3        [(genre-car_type, [tensor([0.3369])]), (genre-...   \n",
       "Payment_1      [(amount-price_per_day, [tensor([0.2821])]), (...   \n",
       "RentalCars_3                                                 NaN   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                   Restaurants_1  \\\n",
       "Alarm_1        [(alarm_name-restaurant_name, [tensor([0.3168]...   \n",
       "Banks_1        [(recipient_account_name-street_address, [tens...   \n",
       "Buses_3        [(departure_date-date, [tensor([0.4911])]), (n...   \n",
       "Calendar_1     [(event_date-date, [tensor([0.7632])]), (event...   \n",
       "Events_3       [(number_of_tickets-party_size, [tensor([0.685...   \n",
       "Flights_4      [(departure_date-date, [tensor([0.5904])]), (r...   \n",
       "Homes_2        [(phone_number-phone_number, [tensor([0.5414])...   \n",
       "Hotels_4       [(check_in_date-date, [tensor([0.8381])]), (ch...   \n",
       "Media_3        [(title-restaurant_name, [tensor([0.3741])]), ...   \n",
       "Movies_1       [(location-city, [tensor([0.5669])]), (street_...   \n",
       "Music_3        [(artist-restaurant_name, [tensor([0.4061])]),...   \n",
       "Payment_1      [(receiver-street_address, [tensor([0.3086])])...   \n",
       "RentalCars_3   [(city-city, [tensor([0.4099])]), (start_date-...   \n",
       "Restaurants_1                                                NaN   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                   RideSharing_2  \\\n",
       "Alarm_1        [(alarm_time-wait_time, [tensor([0.3057])]), (...   \n",
       "Banks_1        [(recipient_account_name-destination, [tensor(...   \n",
       "Buses_3        [(to_city-destination, [tensor([0.5701])]), (p...   \n",
       "Calendar_1     [(event_location-destination, [tensor([0.3419]...   \n",
       "Events_3       [(number_of_tickets-number_of_seats, [tensor([...   \n",
       "Flights_4      [(price-ride_fare, [tensor([0.5492])]), (desti...   \n",
       "Homes_2        [(address-destination, [tensor([0.3640])]), (a...   \n",
       "Hotels_4       [(number_of_rooms-number_of_seats, [tensor([0....   \n",
       "Media_3        [(genre-ride_type, [tensor([0.2577])]), (genre...   \n",
       "Movies_1       [(number_of_tickets-number_of_seats, [tensor([...   \n",
       "Music_3        [(genre-ride_type, [tensor([0.2799])]), (devic...   \n",
       "Payment_1      [(amount-ride_fare, [tensor([0.3081])]), (rece...   \n",
       "RentalCars_3   [(city-destination, [tensor([0.4815])]), (pick...   \n",
       "Restaurants_1  [(party_size-number_of_seats, [tensor([0.5245]...   \n",
       "RideSharing_2                                                NaN   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                      Services_1  \\\n",
       "Alarm_1        [(alarm_time-appointment_time, [tensor([0.4614...   \n",
       "Banks_1        [(recipient_account_name-street_address, [tens...   \n",
       "Buses_3        [(departure_date-appointment_date, [tensor([0....   \n",
       "Calendar_1     [(event_date-appointment_date, [tensor([0.5968...   \n",
       "Events_3       [(date-appointment_date, [tensor([0.5976])]), ...   \n",
       "Flights_4      [(departure_date-appointment_date, [tensor([0....   \n",
       "Homes_2        [(visit_date-appointment_date, [tensor([0.5784...   \n",
       "Hotels_4       [(check_in_date-appointment_date, [tensor([0.6...   \n",
       "Media_3        [(title-stylist_name, [tensor([0.2406])]), (ti...   \n",
       "Movies_1       [(location-city, [tensor([0.5239])]), (show_ti...   \n",
       "Music_3        [(artist-stylist_name, [tensor([0.4598])]), (a...   \n",
       "Payment_1      [(receiver-street_address, [tensor([0.3170])])...   \n",
       "RentalCars_3   [(end_date-appointment_date, [tensor([0.4060])...   \n",
       "Restaurants_1  [(date-appointment_date, [tensor([0.6524])]), ...   \n",
       "RideSharing_2  [(destination-street_address, [tensor([0.3275]...   \n",
       "Services_1                                                   NaN   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                        Trains_1  \\\n",
       "Alarm_1        [(alarm_time-journey_start_time, [tensor([0.48...   \n",
       "Banks_1        [(amount-total, [tensor([0.2734])]), (recipien...   \n",
       "Buses_3        [(to_station-from_station, [tensor([0.8183])])...   \n",
       "Calendar_1     [(event_time-journey_start_time, [tensor([0.61...   \n",
       "Events_3       [(number_of_tickets-number_of_adults, [tensor(...   \n",
       "Flights_4      [(destination_airport-from_station, [tensor([0...   \n",
       "Homes_2        [(visit_date-date_of_journey, [tensor([0.4656]...   \n",
       "Hotels_4       [(number_of_rooms-number_of_adults, [tensor([0...   \n",
       "Media_3        [(genre-class, [tensor([0.1590])]), (title-to_...   \n",
       "Movies_1       [(price-total, [tensor([0.6083])]), (number_of...   \n",
       "Music_3        [(device-from_station, [tensor([0.3036])]), (d...   \n",
       "Payment_1      [(receiver-to_station, [tensor([0.2679])]), (a...   \n",
       "RentalCars_3   [(end_date-date_of_journey, [tensor([0.4192])]...   \n",
       "Restaurants_1  [(party_size-number_of_adults, [tensor([0.5328...   \n",
       "RideSharing_2  [(number_of_seats-number_of_adults, [tensor([0...   \n",
       "Services_1     [(appointment_date-date_of_journey, [tensor([0...   \n",
       "Trains_1                                                     NaN   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                        Travel_1  \\\n",
       "Alarm_1        [(new_alarm_name-attraction_name, [tensor([0.3...   \n",
       "Banks_1        [(recipient_account_type-category, [tensor([0....   \n",
       "Buses_3        [(to_city-location, [tensor([0.5863])]), (from...   \n",
       "Calendar_1     [(event_location-location, [tensor([0.5171])])...   \n",
       "Events_3       [(city-location, [tensor([0.6118])]), (venue_a...   \n",
       "Flights_4      [(destination_airport-location, [tensor([0.412...   \n",
       "Homes_2        [(area-location, [tensor([0.5410])]), (phone_n...   \n",
       "Hotels_4       [(location-location, [tensor([0.6436])]), (pho...   \n",
       "Media_3        [(genre-category, [tensor([0.5094])]), (genre-...   \n",
       "Movies_1       [(location-location, [tensor([0.6035])]), (str...   \n",
       "Music_3        [(device-location, [tensor([0.2507])]), (artis...   \n",
       "Payment_1      [(receiver-phone_number, [tensor([0.2243])]), ...   \n",
       "RentalCars_3   [(city-location, [tensor([0.3945])]), (pickup_...   \n",
       "Restaurants_1  [(serves_alcohol-free_entry, [tensor([0.5922])...   \n",
       "RideSharing_2  [(destination-location, [tensor([0.3454])]), (...   \n",
       "Services_1     [(phone_number-phone_number, [tensor([0.5446])...   \n",
       "Trains_1       [(from_station-location, [tensor([0.4453])]), ...   \n",
       "Travel_1                                                     NaN   \n",
       "Weather_1                                                    NaN   \n",
       "\n",
       "                                                       Weather_1  \n",
       "Alarm_1        [(alarm_name-city, [tensor([0.2700])]), (alarm...  \n",
       "Banks_1        [(recipient_account_name-city, [tensor([0.2633...  \n",
       "Buses_3        [(to_city-city, [tensor([0.5500])]), (from_cit...  \n",
       "Calendar_1     [(event_date-date, [tensor([0.4321])]), (event...  \n",
       "Events_3       [(city-city, [tensor([0.5279])]), (date-date, ...  \n",
       "Flights_4      [(destination_airport-city, [tensor([0.4895])]...  \n",
       "Homes_2        [(area-city, [tensor([0.5535])]), (property_na...  \n",
       "Hotels_4       [(location-city, [tensor([0.5304])]), (street_...  \n",
       "Media_3        [(title-city, [tensor([0.2877])]), (subtitle_l...  \n",
       "Movies_1       [(location-city, [tensor([0.5398])]), (theater...  \n",
       "Music_3        [(artist-city, [tensor([0.3426])]), (track-cit...  \n",
       "Payment_1      [(receiver-city, [tensor([0.2771])]), (payment...  \n",
       "RentalCars_3   [(city-city, [tensor([0.4272])]), (end_date-da...  \n",
       "Restaurants_1  [(city-city, [tensor([0.5751])]), (restaurant_...  \n",
       "RideSharing_2  [(destination-city, [tensor([0.3477])]), (wait...  \n",
       "Services_1     [(city-city, [tensor([0.4708])]), (appointment...  \n",
       "Trains_1       [(to_station-city, [tensor([0.5665])]), (from_...  \n",
       "Travel_1       [(location-city, [tensor([0.5348])]), (attract...  \n",
       "Weather_1                                                    NaN  "
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.iloc[3,3] = pd.NA\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the matrix of ranked slot-pair scores currently held in `df`.\n",
    "df.to_csv(path_or_buf='slot_desc_sim.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Value-set similarity: represent each slot by the mean sentence embedding of the\n",
    "# values listed for it in `domain2slot2value`, then score every slot pair across\n",
    "# every pair of domains with cosine similarity.\n",
    "#\n",
    "# Rebuild `df` from scratch so nothing left in it by earlier cells (the\n",
    "# description-based scores, manual fixes such as `df.iloc[3,3] = pd.NA`) leaks into\n",
    "# the value-set results; only the upper triangle (i < j) is filled below.\n",
    "df = pd.DataFrame(index=domains, columns=domains, dtype=object)\n",
    "\n",
    "domain2slot2embed = {}\n",
    "for domain in domains:\n",
    "    domain2slot2embed[domain] = {}\n",
    "    for slot in domain2slot2value[domain]:\n",
    "        values = list(domain2slot2value[domain][slot])\n",
    "        if not values:\n",
    "            # Nothing to average over; an empty list cannot give a usable embedding.\n",
    "            continue\n",
    "        embeds = model.encode(values)\n",
    "        domain2slot2embed[domain][slot] = np.mean(embeds, axis=0)\n",
    "\n",
    "slot_sims = []  # flat list of ('<domain_i>-<slot_i>@<domain_j>-<slot_j>', score)\n",
    "for i, domain_i in enumerate(domains):\n",
    "    embed_i = domain2slot2embed[domain_i]\n",
    "    for domain_j in domains[i + 1:]:\n",
    "        embed_j = domain2slot2embed[domain_j]\n",
    "        sim_mat = {}\n",
    "        for slot_i, vec_i in embed_i.items():\n",
    "            for slot_j, vec_j in embed_j.items():\n",
    "                sim_score = util.cos_sim(vec_i, vec_j).item()\n",
    "                sim_mat[f'{slot_i}-{slot_j}'] = sim_score\n",
    "                slot_sims.append((f'{domain_i}-{slot_i}@{domain_j}-{slot_j}', sim_score))\n",
    "        # Label-based .at stores the ranked list as a single object in the cell and\n",
    "        # fails loudly if `df` is not indexed by `domains`, unlike positional .iloc.\n",
    "        df.at[domain_i, domain_j] = sorted(sim_mat.items(), key=lambda x: x[1], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the matrix of ranked slot-pair scores currently held in `df`.\n",
    "df.to_csv(path_or_buf='value_set_sim.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('Hotels_4-smoking_allowed', 'RentalCars_3-add_insurance') 3\n",
      "('Payment_1-private_visibility', 'Trains_1-trip_protection') 7\n",
      "('Buses_3-additional_luggage', 'RentalCars_3-add_insurance') 38\n",
      "('Buses_3-additional_luggage', 'Travel_1-free_entry') 3\n",
      "('Buses_3-additional_luggage', 'Travel_1-free_entry', 'Travel_1-good_for_kids') 3\n",
      "('Hotels_4-smoking_allowed', 'Travel_1-good_for_kids') 6\n",
      "('Buses_3-additional_luggage', 'Travel_1-good_for_kids') 2\n",
      "('Hotels_4-smoking_allowed', 'Travel_1-free_entry', 'Travel_1-good_for_kids') 4\n",
      "('Buses_3-additional_luggage', 'Hotels_4-smoking_allowed') 1\n",
      "('Restaurants_1-serves_alcohol', 'Services_1-is_unisex') 7\n"
     ]
    }
   ],
   "source": [
    "# Cross-domain slot groups that shared the literal value 'True' in some dialogue.\n",
    "# `slot_coref` maps a sorted tuple of 'domain-slot' names to the values they shared,\n",
    "# with one entry per dialogue per distinct shared value, so len(values) counts\n",
    "# those (dialogue, value) occurrences. Each group is printed exactly once.\n",
    "for ds_list, values in slot_coref.items():\n",
    "    if 'True' in values:\n",
    "        print(ds_list, len(values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank every cross-domain slot pair in `slot_sims` by score (best first) and save the\n",
    "# full ranking; in a top-to-bottom run these are the value-set scores computed above.\n",
    "slot_sims = sorted(slot_sims, key=lambda pair_score: pair_score[1], reverse=True)\n",
    "pd.DataFrame(slot_sims).to_csv(path_or_buf='value_set_sim_all.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): hidden-state dependency. This exports `slot_sims` as the\n",
    "# slot-DESCRIPTION ranking, but in a plain top-to-bottom run `slot_sims` was last\n",
    "# assigned by the value-set cell above, so the value-set scores would be saved under\n",
    "# this name. Presumably the description-similarity cell must be re-run right before\n",
    "# this one -- TODO move this export next to that computation.\n",
    "slot_sims = sorted(slot_sims, key=lambda pair_score: pair_score[1], reverse=True)\n",
    "pd.DataFrame(slot_sims).to_csv(path_or_buf='slot_desc_sim_all.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Alarm_1',\n",
       " 'Banks_1',\n",
       " 'Buses_3',\n",
       " 'Calendar_1',\n",
       " 'Events_3',\n",
       " 'Flights_4',\n",
       " 'Homes_2',\n",
       " 'Hotels_4',\n",
       " 'Media_3',\n",
       " 'Movies_1',\n",
       " 'Music_3',\n",
       " 'Payment_1',\n",
       " 'RentalCars_3',\n",
       " 'Restaurants_1',\n",
       " 'RideSharing_2',\n",
       " 'Services_1',\n",
       " 'Trains_1',\n",
       " 'Travel_1',\n",
       " 'Weather_1']"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Domain names used as both the row index and the columns of the similarity matrix.\n",
    "df.columns.tolist()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 64-bit ('convlab')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "0d7e61334dfc0ef49fed574cd0889517bf66c7c88797d6df65d4f14c89b6fa83"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
