import numpy as np 
from fcmeans import FCM 
from sklearn.datasets import make_blobs 
#from sklearn.datasets import fetch_openml
from sklearn.datasets import load_breast_cancer
import os
import ssl
import matplotlib.pyplot as plt

import random 


#if (not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(ssl, '_create_unverified_context', None)):
#    ssl._create_default_https_context = ssl._create_unverified_context
#spam = fetch_openml(name='spambase')

# Number of features per sample; expected to equal the number of columns of
# the data matrix X loaded further down.
# NOTE(review): 57 presumably belonged to the commented-out spambase
# experiment above - confirm before re-enabling it.
#d = 57
d = 30
# Exponent applied to the membership values wherever they are used as weights.
alpha = 2
#multiplier = 5


def compute_mean(X, sampled_indices, fcm_labels, index):
    """Return the membership-weighted mean of the sampled rows of ``X``.

    Row ``X[i]`` is weighted by ``fcm_labels[i, index] ** alpha`` where
    ``alpha`` is the module-level exponent.  An index that occurs several
    times in ``sampled_indices`` contributes once per occurrence.

    The length of the result is taken from ``X`` itself rather than from the
    module-level ``d``, so the function works for any feature dimension.

    Args:
        X: 2-D array with one sample per row.
        sampled_indices: Sequence of row indices into ``X`` and
            ``fcm_labels``; repeats are allowed.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        index: Column of ``fcm_labels`` whose mean is computed.

    Returns:
        1-D array with one entry per column of ``X``.  When the weights sum
        to zero the division produces NaN/inf entries (NumPy semantics).
    """
    X = np.asarray(X)
    rows = np.asarray(sampled_indices, dtype=int)
    weights = np.asarray(fcm_labels)[rows, index] ** alpha
    # weights @ X[rows] is sum_i w_i * X[i]; the denominator is sum_i w_i.
    return weights @ X[rows] / weights.sum()


def permutation(X, hatmu):
    """Return the row indices of ``X`` ordered by increasing distance to ``hatmu``.

    Distances are squared Euclidean.  Rows at exactly the same distance keep
    their original relative order (stable sort).  The number of rows is read
    from ``X`` instead of the module-level ``n_samples``, so the function no
    longer depends on a global that is defined after it.

    Args:
        X: 2-D array with one sample per row.
        hatmu: 1-D array with one entry per column of ``X``.

    Returns:
        List of Python ints, a permutation of ``range(len(X))``.
    """
    offsets = np.asarray(X) - np.asarray(hatmu)
    sq_dist = (offsets * offsets).sum(axis=1)
    return np.argsort(sq_dist, kind='stable').tolist()


def binary_search(x, idx, X, fcm_labels, cluster_index):
    """Find the last position in ``idx`` whose membership is at least ``x``.

    The search assumes ``fcm_labels[idx[p], cluster_index]`` is non-increasing
    in ``p``; on other inputs it still terminates but the result is only the
    end point of the bisection path.

    The search range is ``len(idx)`` rather than the module-level
    ``n_samples``, and the midpoint uses integer arithmetic (rounded up, as
    before) instead of float division.

    Args:
        x: Membership threshold.
        idx: Sequence of sample indices (rows of ``fcm_labels``).
        X: Unused; kept so existing call sites stay valid.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        cluster_index: Column of ``fcm_labels`` to read.

    Returns:
        The largest position ``p`` with membership ``>= x``.  Position 0 is
        never probed, so 0 is also returned when no position qualifies.
        ``None`` when ``idx`` is empty.
    """
    n = len(idx)
    if n == 0:
        return None
    low, high = 0, n - 1
    while low < high:
        # Upper midpoint: guarantees progress when the branch sets low = mid.
        mid = (low + high + 1) // 2
        if fcm_labels[idx[mid], cluster_index] >= x:
            low = mid
        else:
            high = mid - 1
    return low

