# This code attempts to recover the support of the unknown vectors.
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import itertools as it
import numpy as np
import scipy
import random



def oracle(v):
    """Return the inner product of ``v`` with one randomly chosen unknown vector.

    One of the module-level vectors ``beta_1``, ``beta_2`` and ``beta_3`` is
    selected with probability 1/3 each.

    Args:
        v: Query vector; must be compatible with ``np.dot`` against the
            module-level ``beta_*`` vectors.

    Returns:
        ``np.dot(v, beta_i)`` for the selected vector ``beta_i``.
    """
    # Draw exactly once. Calling np.random.random() again inside each branch
    # condition compares independent samples against the thresholds, which
    # skews the selection probabilities to 1/3, 8/27 and 10/27.
    draw = np.random.random()
    if draw >= 2.0 / 3:
        return np.dot(v, beta_1)
    if draw >= 1.0 / 3:
        return np.dot(v, beta_2)
    return np.dot(v, beta_3)

def RUFF(n, m, d):
    """Build a dense random ``m``-by-``n`` matrix from independent sparse columns.

    Every column is drawn separately with
    ``scipy.sparse.random(m, 1, density=d / m)`` and converted to a dense
    array, so each column is sampled with the same target density ``d / m``.

    Args:
        n: Number of columns.
        m: Number of rows.
        d: Numerator of the per-column density ``d / m``.

    Returns:
        A dense ``numpy.ndarray`` of shape ``(m, n)``; shape ``(m, 0)`` when
        ``n <= 0``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.sparse`` is loaded on every SciPy version.
    from scipy import sparse

    if n <= 0:
        return np.zeros((m, 0))

    density = float(d) / m
    # ``toarray()`` is the supported spelling of the deprecated ``.A``.
    columns = [sparse.random(m, 1, density=density).toarray() for _ in range(n)]
    # Stack once instead of growing the matrix column by column.
    return np.hstack(columns)

def count_nonzero(v):
    """Estimate a 0..3 count from repeated oracle queries on ``v``.

    Calls ``oracle(v)`` ``T`` times (``T`` is read from module scope), counts
    the nonzero responses, and returns three times the nonzero fraction
    rounded to the nearest integer.
    Presumably this estimates how many of the three unknown vectors have a
    nonzero inner product with ``v`` -- confirm against the algorithm.
    """
    hits = sum(1 for _ in range(T) if oracle(v) != 0)
    return round(hits * 3.0 / T)


def modified_query(indices):
    """Query the oracle with a random vector supported on ``indices``.

    A length-``n`` zero vector (``n`` from module scope) receives an
    independent uniform [0, 1) value at each position in ``indices``; the
    rounded estimate from ``count_nonzero`` is returned as an ``int``.
    """
    query = np.zeros(n)
    for position in indices:
        # One fresh draw per listed position, in the order given.
        query[position] = np.random.random()
    estimate = count_nonzero(query)
    return int(estimate)


def intersection(i, j, k, occ_1, occ_2):
    """Combine single, pairwise and joint-query counts for indices ``i``, ``j``, ``k``.

    Returns ``-(occ_1[i] + occ_1[j] + occ_1[k]) + occ_2[i, j] + occ_2[j, k]
    + occ_2[i, k] + modified_query([i, j, k])``, i.e. an inclusion-exclusion
    style combination. Presumably this is the number of unknown vectors that
    are nonzero at all three indices -- confirm against the algorithm.
    """
    total = -(occ_1[i] + occ_1[j] + occ_1[k])
    for a, b in ((i, j), (j, k), (i, k)):
        total = total + occ_2[a, b]
    # The fresh oracle-backed query is issued last, after the table lookups.
    return total + modified_query([i, j, k])

def intersection2(i, j, occ_1):
    """Return ``occ_1[i] + occ_1[j]`` minus a fresh joint query on ``[i, j]``.

    Presumably the number of unknown vectors that are nonzero at both
    indices (inclusion-exclusion on two sets) -- confirm against the algorithm.
    """
    singles = occ_1[i] + occ_1[j]
    joint = modified_query([i, j])
    return singles - joint


def check(tup, supp_vectors):
    """Return True when no vector in ``supp_vectors`` equals 1 at both ``tup[0]`` and ``tup[1]``."""
    return not any(
        vec[tup[0]] == 1 and vec[tup[1]] == 1
        for vec in supp_vectors
    )

def check2(tup, supp_vectors):
    """Return True when no vector in ``supp_vectors`` is 1 at ``tup[0]`` and 0 at ``tup[1]``."""
    return not any(
        vec[tup[0]] == 1 and vec[tup[1]] == 0
        for vec in supp_vectors
    )

def recover_identifiable_sub(occ_2, occ_3, occ_2c, occ_3c, supp_vectors, c1_final):
    """Recover one 0/1 support indicator that is not ruled out by ``supp_vectors``.

    Two strategies are tried in order over pairs drawn from ``c1_final``:

    1. Find a pair ``(a, b)`` with ``occ_2[a, b] == 1`` that passes
       ``check`` (no already-recovered vector is 1 at both ``a`` and ``b``);
       the support is every ``i`` in ``c1_final`` with
       ``occ_3[a, b, i] == 1``.
    2. Otherwise find a pair with ``occ_2c[a, b] == 1`` that passes
       ``check2`` (no already-recovered vector is 1 at ``a`` and 0 at ``b``);
       the support contains ``a``, excludes ``b``, and adds every other
       ``i`` in ``c1_final`` with ``occ_3c[a, b, i] == 1``.

    Args:
        occ_2: 2-D table indexed by coordinate pairs.
        occ_3: 3-D table indexed by coordinate triples.
        occ_2c: 2-D table used by the second strategy.
        occ_3c: 3-D table used by the second strategy.
        supp_vectors: Previously recovered indicator vectors.
        c1_final: Candidate coordinates to consider.

    Returns:
        A length-``n`` array of 0/1 values (``n`` is read from module scope);
        all zeros when neither strategy finds a usable pair.
    """
    supp = np.zeros(n)

    # Strategy 1: a pair flagged exactly once in occ_2 and not yet explained.
    for tup in it.combinations(c1_final, 2):
        if occ_2[tup] == 1 and check(tup, supp_vectors):
            for i in c1_final:
                if occ_3[tup[0], tup[1], i] == 1:
                    supp[i] = 1
            return supp

    # Strategy 2: a pair flagged exactly once in occ_2c and not yet explained.
    for tup in it.combinations(c1_final, 2):
        if occ_2c[tup] == 1 and check2(tup, supp_vectors):
            supp[tup[0]] = 1
            for i in c1_final:
                # Skip both pivot coordinates: tup[0] is already set and
                # tup[1] must stay 0. (The previous ``or`` test was always
                # true, so tup[1] could be switched back on.)
                if i != tup[0] and i != tup[1] and occ_3c[tup[0], tup[1], i] == 1:
                    supp[i] = 1
            return supp

    return supp
        



def recover_identifiable(occ_2, occ_3, occ_2c, occ_3c, c1_final):
    """Recover three support indicators one after another.

    Each call to ``recover_identifiable_sub`` receives the indicators found
    so far, so that later calls can reject pairs already explained by them.

    Returns:
        The tuple ``(supp_1, supp_2, supp_3)`` in recovery order.
    """
    recovered = []
    for _ in range(3):
        # Pass a snapshot of what has been found before this round.
        known = list(recovered)
        recovered.append(
            recover_identifiable_sub(occ_2, occ_3, occ_2c, occ_3c, known, c1_final)
        )
    supp_1, supp_2, supp_3 = recovered
    return supp_1, supp_2, supp_3


def recover_jennrich(occ_3):
    """Eigendecompose a ratio of two random contractions of ``occ_3``.

    Two random weight vectors of length ``n`` (``n`` from module scope) are
    used to form weighted sums ``T1`` and ``T2`` of the slices ``occ_3[i]``;
    the eigenvalues and eigenvectors of ``T2 @ pinv(T1)`` are returned.
    The function name suggests this follows Jennrich's tensor decomposition
    algorithm -- confirm against the reference.

    Returns:
        ``(w, v)`` as returned by ``scipy.linalg.eig``.
    """
    weights_a = np.random.rand(n)
    weights_b = np.random.rand(n)
    slice_ids = range(n)
    T1 = sum(weights_a[s] * occ_3[s] for s in slice_ids)
    T2 = sum(weights_b[s] * occ_3[s] for s in slice_ids)
    ratio = T2.dot(np.linalg.pinv(T1))
    eigenvalues, eigenvectors = scipy.linalg.eig(ratio)
    return eigenvalues, eigenvectors
    

def recover_support(n, k, m, d, beta_1, beta_2, beta_3, alpha, T):
    """Run the oracle-query pipeline and recover three support indicators.

    Steps, in order:
      1. Draw a random ``m``-by-``n`` query matrix with ``RUFF`` and estimate
         a count for every row via ``count_nonzero``.
      2. Threshold per-column scores at ``0.4 * d`` to obtain the candidate
         coordinate lists for levels 1, 2 and 3, and derive ``occ_1``.
      3. Fill the pairwise (``occ_2``, ``occ_2c``) and triple-wise
         (``occ_3``, ``occ_3c``) tables over the level-1 candidates.
      4. Call ``recover_identifiable`` on those tables.

    Note: ``k``, ``alpha``, ``beta_1``, ``beta_2``, ``beta_3`` and ``T`` are
    not used directly in this body; the helpers it calls read their
    settings from module scope.

    Returns:
        ``(supp_1, supp_2, supp_3, occ_3c, occ_3, occ_2, occ_1, occ_2c)``.
    """
    X = RUFF(n, m, d)

    # b[level, row] is 1 for every level from 0 up to the row's estimated count.
    b = np.zeros([4, m])
    for row in range(m):
        estimate = int(count_nonzero(X[row]))
        for level in range(estimate + 1):
            b[level, row] = 1

    threshold = 0.4 * d

    def heavy_columns(level):
        # Columns whose overlap with the level-indicator row exceeds the threshold.
        scores = [np.dot(b[level], X[:, col]) for col in range(n)]
        return [col for col, score in enumerate(scores) if score > threshold]

    c1_final = heavy_columns(1)
    c2_final = heavy_columns(2)
    c3_final = heavy_columns(3)

    occ_1 = np.zeros(n)
    for idx in range(n):
        in_c1 = idx in c1_final
        in_c2 = idx in c2_final
        in_c3 = idx in c3_final
        if in_c1 and not in_c2:
            occ_1[idx] = 1
        elif in_c2 and not in_c3:
            occ_1[idx] = 2
        elif in_c3:
            occ_1[idx] = 3

    # Every ordered pair is queried (including i == j); later queries
    # overwrite earlier ones for the same unordered pair.
    occ_2 = np.zeros((n, n))
    for i, j in it.product(c1_final, repeat=2):
        pair_value = intersection2(i, j, occ_1)
        occ_2[i, j] = pair_value
        occ_2[j, i] = pair_value

    occ_2c = np.zeros((n, n))
    for i, j in it.product(c1_final, repeat=2):
        occ_2c[i, j] = occ_1[i] - occ_2[i, j]
        occ_2c[j, i] = occ_1[j] - occ_2[i, j]

    # Every ordered triple is queried; the result is written to all six
    # index permutations.
    occ_3 = np.zeros((n, n, n))
    for i, j, k in it.product(c1_final, repeat=3):
        triple_value = intersection(i, j, k, occ_1, occ_2)
        for perm in it.permutations((i, j, k)):
            occ_3[perm] = triple_value

    # occ_3c[p, q, r] = occ_2[p, r] - occ_3[p, q, r] for each index permutation.
    occ_3c = np.zeros((n, n, n))
    for i, j, k in it.product(c1_final, repeat=3):
        for p, q, r in it.permutations((i, j, k)):
            occ_3c[p, q, r] = occ_2[p, r] - occ_3[p, q, r]

    supp_1, supp_2, supp_3 = recover_identifiable(occ_2, occ_3, occ_2c, occ_3c, c1_final)
    return supp_1, supp_2, supp_3, occ_3c, occ_3, occ_2, occ_1, occ_2c


# Algorithm 1 driver, disabled by wrapping it in a string literal: remove the enclosing triple quotes to run it.
'''
n = 500; c =5; k = 10; alpha = 0.2; T_range = [5,10,15,20,25,30,35,40,45,50]
mean_temp = []; var_temp = []

for T in T_range:
    indices1 =  random.sample(range(n), 5)
    indices2 = list(set(indices1[:2]+random.sample(range(n), 3)))
    indices3 = list(set(indices1+indices2))
    beta_1 = np.zeros(n); beta_1[indices1] = np.random.rand(len(indices1))
    beta_2 = np.zeros(n); beta_2[indices2] = np.random.rand(len(indices2))
    beta_3 = np.zeros(n); beta_3[indices3] = np.random.rand(len(indices3))
    exact = []
    for t in range(100):
        print(T, t)
        m = int(c*k*k*int(np.log(n))/(alpha**2))
        d = int(c*k*int(np.log(n))/alpha)
        supp_1, supp_2, supp_3, occ_3c, occ_3, occ_2, occ_1, occ_2c = recover_support(n, k, m, d, beta_1, beta_2, beta_3, alpha, T)
        if set(np.nonzero(supp_1)[0]) == set(indices3) and set(np.nonzero(supp_2)[0]) == set(indices2) and set(np.nonzero(supp_3)[0]) == set(indices1):
            exact.append(1)
        elif set(np.nonzero(supp_1)[0]) == set(indices3) and set(np.nonzero(supp_2)[0]) == set(indices1) and set(np.nonzero(supp_3)[0]) == set(indices2):       
            exact.append(1)
        else:
            exact.append(0)    
    mean_temp.append(np.mean(exact))
    var_temp.append(np.var(exact)) 
    print(mean_temp, var_temp)       

print(np.nonzero(supp_1), indices3)
print(np.nonzero(supp_2), indices2)
print(np.nonzero(supp_3), indices1)
'''

# Algorithm 8 driver (active): the code below runs whenever this file is executed.

# Experiment settings. ``n``, ``T`` and ``beta_1``..``beta_3`` must stay module
# globals because the helper functions read them from module scope.
n = 500
c = 5
k = 10
alpha = 0.2
T_range = list(range(5, 55, 5))
mean_temp = []
var_temp = []

# Query-matrix dimensions depend only on the constants above.
m = int(c*k*k*int(np.log(n))/(alpha**2))
d = int(c*k*int(np.log(n))/alpha)

for T in T_range:
    # Draw three supports: indices2 shares two coordinates with indices1,
    # and indices3 is the union of both.
    indices1 = random.sample(range(n), 5)
    indices2 = list(set(indices1[:2] + random.sample(range(n), 3)))
    indices3 = list(set(indices1 + indices2))

    beta_1 = np.zeros(n)
    beta_1[indices1] = np.random.rand(len(indices1))
    beta_2 = np.zeros(n)
    beta_2[indices2] = np.random.rand(len(indices2))
    beta_3 = np.zeros(n)
    beta_3[indices3] = np.random.rand(len(indices3))

    exact = []
    for t in range(100):
        print(T, t)
        supp_1, supp_2, supp_3, occ_3c, occ_3, occ_2, occ_1, occ_2c = recover_support(
            n, k, m, d, beta_1, beta_2, beta_3, alpha, T
        )
        w, v = recover_jennrich(occ_3)

        # Collect the distinct supports of the first eight eigenvectors,
        # treating entries with magnitude above 0.002 as nonzero.
        support_set = []
        for i in range(8):
            temp = {j for j in range(len(v[:, i])) if np.abs(v[j, i]) > 0.002}
            if temp not in support_set:
                support_set.append(temp)

        # A trial succeeds only when all three true supports were found.
        targets = (set(indices1), set(indices2), set(indices3))
        exact.append(1 if all(target in support_set for target in targets) else 0)

    mean_temp.append(np.mean(exact))
    var_temp.append(np.var(exact))
    print(mean_temp, var_temp)

  







