from Libraries import *
def ClassBlind3Layer(model, percentage, architecture):
    """ This function is an implementation of the compression method "Class Blind" discussed in https://arxiv.org/pdf/1606.09274.pdf
    Returns a compressed model.

    The weight matrices of the three fully connected layers are pooled, and the
    ``percentage`` % of entries with the smallest absolute value across all
    three layers are set to zero. Biases are left untouched. The input model is
    not modified; a deep copy is pruned and returned.

    CAUTION: This function works only with models that have 3 fully connected layers in the classifier. However, the extension is trivial for more layers.

    Parameters
    ----------
    model : PyTorch Model
        A model that we want to compress.
    percentage : float
        Percentage (0-100) of the pooled fully connected weights to prune.
    architecture: str
        Selects how the FC layers are looked up:
        'lenet'   -> ``fc1``, ``fc2``, ``fc3``
        'vgg16'   -> ``classifier[0]``, ``classifier[3]``, ``classifier[6]``
        'alexnet' -> ``fc[0]``, ``fc[3]``, ``fc[6]``

    Returns
    -------
    newmodel:
        The compressed version of the input model. Pruned weights keep the
        device and dtype they had in the input model.

    Raises
    ------
    ValueError
        If ``architecture`` is not supported or ``percentage`` is outside [0, 100].
    """
    if architecture not in ('lenet', 'vgg16', 'alexnet'):
        raise ValueError("Unsupported architecture %r; expected 'lenet', "
                         "'vgg16' or 'alexnet'" % (architecture,))
    if not 0 <= percentage <= 100:
        raise ValueError("percentage must be in [0, 100], got %r" % (percentage,))

    newmodel = copy.deepcopy(model)
    if architecture == 'lenet':
        layers = [newmodel.fc1, newmodel.fc2, newmodel.fc3]
    elif architecture == 'vgg16':
        layers = [newmodel.classifier[i] for i in (0, 3, 6)]
    else:  # 'alexnet'
        layers = [newmodel.fc[i] for i in (0, 3, 6)]

    # Pool every FC weight into one flat vector so the magnitude ranking is
    # global across the three layers ("class blind").
    weights = [layer.weight.data for layer in layers]
    flat = np.concatenate([w.cpu().numpy().reshape(-1) for w in weights])

    # Multiply before dividing: percentage/100*n suffers float rounding
    # (e.g. 29/100*100 == 28.999..., which int() truncates to 28).
    amount = int(percentage * flat.size / 100)
    flat[np.argsort(np.absolute(flat))[:amount]] = 0

    # Split the flat vector back into per-layer matrices, in the same order
    # they were concatenated.
    offset = 0
    for layer, w in zip(layers, weights):
        count = w.numel()
        pruned = flat[offset:offset + count].reshape(tuple(w.shape)).copy()
        offset += count
        # Restore the original device/dtype; the numpy round trip would
        # otherwise leave the weight on the CPU even for a GPU model.
        layer.weight.data = torch.from_numpy(pruned).to(device=w.device, dtype=w.dtype)
    return newmodel