def compute_membership(X, hatmu, fcm_labels, eta, n_samples, cluster_index):
    """Quantise the memberships of one cluster onto a grid of levels.

    Samples are ordered by distance to ``hatmu`` (``permutation``).  For each
    grid level, ``binary_search`` locates the last position whose membership
    still reaches that level; every position between two consecutive
    boundaries receives the lower of the two levels.

    Fixes relative to the previous version:
      * The closest sample ``idx[0]`` is never probed by ``binary_search``.
        The old loop compared ``z - i`` against a running minimum but stored
        ``i``, so the sample always received 0 (or inf when ``z <= 0``).  It
        now receives the largest grid level not exceeding its membership.
      * Positions ``1..val[-1]`` (membership at or above the top level) were
        left at 0; they now receive the top level.
      * ``np.infty`` (removed in NumPy 2.0) is no longer used.

    NOTE(review): ``np.linspace(0, 1, int(1.0 / eta))`` yields
    ``int(1.0 / eta)`` levels, i.e. a spacing of ``1 / (int(1.0 / eta) - 1)``
    rather than ``eta``.  Left unchanged because other functions in this file
    build their partition keys from the same grid.

    Args:
        X: 2-D array with one sample per row.
        hatmu: 1-D centre estimate used to order the samples.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        eta: Grid resolution parameter (see note above).
        n_samples: Length of the returned vector.
        cluster_index: Column of ``fcm_labels`` to quantise.

    Returns:
        1-D array of length ``n_samples`` holding the assigned grid levels.
    """
    hatu = np.zeros(n_samples)
    idx = permutation(X, hatmu)
    if len(idx) == 0:
        return hatu

    eta_arr = np.linspace(0, 1, int(1.0 / eta))
    # Thresholds are compared after rounding to three decimals, as before.
    thresholds = [round(level, 3) for level in eta_arr]
    val = [binary_search(t, idx, X, fcm_labels, cluster_index)
           for t in thresholds]

    # Positions in (val[i + 1], val[i]] reach level i but not level i + 1.
    for i in range(len(val) - 1):
        for j in range(val[i], val[i + 1], -1):
            hatu[idx[j]] = eta_arr[i]
    # Positions 1..val[-1] reach the top level.
    if val:
        for j in range(val[-1], 0, -1):
            hatu[idx[j]] = eta_arr[-1]

    # idx[0] is handled directly: floor its membership onto the grid.
    z = fcm_labels[idx[0], cluster_index]
    reached = [level for level, t in zip(eta_arr, thresholds) if t <= z]
    hatu[idx[0]] = max(reached) if reached else 0.0
    return hatu

    
def estimate_together(X, fcm_labels, m, n_samples, eta, k):
    """Estimate all cluster centres from one shared sample, then memberships.

    ``m`` sample indices are drawn uniformly with replacement from
    ``range(n_samples)``.  Every centre is the weighted mean over that same
    sample (``compute_mean``); memberships are then computed per cluster with
    ``compute_membership``.

    The centre matrix is sized from ``X`` rather than the module-level ``d``,
    and the index population is passed to ``random.choices`` as a ``range``
    instead of a materialised list (same draws, no temporary list).

    Args:
        X: 2-D array with one sample per row.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        m: Number of indices to draw.
        n_samples: Number of samples (rows of ``X``).
        eta: Grid resolution forwarded to ``compute_membership``.
        k: Number of clusters.

    Returns:
        Tuple ``(hatmu, hatu)``: ``hatmu`` has shape ``(features, k)`` with
        one centre per column, ``hatu`` has shape ``(n_samples, k)``.
    """
    sampled_indices = random.choices(range(n_samples), k=m)
    hatmu = np.empty((np.shape(X)[1], k))
    hatu = np.empty((n_samples, k))
    for cluster in range(k):
        hatmu[:, cluster] = compute_mean(X, sampled_indices, fcm_labels, cluster)
    for cluster in range(k):
        hatu[:, cluster] = compute_membership(
            X, hatmu[:, cluster], fcm_labels, eta, n_samples, cluster)
    return hatmu, hatu
    

