import numpy as np 
from fcmeans import FCM 
from sklearn.datasets import make_blobs 
import random 
import matplotlib.pyplot as plt



# Length of the mean vectors allocated by the estimators in this file
# (np.zeros(d) / np.empty((d, 4))); the synthetic data further down is
# generated with 10 features.
d = 10
# Exponent applied to a membership value before it is used as a weight.
alpha = 2
# Scale applied to np.random.rand(...) when the synthetic cluster centres
# are drawn.
multiplier = 5


def compute_mean(X, sampled_indices, fcm_labels, index):
    """Return the membership-weighted mean of the sampled rows of ``X``.

    Every sampled row ``r`` contributes ``X[r]`` with weight
    ``fcm_labels[r, index] ** alpha``; ``alpha`` and the vector length ``d``
    are the module-level settings.  A row index that appears several times
    in ``sampled_indices`` is counted once per occurrence.
    """
    weighted_sum = np.zeros(d)
    total_weight = 0.0
    for row in sampled_indices:
        weight = fcm_labels[row, index] ** alpha
        weighted_sum += weight * X[row]
        total_weight += weight
    return weighted_sum / total_weight


def permutation(X, hatmu):
    """Return the row indices of ``X`` ordered by distance to ``hatmu``.

    Parameters
    ----------
    X : 2-D array, one point per row.
    hatmu : 1-D array with the same length as a row of ``X``.

    Returns
    -------
    list of int
        Row indices sorted by increasing squared Euclidean distance to
        ``hatmu``.  Rows at exactly the same distance keep their original
        relative order (the sort is stable).

    The number of rows is taken from ``X`` itself, so the function does not
    depend on a module-level ``n_samples`` being defined beforehand.
    """
    sq_dist = []
    for row in X:
        offset = row - hatmu
        sq_dist.append(np.dot(offset, offset))
    # sorted() is stable, so ties are resolved by ascending row index.
    return sorted(range(len(sq_dist)), key=sq_dist.__getitem__)


def binary_search(x, idx, X, fcm_labels, cluster_index):
    """Find the last position in ``idx`` whose membership is at least ``x``.

    The search assumes that ``fcm_labels[idx[p], cluster_index]`` does not
    increase with the position ``p``; it then returns the largest ``p`` with
    ``fcm_labels[idx[p], cluster_index] >= x``.

    Parameters
    ----------
    x : threshold compared against the membership values.
    idx : sequence of row indices into ``fcm_labels`` (the search order).
    X : unused; kept so existing callers keep working.
    fcm_labels : 2-D membership array indexed as ``[row, cluster]``.
    cluster_index : column of ``fcm_labels`` to inspect.

    Returns
    -------
    int or None
        The position found.  ``0`` is also returned when no position meets
        the threshold, and ``None`` when ``idx`` is empty.

    The search range comes from ``len(idx)`` rather than from a module-level
    ``n_samples``, so the function works on any index list it is given.
    """
    if len(idx) == 0:
        return None
    low, high = 0, len(idx) - 1
    while low < high:
        # Upper midpoint: guarantees progress when ``low = mid`` is taken.
        mid = (low + high + 1) // 2
        if fcm_labels[idx[mid], cluster_index] >= x:
            low = mid
        else:
            high = mid - 1
    return low

