import numpy as np 
#from fcmeans import FCM 
from sklearn.datasets import make_blobs 
from sklearn.datasets import load_iris #Comment out this line if you want to run on wine dataset
#from sklearn.datasets import load_wine  #Comment out this line if you want to run on iris dataset
import matplotlib.pyplot as plt
import random 


# Module-level problem dimensions.
d = 4  # feature dimension; any vector sized with `d` is combined with rows of X, so it must equal X.shape[1]
k = 3  # number of clusters (columns of the membership matrices)


def compute_mean(X, sampled_indices, fcm_labels, index, alpha):
    """Return the alpha-weighted mean of the sampled rows of X for one cluster.

    Every sampled row ``X[i]`` is weighted by ``fcm_labels[i, index] ** alpha``.
    Repeated indices are counted once per occurrence.

    Args:
        X: sample matrix, one sample per row.
        sampled_indices: row indices of X to average over.
        fcm_labels: membership matrix, ``fcm_labels[i, j]`` is the membership
            of sample i in cluster j.
        index: cluster (column of ``fcm_labels``) whose weights are used.
        alpha: exponent applied to the memberships.

    Returns:
        np.ndarray of shape (X.shape[1],).  If all weights are zero (or no
        indices are given), the result is NaN (0/0).
    """
    rows = np.asarray(sampled_indices, dtype=int)
    weights = fcm_labels[rows, index] ** alpha
    # The vector length comes from X itself rather than a global dimension.
    return weights @ X[rows] / weights.sum()


def permutation(X, hatmu):
    """Return the row indices of X ordered by squared distance to ``hatmu``.

    Ties keep ascending index order, because the sort is stable.

    Args:
        X: sample matrix, one sample per row.
        hatmu: a single centre, broadcastable against a row of X.

    Returns:
        list of int: a permutation of ``range(len(X))``, nearest row first.
    """
    diff = np.asarray(X, dtype=float) - np.asarray(hatmu, dtype=float)
    # Row-wise dot product diff[i] . diff[i] gives the squared Euclidean distance.
    sq_dist = np.einsum("ij,ij->i", diff, diff)
    return np.argsort(sq_dist, kind="stable").tolist()


def binary_search(x, idx, X, fcm_labels, cluster_index):
    """Return the last position p in ``idx`` whose membership is at least x.

    The search is over positions of ``idx`` and compares
    ``fcm_labels[idx[p], cluster_index] >= x``.  It is only correct when these
    memberships are non-increasing along ``idx``.
    NOTE(review): callers presumably guarantee that ordering; it is not checked.

    Args:
        x: threshold to compare memberships against.
        idx: ordering of sample indices to search (e.g. from ``permutation``).
        X: unused; kept for call compatibility.
        fcm_labels: membership matrix.
        cluster_index: column of ``fcm_labels`` to read.

    Returns:
        int: the last qualifying position.  Returns 0 both when only position 0
        qualifies and when no position qualifies (or ``idx`` is empty), so
        callers must treat position 0 separately.
    """
    low, high = 0, len(idx) - 1
    if high < 0:
        return 0
    while low < high:
        # Upper middle: guarantees progress when we move `low` up to `mid`.
        mid = (low + high + 1) // 2
        if fcm_labels[idx[mid], cluster_index] >= x:
            low = mid
        else:
            high = mid - 1
    return low

def compute_membership(X, hatmu, fcm_labels, eta, n_samples, cluster_index):
    """Quantise every point's membership in one cluster onto a discrete grid.

    Points are ordered by distance to the estimated centre ``hatmu``.  For each
    grid level x in ``np.linspace(0, 1, int(1/eta))``, a binary search finds
    the last position whose membership ``fcm_labels[., cluster_index]`` is at
    least ``round(x, 3)``.  Every point then receives the highest level it
    reaches, which is a floor quantisation.  The grid has ``int(1/eta)`` points,
    so its spacing is ``1 / (int(1/eta) - 1)``, not ``eta``.

    NOTE(review): the binary search assumes memberships are non-increasing
    along the distance order; this is not checked here.

    Args:
        X: sample matrix, one sample per row.
        hatmu: estimated centre of the cluster.
        fcm_labels: membership matrix.
        eta: grid resolution parameter, expected in (0, 1].
        n_samples: length of the returned vector.
        cluster_index: column of ``fcm_labels`` to quantise.

    Returns:
        np.ndarray of shape (n_samples,) holding grid levels.

    Raises:
        ValueError: if ``eta`` is so large that the grid is empty.
    """
    eta_arr = np.linspace(0, 1, int(1.0 / eta))
    if len(eta_arr) == 0:
        raise ValueError("eta must be in (0, 1] so that the membership grid is non-empty")
    hatu = np.zeros(n_samples)
    idx = permutation(X, hatmu)
    # val[i]: last sorted position whose membership reaches level eta_arr[i].
    val = [binary_search(round(x, 3), idx, X, fcm_labels, cluster_index) for x in eta_arr]
    for i in range(len(val) - 1):
        # Positions val[i+1]+1 .. val[i] reach level i but not level i+1.
        for j in range(val[i], val[i + 1], -1):
            hatu[idx[j]] = eta_arr[i]
    # Positions 1 .. val[-1] reach the top level, which the loop above never assigns.
    for j in range(val[-1], 0, -1):
        hatu[idx[j]] = eta_arr[-1]
    # binary_search cannot tell whether position 0 qualifies, so give the
    # closest point the highest level its membership reaches, using the same
    # `>= round(level, 3)` rule as binary_search.
    z = fcm_labels[idx[0], cluster_index]
    hatu[idx[0]] = max((level for level in eta_arr if z >= round(level, 3)), default=eta_arr[0])
    return hatu

    