def compute_first(X, fcm_labels, m, n_samples, eta, k):
    """Pick the cluster with the largest sampled weight and estimate its centre.

    ``m`` indices are drawn uniformly with replacement.  For each of the
    first ``k`` clusters the weights ``fcm_labels[i, c] ** alpha`` are summed
    over the sample; the cluster with the largest sum is selected and its
    centre is the weighted mean over the same sample.

    The centre computation now reuses ``compute_mean`` instead of repeating
    its loop, and the unused local ``argmax`` is gone.

    Args:
        X: 2-D array with one sample per row.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        m: Number of indices to draw.
        n_samples: Number of samples to draw from.
        eta: Unused; kept so existing call sites stay valid.
        k: Number of clusters considered.

    Returns:
        Tuple ``(centre, first)`` with the 1-D centre estimate and the index
        of the selected cluster (ties resolve to the lowest index).
    """
    sampled_indices = random.choices(range(n_samples), k=m)
    rows = np.asarray(sampled_indices, dtype=int)
    mass = (np.asarray(fcm_labels)[rows][:, :k] ** alpha).sum(axis=0)
    first = np.argmax(mass)
    return compute_mean(X, sampled_indices, fcm_labels, first), first


def compute_sets(X, fcm_labels, estimate_index, partition, eta_1, clusters, index):
    """Re-bucket every sample after a new cluster centre has been estimated.

    ``partition`` maps a grid level (rounded to three decimals) to the list
    of sample indices currently at that level.  Each sample moves to the
    level ``old_level + membership`` where ``membership`` is its quantised
    membership in cluster ``index``; the result is capped at 1.0.

    Fix: the new key used to be ``min(1.0, round(i + membership[j]), 3)``.
    The ``3`` was meant as the ``ndigits`` argument of ``round`` but sat
    outside its parentheses, so the sum was rounded to an integer and every
    sample ended up in bucket 0 or 1.  The key is now
    ``min(1.0, round(i + membership[j], 3))``.

    The sample count passed to ``compute_membership`` is taken from ``X``
    instead of the module-level ``n_samples``.

    Args:
        X: 2-D array with one sample per row.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        estimate_index: 1-D centre estimate of cluster ``index``.
        partition: Dict from grid level to list of sample indices; must hold
            a key for every level of ``np.linspace(0, 1, int(1.0 / eta_1))``
            rounded to three decimals.
        eta_1: Grid resolution parameter.
        clusters: Unused; kept so existing call sites stay valid.
        index: Cluster whose memberships are added.

    Returns:
        A new dict with the same keys as the grid and the re-bucketed sample
        indices as values.
    """
    membership = compute_membership(
        X, estimate_index, fcm_labels, eta_1, np.shape(X)[0], index)
    grid = np.linspace(0, 1, int(1.0 / eta_1))
    partition_new = {round(level, 3): [] for level in grid}
    for level in grid:
        for j in partition[round(level, 3)]:
            key = min(1.0, round(level + membership[j], 3))
            partition_new[key].append(j)
    return partition_new

