from Libraries import *
#####################################Pruning methods#####################################################################################
def tropical_pruning(model,it,arg):
    """Prune the 'fc1'/'fc2' weights of ``model`` by alternating minimisation.

    The routine alternates between a row-wise L_{2,1} proximal update of a
    surrogate first-layer matrix ``A_k`` and a closed-form non-negative update
    of the positive/negative parts of a surrogate second-layer matrix ``B_k``,
    both fitted to the generators returned by ``get_dictionary_of_generators``.
    Only the sparsity pattern of the final iterates is used: entries of the
    ORIGINAL weights are zeroed wherever ``A_k`` / ``B_k`` are exactly zero.

    Parameters
    ----------
    model : object exposing ``state_dict()`` with the keys ``'fc1.weight'``
        and ``'fc2.weight'``; it is deep-copied and left untouched.
    it : int
        Number of alternating iterations (also the length of ``cost``).
    arg : sequence of numbers
        Regularisation weights. Entry ``k`` is scaled by the norm of the
        ``k``-th generator (in dictionary iteration order); entries ``2*j``
        and ``2*j+1`` are then used with generators ``G(2j+1)`` and
        ``G(2j+2)``. The sequence is copied, the caller's object is NOT
        modified.

    Returns
    -------
    compressed_model : deep copy of ``model`` with the pruned weights written
        into its ``'fc1.weight'`` and ``'fc2.weight'`` entries.
    compression_percentage : percentage of exactly-zero entries in the final
        ``A_k`` and ``B_k`` taken together.
    cost : numpy array of length ``it`` with the objective value evaluated at
        the start of every iteration.
    """
    compressed_model = copy.deepcopy(model)
    A = model.state_dict()['fc1.weight'].clone().cpu().numpy()
    B = model.state_dict()['fc2.weight'].clone().cpu().numpy()

    # NOTE(review): get_dictionary_of_generators is not defined in the visible
    # part of this file; presumably it is provided by `from Libraries import *`.
    GeneratorDict = get_dictionary_of_generators(model)

    # Scale the regularisation weights on a private copy: rescaling `arg`
    # itself would compound the scaling if the caller reuses the same object.
    reg = list(arg)
    for count, generator in enumerate(GeneratorDict.values()):
        reg[count] = reg[count]*np.linalg.norm(generator)

    #Initializations, it could be random, or just initialize with the network weights
    A_k   = A
    B_k   = B
    b_pos = 0.5*(np.abs(B)+B)   # positive part of B
    b_neg = 0.5*(np.abs(B)-B)   # negative part of B (stored as non-negative values)
    cost  = np.zeros(it)

    for iterr in range(it):
        #Updating A_k
        temp  = np.zeros_like(A)
        lamda = np.zeros(B.shape[1])
        beta_square = np.zeros(B.shape[1])

        for j in range(B.shape[0]):
            G1 = GeneratorDict['G' + str(2*j+1)]
            G2 = GeneratorDict['G' + str(2*j+2)]

            b1_plus = np.diag(np.maximum(0,B_k[j,:]))
            b2_plus = np.diag(np.maximum(0,-1*B_k[j,:]))
            # Data-fit part of the objective for the current iterates.
            cost[iterr] += np.linalg.norm(np.matmul(b1_plus,A_k) - G1)**2
            cost[iterr] += np.linalg.norm(np.matmul(b2_plus,A_k) - G2)**2

            beta_square = beta_square + np.diag(b1_plus)**2 + np.diag(b2_plus)**2

            temp += np.matmul(b1_plus, G1)
            temp += np.matmul(b2_plus, G2)
            lamda = lamda + np.diag(reg[2*j]*b1_plus + reg[2*j+1]*b2_plus)

        # L_{2,1} norm in the cost: sum of the Euclidean norms of the rows.
        # (norm(M, 1) on a 2-D array would be the max column sum instead.)
        cost[iterr] += np.sum(np.linalg.norm(np.matmul(np.diag(lamda),A_k),axis=1))

        # Per-row normalisation; a zero entry of beta_square produces NaN
        # (0/0) here, which is reset to zero right after the proximal step.
        temp  = np.divide(temp  ,beta_square.reshape(B.shape[1],-1))
        lamda = np.divide(lamda.reshape(B.shape[1],-1) ,beta_square.reshape(B.shape[1],-1))
        A_k   = prox_l21_new(temp,lamda)

        A_k[np.isnan(A_k)] = 0

        #Updating B_k
        norm_Ak = np.linalg.norm(A_k,axis=1)
        for j in range(B.shape[0]):
            G1    = GeneratorDict['G' + str(2*j+1)].T
            # diag(A_k @ G.T) holds the row-wise inner products <A_k[i], G[i]>.
            Temp1 = (2*np.diag(np.matmul(A_k,G1)) - reg[2*j]*norm_Ak)/(2*norm_Ak**2)

            b_pos[j,:] = np.maximum(Temp1,0)

            G2 = GeneratorDict['G' + str(2*j+2)].T
            Temp2 = (2*np.diag(np.matmul(A_k,G2)) - reg[2*j+1]*norm_Ak)/(2*norm_Ak**2)
            b_neg[j,:] = np.maximum(Temp2,0)

        B_k = b_pos - b_neg
        # Rows of A_k that were pruned to zero give 0/0 above; treat as pruned.
        B_k[np.isnan(B_k)] = 0

    compression_percentage = (np.sum(A_k == 0) + np.sum(B_k == 0))/(B_k.size + A_k.size) *100

    idx1 = np.argwhere(A_k == 0)
    idx2 = np.argwhere(B_k == 0)

    # Re-read the original weights: only the zero pattern of the iterates is
    # transferred, the surviving entries keep their trained values.
    A = model.state_dict()['fc1.weight'].clone().cpu().numpy()
    B = model.state_dict()['fc2.weight'].clone().cpu().numpy()

    A[idx1[:,0],idx1[:,1]] = 0
    B[idx2[:,0],idx2[:,1]] = 0
    # `device` is expected to be a module-level name (not defined in this block).
    compressed_model.state_dict()['fc1.weight'][:,:] = torch.from_numpy(A).to(device)
    compressed_model.state_dict()['fc2.weight'][:,:] = torch.from_numpy(B).to(device)

    return compressed_model,compression_percentage,cost


