{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/zhangpai23/zhuqi/anaconda3/envs/convlab/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from convlab.util import load_dataset\n",
    "from pprint import pprint\n",
    "import queue\n",
    "from copy import deepcopy\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_statistics(dataset_name, min_count=10):\n",
    "    \"\"\"Count dialogues per domain combination in each data split.\n",
    "\n",
    "    Dialogues whose domains include 'police' or 'hospital' are skipped, and\n",
    "    the 'general' domain is dropped before the remaining domains are sorted\n",
    "    and used as the grouping key.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: name passed to ``load_dataset``.\n",
    "        min_count: combinations with fewer dialogues in total are left out\n",
    "            of the result (default 10, the previously hard-coded cutoff).\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts with keys 'domains' (tuple of domain names), 'all'\n",
    "        (total count) and one key per data split, ordered by number of\n",
    "        domains first and total count second.\n",
    "    \"\"\"\n",
    "    data = load_dataset(dataset_name)\n",
    "    domain_cnt = {}\n",
    "    for data_split in data:\n",
    "        for dial in data[data_split]:\n",
    "            if 'police' in dial['domains'] or 'hospital' in dial['domains']:\n",
    "                continue\n",
    "            domains = tuple(sorted(set(dial['domains']) - {'general'}))\n",
    "            # One counter per split actually present in the data, so a split\n",
    "            # named something other than train/validation/test cannot raise\n",
    "            # KeyError.\n",
    "            split_cnt = domain_cnt.setdefault(domains, {split: 0 for split in data})\n",
    "            split_cnt[data_split] += 1\n",
    "\n",
    "    def order(item):\n",
    "        # Tuple key: number of domains first, total count second. Unlike\n",
    "        # len * 10000 + total, this stays correct for totals >= 10000.\n",
    "        return len(item[0]), sum(item[1].values())\n",
    "\n",
    "    table = []\n",
    "    for domains, stat in sorted(domain_cnt.items(), key=order):\n",
    "        total = sum(stat.values())\n",
    "        if total < min_count:\n",
    "            continue\n",
    "        res = {'domains': domains, 'all': total}\n",
    "        for data_split in data:\n",
    "            res[data_split] = stat[data_split]\n",
    "        table.append(res)\n",
    "    return table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-domain-combination dialogue counts for the 'sgd' dataset.\n",
    "table = get_statistics(dataset_name='sgd')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Alarm': {'Alarm_1': 84},\n",
      " 'Banks': {'Banks_1': 207, 'Banks_2': 42},\n",
      " 'Buses': {'Buses_1': 195, 'Buses_2': 159, 'Buses_3': 88},\n",
      " 'Calendar': {'Calendar_1': 169},\n",
      " 'Events': {'Events_1': 289, 'Events_2': 572, 'Events_3': 76},\n",
      " 'Flights': {'Flights_1': 800,\n",
      "             'Flights_2': 185,\n",
      "             'Flights_3': 94,\n",
      "             'Flights_4': 87},\n",
      " 'Homes': {'Homes_1': 349, 'Homes_2': 89},\n",
      " 'Hotels': {'Hotels_1': 142, 'Hotels_2': 304, 'Hotels_3': 129, 'Hotels_4': 115},\n",
      " 'Media': {'Media_1': 281, 'Media_2': 46, 'Media_3': 80},\n",
      " 'Movies': {'Movies_1': 376, 'Movies_2': 47, 'Movies_3': 48},\n",
      " 'Music': {'Music_1': 98, 'Music_2': 331, 'Music_3': 25},\n",
      " 'Payment': {'Payment_1': 36},\n",
      " 'RentalCars': {'RentalCars_1': 143, 'RentalCars_2': 111, 'RentalCars_3': 64},\n",
      " 'Restaurants': {'Restaurants_1': 367, 'Restaurants_2': 146},\n",
      " 'RideSharing': {'RideSharing_1': 106, 'RideSharing_2': 92},\n",
      " 'Services': {'Services_1': 265,\n",
      "              'Services_2': 185,\n",
      "              'Services_3': 188,\n",
      "              'Services_4': 124},\n",
      " 'Trains': {'Trains_1': 84},\n",
      " 'Travel': {'Travel_1': 69},\n",
      " 'Weather': {'Weather_1': 83}}\n"
     ]
    }
   ],
   "source": [
    "# domains: name prefix before the first '_' -> {service name -> total count},\n",
    "# filled from rows of `table` that list exactly one service.\n",
    "# multi_domain: tuple of services -> total count, filled from all other rows.\n",
    "domains = {}\n",
    "multi_domain = {}\n",
    "for row in table:\n",
    "    services = row['domains']\n",
    "    if len(services) != 1:\n",
    "        multi_domain[tuple(services)] = row['all']\n",
    "        continue\n",
    "    service = services[0]\n",
    "    domain = service.partition('_')[0]\n",
    "    domains.setdefault(domain, {})[service] = row['all']\n",
    "pprint(domains)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_multi_domain(combination):\n",
    "    \"\"\"Summarise the `multi_domain` entries fully covered by `combination`.\n",
    "\n",
    "    An entry is covered when every service in its key is contained in\n",
    "    `combination`.\n",
    "\n",
    "    Returns:\n",
    "        (number of covered entries, sum of their counts).\n",
    "    \"\"\"\n",
    "    covered = [\n",
    "        total\n",
    "        for services, total in multi_domain.items()\n",
    "        if all(service in combination for service in services)\n",
    "    ]\n",
    "    return len(covered), sum(covered)\n",
    "\n",
    "def get_single_domain(combination):\n",
    "    \"\"\"Return the summed `domains` counts of every service in `combination`.\n",
    "\n",
    "    Each service is looked up under the part of its name before the first '_'.\n",
    "    \"\"\"\n",
    "    return sum(domains[service.split('_')[0]][service] for service in combination)\n",
    "    \n",
    "def remove_combination(all_services, combination):\n",
    "    \"\"\"Return a deep copy of `all_services` with `combination` members dropped.\n",
    "\n",
    "    In every inner list holding more than one service, the first service that\n",
    "    appears in `combination` is removed. Single-element lists are kept as they\n",
    "    are, and the input itself is never modified.\n",
    "    \"\"\"\n",
    "    remaining = deepcopy(all_services)\n",
    "    for group in remaining:\n",
    "        if len(group) <= 1:\n",
    "            continue\n",
    "        for position, candidate in enumerate(group):\n",
    "            if candidate in combination:\n",
    "                # Drop only the first match, then move on to the next group.\n",
    "                del group[position]\n",
    "                break\n",
    "    return remaining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_combinations = []\n",
    "def dfs(i, services, combination, min_comb=42):\n",
    "    \"\"\"Enumerate every way of picking one service from each list in `services`.\n",
    "\n",
    "    Args:\n",
    "        i: index of the list in `services` to pick from next.\n",
    "        services: list of lists of service names.\n",
    "        combination: set holding the services picked so far; it is mutated\n",
    "            during the search and restored before returning.\n",
    "        min_comb: a complete pick is recorded only when `get_multi_domain`\n",
    "            reports at least this many covered entries (default 42, the\n",
    "            previously hard-coded threshold).\n",
    "\n",
    "    Each recorded pick is appended to the module-level `all_combinations` as\n",
    "    (copy of combination, num_comb, num_multi, num_single).\n",
    "    \"\"\"\n",
    "    if i == len(services):\n",
    "        num_comb, num_multi = get_multi_domain(combination)\n",
    "        if num_comb >= min_comb:\n",
    "            num_single = get_single_domain(combination)\n",
    "            # Store a copy: `combination` keeps changing as the search backtracks.\n",
    "            all_combinations.append((deepcopy(combination), num_comb, num_multi, num_single))\n",
    "        return\n",
    "    for service in services[i]:\n",
    "        combination.add(service)\n",
    "        dfs(i + 1, services, combination, min_comb)\n",
    "        combination.remove(service)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7577"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_services = [list(domains[domain].keys()) for domain in domains]\n",
    "# dfs() appends to the module-level list, so empty it first; otherwise\n",
    "# re-running this cell would accumulate duplicate combinations.\n",
    "all_combinations.clear()\n",
    "dfs(0, all_services, set())\n",
    "len(all_combinations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7577/7577 [00:16<00:00, 452.17it/s] \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "942"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Queue every pair of recorded combinations that share exactly 6 services,\n",
    "# prioritised by how close their multi-domain and single-domain totals are.\n",
    "q = queue.PriorityQueue()\n",
    "num_combinations = len(all_combinations)\n",
    "for i in tqdm(range(num_combinations)):\n",
    "    comb1, num_comb1, num_multi1, num_single1 = all_combinations[i]\n",
    "    for j in range(i + 1, num_combinations):\n",
    "        comb2, num_comb2, num_multi2, num_single2 = all_combinations[j]\n",
    "        if len(comb1 & comb2) != 6:\n",
    "            continue\n",
    "        prior = abs(num_multi1 - num_multi2) + abs(num_single1 - num_single2)\n",
    "        entry = (\n",
    "            prior, deepcopy(comb1), deepcopy(comb2),\n",
    "            num_comb1, num_multi1, num_single1,\n",
    "            num_comb2, num_multi2, num_single2,\n",
    "        )\n",
    "        q.put(entry)\n",
    "q.qsize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(376,\n",
       " {'Alarm_1',\n",
       "  'Banks_1',\n",
       "  'Buses_3',\n",
       "  'Calendar_1',\n",
       "  'Events_3',\n",
       "  'Flights_4',\n",
       "  'Homes_2',\n",
       "  'Hotels_4',\n",
       "  'Media_3',\n",
       "  'Movies_1',\n",
       "  'Music_3',\n",
       "  'Payment_1',\n",
       "  'RentalCars_3',\n",
       "  'Restaurants_1',\n",
       "  'RideSharing_2',\n",
       "  'Services_1',\n",
       "  'Trains_1',\n",
       "  'Travel_1',\n",
       "  'Weather_1'},\n",
       " {'Alarm_1',\n",
       "  'Banks_2',\n",
       "  'Buses_1',\n",
       "  'Calendar_1',\n",
       "  'Events_2',\n",
       "  'Flights_3',\n",
       "  'Homes_1',\n",
       "  'Hotels_1',\n",
       "  'Media_2',\n",
       "  'Movies_3',\n",
       "  'Music_1',\n",
       "  'Payment_1',\n",
       "  'RentalCars_1',\n",
       "  'Restaurants_2',\n",
       "  'RideSharing_1',\n",
       "  'Services_4',\n",
       "  'Trains_1',\n",
       "  'Travel_1',\n",
       "  'Weather_1'},\n",
       " 42,\n",
       " 3387,\n",
       " 2456,\n",
       " 42,\n",
       " 3185,\n",
       " 2630)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pop the entry with the smallest priority. block=False makes an empty\n",
    "# queue raise queue.Empty immediately instead of hanging the kernel.\n",
    "res = q.get(block=False)\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 (conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "0d7e61334dfc0ef49fed574cd0889517bf66c7c88797d6df65d4f14c89b6fa83"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