def compute_index(X, fcm_labels, m, n_samples, partition, clusters, eta_1, k):
    """Select the next cluster by stratified sampling and estimate its centre.

    The sampling budget ``m`` is split evenly over the non-empty strata of
    ``partition``.  From every non-empty stratum ``r`` indices are drawn with
    replacement and weighted by ``len(stratum) / r``.  Among the clusters not
    yet in ``clusters`` the one with the largest weighted sum of
    ``fcm_labels ** alpha`` is selected; its centre is the weighted mean over
    the same draws.

    Fixes relative to the previous version:
      * ``r`` is clamped to at least 1.  With fewer budgeted samples than
        non-empty strata it used to become 0 and the following
        ``len(...) / r`` raised ``ZeroDivisionError``.
      * A partition without any non-empty stratum raises ``ValueError`` with
        a clear message (it previously failed inside ``int()``).
      * The output length is taken from ``X`` instead of the module-level
        ``d``; the unused local ``argmax`` is gone.

    Args:
        X: 2-D array with one sample per row.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        m: Total sampling budget.
        n_samples: Unused; kept so existing call sites stay valid.
        partition: Dict from grid level (rounded to three decimals) to list
            of sample indices.
        clusters: Collection of cluster indices already estimated.
        eta_1: Grid resolution parameter used to enumerate the levels.
        k: Number of clusters.

    Returns:
        Tuple ``(centre, first)`` with the 1-D centre estimate and the index
        of the selected cluster.

    Raises:
        ValueError: If every stratum of ``partition`` is empty.
    """
    X = np.asarray(X)
    fcm_labels = np.asarray(fcm_labels)
    levels = [round(s, 3) for s in np.linspace(0, 1, int(1.0 / eta_1))]

    occupied = sum(1 for level in levels if len(partition[level]) > 0)
    if occupied == 0:
        raise ValueError("compute_index: every stratum of the partition is empty")
    r = max(1, int((1.0 * m) / occupied))

    val = np.zeros(k)
    for index in clusters:
        val[index] = -np.inf  # already estimated: can never win the argmax
    open_clusters = [c for c in range(k) if c not in clusters]

    # First pass: draw the samples and score the remaining clusters.
    samples = {}
    for level in levels:
        stratum = partition[level]
        if len(stratum) == 0:
            continue
        factor = (1.0 * len(stratum)) / r
        rows = np.asarray(random.choices(stratum, k=r), dtype=int)
        samples[level] = (factor, rows)
        weights = fcm_labels[rows] ** alpha
        val[open_clusters] += factor * weights[:, open_clusters].sum(axis=0)
    first = np.argmax(val)

    # Second pass: weighted mean for the selected cluster over the same draws.
    mean_num = np.zeros(X.shape[1])
    mean_den = 0.0
    for level in levels:
        if level not in samples:
            continue
        factor, rows = samples[level]
        weights = fcm_labels[rows, first] ** alpha
        mean_num += factor * (weights @ X[rows])
        mean_den += factor * weights.sum()
    return mean_num / mean_den, first

def estimate_sequentially(X, fcm_labels, m, n_samples, eta_1, eta_2, k):
    """Estimate the cluster centres one after another, then all memberships.

    The first centre comes from ``compute_first``.  All samples start in the
    level-0 stratum and are re-bucketed with ``compute_sets`` after each new
    centre; every further centre is obtained from ``compute_index`` using the
    current partition.  Finally memberships are computed per cluster with
    resolution ``eta_2``.

    The centre matrix is sized from ``X`` rather than the module-level ``d``,
    and the unused ``indices`` list is no longer built.

    Args:
        X: 2-D array with one sample per row.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        m: Sampling budget per centre estimate.
        n_samples: Number of samples (rows of ``X``).
        eta_1: Grid resolution of the partition used while estimating centres.
        eta_2: Grid resolution of the final membership computation.
        k: Number of clusters.

    Returns:
        Tuple ``(hatmu, hatu)``: ``hatmu`` has shape ``(features, k)``,
        ``hatu`` has shape ``(n_samples, k)``.
    """
    hatmu = np.empty((np.shape(X)[1], k))
    hatu = np.empty((n_samples, k))

    estimate, found = compute_first(X, fcm_labels, m, n_samples, eta_1, k)
    clusters = {found}
    hatmu[:, found] = estimate

    # Every sample starts at level 0 (no membership accounted for yet).
    partition = {round(s, 3): [] for s in np.linspace(0, 1, int(1.0 / eta_1))}
    partition[0] = list(range(n_samples))
    partition = compute_sets(X, fcm_labels, estimate, partition, eta_1, clusters, found)

    for _ in range(k - 1):
        estimate, found = compute_index(
            X, fcm_labels, m, n_samples, partition, clusters, eta_1, k)
        clusters.add(found)
        hatmu[:, found] = estimate
        partition = compute_sets(X, fcm_labels, estimate, partition, eta_1, clusters, found)

    for cluster in range(k):
        hatu[:, cluster] = compute_membership(
            X, hatmu[:, cluster], fcm_labels, eta_2, n_samples, cluster)
    return hatmu, hatu

