{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#imports\n",
    "from bisect import bisect_left\n",
    "from bisect import bisect_right\n",
    "from scipy.linalg import svd\n",
    "from scipy.linalg import eig\n",
    "from numpy import *\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.pyplot import figure\n",
    "import matplotlib\n",
    "from sklearn.utils.extmath import randomized_svd\n",
    "import operator\n",
    "from collections import Counter\n",
    "from scipy.sparse.linalg import eigs\n",
    "import numpy.linalg as linalg\n",
    "from scipy.linalg import svd\n",
    "from scipy.stats import bernoulli\n",
    "from random import randint\n",
    "from sklearn.utils.extmath import randomized_svd\n",
    "import plotly.express as px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Gen_graph(n,p,q,kl):\n",
    "    \"\"\"Sample a symmetric stochastic-block-model (SBM) adjacency matrix.\n",
    "\n",
    "    Vertices 0..n-1 are randomly permuted and cut into consecutive clusters\n",
    "    whose sizes are given by ``kl``. Each off-diagonal pair (i, j) is joined\n",
    "    independently with probability ``p`` when i and j share a cluster and\n",
    "    ``q`` otherwise; each diagonal entry is drawn with probability ``p``.\n",
    "\n",
    "    Sampling is vectorised: the previous version made one ``bernoulli.rvs``\n",
    "    call per vertex pair, which takes minutes for n ~ 3000.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n : int\n",
    "        Number of vertices.\n",
    "    p, q : float\n",
    "        Intra- and inter-cluster edge probabilities, in [0, 1].\n",
    "    kl : sequence of int\n",
    "        Cluster sizes. Vertices beyond ``sum(kl)`` belong to no cluster and\n",
    "        get no off-diagonal edges.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    M : ndarray, shape (n, n), dtype float\n",
    "        Symmetric 0/1 adjacency matrix.\n",
    "    clusters : list of list\n",
    "        ``clusters[i]`` holds the vertex ids of cluster ``i``.\n",
    "    \"\"\"\n",
    "    # Cumulative boundaries: cluster i is starter[nkl[i]:nkl[i+1]].\n",
    "    nkl=np.concatenate(([0],np.cumsum(kl))).astype(int)\n",
    "    starter=np.random.permutation(n)\n",
    "\n",
    "    #randomly permute the cluster identities.\n",
    "    clusters=[list(starter[nkl[i]:nkl[i+1]]) for i in range(len(kl))]\n",
    "\n",
    "    # Cluster label of every vertex; -1 marks vertices in no cluster.\n",
    "    labels=np.full(n,-1)\n",
    "    for idx,members in enumerate(clusters):\n",
    "        labels[members]=idx\n",
    "\n",
    "    #Generate the expectation matrix of SBM graph\n",
    "    mean_M=np.where(labels[:,None]==labels[None,:],p,q).astype(float)\n",
    "    unassigned=labels<0\n",
    "    mean_M[unassigned,:]=0.0\n",
    "    mean_M[:,unassigned]=0.0\n",
    "\n",
    "    #Generate the SBM graph: one uniform draw per pair, keep the strict\n",
    "    #lower triangle (i>j) and mirror it so that M is symmetric.\n",
    "    draws=(np.random.random_sample((n,n))<mean_M).astype(float)\n",
    "    M=np.tril(draws,-1)\n",
    "    M=M+M.T\n",
    "\n",
    "    # Diagonal entries use p regardless of cluster membership.\n",
    "    np.fill_diagonal(M,bernoulli.rvs(p,size=n))\n",
    "\n",
    "    return M,clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Estimate size of largest cluster\n",
    "def Estimate(G,n,p,q,W,Y2,num_samples=100):\n",
    "    \"\"\"Estimate the size of the largest planted cluster inside W = [n//2, n).\n",
    "\n",
    "    Samples ``num_samples`` vertices u uniformly from [n//8, n//4) (the same\n",
    "    range the main loop draws its pivot from), counts each one's neighbours\n",
    "    in W and inverts  deg_W(u) ~ q*|W| + (p-q)*s  for the largest count.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G : 2-D array-like\n",
    "        0/1 adjacency matrix indexable as ``G[u, j]``.\n",
    "    n : int\n",
    "        Number of vertices.\n",
    "    p, q : float\n",
    "        Intra-/inter-cluster edge probabilities; must differ.\n",
    "    W, Y2 :\n",
    "        Unused; kept so existing calls keep working.\n",
    "    num_samples : int, optional\n",
    "        Number of sampled vertices (default 100). Must be >= 1.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    L_e : float\n",
    "        Distance threshold ``sqrt(0.2) * (p - q) * sqrt(s_p)``.\n",
    "    s_p : int\n",
    "        Estimated cluster size, clamped at 0 (a negative estimate used to\n",
    "        make ``L_e`` NaN).\n",
    "    \"\"\"\n",
    "    if num_samples<1:\n",
    "        raise ValueError(\"num_samples must be >= 1\")\n",
    "    G=np.asarray(G)\n",
    "\n",
    "    # random.randint is inclusive on both ends, so stop at n//4-1; the old\n",
    "    # randint(n//8,n//4) could also pick vertex n//4, outside the range.\n",
    "    lo=n//8\n",
    "    hi=max(lo,n//4-1)\n",
    "    # |W| = n-n//2; the old q*n//2 floor-divided the product q*n.\n",
    "    w_size=n-n//2\n",
    "\n",
    "    N_w=[int(np.count_nonzero(G[randint(lo,hi),n//2:n]==1)) for _ in range(num_samples)]\n",
    "\n",
    "    s_p=max(0,int((max(N_w)-q*w_size)/(p-q)))\n",
    "\n",
    "    L_e=np.sqrt(0.2)*(p-q)*np.sqrt(s_p)\n",
    "\n",
    "    return L_e,s_p\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Recover V_i \\cap W\n",
    "def identify_cluster_W(M,S,s_p,p,q,n=None):\n",
    "    \"\"\"Return the vertices of W = [n//2, n) densely connected to S.\n",
    "\n",
    "    Vertex i in W is kept when its number of edges into S exceeds\n",
    "    ``q*len(S) + (p-q)*s_p/8``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    M : 2-D array-like\n",
    "        0/1 adjacency matrix.\n",
    "    S : iterable of int\n",
    "        Vertex ids (the plural set); duplicates count with multiplicity.\n",
    "    s_p : int\n",
    "        Estimated cluster size, e.g. from ``Estimate``.\n",
    "    p, q : float\n",
    "        Intra-/inter-cluster edge probabilities.\n",
    "    n : int, optional\n",
    "        Number of vertices. Defaults to ``M.shape[0]``; the function used\n",
    "        to read a notebook-global ``n`` defined in a later cell.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of int\n",
    "        Selected vertex ids in increasing order.\n",
    "    \"\"\"\n",
    "    M=np.asarray(M)\n",
    "    if n is None:\n",
    "        n=M.shape[0]\n",
    "    S=list(S)\n",
    "    half=n//2\n",
    "\n",
    "    # Edges from every vertex of W into S, in one vectorised sum.\n",
    "    c=M[half:n][:,S].sum(axis=1)\n",
    "    threshold=q*len(S)+(p-q)*s_p/8\n",
    "\n",
    "    return (np.nonzero(c>threshold)[0]+half).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Recover V_i \\cap U\n",
    "def identify_cluster_U(M,S1,p,q,n=None):\n",
    "    \"\"\"Return the vertices of U = [0, n//2) densely connected to S1.\n",
    "\n",
    "    Vertex i in U is kept when its number of edges into S1 exceeds\n",
    "    ``q*len(S1) + (p-q)*len(S1)/4``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    M : 2-D array-like\n",
    "        0/1 adjacency matrix.\n",
    "    S1 : iterable of int\n",
    "        Vertex ids recovered in W; duplicates count with multiplicity.\n",
    "    p, q : float\n",
    "        Intra-/inter-cluster edge probabilities.\n",
    "    n : int, optional\n",
    "        Number of vertices. Defaults to ``M.shape[0]``; the function used\n",
    "        to read a notebook-global ``n`` defined in a later cell.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of int\n",
    "        Selected vertex ids in increasing order.\n",
    "    \"\"\"\n",
    "    M=np.asarray(M)\n",
    "    if n is None:\n",
    "        n=M.shape[0]\n",
    "    S1=list(S1)\n",
    "\n",
    "    # Edges from every vertex of U into S1, in one vectorised sum.\n",
    "    c=M[0:n//2][:,S1].sum(axis=1)\n",
    "    threshold=q*len(S1)+(p-q)*len(S1)/4\n",
    "\n",
    "    return np.nonzero(c>threshold)[0].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Cluster sizes, probabilities and instantiation.\n",
    "\n",
    "# Alternative cluster-size configurations from earlier experiments:\n",
    "#[800,300,350,25,15,1,1,1,1,1,1,1,1,1,1]\n",
    "#[1500,800,300,350,25,15,1,1,1,1,1,1,1,1,1,1]\n",
    "#[2000,800,300,350,25,15,1,1,1,1,1,1,1,1,1,1]\n",
    "\n",
    "#Ailon mid-size\n",
    "#[500,150,70,30]\n",
    "#[800,550,70,30]\n",
    "\n",
    "# Seed both RNGs so Restart & Run All reproduces the same graph, estimate\n",
    "# and pivot choices. Set SEED=None to draw a fresh instance on every run.\n",
    "# (Aliased import: the bare name `random` is numpy.random in this notebook.)\n",
    "import random as _py_random\n",
    "SEED=0\n",
    "np.random.seed(SEED)    # numpy's global RNG\n",
    "_py_random.seed(SEED)   # Python's RNG, used by randint inside Estimate\n",
    "\n",
    "# Two large planted clusters plus 997 singleton clusters.\n",
    "set_sizes=[1000, 903]+[1]*997\n",
    "\n",
    "p=0.7   # intra-cluster edge probability\n",
    "q=0.3   # inter-cluster edge probability\n",
    "k_Vu=len(set_sizes)   # number of clusters\n",
    "n=sum(set_sizes)      # number of vertices\n",
    "\n",
    "M,clusters=Gen_graph(n,p,q,set_sizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Start the steps and estimate\n",
    "# 1. Estimate the largest cluster size s_p inside W and the threshold L_e.\n",
    "# 2. Project B onto the top-k left singular vectors of A.\n",
    "# 3. Repeatedly pick a pivot, build its plural set S, grow it into W (S1)\n",
    "#    and accept the first S1 that passes the size and degree checks.\n",
    "# 4. Grow S1 into U (S2) and report overlaps with the planted clusters.\n",
    "\n",
    "MAX_TRIES=300   # the loop runs at most MAX_TRIES+1 attempts\n",
    "\n",
    "s_star=np.sqrt(p*(1-q))*np.sqrt(n)*np.log(n)/(p-q)\n",
    "k=(p-q)*np.sqrt(n)/np.sqrt(p*(1-q))\n",
    "k=max(1,int(k))   # randomized_svd needs at least one component\n",
    "L_e,s_p=Estimate(M,n,p,q,None,None)\n",
    "print(\"L_e=\",L_e,\" s_p=\",s_p,\" k=\",k,\" s*=\",s_star)\n",
    "\n",
    "# Vertex blocks: [0,n//8) -> columns of A, [n//8,n//4) -> columns of B,\n",
    "# [n//4,n//2) -> rows of A and B, [n//2,n) -> W.\n",
    "A=M[n//4:n//2,0:n//8]\n",
    "B=M[n//4:n//2,n//8:n//4]\n",
    "\n",
    "# Track the largest planted cluster (the report below labels it as such).\n",
    "choice=int(np.argmax([len(c) for c in clusters]))\n",
    "clusters1=set(clusters[choice])\n",
    "W_set=set(range(n//2,n))\n",
    "U_set=set(range(0,n//2))\n",
    "\n",
    "#k' dimensional projection of B\n",
    "lU,s,_=randomized_svd(A,n_components=k, n_iter=5, random_state=None)\n",
    "Pk_B=np.dot(lU.T,B)\n",
    "\n",
    "print(\"root(n)log(n)=\",np.sqrt(n)*np.log2(n))\n",
    "\n",
    "\n",
    "def _degree_checks_pass(M,S1,W_set,p,q):\n",
    "    \"\"\"Return True when the non-empty set S1 looks like one dense cluster in W.\n",
    "\n",
    "    Fails if a vertex of S1 has fewer than (0.8p+0.2q)*|S1| neighbours in\n",
    "    S1, or if a vertex of W outside S1 has more than that many.\n",
    "    \"\"\"\n",
    "    threshold=(0.8*p+0.2*q)*len(S1)\n",
    "    #Low-degree vertices in the subgraph induced by S1\n",
    "    deg_in=(M[np.ix_(S1,S1)]==1).sum(axis=1)\n",
    "    if deg_in.min()<threshold:\n",
    "        return False\n",
    "    #Vertices of W outside S1 with high degree into S1\n",
    "    outside=sorted(W_set.difference(S1))\n",
    "    deg_out=(M[np.ix_(outside,S1)]==1).sum(axis=1)\n",
    "    return not (deg_out>threshold).any()\n",
    "\n",
    "\n",
    "counter=0\n",
    "ch=0\n",
    "while(ch!=1 and counter<=MAX_TRIES):\n",
    "    # Exactly one increment per attempt; the old code added a second one on\n",
    "    # degree-check failures, silently halving the retry budget.\n",
    "    counter=counter+1\n",
    "\n",
    "    #choose the random vertex\n",
    "    x=np.random.randint(n//8,n//4)\n",
    "    x2=x-n//8\n",
    "\n",
    "    #Construct plural set over every column of Pk_B. B has n//4-n//8\n",
    "    #columns, which can exceed n//8, so range(0,n//8) missed the last one.\n",
    "    dists=np.linalg.norm(Pk_B-Pk_B[:,[x2]],axis=0)\n",
    "    S=(np.nonzero(dists<L_e)[0]+n//8).tolist()\n",
    "\n",
    "    #Recover V_i intersection W\n",
    "    S1=identify_cluster_W(M,S,s_p,p,q)\n",
    "\n",
    "    #size check (an empty S1 is never accepted as a cluster)\n",
    "    if(len(S1)==0 or len(S1)<0.25*s_p):\n",
    "        continue\n",
    "\n",
    "    if _degree_checks_pass(M,S1,W_set,p,q):\n",
    "        ch=1\n",
    "\n",
    "\n",
    "if(ch==1):\n",
    "    #Recover V_i intersection U\n",
    "    S2=identify_cluster_U(M,S1,p,q)\n",
    "    V=set(S1).union(set(S2))\n",
    "\n",
    "    print(len(S1),len(set(S1).intersection(clusters1)), len(W_set.intersection(clusters1)))\n",
    "    print(len(S2),len(set(S2).intersection(clusters1)), len(U_set.intersection(clusters1)))\n",
    "    print(\"Finished with=\",len(V),\" size of largest cluster=\",len(clusters1),)\n",
    "    print(\"Intersections of final set with hidden clusters=\",end=' ')\n",
    "    for c in clusters:\n",
    "        print(len(V.intersection(c)),end=', ')\n",
    "    print(\"\\nIntersections of the plural set with hidden clusters=\",end=' ')\n",
    "    for c in clusters:\n",
    "        print(len(set(S).intersection(c)),end=', ')\n",
    "\n",
    "    print(\"\\n\",len(set(S)),len(set(S).intersection(clusters1)),len(set(range(n//8,n//4)).intersection(clusters1)))\n",
    "\n",
    "else:\n",
    "    print(\"No recoverable cluster\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3fe79e683bbaf8df0df55ec2b475f92db7ab22b7b80f78586f3eac5e4f96af26"
  },
  "kernelspec": {
   "display_name": "Python 3.9.13 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