def estimate_together(X, fcm_labels, m, n_samples, eta, k, alpha):
    """Estimate all k centres from one shared sample, then all memberships.

    Draws ``m`` distinct indices uniformly from ``range(n_samples)``.  Each
    centre is the alpha-weighted mean of the sampled rows (``compute_mean``).
    Each membership column is then quantised with ``compute_membership`` on a
    grid controlled by ``eta``.

    Returns:
        (hatmu, hatu): centres of shape (X.shape[1], k), one per column, and
        memberships of shape (n_samples, k).

    Raises:
        ValueError: from numpy if ``m > n_samples``.
    """
    # Passing the population size draws from the same permutation as passing
    # the explicit list range(n_samples).
    sampled_indices = np.random.choice(n_samples, size=m, replace=False)
    # Size centres from the data rather than the global `d`.
    hatmu = np.empty((X.shape[1], k))
    hatu = np.empty((n_samples, k))
    for j in range(k):
        hatmu[:, j] = compute_mean(X, sampled_indices, fcm_labels, j, alpha)
    for j in range(k):
        hatu[:, j] = compute_membership(X, hatmu[:, j], fcm_labels, eta, n_samples, j)
    return hatmu, hatu
    

def compute_first(X, fcm_labels, m, n_samples, eta, k, alpha):
    """Pick the cluster with the largest sampled mass and estimate its centre.

    Draws ``m`` indices uniformly *with replacement* (``random.choices``).
    Each cluster j is scored by ``sum(fcm_labels[i, j] ** alpha)`` over the
    sample.  The function returns the alpha-weighted mean of the sampled rows
    for the best-scoring cluster.

    ``eta`` is accepted for signature compatibility but is not used.

    Returns:
        (centre, first): centre has shape (X.shape[1],); first is the chosen
        cluster index (ties go to the lowest index, as with ``np.argmax``).
    """
    sampled_indices = random.choices(range(n_samples), k=m)
    weights = fcm_labels[sampled_indices, :k] ** alpha  # shape (m, k)
    first = np.argmax(weights.sum(axis=0))
    w = weights[:, first]
    # Vector length comes from X rather than the global `d`.
    return w @ X[sampled_indices] / w.sum(), first


def compute_sets(X, fcm_labels, estimate_index, partition, eta_1, clusters, index):
    """Refine the level partition with one newly estimated cluster.

    ``partition`` maps each grid level ``round(s, 3)`` (s in
    ``np.linspace(0, 1, int(1/eta_1))``) to a list of sample indices.  Every
    index j moves from level s to level ``min(1, s + membership[j])``, where
    ``membership`` is the quantised membership in cluster ``index`` around
    ``estimate_index`` (``compute_membership``).  Levels therefore accumulate
    the estimated memberships over the clusters processed so far.

    ``clusters`` is accepted for signature compatibility but is not used.

    Returns:
        dict: a new partition with the same keys.

    Raises:
        KeyError: if ``partition`` lacks a grid key, or if a sum does not land
        on the grid.
    """
    membership = compute_membership(X, estimate_index, fcm_labels, eta_1, len(X), index)
    grid = np.linspace(0, 1, int(1.0 / eta_1))
    partition_new = {round(s, 3): [] for s in grid}
    for s in grid:
        for j in partition[round(s, 3)]:
            # Cap at 1, then round to 3 decimals so the sum matches a grid key.
            # Rounding must use ndigits=3; rounding to an integer would put
            # every point into level 0 or 1.
            partition_new[round(min(1.0, s + membership[j]), 3)].append(j)
    return partition_new