def estimate_sequentially_special(X, fcm_labels, m, r, n_samples, eta, k):
    """Two-cluster variant: estimate one centre, stratify, estimate the other.

    The first centre comes from ``compute_first``.  Samples are ordered by
    distance to it, and the ordering is cut, from the far end inwards, into
    strata located with ``binary_search2`` on a threshold that is halved
    after each successful cut.  When a cut would contain 10 positions or
    fewer, the next 10 positions go to ``special1`` instead (these are used
    exhaustively, not sampled).  Whatever remains at the near end after at
    most 20 rounds becomes ``special2``.  The second centre is then computed
    by ``compute_index2`` and memberships by ``compute_membership``.

    Fixes relative to the previous version:
      * Strata were stored under the float key ``1 - membership`` of their
        first element.  That value repeats (the labels in this file take only
        two values), so a later stratum silently replaced an earlier one and
        its samples were lost.  Strata are now stored under a running integer
        key; ``compute_index2`` only reads the dict values.
      * Only the columns ``first`` and ``second`` of ``hatmu`` are ever
        filled, so ``k != 2`` would return uninitialised centres.  This now
        raises ``ValueError``.
      * Unused locals (``indices``, ``clusters``, the overwritten
        ``special2``) are removed and ``hatmu`` is sized from ``X``.

    The diagnostic ``print`` calls are kept unchanged.

    Args:
        X: 2-D array with one sample per row.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        m: Number of indices drawn by ``compute_first``.
        r: Number of indices drawn per stratum by ``compute_index2``.
        n_samples: Number of samples (rows of ``X``).
        eta: Grid resolution forwarded to ``compute_membership``.
        k: Number of clusters; must be 2.

    Returns:
        Tuple ``(hatmu, hatu)``: ``hatmu`` has shape ``(features, k)``,
        ``hatu`` has shape ``(n_samples, k)``.

    Raises:
        ValueError: If ``k`` is not 2.
    """
    if k != 2:
        raise ValueError("estimate_sequentially_special supports exactly k=2 clusters")

    hatmu = np.empty((np.shape(X)[1], k))
    hatu = np.empty((n_samples, k))
    estimate_first, first = compute_first(X, fcm_labels, m, n_samples, eta, k)
    print(first)
    hatmu[:, first] = estimate_first
    idx = permutation(X, hatmu[:, first])

    partition = {}
    # The farthest sample is always handled exhaustively.
    special1 = [idx[n_samples - 1]]
    peta_qlast = n_samples - 1
    eta_qlast = 1 - fcm_labels[idx[n_samples - 1], first]
    for _ in range(20):
        peta_qprime = binary_search2(0.5 * eta_qlast, idx, X, fcm_labels, first)
        print("C", peta_qlast, peta_qprime)
        if peta_qprime < peta_qlast - 10:
            print("A")
            # Unique integer key: strata must never overwrite each other.
            partition[len(partition)] = list(idx[peta_qprime:peta_qlast])
            peta_qlast = peta_qprime
            eta_qlast = 0.5 * eta_qlast
        else:
            print("B", peta_qlast)
            if peta_qlast - 10 > 0:
                special1 = special1 + list(idx[peta_qlast - 10:peta_qlast])
                peta_qlast = peta_qlast - 10
                eta_qlast = 1 - fcm_labels[idx[peta_qlast], first]
            else:
                special1 = special1 + list(idx[:peta_qlast])
                peta_qlast = 0
                break

    # Everything closer than the last cut forms one final sampled stratum.
    special2 = list(idx[:peta_qlast])
    second = 1 if first == 0 else 0
    hatmu[:, second] = compute_index2(
        X, fcm_labels, r, n_samples, partition, special1, special2, second)
    for cluster in range(k):
        hatu[:, cluster] = compute_membership(
            X, hatmu[:, cluster], fcm_labels, eta, n_samples, cluster)
    return hatmu, hatu