def compute_membership(X, hatmu, fcm_labels, eta, n_samples, cluster_index):
    """Estimate quantised memberships of every point for one cluster.

    Points are ordered by distance to ``hatmu``.  For every level of the
    grid ``np.linspace(0, 1, int(1.0 / eta))`` a binary search over that
    ordering locates the last position whose membership in
    ``fcm_labels[:, cluster_index]`` is still at least that level; the points
    between two consecutive boundaries receive the lower of the two levels.

    Parameters
    ----------
    X : 2-D array of points, one per row.
    hatmu : estimated centre of the cluster.
    fcm_labels : 2-D membership array indexed as ``[row, cluster]``.
    eta : grid resolution; the grid has ``int(1.0 / eta)`` levels.
    n_samples : length of the returned vector.
    cluster_index : column of ``fcm_labels`` that is queried.

    Returns
    -------
    numpy.ndarray of shape ``(n_samples,)`` with the quantised memberships.
    Entries that no boundary interval covers stay ``0``.
    """
    hatu = np.zeros(n_samples)
    idx = permutation(X, hatmu)
    eta_arr = np.linspace(0, 1, int(1.0 / eta))
    val = [binary_search(x, idx, X, fcm_labels, cluster_index) for x in eta_arr]
    for i in range(len(val) - 1):
        # Boundaries shrink as the level grows, hence the downward walk;
        # val[i + 1] itself belongs to the next (higher) level.
        for j in range(val[i], val[i + 1], -1):
            hatu[idx[j]] = eta_arr[i]

    # The closest point (position 0) is never covered by the intervals above,
    # so it is quantised directly: take the grid level with the smallest
    # strictly positive gap below its true membership.  The gap and the
    # chosen level are tracked separately; if no level lies below the
    # membership the point keeps membership 0.
    z = fcm_labels[idx[0], cluster_index]
    level = 0.0
    smallest_gap = np.inf
    for candidate in eta_arr:
        gap = z - candidate
        if 0 < gap < smallest_gap:
            smallest_gap = gap
            level = candidate
    hatu[idx[0]] = level
    return hatu

    
def estimate_together(X, fcm_labels, m, n_samples, eta):
    """Estimate all four cluster means and memberships from one shared sample.

    ``m`` row indices are drawn with replacement from ``range(n_samples)``;
    the same draw is used for the weighted mean of every cluster.  Each mean
    is then handed to ``compute_membership`` to obtain that cluster's
    membership column.

    Returns ``(hatmu, hatu)`` with shapes ``(d, 4)`` and ``(n_samples, 4)``.
    """
    sampled_indices = random.choices(list(range(n_samples)), k=m)
    hatmu = np.empty((d, 4))
    hatu = np.empty((n_samples, 4))
    for cluster in range(4):
        centre = compute_mean(X, sampled_indices, fcm_labels, cluster)
        hatmu[:, cluster] = centre
        # Only this cluster's own column of hatmu is needed for its memberships.
        hatu[:, cluster] = compute_membership(
            X, hatmu[:, cluster], fcm_labels, eta, n_samples, cluster
        )
    return hatmu, hatu
    

def compute_first(X, fcm_labels, m, n_samples, eta):
    """Pick the cluster with the largest sampled weight and estimate its mean.

    ``m`` row indices are drawn with replacement from ``range(n_samples)``.
    For each of the four clusters the memberships of the sampled rows, raised
    to ``alpha``, are summed; the cluster with the largest sum is selected and
    its membership-weighted mean over the same sample is returned.

    ``eta`` is accepted but not used.

    Returns ``(mean, first)`` where ``first`` is the selected cluster index.
    """
    sampled_indices = random.choices(list(range(n_samples)), k=m)
    val = np.zeros(4)
    for cluster in range(4):
        powered = [fcm_labels[row, cluster] ** alpha for row in sampled_indices]
        val[cluster] = np.sum(powered)
    first = np.argmax(val)
    # compute_mean performs the same weighted accumulation over the sample.
    return compute_mean(X, sampled_indices, fcm_labels, first), first


def compute_sets(X, fcm_labels, estimate_index, partition, eta_1, clusters, index):
    """Re-bucket every point by its accumulated quantised membership.

    ``partition`` maps each level of ``np.linspace(0, 1, int(1.0 / eta_1))``
    to the list of point indices currently at that level.  The quantised
    membership of every point in cluster ``index`` (centre
    ``estimate_index``) is added to the point's level, capped at 1.0, and the
    point is placed in the bucket of the resulting level.

    Parameters
    ----------
    X : 2-D array of points, one per row.
    fcm_labels : 2-D membership array indexed as ``[row, cluster]``.
    estimate_index : estimated centre of cluster ``index``.
    partition : dict keyed by the grid levels, values are lists of indices.
    eta_1 : grid resolution.
    clusters : accepted but not used.
    index : cluster whose memberships are added.

    Returns
    -------
    dict with the same grid levels as keys and the new index lists as values.
    """
    # The number of points comes from X rather than a module-level name.
    membership = compute_membership(X, estimate_index, fcm_labels, eta_1, len(X), index)
    levels = np.linspace(0, 1, int(1.0 / eta_1))
    top = len(levels) - 1
    partition_new = {level: [] for level in levels}
    for level in levels:
        for j in partition[level]:
            total = min(1.0, level + membership[j])
            # A float sum of two grid levels can differ from the stored grid
            # level by a rounding error, which would miss the dict key.
            # Snap the sum to the nearest grid position before the lookup.
            slot = min(top, max(0, int(round(total * top))))
            partition_new[levels[slot]].append(j)
    return partition_new