def compute_index(X, fcm_labels, m, n_samples, partition, clusters, eta_1, k, alpha):
    """Choose the next cluster and estimate its centre by stratified sampling.

    The budget ``m`` is split evenly over the non-empty levels of ``partition``.
    From each such level, ``r`` indices are drawn with replacement, and every
    draw is reweighted by ``len(level) / r`` so that the sums estimate totals
    over the whole level.  Among clusters not yet in ``clusters``, the one with
    the largest estimated mass ``sum(fcm_labels[i, j] ** alpha)`` is chosen.
    Its centre is the reweighted alpha-weighted mean of the drawn rows.

    ``n_samples`` is accepted for signature compatibility but is not used.

    Returns:
        (centre, index): centre has shape (X.shape[1],); index is the chosen
        cluster.

    Raises:
        ValueError: if every level of ``partition`` is empty.
    """
    levels = [round(s, 3) for s in np.linspace(0, 1, int(1.0 / eta_1))]
    nonempty = [level for level in levels if len(partition[level]) > 0]
    if not nonempty:
        raise ValueError("partition has no non-empty level to sample from")
    # Draw at least one sample per level.  When m < len(nonempty),
    # int(m / len(nonempty)) is 0 and the factor len(level) / r would divide
    # by zero.
    r = max(1, int((1.0 * m) / len(nonempty)))
    open_mask = np.array([j not in clusters for j in range(k)])
    val = np.zeros(k)
    val[~open_mask] = -np.inf  # already-found clusters can never win the argmax
    samples = {}
    for level in nonempty:
        samples[level] = random.choices(partition[level], k=r)
        factor = (1.0 * len(partition[level])) / r
        mass = (fcm_labels[samples[level], :k] ** alpha).sum(axis=0)
        val[open_mask] += factor * mass[open_mask]
    first = np.argmax(val)
    mean_num = np.zeros(X.shape[1])
    mean_den = 0.0
    for level in nonempty:
        factor = (1.0 * len(partition[level])) / r
        w = factor * fcm_labels[samples[level], first] ** alpha
        mean_num += w @ X[samples[level]]
        mean_den += w.sum()
    return mean_num / mean_den, first

def estimate_sequentially(X, fcm_labels, m, n_samples, eta_1, eta_2, k, alpha):
    """Estimate the k centres one at a time, then quantise all memberships.

    1. ``compute_first`` picks the heaviest cluster and its centre.
    2. All points start at level 0 of the partition.  ``compute_sets`` moves
       them by their estimated membership in the newly found cluster.
    3. The remaining k-1 clusters are found by stratified sampling over that
       partition (``compute_index``), refining it after each one.
    4. Each membership column is quantised with ``compute_membership`` on the
       ``eta_2`` grid.

    Returns:
        (hatmu, hatu): centres of shape (X.shape[1], k), one per column, and
        memberships of shape (n_samples, k).
    """
    # Size centres from the data rather than the global `d`.
    hatmu = np.empty((X.shape[1], k))
    hatu = np.empty((n_samples, k))
    estimate_first, first = compute_first(X, fcm_labels, m, n_samples, eta_1, k, alpha)
    clusters = {first}
    hatmu[:, first] = estimate_first
    partition = {round(s, 3): [] for s in np.linspace(0, 1, int(1.0 / eta_1))}
    partition[0] = list(range(n_samples))  # every point starts with zero accumulated membership
    partition = compute_sets(X, fcm_labels, estimate_first, partition, eta_1, clusters, first)
    for _ in range(k - 1):
        estimate_index, index = compute_index(X, fcm_labels, m, n_samples, partition, clusters, eta_1, k, alpha)
        clusters.add(index)
        hatmu[:, index] = estimate_index
        partition = compute_sets(X, fcm_labels, estimate_index, partition, eta_1, clusters, index)
    for j in range(k):
        hatu[:, j] = compute_membership(X, hatmu[:, j], fcm_labels, eta_2, n_samples, j)
    return hatmu, hatu

def objective_function(X, alpha, hatmu, hatu, n_samples, k):
    """Return the fuzzy k-means objective of an estimate, averaged over points.

    Each of the first ``n_samples`` membership rows is shifted by adding
    ``(1 - row_sum) / k`` to every entry, so that rows of length k sum to 1.
    The function then returns

        (1 / n_samples) * sum_i sum_{j<k} u[i, j] ** alpha * ||X[i] - hatmu[:, j]||^2.

    The shift is applied to a copy, so the caller's ``hatu`` is not modified.

    Args:
        X: sample matrix, one sample per row.
        alpha: exponent applied to the memberships.
        hatmu: centres, one per column.
        hatu: membership matrix, one row per sample.
        n_samples: number of leading rows to use.
        k: number of clusters.
    """
    u = np.array(hatu[:n_samples], dtype=float)  # copy: never write back into hatu
    u += (1 - u.sum(axis=1, keepdims=True)) / k
    pts = np.asarray(X, dtype=float)[:n_samples]
    centres = np.asarray(hatmu, dtype=float)[:, :k]
    # Squared distance of every point to every centre, shape (n_samples, k).
    sq_dist = ((pts[:, :, None] - centres[None, :, :]) ** 2).sum(axis=1)
    return np.sum(u[:, :k] ** alpha * sq_dist) / n_samples