def binary_search2(x, idx, X, fcm_labels, cluster_index):
    """Find the first position in ``idx`` whose complement membership reaches ``x``.

    The quantity compared is ``1 - fcm_labels[idx[p], cluster_index]``, which
    the search assumes to be non-decreasing in ``p``.

    The search range is ``len(idx)`` rather than the module-level
    ``n_samples``, and the midpoint uses integer arithmetic (rounded down, as
    before).  The threshold is still printed, as in the previous version.

    Args:
        x: Threshold on ``1 - membership``.
        idx: Sequence of sample indices (rows of ``fcm_labels``).
        X: Unused; kept so existing call sites stay valid.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        cluster_index: Column of ``fcm_labels`` to read.

    Returns:
        The smallest position ``p`` with ``1 - membership >= x``.  The last
        position is never probed, so ``len(idx) - 1`` is also returned when
        no position qualifies.  ``None`` when ``idx`` is empty.
    """
    print(x)
    n = len(idx)
    if n == 0:
        return None
    low, high = 0, n - 1
    while low < high:
        # Lower midpoint: guarantees progress when the branch sets high = mid.
        mid = (low + high) // 2
        if 1 - fcm_labels[idx[mid], cluster_index] >= x:
            high = mid
        else:
            low = mid + 1
    return low

def _weighted_sums(X, fcm_labels, rows, column):
    """Return ``(sum_i w_i * X[i], sum_i w_i)`` over ``rows``, ``w_i = fcm_labels[i, column] ** alpha``."""
    rows = np.asarray(rows, dtype=int)
    weights = fcm_labels[rows, column] ** alpha
    return weights @ X[rows], weights.sum()


def _sampled_stratum_sums(X, fcm_labels, stratum, r, column):
    """Draw ``r`` indices from ``stratum`` with replacement and scale their sums by ``len(stratum) / r``."""
    sampled_indices = random.choices(stratum, k=r)
    factor = (1.0 * len(stratum)) / r
    num, den = _weighted_sums(X, fcm_labels, sampled_indices, column)
    return factor * num, factor * den


def compute_index2(X, fcm_labels, r, n_samples, partition, special1, special2, second):
    """Estimate the centre of cluster ``second`` from strata and special sets.

    Three sources contribute to the weighted mean, with weights
    ``fcm_labels[i, second] ** alpha``:
      * every stratum in ``partition`` - ``r`` indices drawn with
        replacement, scaled by ``len(stratum) / r``;
      * ``special2`` (if non-empty) - treated as one more sampled stratum;
      * ``special1`` - every index used once, unscaled.

    The triplicated accumulation loop is factored into two private helpers,
    only the dict values of ``partition`` are iterated (the keys were never
    used), and the output length is taken from ``X`` instead of the
    module-level ``d``.  Sampling order and the diagnostic ``print`` calls
    are unchanged.

    Args:
        X: 2-D array with one sample per row.
        fcm_labels: 2-D array of memberships, indexed ``[sample, cluster]``.
        r: Number of indices drawn per sampled stratum.
        n_samples: Unused; kept so existing call sites stay valid.
        partition: Dict whose values are lists of sample indices.
        special1: Sample indices that are used exhaustively.
        special2: Sample indices forming one additional sampled stratum.
        second: Cluster (column of ``fcm_labels``) to estimate.

    Returns:
        1-D centre estimate with one entry per column of ``X``.
    """
    X = np.asarray(X)
    fcm_labels = np.asarray(fcm_labels)
    mean_num = np.zeros(X.shape[1])
    mean_den = 0.0
    print("A")
    print(len(special1))
    print(len(special2),r)

    for stratum in partition.values():
        print(len(stratum), r)
        num, den = _sampled_stratum_sums(X, fcm_labels, stratum, r, second)
        mean_num += num
        mean_den += den

    if len(special2) > 0:
        num, den = _sampled_stratum_sums(X, fcm_labels, special2, r, second)
        mean_num += num
        mean_den += den

    num, den = _weighted_sums(X, fcm_labels, special1, second)
    mean_num += num
    mean_den += den

    return mean_num / mean_den

