import numpy as np
from numpy.core.numeric import flatnonzero
from convex_representation import convex_representation

def MerageIndepSets(beta1,B1,beta2,B2):
    """Randomly merge two weighted 0/1 indicator vectors of possibly different size.

    The vector with fewer ones is padded with randomly chosen indices that are
    set only in the larger vector until both have the same number of ones, the
    two are merged with ``MergeBase``, and each padded index is then cleared
    from the merged result with probability ``beta2 / (beta1 + beta2)``
    (after clamping), where ``beta2`` is the weight of the smaller vector.

    Args:
        beta1: Weight attached to ``B1``.
        B1: 1-D NumPy array whose entries round to 0 or 1.
        beta2: Weight attached to ``B2``.
        B2: 1-D NumPy array of the same length as ``B1``.

    Returns:
        A tuple ``(B_new, beta_new)``: the merged indicator vector and the
        combined weight as returned by ``MergeBase``.

    Note:
        The inputs are not modified; all work happens on rounded copies.
    """
    # Round first, on copies: cardinalities must be counted on the cleaned-up
    # indicators (float noise would otherwise count as a member), and callers
    # frequently pass column views of a larger matrix that must stay intact.
    B1 = np.rint(B1).astype(np.asarray(B1).dtype)
    B2 = np.rint(B2).astype(np.asarray(B2).dtype)

    # Ensure (beta1, B1) is the pair with at least as many ones.
    if np.count_nonzero(B1) < np.count_nonzero(B2):
        beta1, beta2 = beta2, beta1
        B1, B2 = B2, B1

    over_number = np.count_nonzero(B1) - np.count_nonzero(B2)
    if over_number == 0:
        return MergeBase(beta1, B1, beta2, B2)

    # Indices set in B1 but not in B2 are the candidates for padding B2.
    # The difference is taken on signed integers so unsigned/bool inputs
    # cannot wrap around.
    candidates = np.flatnonzero(B1.astype(int) - B2.astype(int) == 1)
    S_index = np.random.choice(candidates, over_number, replace=False)
    B2[S_index] = 1

    B_new_inside, beta_new_inside = MergeBase(beta1, B1, beta2, B2)

    # Probability of dropping a padded index again; values at or above 0.999
    # become 0.99 and values at or below 0.0001 become 0.001.
    tmp_beta2 = beta2/(beta1+beta2)
    if tmp_beta2 >= 0.999:
        tmp_beta2 = 0.99
    if tmp_beta2 <= 0.0001:
        tmp_beta2 = 0.001

    for padded_index in S_index:
        # One Bernoulli draw (0 or 1) per padded index.
        if np.random.binomial(1, tmp_beta2) == 1:
            B_new_inside[padded_index] = 0

    return B_new_inside, beta_new_inside


def MergeBase(beta1,B1,beta2,B2):
    """Randomly merge two weighted 0/1 indicator vectors into a single one.

    While the two (rounded) vectors differ, one index ``i`` set only in the
    first and one index ``j`` set only in the second are drawn uniformly at
    random. With probability ``beta1 / (beta1 + beta2)`` (after clamping) the
    second vector exchanges ``j`` for ``i``; otherwise the first vector
    exchanges ``i`` for ``j``. Each exchange keeps the number of ones in both
    vectors unchanged and brings them one pair closer together.

    Args:
        beta1: Weight attached to ``B1``.
        B1: 1-D NumPy array whose entries round to 0 or 1.
        beta2: Weight attached to ``B2``.
        B2: 1-D NumPy array of the same length as ``B1``.

    Returns:
        A tuple ``(B, beta)`` where ``B`` is the first vector after all
        exchanges (same dtype as ``B1``) and ``beta`` is ``beta1 + beta2``.

    Note:
        The inputs are not modified. If the vectors do not contain the same
        number of ones, merging stops as soon as no exchange pair is left and
        the current state of the first vector is returned.
    """
    # Probability that the second vector moves toward the first; values at or
    # above 0.999 become 0.99 and values at or below 0.0001 become 0.001.
    tmp = beta1/(beta1+beta2)
    if tmp >= 0.999:
        tmp = 0.99
    if tmp <= 0.0001:
        tmp = 0.001

    # Rounded private copies: callers often pass column views of a larger
    # matrix, which must not be altered by the merge.
    merged1 = np.rint(B1).astype(np.asarray(B1).dtype)
    merged2 = np.rint(B2).astype(np.asarray(B2).dtype)

    while True:
        # Signed integer difference so unsigned/bool inputs cannot wrap around.
        diff = merged1.astype(int) - merged2.astype(int)
        only_in_first = np.flatnonzero(diff == 1)
        only_in_second = np.flatnonzero(diff == -1)
        # No exchange is possible once either side has nothing left to offer;
        # for equally sized sets this is exactly when the vectors coincide.
        if only_in_first.size == 0 or only_in_second.size == 0:
            break

        i = np.random.choice(only_in_first)
        j = np.random.choice(only_in_second)

        # One Bernoulli draw (0 or 1) decides which vector gives way.
        if np.random.binomial(1, tmp) == 1:
            merged2[j] = 0
            merged2[i] = 1
        else:
            merged1[i] = 0
            merged1[j] = 1

    return merged1, beta1+beta2

def SwapRound(beta,B):
    """Fold the columns of ``B`` into one vector via repeated ``MerageIndepSets``.

    Starts from the first column and its weight, then merges every following
    column (with its weight) into the running result, left to right.

    Args:
        beta: Sequence of weights; ``beta[k]`` belongs to column ``k`` of ``B``.
        B: 2-D array indexed as ``B[:, k]``, one column per weight.

    Returns:
        The vector left after the last merge; when ``beta`` has a single
        entry this is ``B[:, 0]`` itself.
    """
    running_weight = beta[0]
    running_set = B[:,0]
    column = 1
    while column < len(beta):
        running_set, running_weight = MerageIndepSets(
            running_weight, running_set, beta[column], B[:,column]
        )
        column += 1
    return running_set

# k = 3
# x = [0, 0.6, 0, 0, 0.7,0.8,0.3,0.2]
# acc_B_new = np.zeros(8)
# for i in range(10000):
#     beta,B= convex_representation(k,x)
#     B_new = SwapRound(beta,B)
#     acc_B_new = B_new+acc_B_new

# print(acc_B_new/10000)
# B1 = np.array([0,1,1,1,0,0,0])
# B2 = np.array([1,0,0,0,1,0,1])
# B,beta = MergeBase(0.2,B1,0.3,B2)
# print(B)



        