def compute_index(X, fcm_labels, m, n_samples, partition, clusters, eta_1):
    """Choose the next cluster by stratified sampling and estimate its mean.

    The sampling budget ``m`` is split evenly over the non-empty buckets of
    ``partition``; ``r`` indices are drawn with replacement from each of
    them.  Every draw is weighted by ``len(bucket) / r``.  Among the clusters
    not yet in ``clusters`` the one with the largest weighted sum of
    memberships (raised to ``alpha``) is selected, and its weighted mean over
    the same draws is returned.

    ``n_samples`` is accepted but not used.

    Returns ``(mean, first)`` where ``first`` is the selected cluster index.
    """
    levels = np.linspace(0, 1, int(1.0 / eta_1))
    occupied = np.sum([1 for s in levels if len(partition[s]) > 0])
    r = int((1.0 * m) / occupied)

    # Clusters that were already estimated can never win the argmax.
    val = np.zeros(4)
    for done in clusters:
        val[done] = -np.inf

    draws = []
    for s in levels:
        factor = (1.0 * len(partition[s])) / r
        if len(partition[s]) == 0:
            continue
        picked = random.choices(partition[s], k=r)
        draws.append((factor, picked))
        for cluster in range(4):
            if cluster in clusters:
                continue
            val[cluster] += factor * np.sum(
                [fcm_labels[i, cluster] ** alpha for i in picked]
            )
    first = np.argmax(val)

    mean_num = np.zeros(d)
    mean_den = 0.0
    for factor, picked in draws:
        for i in picked:
            weight = factor * fcm_labels[i, first] ** alpha
            mean_num += weight * X[i]
            mean_den += weight
    return mean_num / mean_den, first

def estimate_sequentially(X, fcm_labels, m, n_samples, eta_1, eta_2):
    """Estimate the four cluster means one after another, then memberships.

    The first cluster comes from ``compute_first``.  All points start in the
    level-0 bucket of a partition over the ``eta_1`` grid; after each cluster
    is estimated ``compute_sets`` re-buckets the points, and
    ``compute_index`` uses the buckets to pick and estimate the next cluster.
    Once all four means are known, memberships are computed on the ``eta_2``
    grid.

    The set of already-estimated clusters is printed before each of the
    three follow-up rounds.

    Returns ``(hatmu, hatu)`` with shapes ``(d, 4)`` and ``(n_samples, 4)``.
    """
    hatmu = np.empty((d, 4))
    hatu = np.empty((n_samples, 4))

    estimate_first, first = compute_first(X, fcm_labels, m, n_samples, eta_1)
    clusters = {first}
    hatmu[:, first] = estimate_first

    partition = {s: [] for s in np.linspace(0, 1, int(1.0 / eta_1))}
    partition[0] = list(range(n_samples))
    partition = compute_sets(
        X, fcm_labels, estimate_first, partition, eta_1, clusters, first
    )

    for _ in range(3):
        print(clusters)
        estimate_index, index = compute_index(
            X, fcm_labels, m, n_samples, partition, clusters, eta_1
        )
        clusters.add(index)
        hatmu[:, index] = estimate_index
        partition = compute_sets(
            X, fcm_labels, estimate_index, partition, eta_1, clusters, index
        )

    for cluster in range(4):
        hatu[:, cluster] = compute_membership(
            X, hatmu[:, cluster], fcm_labels, eta_2, n_samples, cluster
        )
    return hatmu, hatu