def objective_function(X, alpha, hatmu, hatu, n_samples, k):
    """Return the average membership-weighted squared distance to the centres.

    ``hatu`` is first adjusted IN PLACE: for each of the first ``n_samples``
    rows, ``(1 - row_sum) / k`` is added to every entry of the row.  The
    objective is then the sum over samples ``i`` and clusters ``j < k`` of
    ``hatu[i, j] ** alpha * ||X[i] - hatmu[:, j]||^2``, divided by
    ``n_samples``.

    Args:
        X: 2-D array with one sample per row.
        alpha: Exponent applied to the memberships.
        hatmu: 2-D array with one centre per column.
        hatu: 2-D array of memberships, indexed ``[sample, cluster]``;
            modified in place.
        n_samples: Number of rows to process.
        k: Number of clusters.

    Returns:
        The objective value as a float.
    """
    # Spread each row's missing (or excess) mass evenly over its entries.
    for row in range(n_samples):
        slack = 1 - np.sum(hatu[row])
        hatu[row] = hatu[row] + (slack * 1.0) / k

    total = 0.0
    for row in range(n_samples):
        for col in range(k):
            offset = X[row] - hatmu[:, col]
            total += (hatu[row, col] ** (alpha)) * np.dot(offset, offset)
    return total / n_samples

# Load the data set; n_samples is read as a global by several functions above.
data = load_breast_cancer()
X = data['data']
n_samples = X.shape[0]

k = 2
# Per-procedure mean error rates and the matching query counts, one entry per
# sampling budget tried in the experiment loop.
acc1, acc2, acc3 = [], [], []
queries = []

# Near-hard memberships derived from the class labels: 0.99 for a sample's
# own class and 0.01 for every other class.
labels = np.full((X.shape[0], k), 0.01)
labels[np.arange(X.shape[0]), data['target']] = 0.99
fcm_labels = labels
# Sweep the sampling budget (10, 30, ..., 390) and average the error rate of
# each procedure over 500 random trials.
for factor in range(10, 400, 20):
    factor2 = int((factor*2.0)/(k+1))
    queries.append(factor*k)
    temp1, temp2, temp3 = [], [], []
    for t in range(500):
        print(t)
        hatmu, hatu = estimate_together(X, fcm_labels, factor, n_samples, 0.1, k)
        hatmu2, hatu2 = estimate_sequentially(X, fcm_labels, factor2, n_samples, 0.1, 0.1, k)
        hatmu3, hatu3 = estimate_sequentially_special(X, fcm_labels, factor2, factor2, n_samples, 0.1, 2)
        # Error rate: share of samples whose largest membership disagrees
        # with the class label.
        for errors, memberships in ((temp1, hatu), (temp2, hatu2), (temp3, hatu3)):
            mismatches = int(np.count_nonzero(
                np.argmax(memberships, axis=1) != data['target']))
            errors.append((1.0*mismatches)/X.shape[0])
    acc1.append(np.mean(temp1))
    acc2.append(np.mean(temp2))
    acc3.append(np.mean(temp3))



linestyles = ['-', '--', '-.', ':']

markers=['bo-','rD-','g*-','b-','c2']

# acc1..acc3 hold mean error rates, so 100 * (1 - error) is plotted; only
# every second budget is shown.
curves = (
    (acc1, "Procedure 1"),
    (acc2, "Procedure 4"),
    (acc3, "Procedure 6"),
)
for fmt, (errors, name) in zip(markers, curves):
    plt.plot(queries[::2], 100*(1-np.array(errors[::2])), fmt,
             markersize=8, label=name, linewidth=1)

plt.legend(fontsize=10)
#plt.legend(bbox_to_anchor=(0., 0.99, 1., .102), loc=3,
#           ncol=3, mode="expand", borderaxespad=0.)
plt.xlabel(r'Number of queries',fontsize=10)
plt.ylabel(r'Accuracy',fontsize=10)
plt.ticklabel_format(style='sci', axis='both', scilimits=(0,0),useMathText=True)
plt.ylim([82.5, 86])
plt.grid()
plt.show()
plt.close()

           