{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.lines as mlines\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy\n",
    "import scipy.special\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "plt.rcParams.update({'font.size': 16})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from iaa_api import InterAnnotatorAgreementAPI\n",
    "from competitors.data_handling_competitors import seed_everything\n",
    "from utils import *\n",
    "from syntethic_exps import to_LA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the random number generators through the project helper (assumed to cover numpy and\n",
    "# Python's random module -- TODO confirm) so the simulated annotations below are reproducible.\n",
    "seed_everything(seed=42)"
   ]
  },
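  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transformation of H into the MAP thresholds A and B\n",
    "\n",
    "Here $\\nu$ is the prior probability of class 0, $T$ is the $2 \\times 2$ transition matrix of a single annotator and $H$ is the (odd) number of annotators. $A$ and $B$ are the vote thresholds used by the oracle MAP below for class 0 and class 1, respectively:\n",
    "\n",
    "-  $A = \\frac{\\log{\\frac{1-\\nu}{\\nu}} + H \\log{\\frac{T_{11}}{1-T_{00}}}}{\\log{\\frac{T_{11}T_{00}}{(1-T_{11})(1-T_{00})}}}$\n",
    "\n",
    "\n",
    "- $B = \\frac{\\log{\\frac{\\nu}{1-\\nu}} + H \\log{\\frac{T_{00}}{1-T_{11}}}}{\\log{\\frac{T_{11}T_{00}}{(1-T_{11})(1-T_{00})}}}$\n",
    "\n",
    "When $T_{00} = T_{11}$ both expressions reduce to $\\frac{H}{2} \\mp \\frac{\\log{\\frac{\\nu}{1-\\nu}}}{2 \\log{\\frac{T_{00}}{1-T_{00}}}}$ (minus for $A$, plus for $B$). In the code a small slack is added inside every logarithm to avoid $\\log 0$."
   ]
  },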
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def H_to_AB(vu, T, H:float, A: bool = True, slack: float = 1e-4):\n",
    "    \"\"\"\n",
    "    We obtain the values of A and B from T, H and vu as defined in the report\n",
    "    vu: distribution of the classes\n",
    "    T: T matrix\n",
    "    H: number of annotators (we assume it odd)\n",
    "    A: if True it returns the A matrix, else the B matrix\n",
    "    slack: is a slack variable used to avoid to compute log(0) which goes to infinity\n",
    "    \"\"\"\n",
    "    nu_ratio = vu/(1-vu)\n",
    "    delta_0 = 1-T[0][0]\n",
    "    delta_1 = 1-T[1][1]\n",
    "    if T[0][0] == T[1][1] and A:\n",
    "        return ((-np.log(nu_ratio+slack)) / (2*np.log((T[0][0]/(delta_0))+slack))) + (H/2)\n",
    "    if T[0][0] == T[1][1] and not A:\n",
    "        return ((np.log(nu_ratio+slack)) / (2*np.log((T[0][0]/(delta_0))+slack))) + (H/2)\n",
    "    if T[0][0] != T[1][1] and A:\n",
    "        return (np.log((1/nu_ratio)+slack)+ H*np.log((T[1][1]/delta_0)+slack)) / (np.log(((T[0][0]*T[1][1])/(delta_0*delta_1))+slack))\n",
    "    if T[0][0] != T[1][1] and not A:\n",
    "        return (np.log(nu_ratio+slack)+ H*np.log((T[0][0]/delta_1)+slack)) / (np.log(((T[0][0]*T[1][1])/(delta_0*delta_1))+slack))"
   ]
  },
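  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Majority voting\n",
    "\n",
    "$T_{cc}^{MV} = \\sum_{i=\\lceil{\\frac{H}{2}}\\rceil}^{H} \\binom{H}{i} T_{cc}^{i}\\,(1-T_{cc})^{H-i}$\n",
    "\n",
    "This is the probability that at least $\\lceil H/2 \\rceil$ of the $H$ independent annotators vote for the true class $c$. For odd $H$, $\\lceil H/2 \\rceil = (H+1)/2$, so ties cannot occur."
   ]
  },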
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def T_majority(T, H:int):\n",
    "    \"\"\"\n",
    "    Simple computation of MV as in the formula\n",
    "    \"\"\"\n",
    "    T_mv = np.zeros_like(T)\n",
    "    for k in range(int(((H+1)/2)),H+1):\n",
    "        T_mv = T_mv + scipy.special.binom(H,k) * np.power(T,k)* \\\n",
    "        np.power(1-T,H-k)\n",
    "    return T_mv"
   ]
  },
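  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Oracle Maximum A Posteriori (MAP)\n",
    "\n",
    "$T_{cc}^{MAP} = \\sum_{i=\\lceil{A_c}\\rceil}^{H} \\binom{H}{i} T_{cc}^{i}\\,(1-T_{cc})^{H-i}$\n",
    "\n",
    "where $A_0 = A$ and $A_1 = B$ are the thresholds defined above. The MAP is called *oracle* because it is computed with the true $T$ and $\\nu$."
   ]
  },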
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def T_map(T, H:int, vu:float, debug:bool=False):\n",
    "    \"\"\"\n",
    "    Transition matrix of the oracle MAP aggregator for a binary task.\n",
    "\n",
    "    T_map[0][0] is the probability that at least ceil(A) of the H votes are correct when the\n",
    "    true class is 0, and T_map[1][1] the same with ceil(B) for class 1 (A, B from H_to_AB).\n",
    "    T: 2x2 transition matrix of a single annotator (only the diagonal is read)\n",
    "    H: number of annotators (we assume it odd)\n",
    "    vu: prior probability of class 0\n",
    "    debug: if True, print the integer thresholds ceil(A), ceil(B) and (H+1)/2\n",
    "    Returns a 2x2 matrix whose off-diagonal entries are the complements of the diagonal.\n",
    "    Raises ValueError when A or B is NaN (int() of a NaN threshold).\n",
    "    \"\"\"\n",
    "    A = H_to_AB(vu, T, H, A=True)\n",
    "    B = H_to_AB(vu, T, H, A=False)\n",
    "    start_0 = int(np.ceil(A))\n",
    "    start_1 = int(np.ceil(B))\n",
    "    if debug:\n",
    "        print(\"A: \", start_0)\n",
    "        print(\"B: \", start_1)\n",
    "        print(\"(H+1)/2: \", (H+1)/2)\n",
    "\n",
    "    T_out = np.zeros_like(T)\n",
    "    T_out[0][0] = _tail_probability(T[0][0], H, start_0)\n",
    "    T_out[1][1] = _tail_probability(T[1][1], H, start_1)\n",
    "    T_out[0][1] = 1-T_out[0][0]\n",
    "    T_out[1][0] = 1-T_out[1][1]\n",
    "    return T_out\n",
    "\n",
    "\n",
    "def _tail_probability(p, H:int, start:int):\n",
    "    \"\"\"Probability that at least `start` of H independent votes are correct, each with probability p.\"\"\"\n",
    "    total = 0.0\n",
    "    # binom(H, k) is 0 for k < 0, but p**k overflows to inf for very negative k and 0*inf = nan;\n",
    "    # starting at max(start, 0) therefore leaves the value unchanged and keeps it finite.\n",
    "    for k in range(max(start, 0), H+1):\n",
    "        total += scipy.special.binom(H,k) * np.power(p,k) * np.power(1-p,H-k)\n",
    "    return total"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Method to check when MAP is better than MV and when they are equal\n",
    "\n",
    "Both aggregators are compared through their prior-weighted accuracy $\\mathrm{diag}(T)^\\top (\\nu,\\ 1-\\nu)^\\top$:\n",
    "\n",
    "If $\\mathrm{diag}(T^{MV})^\\top (\\nu,\\ 1-\\nu)^\\top < \\mathrm{diag}(T^{MAP})^\\top (\\nu,\\ 1-\\nu)^\\top \\Rightarrow$ MAP is better than MV\n",
    "\n",
    "If $\\mathrm{diag}(T^{MV})^\\top (\\nu,\\ 1-\\nu)^\\top > \\mathrm{diag}(T^{MAP})^\\top (\\nu,\\ 1-\\nu)^\\top \\Rightarrow$ MV is better than MAP\n",
    "\n",
    "If $\\left| \\mathrm{diag}(T^{MV})^\\top (\\nu,\\ 1-\\nu)^\\top - \\mathrm{diag}(T^{MAP})^\\top (\\nu,\\ 1-\\nu)^\\top \\right| < 10^{-2} \\Rightarrow$ MAP = MV\n",
    "\n",
    "The code checks the first condition before the tolerance condition, so a positive difference smaller than $10^{-2}$ is still reported as \"MAP better than MV\"; \"MV better than MAP\" is reported only when neither of the other two conditions holds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def MAP_better_majority(T_majority, T_map, vu:float,\n",
    "                        debug:bool=False, required_difference=False):\n",
    "  \n",
    "  if debug:\n",
    "    print( \"Difference Majority-MAP: \", -np.dot(np.diag(T_majority), np.array([vu, 1-vu])) + np.dot(np.diag(T_map), np.array([vu, 1-vu])))\n",
    "    print(\"Majority: \", np.dot(np.diag(T_majority), np.array([vu, 1-vu])))\n",
    "    print(\"MAP: \", np.dot(np.diag(T_map), np.array([vu, 1-vu])))\n",
    "    print( \"Check 1. T_00 MAP > T_00 MV: \",  T_map[0][0] - T_majority[0][0] > (1-vu)/(vu)*(-T_map[1][1] + T_majority[1][1])  )\n",
    "    print( \"Check 2. T_11 MAP > T_11 MV \",  -T_map[0][0] + T_majority[0][0] < (1-vu)/(vu)*(T_map[1][1] - T_majority[1][1])  )\n",
    "    print( \" 1-nu/nu\", (1-vu)/vu)\n",
    "  if not required_difference:\n",
    "    if np.dot(np.diag(T_majority), np.array([vu, 1-vu])) < np.dot(np.diag(T_map), np.array([vu, 1-vu])):\n",
    "      return \"MAP > MV\"\n",
    "    elif np.linalg.norm(np.dot(np.diag(T_majority), np.array([vu, 1-vu])) - np.dot(np.diag(T_map), np.array([vu, 1-vu]))) < 1e-2:\n",
    "      return \"MAP = MV\"\n",
    "    else:\n",
    "      return \"MAP < MV\"\n",
    "  else:\n",
    "    return -np.dot(np.diag(T_majority), np.array([vu, 1-vu])) + np.dot(np.diag(T_map), np.array([vu, 1-vu]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_differences_mv_map(T, vu:float, max_H:int=30):\n",
    "    \"\"\"\n",
    "    Accuracy gap between oracle MAP and majority voting as a function of H.\n",
    "\n",
    "    For every odd H = 1, 3, ..., <= max_H it computes\n",
    "    diag(T_MAP) . (vu, 1-vu) - diag(T_MV) . (vu, 1-vu).\n",
    "    T: 2x2 transition matrix of a single annotator\n",
    "    vu: prior probability of class 0\n",
    "    max_H: maximum number of annotators\n",
    "    Returns (H values, differences) as two lists of equal length.\n",
    "    \"\"\"\n",
    "    prior = np.array([vu, 1-vu])\n",
    "    all_H_values = list(range(1, max_H + 1, 2))\n",
    "    differences = []\n",
    "\n",
    "    for H in all_H_values:\n",
    "        T_maj = T_majority(T,H)\n",
    "        T_map_matrix = T_map(T, H, vu)\n",
    "        differences.append(-np.dot(np.diag(T_maj), prior) + np.dot(np.diag(T_map_matrix), prior))\n",
    "    return all_H_values, differences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_multiple_differences(T_matrices: list, vu_values: list, savefig:bool=False,\n",
    "                              max_H:int=100):\n",
    "    \"\"\"\n",
    "    Plot the MAP - MV accuracy gap against H for several (T, vu) pairs in one figure.\n",
    "\n",
    "    T_matrices: 2x2 transition matrices, paired element-wise with vu_values\n",
    "    vu_values: priors of class 0\n",
    "    savefig: if True, also save the figure to results/infinity.pdf\n",
    "    max_H: maximum number of annotators passed to compute_differences_mv_map\n",
    "    Each curve keeps its own (T, vu) pair, so matrices sharing the same T_00 are all plotted.\n",
    "    Raises ValueError if the two lists are empty or have different lengths.\n",
    "    \"\"\"\n",
    "    if len(T_matrices) == 0 or len(T_matrices) != len(vu_values):\n",
    "        raise ValueError(\"T_matrices and vu_values must be non-empty and of the same length.\")\n",
    "    curves = []\n",
    "    for T_single, vu_single in zip(T_matrices, vu_values):\n",
    "        # x_values depends only on max_H, so the last computed one is valid for all curves.\n",
    "        x_values, diffs = compute_differences_mv_map(T_single, vu_single, max_H)\n",
    "        curves.append((T_single[0][0], vu_single, diffs))\n",
    "    all_colors = obtain_color(len(T_matrices))\n",
    "    all_markers = obtain_markers(len(T_matrices))\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    for i, (T_cc, vu_single, diffs) in enumerate(curves):\n",
    "        label =  r'$T_{{cc}} = {:.2f}, \\, \\nu = {:.1f},$'.format(T_cc, vu_single)\n",
    "        ax.plot(x_values, diffs, color=all_colors[i], marker=all_markers[i], label=label)\n",
    "    ax.set_xlim(0,x_values[-1]+1)\n",
    "    ax.legend()\n",
    "    ax.set_xlabel(r'$H$')\n",
    "    ax.set_ylabel(r'$\\text{diag}(T^{MAP})\\left(\\nu \\,\\,\\, ,  1-\\nu\\right) - \\text{diag}(T^{MV})\\left(\\nu \\,\\,\\, , 1-\\nu\\right)$')\n",
    "    fig.tight_layout()\n",
    "    if savefig:\n",
    "        os.makedirs('results', exist_ok=True)\n",
    "        fig.savefig('results/infinity.pdf',dpi=600, format='pdf')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (T, vu) pairs: each annotator quality is paired with one class-0 prior.\n",
    "experiments = [\n",
    "    (np.array([[0.53, 0.47], [0.47, 0.53]]), 0.3),\n",
    "    (np.array([[0.7, 0.3], [0.3, 0.7]]), 0.9),\n",
    "    (np.array([[0.95, 0.05], [0.05, 0.95]]), 0.5),\n",
    "    (np.array([[0.55, 0.45], [0.45, 0.55]]), 0.1),\n",
    "]\n",
    "T_matrices = [matrix for matrix, _ in experiments]\n",
    "vu_values = [prior for _, prior in experiments]\n",
    "\n",
    "plot_multiple_differences(T_matrices=T_matrices, vu_values=vu_values, savefig=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main Theorem Plots with 2-coin case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_conditions_two_coin(H:int, vu_values = [x/10 for x in range(1,10)],\n",
    "                    obtain_diff:bool=False, rel_improvement = 1e-3):\n",
    "    \"\"\"\n",
    "    Compare oracle MAP and majority voting (MV) on a grid of asymmetric (two-coin) T matrices.\n",
    "\n",
    "    For each prior vu, T00 sweeps [0.5, 1) and T11 sweeps (0.5, 1) with step rel_improvement.\n",
    "    Every (T00, T11) point is coloured red when MAP is strictly better than MV and the\n",
    "    conditions a and b (computed with my_delta / my_rho, marking the MAP = MV region) do\n",
    "    not both hold, and blue otherwise. One subplot is drawn per vu.\n",
    "    H: number of annotators (we assume it odd)\n",
    "    vu_values: priors of class 0; the length must be a multiple of 2 or 3\n",
    "    obtain_diff: if True, plot the MAP - MV accuracy gap against T00 instead of the (T00, T11) grid\n",
    "    rel_improvement: grid step for T00 and T11\n",
    "    The figure is saved to results/plot_two_coin.pdf.\n",
    "    \"\"\"\n",
    "    assert len(vu_values) % 2== 0 or len(vu_values)%3==0, \"Lenght of the array must be multiple of 2 or 3.\"\n",
    "    divisor = 3 if len(vu_values) % 2 != 0 else 2\n",
    "    # squeeze=False keeps `ax` two-dimensional for every layout, including 1x2 and 1x3.\n",
    "    fig, ax = plt.subplots(int(len(vu_values)/divisor), divisor, figsize=(10,10), squeeze=False)\n",
    "    red, blue = '#e41a1c', '#377eb8'\n",
    "\n",
    "    for i, vu in enumerate(tqdm(vu_values)):\n",
    "        row, index = divmod(i, divisor)\n",
    "        prior = np.array([vu, 1-vu])\n",
    "        final = {}\n",
    "        labels = {}\n",
    "        for t in np.arange(0.5, 1-rel_improvement, rel_improvement):\n",
    "            for s in np.arange(1-rel_improvement, 0.5, -rel_improvement):\n",
    "                T = np.array([[t, 1-t], [1-s, s]])\n",
    "                T_matrix_map = T_map(T=T, H=H, vu=vu)\n",
    "                T_matrix_mv = T_majority(T=T, H=H)\n",
    "                text_to_color = MAP_better_majority(T_majority=T_matrix_mv, T_map=T_matrix_map,\n",
    "                                                    vu=vu)\n",
    "\n",
    "                ratio_1 = my_delta(T[1][1], T[0][0])/my_delta(T[0][0],T[1][1])\n",
    "                ratio_0 = my_delta(T[0][0], T[1][1])/my_delta(T[1][1],T[0][0])\n",
    "                a = (1/(ratio_1**(H/2)*np.sqrt(my_rho(T))+1)<(1-vu)) and ((1-vu) < 1/(ratio_1**(H/2)* (1/ np.sqrt(my_rho(T)))+1))\n",
    "                b = (1/(ratio_0**(H/2)*np.sqrt(my_rho(T))+1)<vu) and (vu < 1/(ratio_0**(H/2)* (1/ np.sqrt(my_rho(T)))+1))\n",
    "                # The colour is decided afresh for every point, so no value leaks between points.\n",
    "                final[(t,s)] = red if text_to_color == \"MAP > MV\" and not (a and b) else blue\n",
    "                labels[(t,s)] = np.dot(np.diag(T_matrix_map), prior) - np.dot(np.diag(T_matrix_mv), prior)\n",
    "        t_00 = [key[0] for key in final]\n",
    "        t_11 = [key[1] for key in final]\n",
    "        current_ax = ax[row][index]\n",
    "        if obtain_diff:\n",
    "            current_ax.scatter(t_00, list(labels.values()), c=list(final.values()))\n",
    "        else:\n",
    "            current_ax.scatter(t_00, t_11, c=list(final.values()))\n",
    "            current_ax.set_xlabel(r'$T_{00}$')\n",
    "            current_ax.set_ylabel(r'$T_{11}$')\n",
    "            current_ax.set_title(r'$\\nu=$' + str(vu))\n",
    "    custom_labels = ['MAP > MV', 'MAP = MV']\n",
    "    handles = [plt.Line2D([0], [0], color=red, lw=2),\n",
    "               plt.Line2D([0], [0], color=blue, lw=2),]\n",
    "    fig.legend(handles, custom_labels, loc='lower right', bbox_to_anchor=(0.81, 0.63), fontsize='14') \n",
    "    fig.tight_layout() \n",
    "    os.makedirs('results', exist_ok=True)\n",
    "    plt.savefig('results/plot_two_coin.pdf', format='pdf')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "H = 3\n",
    "two_coin_vu_values = [x / 10 for x in range(1, 10)]\n",
    "check_conditions_two_coin(H=H, vu_values=two_coin_vu_values, obtain_diff=False,\n",
    "                          rel_improvement=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_heatmap_two_coin(H:int, vu_values = [x/10 for x in range(1,10)],\n",
    "                          rel_improvement = 1e-3):\n",
    "    \"\"\"\n",
    "    Heatmap of the accuracy gap MAP - MV over asymmetric (two-coin) T matrices, one panel per vu.\n",
    "\n",
    "    For each prior vu, T00 sweeps [0.5, 1) and T11 sweeps (0.5, 1) with step rel_improvement.\n",
    "    H: number of annotators (we assume it odd)\n",
    "    vu_values: priors of class 0; the length must be a multiple of 2 or 3\n",
    "    rel_improvement: grid step for T00 and T11\n",
    "    All panels share the same colour limits, so the single colorbar applies to every panel.\n",
    "    The figure is saved to results/plot_heatmap_two_coin.pdf.\n",
    "    \"\"\"\n",
    "    assert len(vu_values) % 2== 0 or len(vu_values)%3==0, \"Lenght of the array must be multiple of 2 or 3.\"\n",
    "    divisor = 3 if len(vu_values) % 2 != 0 else 2\n",
    "    # squeeze=False keeps `ax` two-dimensional for every layout, including 1x2 and 1x3.\n",
    "    fig, ax = plt.subplots(int(len(vu_values)/divisor), divisor, figsize=(10,10), squeeze=False)\n",
    "\n",
    "    # Compute every panel first, so that the colour limits can be shared by all of them.\n",
    "    panels = []\n",
    "    for vu in tqdm(vu_values):\n",
    "        t_00, t_11, differences = [], [], []\n",
    "        for t in np.arange(0.5, 1-rel_improvement, rel_improvement):\n",
    "            for s in np.arange(1-rel_improvement, 0.5, -rel_improvement):\n",
    "                T = np.array([[t, 1-t], [1-s, s]])\n",
    "                T_matrix_map = T_map(T=T, H=H, vu=vu)\n",
    "                T_matrix_mv = T_majority(T=T, H=H)\n",
    "                t_00.append(t)\n",
    "                t_11.append(s)\n",
    "                differences.append(MAP_better_majority(T_majority=T_matrix_mv, T_map=T_matrix_map,\n",
    "                                                       vu=vu, required_difference=True))\n",
    "        panels.append((vu, t_00, t_11, differences))\n",
    "\n",
    "    all_differences = np.array([d for panel in panels for d in panel[3]], dtype=float)\n",
    "    finite = all_differences[np.isfinite(all_differences)]\n",
    "    vmin, vmax = (finite.min(), finite.max()) if finite.size else (None, None)\n",
    "\n",
    "    im = None\n",
    "    for i, (vu, t_00, t_11, differences) in enumerate(panels):\n",
    "        row, index = divmod(i, divisor)\n",
    "        current_ax = ax[row][index]\n",
    "        im = current_ax.scatter(t_00, t_11, c=differences, cmap='YlOrBr', vmin=vmin, vmax=vmax)\n",
    "        current_ax.set_xlabel(r'$T_{00}$')\n",
    "        current_ax.set_ylabel(r'$T_{11}$')\n",
    "        current_ax.set_title(r'$\\nu=$' + str(vu))\n",
    "    if im is not None:\n",
    "        fig.colorbar(im, ax=ax, location='bottom', orientation = 'horizontal', anchor = (0.5, -.9))\n",
    "    fig.tight_layout()\n",
    "    os.makedirs('results', exist_ok=True)\n",
    "    plt.savefig('results/plot_heatmap_two_coin.pdf', format='pdf')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "heatmap_vu_values = [x / 10 for x in range(1, 10)]\n",
    "plot_heatmap_two_coin(3, vu_values=heatmap_vu_values, rel_improvement=1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main Theorem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_condistions_main_theorem(H:int, vu_values = [x/10 for x in range(1,10)], rel_improvement = 1e-2,\n",
    "                                   estimate:str=None):\n",
    "    \"\"\"\n",
    "    Scatter-plot, over the (1 - T_cc, vu) plane, whether MAP beats majority voting (MV)\n",
    "    for symmetric annotators (T00 = T11 = T_cc).\n",
    "\n",
    "    H: number of annotators (we assume it odd)\n",
    "    vu_values: priors of class 0 to sweep (any length; a single plot is produced)\n",
    "    rel_improvement: grid step for T_cc in [0.5, 1)\n",
    "    estimate: None -> MAP and MV use the true T and vu (oracle MAP);\n",
    "              'iaa' -> T and the prior used by MAP are estimated with InterAnnotatorAgreementAPI\n",
    "                       from N=100 simulated items;\n",
    "              'iwmv' -> T is estimated with iwmv from the same kind of simulated data\n",
    "                        (MAP still uses the true vu).\n",
    "        The comparison itself (MAP_better_majority) always weights accuracies with the true vu.\n",
    "    Grid points for which the comparison raises ValueError (e.g. NaN MAP thresholds) are\n",
    "    skipped and their number is printed once.\n",
    "    The figure is saved to results/<standard|iaa|iwmv>_MV_MAP.pdf.\n",
    "    Raises ValueError for any other value of estimate.\n",
    "    \"\"\"\n",
    "    red, blue = '#e41a1c', '#377eb8'\n",
    "    settings = {\n",
    "        None: ('standard_MV_MAP', {red: 'oMAP>MV', blue: 'oMAP=MV'}),\n",
    "        'iaa': ('iaa_MV_MAP', {red: r'eMAP$_{IAA} > MV$', blue: r'eMAP$_{IAA} = MV$'}),\n",
    "        'iwmv': ('iwmv_MV_MAP', {red: r'eMAP$_{IWMV} > MV$', blue: r'eMAP$_{IWMV} = MV$'}),\n",
    "    }\n",
    "    if estimate not in settings:\n",
    "        raise ValueError(f\"Unknown estimate {estimate!r}; expected None, 'iaa' or 'iwmv'.\")\n",
    "    plt_name, color_map = settings[estimate]\n",
    "\n",
    "    final = {}\n",
    "    skipped = 0\n",
    "    for vu in tqdm(vu_values):\n",
    "        for t in np.arange(0.5, 1-rel_improvement, rel_improvement):\n",
    "            T = np.array([[t, 1-t], [1-t, t]])\n",
    "            vu_for_map = vu\n",
    "            if estimate is not None:\n",
    "                true_labels = generate_true_labels(C=2, N=100, D=[vu, 1-vu])\n",
    "                data = generate_annotations(true_labels, T, H=H, obtain_list=True)\n",
    "                if estimate == 'iaa':\n",
    "                    iaa = InterAnnotatorAgreementAPI(np.array(data, dtype=object))\n",
    "                    iaa._build_t_matrix()\n",
    "                    T = iaa._t_hat\n",
    "                    vu_for_map = iaa._label_distribution[0]\n",
    "                else:\n",
    "                    e2wl, w2el, label_set = to_LA(data)\n",
    "                    _, _, T = iwmv(e2wl, w2el, label_set, T_required=True)\n",
    "            try:\n",
    "                T_matrix_map = T_map(T=T, H=H, vu=vu_for_map, debug=False)\n",
    "                T_matrix_mv = T_majority(T=T, H=H)\n",
    "                text_to_color = MAP_better_majority(T_majority=T_matrix_mv, T_map=T_matrix_map,\n",
    "                                                    vu=vu)\n",
    "            except ValueError:\n",
    "                skipped += 1\n",
    "                continue\n",
    "            # \"MAP < MV\" and \"MAP = MV\" share the blue colour; only a strict MAP win is red.\n",
    "            final[(t, vu)] = red if text_to_color == \"MAP > MV\" else blue\n",
    "    if skipped:\n",
    "        print(f'Skipped {skipped} grid points that raised ValueError.')\n",
    "\n",
    "    t_values = [1 - key[0] for key in final]\n",
    "    vu_all = [key[1] for key in final]\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    ax.scatter(t_values, vu_all, c=list(final.values()))\n",
    "\n",
    "    legend_artists = [mlines.Line2D([], [], color=color, marker='o', linestyle='None', markersize=10, label=label)\n",
    "                      for color, label in color_map.items()]\n",
    "    ax.set_xlabel(r'$1-T_{cc}$')\n",
    "    ax.set_ylabel(r'$\\nu$')\n",
    "    ax.legend(handles=legend_artists, labels=list(color_map.values()), loc='best')\n",
    "\n",
    "    fig.tight_layout()\n",
    "    os.makedirs('results', exist_ok=True)\n",
    "    fig.savefig(f'results/{plt_name}.pdf', dpi=300, format='pdf')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimates = [None, 'iaa', 'iwmv']\n",
    "main_theorem_vu_values = np.linspace(start=0.0001, stop=0.999, num=50)\n",
    "for est in estimates:\n",
    "    check_condistions_main_theorem(H=3, vu_values=main_theorem_vu_values,\n",
    "                                   rel_improvement=1e-3, estimate=est)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_heatmap_difference(H:int, vu_values = [x/10 for x in range(1,10)],\n",
    "                            rel_improvement = 1e-3):\n",
    "    \"\"\"\n",
    "    Heatmap of the MAP - MV accuracy gap for symmetric annotators over the (1 - T_cc, vu) plane.\n",
    "\n",
    "    H: number of annotators (we assume it odd)\n",
    "    vu_values: priors of class 0\n",
    "    rel_improvement: grid step for T_cc in [0.5, 1)\n",
    "    The figure is saved to results/plot_heatmap.pdf.\n",
    "    \"\"\"\n",
    "    final = {}\n",
    "    for vu in vu_values:\n",
    "        for t in np.arange(0.5, 1-rel_improvement, rel_improvement):\n",
    "            T = np.array([[t, 1-t], [1-t, t]])\n",
    "            T_matrix_map = T_map(T=T, H=H, vu=vu, debug=False)\n",
    "            T_matrix_mv = T_majority(T=T, H=H)\n",
    "            final[(t, vu)] = MAP_better_majority(T_majority=T_matrix_mv, T_map=T_matrix_map,\n",
    "                                                 vu=vu, required_difference=True)\n",
    "\n",
    "    t_values = [1 - key[0] for key in final]\n",
    "    vu_all = [key[1] for key in final]\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    im = ax.scatter(t_values, vu_all, c=list(final.values()), cmap='YlOrBr')\n",
    "    cbar = fig.colorbar(im, ax=ax)\n",
    "    cbar.set_label('MAP - MV accuracy')\n",
    "\n",
    "    ax.set_xlabel(r'$1-T_{cc}$')\n",
    "    ax.set_ylabel(r'$\\nu$')\n",
    "\n",
    "    fig.tight_layout()\n",
    "    os.makedirs('results', exist_ok=True)\n",
    "    fig.savefig('results/plot_heatmap.pdf', format='pdf')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vu_values = np.arange(0.0, 1.0, 0.01)  # includes vu = 0, excludes vu = 1\n",
    "plot_heatmap_difference(H=3, vu_values=vu_values, rel_improvement=1e-3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