def objective_function(X, alpha, hatmu, hatu, n_samples):
    """Return the average fuzzy clustering cost of the first ``n_samples`` points.

    Each membership row is first shifted so that it sums to one: the missing
    mass ``1 - sum(row)`` is spread equally over the clusters.  The cost is
    the sum over points and clusters of ``membership ** alpha`` times the
    squared Euclidean distance to the cluster centre, divided by
    ``n_samples``.

    Parameters
    ----------
    X : 2-D array of points, one per row.
    alpha : exponent applied to the memberships.
    hatmu : 2-D array of centres, one per *column*.
    hatu : 2-D membership array indexed as ``[row, cluster]``.  It is read
        only; the normalisation is done on a separate array so the caller's
        memberships are left untouched.
    n_samples : number of leading rows of ``X`` / ``hatu`` to use.

    The number of clusters is taken from ``hatu.shape[1]``.
    """
    points = X[:n_samples]
    raw = hatu[:n_samples]
    n_clusters = raw.shape[1]
    # Builds a new array; the argument itself is not modified.
    u = raw + (1.0 - np.sum(raw, axis=1, keepdims=True)) / n_clusters

    obj = 0.0
    for j in range(n_clusters):
        offset = points - hatmu[:, j]
        sq_dist = np.sum(offset * offset, axis=1)
        obj += np.sum((u[:, j] ** alpha) * sq_dist)
    return obj / n_samples

def xie_beni(X, alpha, hatmu, hatu, n_samples):
    """Return the Xie-Beni style index of a fuzzy clustering.

    The numerator is the same cost as in ``objective_function`` before
    averaging: memberships are shifted so every row sums to one, then
    ``membership ** alpha`` times the squared distance to the centre is
    summed over points and clusters.  The denominator is ``n_samples`` times
    the smallest squared distance between two different centres.

    Parameters
    ----------
    X : 2-D array of points, one per row.
    alpha : exponent applied to the memberships.
    hatmu : 2-D array of centres, one per *column*.
    hatu : 2-D membership array indexed as ``[row, cluster]``.  It is read
        only; the normalisation is done on a separate array so the caller's
        memberships are left untouched.
    n_samples : number of leading rows of ``X`` / ``hatu`` to use.

    The number of clusters is taken from ``hatu.shape[1]``.
    """
    points = X[:n_samples]
    raw = hatu[:n_samples]
    n_clusters = raw.shape[1]
    # Builds a new array; the argument itself is not modified.
    u = raw + (1.0 - np.sum(raw, axis=1, keepdims=True)) / n_clusters

    compactness = 0.0
    for j in range(n_clusters):
        offset = points - hatmu[:, j]
        sq_dist = np.sum(offset * offset, axis=1)
        compactness += np.sum((u[:, j] ** alpha) * sq_dist)

    # The squared distance is symmetric, so each unordered pair is enough.
    min_separation = np.inf
    for a in range(n_clusters):
        for b in range(a + 1, n_clusters):
            gap = hatmu[:, a] - hatmu[:, b]
            min_separation = min(min_separation, np.dot(gap, gap))
    return compactness / (n_samples * min_separation)





# Centres of the four synthetic clusters, 10 coordinates each: random values
# from np.random.rand scaled by `multiplier`; mu_0 is shifted by +1000 and
# mu_3 by -1000.  The generator is not seeded here.
mu_0 = tuple(1000+np.random.rand(10)*multiplier)
mu_1 = tuple(np.random.rand(10)*multiplier)
mu_2 = tuple(np.random.rand(10)*multiplier)
mu_3 = tuple(-1000+np.random.rand(10)*multiplier)


# Result tables.  Rows: the 24 values of beta.  Columns: the 10 sampling
# budgets.  Suffix 1 = estimate_together, suffix 2 = estimate_sequentially.
epsilon_1 = np.empty((24,10))
epsilon_2 = np.empty((24,10))
obj1 = np.empty((24,10))
obj2 = np.empty((24,10))
xb1 = np.empty((24,10))
xb2 = np.empty((24,10))
queries = np.empty((24,10))
# Per-beta values computed from the fitted FCM model itself.
obj = np.empty(24)
xb = np.empty(24)


