{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.cluster.hierarchy import linkage, dendrogram, fcluster\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import math\n",
    "import time\n",
    "from functools import partial\n",
    "import os\n",
    "import random\n",
    "from collections import defaultdict\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## File IO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_preferences(file_path):\n",
    "    \"\"\"Read preferences from a file. Each preference list is in a separate line, prefernces are\n",
    "    separated using , and they might come after :. Lines without , are skipped\"\"\"\n",
    "    \n",
    "    with open(file_path, 'r') as f:\n",
    "        lines = [line.strip() for line in f]\n",
    "\n",
    "    count = 0\n",
    "    preferences = list()\n",
    "    for l in lines:\n",
    "        #print (count, \": \", l)\n",
    "    \n",
    "        if ',' in l: \n",
    "            num_string = l.rsplit(\":\", 1)[-1]\n",
    "            preferences.append([int(n) for n in num_string.split(\",\")])\n",
    "            count += 1\n",
    "            \n",
    "    return preferences"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## rank distance & kendal-tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def position(top):\n",
    "  return {t:i for i, t in enumerate(top, start = 1)}\n",
    "\n",
    "def num_misorders(top_1, top_2, n):\n",
    "  \"\"\"Calculates kendal-tau distance (integral and fractional) for top-k rankings in O(k^2)\"\"\"\n",
    "  count = 0\n",
    "    \n",
    "  if len(top_1) != len(top_2): \n",
    "    print(\"inconsistent k\")\n",
    "    return count\n",
    "      \n",
    "  k = len(top_1)\n",
    "  pos_1 = position(top_1)\n",
    "  pos_2 = position(top_2)\n",
    "  s = pos_1.keys() & pos_2.keys()\n",
    "  l = len(s)\n",
    "  #print (\"common keys: \", s)\n",
    "  \n",
    "  # case i \\in s, j \\in s, O(s^2)\n",
    "  for i in s:\n",
    "    for j in s: \n",
    "      if (pos_1[i] - pos_1[j]) * (pos_2[i] - pos_2[j]) < 0:\n",
    "        count += 1\n",
    "\n",
    "  # since double counting need to divide by 2.\n",
    "  count /= 2\n",
    "\n",
    "  #print(\"total distance for s x s: \", total_distance)\n",
    "   \n",
    "  # case i \\in top_1\\s, j \\in s, O(k * s)\n",
    "  for i in top_1:\n",
    "    for j in s:\n",
    "      if i in s: \n",
    "        continue\n",
    "      if pos_1[i] < pos_1[j]:\n",
    "        count += 1 \n",
    "  #print(\"total distance updated after adding top_1\\\\s x s: \", total_distance)\n",
    "\n",
    "  # case i \\in top_2\\s, j \\in s, O(k * s)\n",
    "  for i in top_2:\n",
    "    for j in s:\n",
    "      if i in s:\n",
    "        continue\n",
    "      if pos_2[i] < pos_2[j]:\n",
    "        count += 1\n",
    "\n",
    "  # case i \\in top_2\\s, j \\in top_1\\s or vice-versa, O(1)\n",
    "  count += (k-l) * (k-l)\n",
    "  #print(\"total distance updated after adding top_2\\\\s x top_1\\\\s or vice-versa: \", total_distance)\n",
    "\n",
    "  return count\n",
    "\n",
    "def num_incomparable_diff(top_1, top_2, n):\n",
    "  \"\"\"computes poition and then after thay it is just O(1).\"\"\"\n",
    "  count = 0\n",
    "    \n",
    "  if len(top_1) != len(top_2): \n",
    "    print(\"inconsistent k\")\n",
    "    return count\n",
    "      \n",
    "  k = len(top_1)\n",
    "  pos_1 = position(top_1)\n",
    "  pos_2 = position(top_2)\n",
    "  s = pos_1.keys() & pos_2.keys()\n",
    "  l = len(s)\n",
    "    \n",
    "  count = 0\n",
    "  #print(\"total distance updated after adding top_2\\\\s x s: \", total_distance)\n",
    "\n",
    "  # case i \\in top_1\\s, j \\in top_1\\s and i \\in top_2\\s , j \\in top_2\\s, O(1) \n",
    "  count += (k-l) * (k-l-1) \n",
    "  #print(\"total distance updated after adding top_1\\\\s x top_1\\\\s and top_2\\\\s x top_2\\\\s: \", total_distance)\n",
    "\n",
    "  # case i \\in bag\\(top_1 U top_2), j \\in top_1\\s or j \\in top_2\\s, O(1)\n",
    "  count += (k-l) * (n-2*k+l) * 2\n",
    "  #print(\"total distance updated after adding bag\\\\(top_1 U top_2) x [top_1\\\\s or top_2\\\\s]: \", total_distance)\n",
    "\n",
    "  return count\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def disagreements (rankings, num_elements):\n",
    "    num_rankings = len(rankings)\n",
    "    misorders = np.zeros((num_rankings,num_rankings))\n",
    "    incomparables = np.zeros((num_rankings,num_rankings))\n",
    "    for i in range(num_rankings):\n",
    "        if i % 100 == 0: print(i)\n",
    "        for j in range(i + 1, num_rankings):\n",
    "            #print(\"ranking a \", rankings[i])\n",
    "            #print(\"ranking b \", rankings[j])\n",
    "            misorders[i][j] = misorders[j][i] = num_misorders(rankings[i], rankings[j], num_elements)\n",
    "            #print(\"num_misorders between {} and {} is {} \".format(rankings[i], rankings[j], misorders_ij))\n",
    "            incomparables[i][j] = incomparables[j][i] = num_incomparable_diff(rankings[i], rankings[j], num_elements)\n",
    "            #print(\"num_incomparables between {} and {} is {} \".format(rankings[i], rankings[j], incomparables_ij))\n",
    "       \n",
    "    return misorders, incomparables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [],
   "source": [
    "def exp_kendal_tau(top_1, top_2, n, p, beta):\n",
    "    misorders = num_misorders(top_1, top_2, n)\n",
    "    incomparables = num_incomparable_diff(top_1, top_2, n)\n",
    "    return misorders + .001 * incomparables\n",
    "    #return math.exp(-beta * (misorders + incomparables * p)) \n",
    "\n",
    "\n",
    "def cluster_preferences(preferences, misorders, incomparables, n , p, beta ):\n",
    "\n",
    "    snapshot_time = time.time()\n",
    "    #param_kendal_tau = partial(exp_kendal_tau, n, p, beta)\n",
    "    #dist_array = pdist(preferences, metric=param_kendal_tau)  # 1D condensed form\n",
    "    #dist_matrix = squareform(dist_array)       # Convert to square matrix\n",
    "    #dist_matrix = misorders + 0.001 * incomparables\n",
    "    dist_matrix = np.exp(-beta * (misorders + p * incomparables))\n",
    "    # print(\"dist_matrix constructed in \", time.time() - snapshot_time)\n",
    "    # print(\"dist_matrix[:10, :10]: \", dist_matrix[:10,:10])\n",
    "    # print(\"misorders[:10, :10]: \", misorders[:10,:10])\n",
    "    # print(\"incomparables[:10, :10]: \", incomparables[:10,:10])\n",
    "    snapshot_time = time.time()\n",
    "\n",
    "\n",
    "    # Perform hierarchical clustering\n",
    "    Z = linkage(dist_matrix, method='average')  # or 'complete', 'single'\n",
    "    print(\"linkage completed in \", time.time() - snapshot_time)\n",
    "\n",
    "    return Z\n",
    "    \n",
    "# Get cluster assignments (e.g., 10 clusters)\n",
    "#k = 4\n",
    "#labels = fcluster(Z, t=k, criterion='maxclust')\n",
    "\n",
    "#print(\"Cluster labels: \", labels)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DYCHIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 243,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "from collections import OrderedDict\n",
    "import matplotlib.pyplot as plt\n",
    "import sys\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def printDP(L):\n",
    "  n=len(L)\n",
    "  for i in range(n):\n",
    "    print(\"array number\", i)\n",
    "    print(L[i])\n",
    "\n",
    "\n",
    "\n",
    "def DypChip(A, S, sigma, n, k, p , beta, w):\n",
    "  # the following array stores values of pi_S(i,j,s) as defined in the paper\n",
    "  m=len(A)\n",
    "  ell=len(S)\n",
    "  pr_perp=[0]*m\n",
    " # print(\"We just entered DypChip with: \")\n",
    " # print(\"profile:\", S)\n",
    " # print(\"parameters: size of profile\",ell,\", size of assortment:\",m)\n",
    "  DPtable=np.zeros((k,m,k+1))\n",
    "  up=[sigma[i] for i in S]\n",
    " # print(\"up:\",up)\n",
    "  Aperp=[a for a in A if a not in sigma]\n",
    " # print(\"Aperp\",Aperp)\n",
    "  Aperpsize=len(Aperp)\n",
    "  #initializing the DP table\n",
    "  C=0\n",
    "  for j in range(k-ell):\n",
    "    C_j=1/(n-k-j)\n",
    "    for count in range(j):\n",
    "      C_j=C_j*(1-(Aperpsize/(n-k-count)))\n",
    "    C=C+C_j\n",
    "    for i in range(m):\n",
    "      if A[i] in Aperp:\n",
    "       # print(\"setting DP table at i,j,k-ell to C:\",i,j,k-ell,C_j)\n",
    "        DPtable[k-ell-1][i][j]=C_j\n",
    "      else:\n",
    "        DPtable[k-ell-1][i][j]=0\n",
    "    for i in range(m):\n",
    "      if  A[i] in Aperp:\n",
    "        DPtable[k-ell-1][i][k]=(1/Aperpsize)-C\n",
    "        pr_perp[i]=(1/Aperpsize)-C\n",
    "  # we now iterate over s=k-ell+1 ..k\n",
    "  ind=ell-1;\n",
    "  for s in range(k-ell,k):\n",
    "  #  print(\"index of inserted element in S\", s, \"iteration\", ind)\n",
    "    a=sigma[S[ind]]\n",
    " #   print(\"****** s *****\", s,\"index\", ind, \"inserting\",a)\n",
    "    # we now calculate RIM probablities for a, since a is ranked s in the profile S, the possible number of inversions for it can be any number in [1...ind]\n",
    "    RIM_a=[0]*(s+1)\n",
    "    for i in range(s+1):\n",
    "      RIM_a[i]=math.exp(-beta*w[sigma.index(a)]*i)\n",
    "    T=sum(RIM_a)\n",
    "    RIM_a=[i/T for i in RIM_a]\n",
    "  #  print(\"RIM:\",RIM_a)\n",
    "    # case 1\n",
    "    if a not in A:\n",
    "    #  print(\"case 1, inserting\", a)\n",
    "      for j in range(s+1):\n",
    "        bef_j=sum(RIM_a[0:j])\n",
    "        af_j=sum(RIM_a[j+1:s+1])\n",
    "      #  print(\"bf\",RIM_a[0:j],bef_j)\n",
    "     #   print(\"af\",RIM_a[j+1:s+1],af_j)\n",
    "    #    print(\"check sum\", bef_j+af_j+RIM_a[j])\n",
    "        for i in range(m):\n",
    "       #   print(\"setting\", iter,\":array\", i,\":element\", j,\":loc\")\n",
    "     #     print(\"i\",i,\": af_j*DPtable[s-1][i][j]+ bef_j*DPtable[s-1][i][j-1]\",af_j,\"*\",DPtable[s-1][i][j],\"+\",bef_j,\"*\",DPtable[s-1][i][j-1])\n",
    "          DPtable[s][i][j]=af_j*DPtable[s-1][i][j]+ bef_j*DPtable[s-1][i][j-1]\n",
    "      #set last column\n",
    "\n",
    "          if  A[i] in Aperp:\n",
    "            DPtable[s][i][k]=DPtable[s-1][i][k]\n",
    "\n",
    "    # case 2\n",
    "    if a in A:\n",
    "      loc=A.index(a)\n",
    "  #    print(\"case 2, inserting\",a)\n",
    "      for j in range(s+1):\n",
    "        for i in range(m):\n",
    "          if i!=loc:\n",
    "           # print(i)\n",
    "           # print(\"pr_af\",j,sum(RIM_a[j+1:s+1]))\n",
    "            DPtable[s][i][j]=DPtable[s-1][i][j]*sum(RIM_a[j+1:s+1])\n",
    "       #     print(\"setting i\",i, \"and j\",j, \"elemt\", A[i],\"to\", DPtable[s-1][i][j]*sum(RIM_a[j+1:s+1]),\":\",DPtable[s-1][i][j],\"*\",sum(RIM_a[j+1:s+1]) )\n",
    "        C=FindSumSubarray(DPtable,s,j,m)\n",
    "    #    print('sum,',j,\"is\",C)\n",
    "   #     print(\"loc\",loc)\n",
    "        DPtable[s][loc][j]=RIM_a[j]*C\n",
    "    #    print(\"just set the value for index\",ind,\"a\",a,\"loc\",loc,\"at j\",j,\"to\",RIM_a[j],\"*\",C,\"=\",RIM_a[j]*C)\n",
    "\n",
    "    ind=ind-1\n",
    "\n",
    " # printDP(DPtable)\n",
    "  finalDP=DPtable[k-1][0:m][0:k+1]\n",
    "  # p_perp\n",
    "  Atop=[i for i in A if i in up]\n",
    "  if(len(Atop)!=0):\n",
    "    pr_perp=[0]*m\n",
    "  return finalDP, pr_perp\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def DypChipII(A, sigma, n, k, p , beta, w):\n",
    "  # the following array stores values of pi_S(i,j,s) as defined in the paper\n",
    "  m=len(A)\n",
    "  pr_perp=[0]*m\n",
    " # print(\"We just entered DypChip II \")\n",
    "  DPtable=np.zeros((k,m,k+1))\n",
    " # print(\"up:\",up)\n",
    "  ind=k-1\n",
    "  first_sum=1\n",
    "  #printDP(DPtable)\n",
    "  # we now iterate over s=k-ell+1 ..k\n",
    "  for s in range(k):\n",
    "  #  print(\"index of inserted element in S\", s, \"iteration\", ind)\n",
    "    a=sigma[ind]\n",
    " #   print(\"****** s *****\", s,\"index\", ind, \"inserting\",a)\n",
    "\n",
    "    # we now calculate RIM probablities for a, since a is ranked s in the profile S, the possible number of inversions for it can be any number in [1...ind]\n",
    "    RIM_a=[0]*(s+1)\n",
    "    for i in range(s+1):\n",
    "      RIM_a[i]=math.exp(-beta*w[sigma.index(a)]*i)\n",
    "    T=sum(RIM_a)\n",
    "    RIM_a=[i/T for i in RIM_a]\n",
    " #   print(\"RIM:\",RIM_a)\n",
    "    # case 1\n",
    "    if a not in A:\n",
    "  #    print(\"case 1, inserting\", a)\n",
    "      for j in range(s+1):\n",
    "        bef_j=sum(RIM_a[0:j])\n",
    "        af_j=sum(RIM_a[j+1:s+1])\n",
    "      #  print(\"bf\",RIM_a[0:j],bef_j)\n",
    "     #   print(\"af\",RIM_a[j+1:s+1],af_j)\n",
    "    #    print(\"check sum\", bef_j+af_j+RIM_a[j])\n",
    "        for i in range(m):\n",
    "       #   print(\"setting\", iter,\":array\", i,\":element\", j,\":loc\")\n",
    "       #   print(\"i\",i,\": af_j*DPtable[s-1][i][j]+ bef_j*DPtable[s-1][i][j-1]\",af_j,DPtable[s-1][i][j],bef_j,DPtable[s-1][i][j-1])\n",
    "          DPtable[s][i][j]=af_j*DPtable[s-1][i][j]+ bef_j*DPtable[s-1][i][j-1]\n",
    "\n",
    "\n",
    "    # case 2\n",
    "    if a in A:\n",
    "      loc=A.index(a)\n",
    "  #    print(\"case 2, inserting\",a)\n",
    "      for j in range(s+1):\n",
    "        for i in range(m):\n",
    "          if i!=loc:\n",
    "           # print(i)\n",
    "           # print(\"pr_af\",j,sum(RIM_a[j+1:s+1]))\n",
    "            DPtable[s][i][j]=DPtable[s-1][i][j]*sum(RIM_a[j+1:s+1])\n",
    "           # print(\"setting i\",i, \"and j\",j, \"elemt\", A[i],\"to\", DPtable[s-1][i][j]*sum(RIM_a[j+1:s+1]),\":\",DPtable[s-1][i][j],\"*\",sum(RIM_a[j+1:s+1]) )\n",
    "        C=FindSumSubarray(DPtable,s,j,m)\n",
    "        C=max(C,first_sum)\n",
    "\n",
    "      #  printDP(DPtable)\n",
    "      #  print(\"Setting loc\",loc,\"j\",j, \"to:\" ,RIM_a[j],\"*\",\"****C***\",C, \"s is:\",s)\n",
    "   #     print(\"loc\",loc)\n",
    "        DPtable[s][loc][j]=RIM_a[j]*C\n",
    "     #   print(\"print row\",DPtable[s][loc][0:s+1])\n",
    "       # printDP(DPtable)\n",
    "   #   print(\"now set it to zero\")\n",
    "     # print(\"just set the value for\",ind,\"a\",a,\"loc\",loc,\"at j\",j,\"to\",RIM_a[j]*C)\n",
    "      first_sum=0\n",
    "    ind=ind-1\n",
    "\n",
    " # printDP(DPtable)\n",
    "  finalDP=DPtable[k-1][0:m][0:k+1]\n",
    "  # finding p_perp\n",
    "  Atop=[i for i in A if i in sigma]\n",
    "  if(len(Atop)==0):\n",
    "    pr_perp=[1/m]*m\n",
    "  return finalDP, pr_perp\n",
    "\n",
    "\n",
    "def ConvDPtoProbVec(A,DP,pr_perp):\n",
    "  m=len(A)\n",
    "  ret_vec=[0]*m\n",
    "  for a in range(m):\n",
    "    for j in range(k):\n",
    "      ret_vec[a]=ret_vec[a]+DP[a][j]\n",
    "    ret_vec[a]=ret_vec[a]+pr_perp[a]\n",
    "   # print(\"prob: \",a ,pr_a_top,pr_perp)\n",
    "\n",
    "  return ret_vec\n",
    "\n",
    "\n",
    "def FindSumSubarray(D,s,j,m):\n",
    "  C=0\n",
    "  for kappa in range(j,k+1):\n",
    "    for a in range(m):\n",
    "      C=C+D[s-1][a][kappa]\n",
    "    #  print(\"summing all subarry to find C, now\",D[s-1][a][kappa])\n",
    "  return C"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning Alg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import random \n",
    "\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "List of algorithms:\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "List of common parameters:\n",
    "n is the number of items \n",
    "A: is a list, it shows the assortment; A should always be a subset of 1....n\n",
    "We always show the no choice option with 0\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "def Choice(Assortment, tau,n):\n",
    "\n",
    "    N=list(range(1,n+1))\n",
    "    if ( not set(Assortment)<= set(N)):\n",
    "        print(\"A must be subset of 1,2,..n\")\n",
    "        return None\n",
    "    Anull=Assortment+[0]\n",
    "    Nnull=N+[0]\n",
    "    if ( not set(tau)<= set(Nnull)):\n",
    "        print(\"tau must be subset of 0,1,2,..n: \", tau)\n",
    "        print(\"Nnull: \", Nnull)\n",
    "        return None\n",
    "\n",
    "    \n",
    "    for item in tau:\n",
    "        if item in Anull:\n",
    "          return item\n",
    "    return random.choice(Anull)\n",
    "  \n",
    "\n",
    "def LearnTopElement(A,C):\n",
    "    # Algorithm 3\n",
    "    # A is the assortment \n",
    "    # C is a set of choice samples \n",
    "    # two dimensional array X is filled so that at the end of the for loop we have:\n",
    "    # X_ij= # samples in which  A[i] was chosen and A[j] was not, we let A[r]=Null which is A[r]=0 \n",
    "    #print(\"A: \", A, \"C: \", C)\n",
    "    r = len(A)\n",
    "    A = A + [0]\n",
    "    num_picked_over = np.zeros((r+1, r+1))#[[0 for i in range(r+1)] for j in range(r+1)]\n",
    "    m=len(C)\n",
    "    for c in C:\n",
    "        i=A.index(c)\n",
    "        for j in range(r+1):\n",
    "            if j !=i:\n",
    "                num_picked_over[j][i] -= 1\n",
    "                num_picked_over[i][j] += 1\n",
    "                #print(\"incremented X\",i,j, \"because choide is\",c)\n",
    "\n",
    "    #print(\"Assortment: \", A)\n",
    "    #print(\"Count ith item picked over jth item:\\n \",num_picked_over)\n",
    "    num_picked_over_normalized = num_picked_over/m #[[X[i][j]/m for i in range(r+1)] for j in range(r+1)]\n",
    "    #print(\"Normalized count: \\n\",num_picked_over_normalized )\n",
    "    count_mat = num_picked_over_normalized - 1/(2*(r+1))\n",
    "    count_per_item = np.sum(count_mat >= 0, axis = 1)\n",
    "    max_item_index = np.argmax(count_per_item)\n",
    "    if count_per_item[max_item_index] >= r:\n",
    "      most_wanted_item = A[max_item_index]\n",
    "    else:\n",
    "      most_wanted_item = None\n",
    "\n",
    "    return most_wanted_item\n",
    "        \n",
    "    # # print(\"count_per_item: \", count_per_item)\n",
    "    # T=[None] * (r+1)\n",
    "    # for i in range(r+1):\n",
    "    #     count=0\n",
    "    #     for j in range(r+1):\n",
    "    #         if num_picked_over_normalized[i][j] >= 1/(2*(r+1)):\n",
    "    #             count=count+1\n",
    "    #     T[i]=count\n",
    "    # # print(\"all counts:\",T)\n",
    "    \n",
    "\n",
    "    # if T.count(r)==1:\n",
    "    #    return A[T.index(r)]\n",
    "    # if T.count(r)==0:\n",
    "    #    return None\n",
    "    # else: \n",
    "    #     \"print multiple elements have been found\"\n",
    "    #     return None \n",
    "    \n",
    "def BuCchoi(N,S):\n",
    "    n=len(N)\n",
    "    T=[]\n",
    "    for i in N:\n",
    "        A=[i]\n",
    "        C=[]\n",
    "        for tau in S:\n",
    "            c=Choice(A, tau,n)\n",
    "            C.append(c)\n",
    "        top=LearnTopElement(A,C)\n",
    "        # print(\"learned element, A=\",A ,\"is\",top)\n",
    "        if top!=None and top not in T and top!=0:\n",
    "            T=T+[top]\n",
    "     # sigma_dic maps each element in T to its learned rank\n",
    "    sigma_dic={}\n",
    "   # print(\"sigma\",sigma)\n",
    "    #print(\"learned top elements\",T)\n",
    "    for i in T:\n",
    "        l=0\n",
    "        for j in T:\n",
    "            if j!=i:\n",
    "                A=[i,j]\n",
    "                C = []\n",
    "                for tau in S:\n",
    "                  C.append(Choice(A, tau,n))\n",
    "                top=LearnTopElement(A,C)\n",
    "                #print(\"between {}, {} is preferred\".format(A, top))\n",
    "                if top==j:\n",
    "                    l=l+1\n",
    "       # print(\"i,l\",i,l)\n",
    "        sigma_dic[i]=l\n",
    "\n",
    "    return sigma_dic\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        \n",
    "\n",
    "        \n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sushi Dataset experiments\n",
    "\n",
    "\n",
    "IDENTIFY CLUSTER CENTER (Given a set of top-k)\n",
    "A_1 = {1,2,...r} for now we assume r = 2.\n",
    "Let d = n/r + 2k (we use r = 2)\n",
    "For i = \t1 to d\n",
    "Display A_i to \\tau in train, and calculate prob_c(j,A_i) for j \\in A_i\n",
    "Using prob_c,generate A_{i+1} using information regarding A_i\n",
    "We have identified the center. \n",
    "\n",
    "train, test= 80/20%.\n",
    "IDENTIFY_CLUSTER_CENTER(train)\n",
    "p_choice = DYPCHIP(center, i,[100]) for all i \\in [100]\n",
    "Compare p_choice , p_test. \n",
    "Display A’s to test set, p_test = #i is ranked top. \n",
    "How about a mixture of centers?\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sushi 2 dataset\n",
    "\n",
    "5000 top-10 ranking of 100 different sushi types\n",
    "\n",
    "Experiment setup: \n",
    "  - train, test = data partitioned 80,20\n",
    "  - num_clusters = {1,2,4,16,32}\n",
    "  - cluster train set into num_clusters clusters and learn different cluster centers using the BuCchoi \n",
    "  - define the distribution over permutations by defining the weight to be the size of cluster / total num permutations (=5000)\n",
    "  - calculate prob choice for each cluster using the DYPCHIP\n",
    "  - add probabilities using the weight of different clusters\n",
    "  - compare this probability to mnl probabilities\n",
    "\n",
    "mnl probabilities ~ count number of times item i is picked over 0. \n",
    "\n",
    "compare the two probabilities to p defined using the test set. p_i ~ number of times item i is ranked 1st in permutations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Traceback (most recent call last):\n",
       "  File \"/Users/sh1678/.vscode/extensions/ms-python.python-2025.2.0-darwin-arm64/python_files/python_server.py\", line 133, in exec_user_input\n",
       "    retval = callable_(user_input, user_globals)\n",
       "             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
       "  File \"<string>\", line 8, in <module>\n",
       "NameError: name 'time' is not defined. Did you forget to import 'time'\n",
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "file_dir = \"/Users/Sarachi/Documents/Codes/top_k_sampling/topkmallows-choices/00014_sushi/\"\n",
    "file_name = \"00014-00000002\"#, \"00014-00000001\"] #, \"00014-00000003.toi\"]\n",
    "file_type = \".soi\" #, \".soc\"]#, \".toi\"]\n",
    "\n",
    "n = 100 \n",
    "file_path = file_dir + file_name + file_type\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "preferences = read_preferences(file_path)\n",
    "# count number of misorders (I_i) and incomparables (P_i) separately and save to a file for future usecases. \n",
    "# misorders, incomparables = disagreements(preferences, n)\n",
    "# np.save(base_file_path + \"_misorders.npy\", misorders)\n",
    "# np.save(base_file_path + \"_incomparables.npy\", incomparables)\n",
    "# read number of misorders and incomparables from file\n",
    "dist_misorders = np.load(base_file_path + \"_misorders.npy\")\n",
    "dist_incomparables = np.load(base_file_path + \"_incomparables.npy\")\n",
    "\n",
    "#print(time.time() - start, \" finished \", file_name)\n",
    "\n",
    "# print(preferences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linkage completed in  25.783266067504883\n",
      "Learned cluster center of cluster 1:  {1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 100: 0}\n",
      "linkage completed in  24.50744605064392\n",
      "Learned cluster center of cluster 1:  {1: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 12: 0, 13: 0, 100: 0}\n",
      "Learned cluster center of cluster 2:  {1: 2, 2: 2, 4: 0, 5: 7, 6: 2, 7: 2, 9: 3, 10: 2, 11: 1, 12: 9, 13: 2, 14: 3, 15: 5, 16: 7, 17: 11, 19: 3, 20: 12, 22: 13, 23: 9, 27: 6, 28: 13, 31: 12, 41: 17, 60: 15, 63: 15, 83: 16, 97: 17, 99: 17, 100: 1}\n",
      "linkage completed in  26.36651110649109\n",
      "Learned cluster center of cluster 1:  {1: 1, 3: 1, 6: 1, 7: 1, 8: 0, 11: 1, 100: 1}\n",
      "Learned cluster center of cluster 2:  {1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 10: 0, 11: 0, 100: 0}\n",
      "Learned cluster center of cluster 3:  {3: 4, 4: 0, 5: 8, 6: 5, 8: 5, 9: 8, 10: 3, 11: 1, 12: 5, 13: 6, 15: 12, 16: 6, 17: 9, 18: 12, 19: 3, 24: 5, 25: 7, 26: 16, 27: 7, 31: 13, 34: 8, 39: 15, 47: 11, 49: 10, 56: 10, 59: 17, 62: 13, 64: 12, 65: 13, 77: 13, 88: 13, 89: 9, 90: 8, 100: 1}\n",
      "Learned cluster center of cluster 4:  {1: 1, 2: 2, 4: 0, 5: 5, 6: 2, 7: 7, 9: 1, 10: 8, 11: 8, 13: 1, 14: 4, 15: 3, 23: 11, 24: 13, 26: 13, 28: 11, 31: 9, 34: 14, 36: 13, 37: 12, 38: 12, 43: 13, 47: 15, 50: 10, 52: 16, 55: 10, 58: 14, 65: 14, 66: 14, 73: 12, 75: 13, 80: 14, 90: 16, 92: 13}\n"
     ]
    }
   ],
   "source": [
    "n = 100\n",
    "item_set = np.arange(1,n+1)\n",
    "num_preferences = len(preferences)\n",
    "train_size = int( num_preferences * .8)\n",
    "random.seed(42)\n",
    "all_indices = list(range(num_preferences))\n",
    "train_indices = random.sample(all_indices, train_size)\n",
    "test_indices = list(set(all_indices) - set(train_indices))\n",
    "train = [preferences[i] for i in train_indices]\n",
    "test = [preferences[i] for i in test_indices]\n",
    "for num_clusters in [1,2,4]:\n",
    "  \n",
    "  Z = cluster_preferences(train, dist_misorders[train_indices], dist_incomparables[train_indices], n = 100, p = 0.005, beta = 0.01)\n",
    "  labels = fcluster(Z, t=num_clusters, criterion='maxclust')\n",
    "\n",
    "  clusters = defaultdict(list)\n",
    "  for point, label in zip(train, labels):\n",
    "      clusters[label].append(point)\n",
    "\n",
    "  \n",
    "  item_probability = np.zeros(n)\n",
    "  for label in clusters: \n",
    "      cluster_weight = len(clusters[label]) / train_size\n",
    "      learned_cluster_center = BuCchoi(range(1,n+1),clusters[label])\n",
    "      print(\"Learned cluster center of cluster {}: \".format(label), learned_cluster_center)\n",
    "      # Calculate choice probability of each item using DYCHIP\n",
    "      item_probability_for_cluster = np.ones(n)/n #DYCHIP(learned_cluster_center, clusters[label])\n",
    "      item_probability += item_probability_for_cluster * cluster_weight\n",
    "  \n",
    "  # compare probability to mnl, test?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sigma_dic)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.     0.651  0.6322 0.6366 0.631  0.6362 0.638  0.6414 0.6274 0.6194\n",
      " 0.6434 0.6178 0.6192 0.614  0.6104 0.611  0.6096 0.6006 0.5978 0.591\n",
      " 0.5936 0.5912 0.57   0.582  0.5824 0.5872 0.5834 0.5826 0.567  0.5664\n",
      " 0.5678 0.5604 0.5558 0.561  0.5524 0.5634 0.5466 0.5624 0.5438 0.5394\n",
      " 0.5494 0.5454 0.529  0.5358 0.5282 0.5308 0.5376 0.5362 0.534  0.5382\n",
      " 0.547  0.5284 0.5284 0.5374 0.5236 0.5146 0.531  0.5236 0.5472 0.5152\n",
      " 0.5242 0.5192 0.525  0.524  0.5232 0.517  0.5138 0.5104 0.5162 0.505\n",
      " 0.5086 0.5212 0.4988 0.5194 0.5152 0.5168 0.5112 0.5124 0.5108 0.5186\n",
      " 0.5176 0.4988 0.5006 0.5114 0.4988 0.506  0.5152 0.5126 0.5076 0.5176\n",
      " 0.5012 0.5118 0.51   0.525  0.4874 0.506  0.5036 0.5132 0.517  0.5212\n",
      " 0.6504]\n",
      "[0.01785976 0.0116267  0.01129094 0.01136952 0.01126951 0.01136238\n",
      " 0.01139453 0.01145525 0.01120521 0.01106233 0.01149097 0.01103376\n",
      " 0.01105876 0.01096589 0.0109016  0.01091231 0.01088731 0.01072657\n",
      " 0.01067656 0.01055512 0.01060155 0.01055869 0.01018006 0.01039438\n",
      " 0.01040152 0.01048725 0.01041938 0.0104051  0.01012648 0.01011577\n",
      " 0.01014077 0.01000861 0.00992645 0.01001932 0.00986573 0.01006219\n",
      " 0.00976214 0.01004433 0.00971214 0.00963355 0.00981215 0.00974071\n",
      " 0.00944781 0.00956926 0.00943352 0.00947996 0.00960141 0.0095764\n",
      " 0.00953711 0.00961212 0.00976929 0.0094371  0.0094371  0.00959783\n",
      " 0.00935137 0.00919063 0.00948353 0.00935137 0.00977286 0.00920135\n",
      " 0.00936209 0.00927279 0.00937637 0.00935851 0.00934423 0.00923349\n",
      " 0.00917634 0.00911562 0.00921921 0.00901918 0.00908347 0.00930851\n",
      " 0.00890845 0.00927636 0.00920135 0.00922992 0.00912991 0.00915134\n",
      " 0.00912276 0.00926207 0.00924421 0.00890845 0.00894059 0.00913348\n",
      " 0.00890845 0.00903704 0.00920135 0.00915491 0.00906561 0.00924421\n",
      " 0.00895131 0.00914062 0.00910848 0.00937637 0.00870485 0.00903704\n",
      " 0.00899417 0.00916563 0.00923349 0.00930851 0.01161599]\n"
     ]
    }
   ],
   "source": [
    "#MNL approach\n",
    "n = 100\n",
    "fixed_item = 0\n",
    "w = np.zeros(n+1)\n",
    "w[0] = 1\n",
    "for i in range(1, n+ 1):\n",
    "    A = [i]\n",
    "    count = np.sum([1 for tau in preferences if Choice(A,tau,n) != 0])\n",
    "    w[i] = count / (1.0 * len(preferences))\n",
    "\n",
    "print(w)\n",
    "w_total = np.sum(w)\n",
    "p = w / w_total\n",
    "print(p)\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:top_k_sampling]",
   "language": "python",
   "name": "conda-env-top_k_sampling-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