def ClassUniform3Layer(model, percentage, architecture):
    """ This function is an implementation of the compression method "Class Uniform" discussed in https://arxiv.org/pdf/1606.09274.pdf
    Returns a compressed model.

    Each of the three fully connected layers is pruned independently: within
    every layer, the ``percentage`` % of weight entries with the smallest
    absolute value are set to zero. Biases are left untouched. The input model
    is not modified; a deep copy is pruned and returned.

    CAUTION: This function works only with models that have 3 fully connected layers in the classifier. However, the extension is trivial for more layers.

    Parameters
    ----------
    model : PyTorch Model
        A model that we want to compress.
    percentage : float
        Percentage (0-100) of each layer's weights to prune.
    architecture: str
        Selects how the FC layers are looked up:
        'lenet'   -> ``fc1``, ``fc2``, ``fc3``
        'vgg16'   -> ``classifier[0]``, ``classifier[3]``, ``classifier[6]``
        'alexnet' -> ``fc[0]``, ``fc[3]``, ``fc[6]``

    Returns
    -------
    newmodel:
        The compressed version of the input model. Pruned weights keep the
        device and dtype they had in the input model.

    Raises
    ------
    ValueError
        If ``architecture`` is not supported or ``percentage`` is outside [0, 100].
    """
    if architecture not in ('lenet', 'vgg16', 'alexnet'):
        raise ValueError("Unsupported architecture %r; expected 'lenet', "
                         "'vgg16' or 'alexnet'" % (architecture,))
    if not 0 <= percentage <= 100:
        raise ValueError("percentage must be in [0, 100], got %r" % (percentage,))

    newmodel = copy.deepcopy(model)
    if architecture == 'lenet':
        layers = [newmodel.fc1, newmodel.fc2, newmodel.fc3]
    elif architecture == 'vgg16':
        layers = [newmodel.classifier[i] for i in (0, 3, 6)]
    else:  # 'alexnet'
        layers = [newmodel.fc[i] for i in (0, 3, 6)]

    for layer in layers:
        w = layer.weight.data
        flat = w.clone().cpu().numpy().reshape(-1)
        # Multiply before dividing to avoid float truncation
        # (29/100*100 == 28.999..., which int() turns into 28).
        amount = int(percentage * flat.size / 100)
        flat[np.argsort(np.absolute(flat))[:amount]] = 0
        # Restore the original device/dtype; the numpy round trip would
        # otherwise leave the weight on the CPU even for a GPU model.
        layer.weight.data = torch.from_numpy(flat.reshape(tuple(w.shape))).to(
            device=w.device, dtype=w.dtype)
    return newmodel

def ClassDist3Layer(model, lamda, architecture):
    """ This function is an implementation of the compression method "Class Distribution" discussed in https://arxiv.org/pdf/1606.09274.pdf
    Returns a compressed model.

    Each of the three fully connected layers is pruned independently: entries
    whose absolute value is <= ``lamda * std(layer weights)`` are set to zero.
    Biases are left untouched. The input model is not modified; a deep copy is
    pruned and returned.

    CAUTION: This function works only with models that have 3 fully connected layers in the classifier. However, the extension is trivial for more layers.

    Parameters
    ----------
    model : PyTorch Model
        A model that we want to compress.
    lamda : float
        Multiplier of each layer's weight standard deviation; the product is
        the pruning threshold for that layer.
    architecture: str
        Selects how the FC layers are looked up:
        'lenet'   -> ``fc1``, ``fc2``, ``fc3``
        'vgg16'   -> ``classifier[0]``, ``classifier[3]``, ``classifier[6]``
        'alexnet' -> ``fc[0]``, ``fc[3]``, ``fc[6]``

    Returns
    -------
    newmodel:
        The compressed version of the input model. Pruned weights keep the
        device and dtype they had in the input model.

    Raises
    ------
    ValueError
        If ``architecture`` is not supported.
    """
    if architecture not in ('lenet', 'vgg16', 'alexnet'):
        raise ValueError("Unsupported architecture %r; expected 'lenet', "
                         "'vgg16' or 'alexnet'" % (architecture,))

    newmodel = copy.deepcopy(model)
    if architecture == 'lenet':
        layers = [newmodel.fc1, newmodel.fc2, newmodel.fc3]
    elif architecture == 'vgg16':
        layers = [newmodel.classifier[i] for i in (0, 3, 6)]
    else:  # 'alexnet'
        layers = [newmodel.fc[i] for i in (0, 3, 6)]

    for layer in layers:
        w = layer.weight.data
        arr = w.clone().cpu().numpy()
        threshold = np.std(arr) * lamda
        # Boolean mask zeroes exactly the entries at or below the threshold.
        # (Indexing with np.argwhere's (row, col) pairs would also hit row 0.)
        arr[np.absolute(arr) <= threshold] = 0
        # Restore the original device/dtype; the numpy round trip would
        # otherwise leave the weight on the CPU even for a GPU model.
        layer.weight.data = torch.from_numpy(arr).to(device=w.device, dtype=w.dtype)
    return newmodel