for beta in range(1,25):
    print(beta)
    # First cluster always has 5000 points; the other three grow with beta.
    samples_dist = [5000, 5000*beta, 5000*beta, 5000*beta]
    # NOTE: n_samples is a module-level name; keep it, code elsewhere in this
    # file may refer to it by name.
    n_samples = sum(samples_dist)
    centers = [mu_0, mu_1, mu_2, mu_3]

    X,_ = make_blobs(n_samples = samples_dist, n_features=10, cluster_std = 20.0, centers=centers, shuffle = False, random_state=42)

    fcm = FCM(n_clusters=4, m=2)
    fcm.fit(X)

    # Presumably one centre per row (see the fcm_centers[i] indexing below)
    # and a membership matrix indexed [point, cluster] - confirm against the
    # fcmeans documentation.
    fcm_centers = fcm.centers 
    fcm_labels = fcm.u

    # The metric functions take centres as columns, hence the transpose.
    obj[beta-1] = objective_function(X, 2, np.transpose(fcm.centers), fcm.u, n_samples)
    xb[beta-1] = xie_beni(X, 2, np.transpose(fcm.centers), fcm.u, n_samples)


    for y, factor in enumerate([1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]):
        print(factor)
        queries[beta-1][y] = factor*4
        # The sequential procedure is given a smaller per-call budget.
        factor2 = int(factor/2.5)
        temp1 = []; temp2 = []
        # Ten independent repetitions per (beta, budget) pair.
        for qq in range(10):
            print(qq)
            hatmu, hatu = estimate_together(X, fcm_labels, factor, n_samples, 0.1) 
            hatmu2, hatu2 = estimate_sequentially(X, fcm_labels, factor2, n_samples, 0.1, 0.1)
            # Largest squared distance between an estimated mean and the
            # FCM centre with the same index.
            temp1.append(max([np.dot(hatmu[:,i]-fcm_centers[i],hatmu[:,i]-fcm_centers[i]) for i in range(4)]))
            temp2.append(max([np.dot(hatmu2[:,i]-fcm_centers[i],hatmu2[:,i]-fcm_centers[i]) for i in range(4)]))
        epsilon_1[beta-1][y] = np.mean(temp1)
        epsilon_2[beta-1][y] = np.mean(temp2)
        # These four use hatmu/hatu from the LAST repetition only; they are
        # not averaged over the ten repetitions.
        obj1[beta-1][y] = objective_function(X, 2, hatmu, hatu, n_samples)
        obj2[beta-1][y] = objective_function(X, 2, hatmu2, hatu2, n_samples)
        xb1[beta-1][y] = xie_beni(X, 2, hatmu, hatu, n_samples)
        xb2[beta-1][y] = xie_beni(X, 2, hatmu2, hatu2, n_samples) 

# NOTE: np.save writes NumPy's binary .npy format and appends ".npy" to file
# names that do not already end with it, so these become e.g.
# 'ep_1_20.txt.npy'.
# NOTE(review): `xb` and `queries` are filled above but not saved here.
np.save('ep_1_20.txt',epsilon_1)    
np.save('ep_2_20.txt',epsilon_2)
np.save('ob_1_20.txt',obj1)
np.save('ob_2_20.txt',obj2)
np.save('xb_1_20.txt',xb1)
np.save('xb_2_20.txt',xb2)
np.save('ob_20.txt',obj)



# Line styles and matplotlib format strings shared by the figures.
linestyles = ['-', '--', '-.', ':']
markers = ['bo-', 'rD-', 'g*-', 'b-', 'c2']

# x-axis value for each of the 24 beta settings.
true_beta = [4.0 / (1 + 3 * b) for b in range(1, 25)]

# One entry per curve: (result table, column, format string, line style, label).
# Columns 0 and 2 of both tables are shown; the first three beta settings
# are left out of every curve.
curves = [
    (epsilon_1, 0, markers[0], linestyles[0], r'Procedure 1 (Queries = $8 \times 10^3$)'),
    (epsilon_2, 0, markers[1], linestyles[1], r'Procedure 4 (Queries = $8 \times 10^3$)'),
    (epsilon_1, 2, markers[2], linestyles[2], r'Procedure 1 (Queries = $24 \times 10^3$)'),
    (epsilon_2, 2, markers[0], linestyles[3], r'Procedure 4 (Queries = $24 \times 10^3$)'),
]
for table, column, fmt, dash, caption in curves:
    plt.plot(
        true_beta[3:],
        table[3:, column],
        fmt,
        markersize=0,
        label=caption,
        linewidth=4,
        linestyle=dash,
    )
#plt.plot(true_beta, epsilon_1[:,4], markers[1],markersize=0,label="Proc. 1 (Q = 40)",linewidth=4, linestyle=linestyles[0])
#plt.plot(true_beta, epsilon_2[:,4], markers[2],markersize=0,label="Proc. 4 (Q = 40)",linewidth=4, linestyle=linestyles[1])