# Preallocated result buffers.  NOTE(review): none of them is filled by the
# visible code; `queries` is rebound by the experiment loop further down.
epsilon_1 = np.zeros((11,40))
epsilon_2 = np.zeros((11,40))
obj1 = np.zeros((11,40))
obj2 = np.zeros((11,40))
queries = np.zeros((11,40))
obj = np.zeros(11)

# Dataset selection: iris is active.  To use wine instead, comment out the
# load_iris line below, uncomment the load_wine line, and uncomment the
# load_wine import at the top of the file.
data = load_iris()
#data = load_wine()
X = data['data']
n_samples = X.shape[0]
# Any code that sizes vectors with the global `d` must agree with the data.
# `d` is otherwise fixed at 4, which only fits iris, so derive it here;
# switching datasets then needs no other edit.
d = X.shape[1]



  
def accuracy(alpha, n_trials=500):
    """Monte-Carlo misclassification rates of both estimators for one alpha.

    Ground-truth memberships are near-hard: 0.99 on each sample's true class
    and 0.005 on the others, so cluster j corresponds to class j.  For each
    budget ``factor`` in 3, 9, ..., 117, both estimators run ``n_trials``
    times.  A point counts as an error when the argmax of its estimated
    membership row differs from its true class.

    Args:
        alpha: exponent applied to memberships by both estimators.
        n_trials: repetitions per budget (default 500).

    Returns:
        (queries, acc1, acc2): ``queries[y] = factor * k``.  ``acc1`` and
        ``acc2`` are mean *error* fractions (not accuracies, despite the names)
        for ``estimate_together`` and ``estimate_sequentially`` respectively.
    """
    target = data['target']
    n_points = X.shape[0]
    fcm_labels = np.full((n_points, k), 0.005)
    fcm_labels[np.arange(n_points), target] = 0.99
    acc1 = []
    acc2 = []
    queries = []
    for y, factor in enumerate([i * 3 for i in range(1, 41, 2)]):
        print(y)  # progress indicator
        # Budget for the sequential procedure, derived from the same factor.
        factor2 = int((factor * 2.0) / (k + 1))
        queries.append(factor * k)
        err_together = []
        err_sequential = []
        for _ in range(n_trials):
            _, hatu = estimate_together(X, fcm_labels, factor, n_samples, 0.1, k, alpha)
            _, hatu2 = estimate_sequentially(X, fcm_labels, factor2, n_samples, 0.1, 0.1, k, alpha)
            err_together.append(np.mean(np.argmax(hatu, axis=1) != target))
            err_sequential.append(np.mean(np.argmax(hatu2, axis=1) != target))
        acc1.append(np.mean(err_together))
        acc2.append(np.mean(err_sequential))
    return queries, acc1, acc2



linestyles = ['-', '--', '-.', ':']

markers = ['bo-', 'rD-', 'g*-', 'b-', 'c2']

# Run the experiment once per alpha.  Only the sequential procedure's error
# rates are kept for plotting; `acc1` (the shared-sample procedure) is discarded.
accuracy_list = []

for alpha in [1, 2, 3, 4, 5]:
    queries, acc1, accuracy_temp = accuracy(alpha)
    accuracy_list.append(accuracy_temp)
    print(queries[:20], 100*(1-np.array(accuracy_temp[:20])))

# One curve per alpha: accuracy in percent = 100 * (1 - error rate) against
# the query budget of the last run.
for curve_no, curve in enumerate(accuracy_list):
    plt.plot(
        queries[:20],
        100 * (1 - np.array(curve[:20])),
        markers[curve_no],
        markersize=8,
        label=r'$\alpha=%d$' % (curve_no + 1),
        linewidth=1,
    )

plt.legend(fontsize=10)
plt.xlabel(r'Number of queries', fontsize=10)
plt.ylabel(r'Accuracy', fontsize=10)
plt.ticklabel_format(style='sci', axis='both', scilimits=(0, 0), useMathText=True)
plt.ylim([77, 84])
plt.grid()
plt.show()
plt.close()



