{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0816a8e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "from itertools import chain, combinations\n",
    "\n",
    "import numpy as np\n",
    "from numpy.random import default_rng\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6d74336",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data(m, rng, mean=50, variance=30):\n",
    "    \"\"\"Draw `m` samples from a normal distribution, clipped below at 1.\n",
    "\n",
    "    `variance` is forwarded to `rng.normal` as its `scale` argument, so it\n",
    "    acts as the standard deviation (not the variance) of the distribution.\n",
    "    \"\"\"\n",
    "    samples = rng.normal(mean, variance, m)\n",
    "    return np.clip(samples, 1, None)\n",
    "\n",
    "# Fixed seed so a fresh Restart & Run All reproduces the same figures.\n",
    "SEED = 42\n",
    "rng = default_rng(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321eb4dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_combinations_with_element(iterable, r, element):\n",
    "    \"\"\"Yield every r-item subset of `iterable`, each extended with `element`.\n",
    "\n",
    "    Results are frozensets holding `element` plus `r` items of `iterable`,\n",
    "    produced in itertools.combinations order (lexicographic by position in\n",
    "    `iterable`). Nothing is yielded when `r` exceeds the number of items;\n",
    "    `r` must be non-negative.\n",
    "    \"\"\"\n",
    "    for combo in combinations(iterable, r):\n",
    "        yield frozenset((element, *combo))\n",
    "        \n",
    "def gen_powerset_with_element(iterable, element):\n",
    "    \"\"\"Chain every subset of `iterable` (sizes 0 to len), each with `element` added.\n",
    "\n",
    "    `iterable` must support len(); subsets come out grouped by increasing size.\n",
    "    \"\"\"\n",
    "    subset_sizes = range(len(iterable) + 1)\n",
    "    per_size = (gen_combinations_with_element(iterable, k, element) for k in subset_sizes)\n",
    "    return chain.from_iterable(per_size)\n",
    "\n",
    "def gen_combinations(iterable, r):\n",
    "    \"\"\"Yield every r-item subset of `iterable` as a frozenset.\n",
    "\n",
    "    Subsets come out in itertools.combinations order (lexicographic by\n",
    "    position in `iterable`); nothing is yielded when `r` exceeds the number\n",
    "    of items. `r` must be non-negative.\n",
    "    \"\"\"\n",
    "    yield from map(frozenset, combinations(iterable, r))\n",
    "        \n",
    "def gen_powerset(iterable):\n",
    "    \"\"\"Chain every subset of `iterable` as frozensets, grouped by increasing size.\n",
    "\n",
    "    `iterable` must support len().\n",
    "    \"\"\"\n",
    "    subsets_by_size = (gen_combinations(iterable, k) for k in range(len(iterable) + 1))\n",
    "    return chain.from_iterable(subsets_by_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cf4f5dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_firms_profits(coalitions, data, g, b):\n",
    "    \"\"\"Return every firm's profit under a given coalition structure.\n",
    "\n",
    "    Each member of a coalition is credited with the summed data of its\n",
    "    coalition; firms listed in no coalition keep a placeholder value of 1.\n",
    "    With s_i the credited value of firm i and n = len(data), the profit of\n",
    "    firm i is -(n + 1) * s_i**(-b) + g * sum_j s_j**(-b).\n",
    "\n",
    "    coalitions: iterable of collections of firm indices into `data`.\n",
    "    Returns a float array of length len(data).\n",
    "    \"\"\"\n",
    "    n_firms = len(data)\n",
    "    pooled = np.ones(n_firms, dtype=float)\n",
    "    for members in coalitions:\n",
    "        members = list(members)\n",
    "        pooled[members] = sum(data[firm] for firm in members)\n",
    "    weights = pooled ** (-b)\n",
    "    return -(n_firms + 1) * weights + g * np.sum(weights)\n",
    "\n",
    "def find_equilibrium_coalitions(turn, remaining, data, cache, g, b):\n",
    "    \"\"\"Recursively compute the coalition structure formed from firm `turn` on.\n",
    "\n",
    "    Firms move in index order. If firm `turn` is still in `remaining`, its\n",
    "    outside option is to propose nothing and let later firms move. It then\n",
    "    tries every proposal made of itself plus a subset of the other remaining\n",
    "    firms; a proposal is rejected if any member would earn less than under\n",
    "    the current outside option. The proposer keeps the acceptable proposal\n",
    "    with the highest own profit (the first one wins ties), and the search is\n",
    "    repeated against that new outside option until its profit stops improving.\n",
    "\n",
    "    turn: index of the firm whose move is being decided.\n",
    "    remaining: frozenset of firm indices not yet committed to a coalition.\n",
    "    data, g, b: forwarded to find_firms_profits.\n",
    "    cache: dict memoising results by (turn, remaining); returned lists may be\n",
    "        shared with it and must not be mutated by callers.\n",
    "    Returns a list of frozensets covering `remaining`; firms never placed in\n",
    "    a larger coalition end up as singletons.\n",
    "    \"\"\"\n",
    "    if turn == len(data):\n",
    "        return [frozenset({i}) for i in remaining]\n",
    "    if not remaining:\n",
    "        return []\n",
    "    if (turn, remaining) in cache:\n",
    "        return cache[(turn, remaining)]\n",
    "    if turn not in remaining:\n",
    "        return find_equilibrium_coalitions(turn + 1, remaining, data, cache, g, b)\n",
    "    # Outside option: firm `turn` proposes nothing and later firms move.\n",
    "    best_result = find_equilibrium_coalitions(turn + 1, remaining, data, cache, g, b)\n",
    "    outside_profits = find_firms_profits(best_result, data, g, b)\n",
    "    best_profit = outside_profits[turn]\n",
    "    while True:\n",
    "        for proposal in gen_powerset_with_element(remaining - {turn}, turn):\n",
    "            # The recursive result may be a cached list, so copy it before\n",
    "            # appending; a shallow copy suffices since its items are frozensets.\n",
    "            curr_result = list(\n",
    "                find_equilibrium_coalitions(turn + 1, remaining - proposal, data, cache, g, b)\n",
    "            )\n",
    "            curr_result.append(proposal)\n",
    "            curr_profits = find_firms_profits(curr_result, data, g, b)\n",
    "            curr_profit = curr_profits[turn]\n",
    "            # Treat the proposal as unacceptable if any member is worse off.\n",
    "            for firm in proposal:\n",
    "                if curr_profits[firm] < outside_profits[firm]:\n",
    "                    curr_profit = outside_profits[turn]\n",
    "                    break\n",
    "            if curr_profit > best_profit:\n",
    "                best_result = curr_result\n",
    "                best_profit = curr_profit\n",
    "        if best_profit <= outside_profits[turn]:\n",
    "            break\n",
    "        # The improved structure becomes the new outside option; search again.\n",
    "        outside_profits = find_firms_profits(best_result, data, g, b)\n",
    "    cache[(turn, remaining)] = best_result\n",
    "    return best_result\n",
    "\n",
    "def find_equilibrium_coalitions_init(data, g, b):\n",
    "    \"\"\"Solve the coalition game for one draw of per-firm data.\n",
    "\n",
    "    Side effect: when `data` is a NumPy array it is sorted in place into\n",
    "    descending order, so firm index 0 holds the largest value and the\n",
    "    caller's array matches the firm indices of the result.\n",
    "    Returns the list of coalitions (frozensets of firm indices) found by\n",
    "    find_equilibrium_coalitions starting at firm 0 with every firm free.\n",
    "    \"\"\"\n",
    "    descending = data[::-1]\n",
    "    descending.sort()  # ascending sort of the reversed view leaves `data` descending\n",
    "    all_firms = frozenset(range(len(data)))\n",
    "    return find_equilibrium_coalitions(0, all_firms, data, {}, g, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ea9347c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def instance_to_str(coalition_and_number, tries):\n",
    "    \"\"\"Format a (coalition structure, count) pair as one report line.\n",
    "\n",
    "    Firms are shown 1-based; members of a coalition are concatenated and\n",
    "    coalitions separated by spaces (e.g. '12 3:'), left-aligned in 15\n",
    "    characters, followed by the count as a percentage of `tries`.\n",
    "    \"\"\"\n",
    "    structure, count = coalition_and_number\n",
    "    labels = []\n",
    "    for group in structure:\n",
    "        labels.append(''.join(sorted(str(firm + 1) for firm in group)))\n",
    "    label = ' '.join(sorted(labels)) + ':'\n",
    "    share = count * 100 / tries\n",
    "    return \"{:<15}\".format(label) + \"{:.1f}\".format(share) + '%'\n",
    "\n",
    "def instance_to_stat(coalition_and_stat, tries):\n",
    "    \"\"\"Turn (coalition structure, per-experiment counts) into (structure, mean, std).\n",
    "\n",
    "    Counts are divided by `tries` to obtain frequencies; the spread is the\n",
    "    sample standard deviation (ddof=1), which is NaN for a single experiment.\n",
    "    \"\"\"\n",
    "    structure, counts = coalition_and_stat\n",
    "    freqs = np.divide(np.array(counts), tries)\n",
    "    return structure, freqs.mean(), freqs.std(ddof=1)\n",
    "\n",
    "def stat_to_str(stat):\n",
    "    \"\"\"Format (coalition structure, mean, std) as 'label: mean% +/- std%'.\n",
    "\n",
    "    The label lists 1-based firm numbers, members concatenated and coalitions\n",
    "    separated by spaces, left-aligned in 15 characters; mean and std are\n",
    "    fractions printed as percentages with one decimal.\n",
    "    \"\"\"\n",
    "    structure, avg, spread = stat\n",
    "    label = ' '.join(sorted(''.join(sorted(str(firm + 1) for firm in group)) for group in structure))\n",
    "    pieces = [\n",
    "        \"{:<15}\".format(label + ':'),\n",
    "        \"{:.1f}\".format(avg * 100) + '%',\n",
    "        \" +/- \",\n",
    "        \"{:.1f}\".format(spread * 100) + '%',\n",
    "    ]\n",
    "    return ''.join(pieces)\n",
    "\n",
    "def find_equilibria_stat(experiments, g, b, tries=None):\n",
    "    \"\"\"Tabulate how often each equilibrium coalition structure occurs.\n",
    "\n",
    "    experiments: sequence of experiments, each an iterable of per-firm data\n",
    "        draws passed to find_equilibrium_coalitions_init.\n",
    "    g, b: model parameters forwarded to the solver.\n",
    "    tries: draws per experiment used to turn counts into frequencies.\n",
    "        Defaults to the number of draws actually seen in each experiment,\n",
    "        instead of silently reading a notebook-global `tries`.\n",
    "\n",
    "    Returns (structure, mean frequency, sample std of the frequency) tuples\n",
    "    sorted by mean frequency, highest first. A structure that does not occur\n",
    "    in an experiment counts as frequency 0 for that experiment.\n",
    "    \"\"\"\n",
    "    freqs_by_structure = {}\n",
    "    for j, experiment in enumerate(experiments):\n",
    "        counts = {}\n",
    "        n_draws = 0\n",
    "        for data in experiment:\n",
    "            equilibrium = frozenset(find_equilibrium_coalitions_init(data, g, b))\n",
    "            counts[equilibrium] = counts.get(equilibrium, 0) + 1\n",
    "            n_draws += 1\n",
    "        scale = tries if tries is not None else max(n_draws, 1)\n",
    "        for structure in counts:\n",
    "            # First seen now: zero frequency in all earlier experiments.\n",
    "            freqs_by_structure.setdefault(structure, [0] * j)\n",
    "        for structure, freqs in freqs_by_structure.items():\n",
    "            freqs.append(counts.get(structure, 0) / scale)\n",
    "    # Frequencies are already normalised, hence the divisor of 1.\n",
    "    return sorted((instance_to_stat(item, 1) for item in freqs_by_structure.items()),\n",
    "                  key=lambda row: row[1], reverse=True)\n",
    "\n",
    "def find_stat(experiments, g, b, stat_function):\n",
    "    \"\"\"Average `stat_function` over the equilibria of each experiment's draws.\n",
    "\n",
    "    For every draw the equilibrium is solved and stat_function(equilibrium, data)\n",
    "    is evaluated on the draw as left by the solver (NumPy rows end up sorted\n",
    "    in descending order). Returns the mean and population std (ddof=0) of the\n",
    "    per-experiment averages.\n",
    "    \"\"\"\n",
    "    per_experiment = np.empty(len(experiments))\n",
    "    for idx, batch in enumerate(experiments):\n",
    "        total = sum(\n",
    "            stat_function(frozenset(find_equilibrium_coalitions_init(draw, g, b)), draw)\n",
    "            for draw in batch\n",
    "        )\n",
    "        per_experiment[idx] = total / len(batch)\n",
    "    return per_experiment.mean(), per_experiment.std()\n",
    "\n",
    "def find_mean(coalitions, data):\n",
    "    \"\"\"Average number of firms per coalition (`data` is unused).\"\"\"\n",
    "    return sum(map(len, coalitions)) / len(coalitions)\n",
    "\n",
    "def find_max(coalitions, data):\n",
    "    \"\"\"Size of the largest coalition (`data` is unused).\"\"\"\n",
    "    return max(map(len, coalitions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "573f818b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "# Export figures as PGF files typeset with pdflatex (needs a LaTeX install).\n",
    "matplotlib.use(\"pgf\")\n",
    "matplotlib.rcParams['pgf.texsystem'] = 'pdflatex'\n",
    "matplotlib.rcParams['text.usetex'] = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5144ddc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment(experiments, size, b, g, g_range=(0.1, 1.0), b_range=(0.5, 1.0)):\n",
    "    \"\"\"Sweep gamma and beta one at a time and record the average coalition size.\n",
    "\n",
    "    experiments: simulated data forwarded to find_stat.\n",
    "    size: number of evenly spaced grid points in each sweep.\n",
    "    b: beta held fixed while gamma sweeps over `g_range`.\n",
    "    g: gamma held fixed while beta sweeps over `b_range`.\n",
    "    g_range, b_range: inclusive (start, stop) of each sweep; the defaults\n",
    "        are the ranges this function has always used.\n",
    "\n",
    "    Returns (xg, yg, xb, yb): the gamma grid with the mean of find_mean at\n",
    "    each point, then the same for the beta grid.\n",
    "    \"\"\"\n",
    "    xg = np.linspace(*g_range, size)\n",
    "    yg = np.array([find_stat(experiments, gamma, b, find_mean)[0] for gamma in xg])\n",
    "    xb = np.linspace(*b_range, size)\n",
    "    yb = np.array([find_stat(experiments, g, beta, find_mean)[0] for beta in xb])\n",
    "    return xg, yg, xb, yb\n",
    "\n",
    "def calculate_figure(repeats, tries, mean, stds, m, size, b, g, generator=None):\n",
    "    \"\"\"Simulate data for every standard deviation in `stds` and run both sweeps.\n",
    "\n",
    "    A single batch of standard-normal draws of shape (repeats, tries, m) is\n",
    "    shared by all entries of `stds`; for each std it is rescaled to mean\n",
    "    `mean`, clipped below at 1 and passed to experiment().\n",
    "\n",
    "    generator: NumPy Generator to draw from; defaults to the notebook-level\n",
    "        `rng` so existing calls behave as before.\n",
    "\n",
    "    Returns X, Y of shape (len(stds), 2, size): [i][0] holds the gamma sweep\n",
    "    and [i][1] the beta sweep for stds[i].\n",
    "    \"\"\"\n",
    "    if generator is None:\n",
    "        generator = rng\n",
    "    X = np.empty((len(stds), 2, size))\n",
    "    Y = np.empty((len(stds), 2, size))\n",
    "    experiments = generator.standard_normal((repeats, tries, m))\n",
    "    for i, std in enumerate(tqdm(stds)):\n",
    "        data = np.clip(experiments * std + mean, 1, None)\n",
    "        X[i][0], Y[i][0], X[i][1], Y[i][1] = experiment(data, size, b, g)\n",
    "    return X, Y\n",
    "\n",
    "def get_figure(X, Y, m, stds, mean=1000):\n",
    "    \"\"\"Plot the gamma and beta sweeps for one firm count and save them as PGF.\n",
    "\n",
    "    Row i shows stds[i]; the left panel plots Y[i][0] against X[i][0] (gamma\n",
    "    sweep), the right panel Y[i][1] against X[i][1] (beta sweep). The figure\n",
    "    is written to 'dep_m' + str(m) + '.pgf' and then closed.\n",
    "\n",
    "    mean: mean of the data distribution shown in the panel titles\n",
    "        (previously hard-coded as 1000, which stays the default).\n",
    "    \"\"\"\n",
    "    # squeeze=False keeps `ax` two-dimensional even for a single row.\n",
    "    fig, ax = plt.subplots(len(X), 2, squeeze=False)\n",
    "    fig.set_size_inches(w=2.7 * 2, h=2.4 * len(X))\n",
    "    for i in range(len(X)):\n",
    "        # Raw strings: '\\m' and '\\g' are invalid escape sequences in plain literals.\n",
    "        title = r'$m = ' + str(m) + r', P = \\mathrm{N}(' + str(mean) + ',' + str(stds[i]) + r'^2)$'\n",
    "        for col in range(2):\n",
    "            ax[i][col].set_title(title)\n",
    "            ax[i][col].plot(X[i][col], Y[i][col], linestyle='-', marker='o', color='b')\n",
    "    fig.supylabel(\"Average coalition size\")\n",
    "    ax[-1][0].set_xlabel(r\"Similarity between products ($\\gamma$)\")\n",
    "    ax[-1][1].set_xlabel(r'Simplicity of learning task ($\\beta$)')\n",
    "    fig.tight_layout()\n",
    "    fig.savefig('dep_m' + str(m) + '.pgf')\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06579fdf",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Simulation settings\n",
    "repeats = 1  # independent repetitions of every experiment\n",
    "tries = 10000  # data draws per repetition\n",
    "mean = 1000  # mean of the simulated per-firm data (before clipping at 1)\n",
    "stds = [300, 600, 900]  # standard deviations to compare\n",
    "ms = [3, 4, 5]  # numbers of firms to compare\n",
    "size = 10  # grid points per gamma / beta sweep\n",
    "b = 0.9  # beta, held fixed during the gamma sweep\n",
    "g = 0.8  # gamma, held fixed during the beta sweep\n",
    "\n",
    "# XX[i], YY[i]: sweep grids and average coalition sizes for ms[i], each of\n",
    "# shape (len(stds), 2, size) as returned by calculate_figure.\n",
    "XX = np.empty((len(ms), len(stds), 2, size))\n",
    "YY = np.empty_like(XX)\n",
    "for i, m in enumerate(ms):\n",
    "    XX[i], YY[i] = calculate_figure(repeats, tries, mean, stds, m, size, b, g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "228b3968",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One PGF file per entry of ms, named 'dep_m' + str(m) + '.pgf' by get_figure.\n",
    "for X, Y, m in zip(XX, YY, ms):\n",
    "    get_figure(X, Y, m=m, stds=stds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bbf583a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combined figure. Rows: ms[0], ms[1]. Columns: gamma sweep for stds[0] and\n",
    "# stds[1], then beta sweep for stds[0] and stds[1].\n",
    "fig, ax = plt.subplots(2, 4)\n",
    "\n",
    "# Frameless axes spanning each half of the figure carry the shared x-labels.\n",
    "axbig = [fig.add_subplot(121, frameon=False), fig.add_subplot(122, frameon=False)]\n",
    "for i in range(2):\n",
    "    plt.setp(axbig[i].get_yticklabels(), visible=False)\n",
    "    plt.setp(axbig[i].get_xticklabels(), visible=False)\n",
    "    axbig[i].tick_params(axis='both', which='both', length=0)\n",
    "\n",
    "fig.set_size_inches(w=2.7 * 4, h=2.7 * 2)\n",
    "\n",
    "for i in range(2):\n",
    "    for j in range(2):\n",
    "        for k in range(2):\n",
    "            # Raw string: '\\m' is an invalid escape sequence in a plain literal.\n",
    "            title = r'$m = ' + str(ms[i]) + r', P = \\mathrm{N}(' + str(mean) + ',' + str(stds[k]) + r'^2)$'\n",
    "            ax[i][2*j + k].plot(XX[i][k][j], YY[i][k][j], linestyle='-', marker='o', color='b')\n",
    "            ax[i][2*j + k].set_title(title)\n",
    "\n",
    "# Own name so the sweep grid `size` from the settings cell is not clobbered.\n",
    "font_size = 18\n",
    "fig.supylabel(\"Average coalition size\", fontsize=font_size)\n",
    "axbig[0].set_xlabel(r\"Similarity between products ($\\gamma$)\", labelpad=22, fontsize=font_size)\n",
    "axbig[1].set_xlabel(r'Simplicity of learning task ($\\beta$)', labelpad=22, fontsize=font_size)\n",
    "fig.tight_layout()\n",
    "fig.savefig('dep_main.pgf')\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f18c0a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