##################################################Some utilities ###########################################
def prox_l21_new(x,lamda): #Proximal of L_2,1 Norm
    """Apply row-wise shrinkage to ``x`` in place and return it.

    Every row with a non-zero Euclidean norm is multiplied by
    ``max(0, 1 - lamda[i] / (eps + ||row||))`` where ``eps`` is the machine
    epsilon; rows whose norm is exactly zero are left untouched.

    x     : 2-D array, modified in place.
    lamda : per-row thresholds, indexed as ``lamda[i]``.
    Returns the same array object that was passed as ``x``.
    """
    tiny = sys.float_info.epsilon
    for row_idx in range(len(x)):
        row_norm = np.linalg.norm(x[row_idx,:])
        if row_norm == 0:
            # Nothing to shrink for an all-zero row.
            continue
        shrink = np.maximum(0,1-(lamda[row_idx]/(tiny + row_norm)))
        x[row_idx,:] = shrink*x[row_idx,:]
    return x
def get_dictionary_of_generators_temp(model,bias=False):
    """Collect the generators of every output row of ``model`` in one dict.

    Consecutive class pairs ``[i, i+1]`` are passed to
    ``get_model_zonotopes_temp``; the four arrays it returns are stored under
    the keys ``'G<2i+1>'`` ... ``'G<2i+4>'``. Neighbouring pairs overlap, so
    the entries of the shared class are simply written again.

    model : object exposing ``state_dict()`` with the key ``'fc2.weight'``.
    bias  : forwarded unchanged to ``get_model_zonotopes_temp``.
    Returns a dict mapping ``'G1'``, ``'G2'``, ... to the generator arrays;
    it is empty when ``'fc2.weight'`` has fewer than two rows.
    """
#Getting the generators "Bulky and need to be replaced efficiently"
    fc2_weight = model.state_dict()['fc2.weight'].clone().cpu().numpy()
    n_labels = fc2_weight.shape[0]
    generators = dict()
    for first_label in range(n_labels-1):
        base = 2*first_label + 1   # index of the first key written for this pair
        g_a,g_b,g_c,g_d = get_model_zonotopes_temp(model,[first_label,first_label+1],bias)
        generators['G'+str(base)] = g_a
        generators['G'+str(base+1)] = g_b
        generators['G'+str(base+2)] = g_c
        generators['G'+str(base+3)] = g_d
    return generators
def get_model_zonotopes_temp(model,classes = [0,1],bias=False):
    """Return the positive/negative generators for two output rows of ``model``.

    model   : object exposing ``state_dict()`` with ``'fc1.weight'``,
              ``'fc2.weight'`` and, when ``bias`` is true, ``'fc1.bias'``.
    classes : row indices selected from ``'fc2.weight'``; only the first two
              selected rows are used. The default list is never mutated here.
    bias    : when true, ``'fc1.bias'`` is appended to the first-layer weights
              as one extra column.
    Returns ``(G1, G2, G3, G4)``: the (positive, negative) pair of the first
    selected row followed by the pair of the second one, as produced by
    ``get_generators_new``.
    """
    params = model.state_dict()
    first_layer = params['fc1.weight'].cpu().numpy()
    if bias:
        bias_column = params['fc1.bias'].cpu().numpy().reshape(first_layer.shape[0],1)
        first_layer = np.hstack((first_layer,bias_column))
    selected_rows = params['fc2.weight'][classes,:].cpu().numpy()
    pairs = get_generators_new(first_layer,selected_rows)
    G1,G2 = pairs[0]
    G3,G4 = pairs[1]
    return G1,G2,G3,G4
def get_generators_new(A,B):
    """Build one (positive, negative) generator pair per row of ``B``.

    For every row ``b`` of ``B`` the positive part ``max(0, b)`` and the
    negative part ``max(0, -b)`` are taken, and row ``i`` of ``A`` is scaled
    by the ``i``-th entry of each part.

    A : 2-D array (assumes ``A.shape[0] == B.shape[1]`` so the broadcast
        against ``A.T`` is valid).
    B : 2-D array, one row per output label.
    Returns a list with ``B.shape[0]`` tuples ``(positive, negative)``, each
    array having the shape of ``A``.
    """
    A_t = A.T

    def split_row(weights):
        # Scale the columns of A.T (i.e. the rows of A) by each sign part.
        pos_part = np.maximum(0,weights)
        neg_part = np.maximum(0,-1*weights)
        scaled_pos = A_t * pos_part
        scaled_neg = A_t * neg_part
        return (scaled_pos.T, scaled_neg.T)

    return [split_row(B[label, :]) for label in range(B.shape[0])]