plt.legend(fontsize=10)
#plt.legend(bbox_to_anchor=(0., 0.99, 1., .102), loc=3,
#           ncol=3, mode="expand", borderaxespad=0.)
plt.xlabel(r'$\beta$ (Minimum membership)', fontsize=10)
plt.ylabel(r'$\epsilon$ (Maximum deviation of estimated mean)', fontsize=10)
plt.ticklabel_format(style='sci', axis='both', scilimits=(0, 0), useMathText=True)
plt.ylim([0, 360])
plt.grid()
plt.show()
plt.close()
# Disabled alternative figures (error and Xie-Beni index against the query
# budget), kept inside a string literal so that they never run.  The literal
# is raw because the text contains LaTeX backslashes such as "\epsilon",
# which are not valid escape sequences in an ordinary string.
r'''

plt.plot(queries[5], epsilon_1[5], markers[0],markersize=0,label="Procedure 1 ("+r'$\beta = 0.25$)',linewidth=4, linestyle=linestyles[0])
plt.plot(queries[5], epsilon_2[5], markers[1],markersize=0,label="Procedure 4 ("+r'$\beta = 0.25$)',linewidth=4, linestyle=linestyles[1])
plt.plot(queries[13], epsilon_1[13], markers[2],markersize=0,label="Procedure 1 ("+r'$\beta = 0.1$)',linewidth=4, linestyle=linestyles[2])
plt.plot(queries[13], epsilon_2[13], markers[3],markersize=0,label="Procedure 4 ("+r'$\beta = 0.1$)',linewidth=4, linestyle=linestyles[3])
#plt.plot(queries[21], epsilon_1[21], markers[1],markersize=0,label="Proc. 1 ("+r'$\beta = 0.0625$)',linewidth=4, linestyle=linestyles[0])
#plt.plot(queries[21], epsilon_2[21], markers[2],markersize=0,label="Proc. 4 ("+r'$\beta = 0.0625$)',linewidth=4, linestyle=linestyles[1])

plt.legend(fontsize=10)
#plt.legend(bbox_to_anchor=(0., 0.99, 1., .102), loc=3,
#           ncol=3, mode="expand", borderaxespad=0.)
plt.xlabel('Query complexity',fontsize=10)
plt.ylabel(r'$\epsilon$ (Maximum deviation of estimated mean)',fontsize=10)
plt.ticklabel_format(style='sci', axis='both', scilimits=(0,0),useMathText=True)
plt.ylim([0, 175])
plt.grid()
plt.show()                       
plt.close()


plt.plot(queries[5], xb1[5], markers[0],markersize=0,label="Procedure 1 ("+r'$\beta = 0.25$)',linewidth=4, linestyle=linestyles[0])
plt.plot(queries[5], xb2[5], markers[1],markersize=0,label="Procedure 4 ("+r'$\beta = 0.25$)',linewidth=4, linestyle=linestyles[1])
plt.plot(queries[13], xb1[13], markers[2],markersize=0,label="Procedure 1 ("+r'$\beta = 0.1$)',linewidth=4, linestyle=linestyles[2])
plt.plot(queries[13], xb2[13], markers[3],markersize=0,label="Procedure 4 ("+r'$\beta = 0.1$)',linewidth=4, linestyle=linestyles[3])
#plt.plot(queries[21], xb1[21], markers[1],markersize=0,label="Proc. 1 ("+r'$\beta = 0.0625$)',linewidth=4, linestyle=linestyles[0])
#plt.plot(queries[21], xb2[21], markers[2],markersize=0,label="Proc. 4 ("+r'$\beta = 0.0625$)',linewidth=4, linestyle=linestyles[1])

plt.legend(fontsize=10)
#plt.legend(bbox_to_anchor=(0., 0.99, 1., .102), loc=3,
#           ncol=3, mode="expand", borderaxespad=0.)
plt.xlabel('Query complexity',fontsize=10)
plt.ylabel('Xie-Beni',fontsize=10)
plt.ticklabel_format(style='sci', axis='both', scilimits=(0,0),useMathText=True)
plt.ylim([0, 0.01])
plt.grid()
plt.show()                       
plt.close()
'''




