from Libraries import *
from TG_pruning import *
from CB_CD_CU import *
from Models_training import *



# Prefer the GPU whenever PyTorch reports one, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_gpu = device == 'cuda'
# 8 workers when running on the GPU, none on the CPU.
workers = 8 if use_gpu else 0
# Module-wide loss object, placed on the selected device.
criterion = nn.CrossEntropyLoss().to(device)






# Architectures whose fully connected layers compress_tg knows how to locate.
_TG_ARCHITECTURES = ('alexnet', 'lenet', 'vgg16')


def _tg_fc_layers(net, architecture):
    """Return the three fully connected layers of ``net`` in forward order.

    This is the only place that knows how each supported architecture names
    its fully connected layers.
    """
    if architecture == 'alexnet':
        return net.fc[0], net.fc[3], net.fc[6]
    if architecture == 'lenet':
        return net.fc1, net.fc2, net.fc3
    if architecture == 'vgg16':
        return net.classifier[0], net.classifier[3], net.classifier[6]
    raise ValueError(
        "Unsupported architecture {!r}; expected one of {}".format(architecture, _TG_ARCHITECTURES)
    )


def _tg_prune_pair(net_args, lower_weight, upper_weight, percentage, biases):
    """Run one tropical pruning pass over two consecutive weight matrices.

    ``net_args`` are the three size arguments handed to ``Net``; the lamda
    vector has length ``2 * net_args[2]`` and every entry equals ``percentage``.
    Returns the (lower, upper) weight tensors held by the pruned ``Net``.
    """
    pair = Net(net_args[0], net_args[1], net_args[2], biases)
    # Only the weights are transferred; the temporary Net keeps the biases it
    # was constructed with.
    # NOTE(review): confirm tropical_pruning does not depend on those biases.
    pair.fc1.weight.data = lower_weight
    pair.fc2.weight.data = upper_weight
    lamda = percentage * np.ones(2 * net_args[2])
    pair, _, _ = tropical_pruning(pair, 1, lamda)
    return pair.fc1.weight.data, pair.fc2.weight.data


def compress_tg(net,train_loader,test_loader,criterion, category,retrain,architecture):
    """ This function is an implementation of the compression method derived using Tropical Geometry.
        Returns a compressed model, accuracy of compressed model and percentage of compression corresponding to each accuracy.

    The three fully connected layers of ``net`` are copied into a ``MultiNet``.
    Each step raises the pruning strength, applies ``tropical_pruning`` to the
    layer pairs (1, 2) and (2, 3), writes the pruned weights into a copy of
    ``net``, optionally retrains it, and records accuracy and sparsity. The
    sweep stops after 10000 steps or once every weight is zero.

    Parameters
    ----------
    net : PyTorch Model
        A model that we want to compress. It is moved to the module-level ``device``.
    train_loader : PyTorch Dataloader
        A PyTorch trainset dataloader (only used when ``retrain`` is true)
    test_loader : PyTorch Dataloader
        A PyTorch testset dataloader
    criterion : PyTorch Loss Criterion
        The loss criterion handed to ``train_net`` when retraining
    category : str
        A string to specify the retraining method Choices: fcbias, allbias, all
    retrain: bool
        A boolean to specify whether or not we want to retrain the model after compression.
    architecture: str
        One of 'alexnet', 'lenet' or 'vgg16' (This is used since different models have different naming for the FC Layer)

    Returns
    -------
    target_network:
        The compressed version of the input model, as left by the last step.
    AccList:
        The list of accuracies of the compressed model, one per step.
    PerList:
        The list of compression rates (percentage of zero FC weights) corresponding to the accuracies.

    Raises
    ------
    ValueError
        If ``architecture`` is not one of the supported names.
    """
    # Fail fast with a clear message; previously an unknown architecture fell
    # through every branch and crashed at the return statement.
    if architecture not in _TG_ARCHITECTURES:
        raise ValueError(
            "Unsupported architecture {!r}; expected one of {}".format(architecture, _TG_ARCHITECTURES)
        )

    net = net.to(device)
    biases = True

    lin1, lin2, lin3 = _tg_fc_layers(net, architecture)
    lin1_out_size, lin1_inp_size = lin1.weight.data.shape
    lin2_out_size, lin2_inp_size = lin2.weight.data.shape
    lin3_out_size, lin3_inp_size = lin3.weight.data.shape

    # Stand-alone copy of the fully connected stack that gets pruned step by step.
    pruned = MultiNet(lin2_inp_size, lin2_out_size, lin1_inp_size, lin3_out_size, biases)
    state = pruned.state_dict()
    for prefix, layer in (('fc1', lin1), ('fc2', lin2), ('fc3', lin3)):
        # Slice assignment writes into the tensors that back the parameters.
        state[prefix + '.weight'][:, :] = layer.weight.data
        state[prefix + '.bias'][:] = layer.bias.data

    target_network = copy.deepcopy(net).to(device)

    AccList = []
    PerList = []
    percentage = 0
    CompressionRate = 0
    for _ in range(10000):
        # Take much larger steps once almost everything is pruned so the
        # sweep reaches full sparsity in a reasonable number of iterations.
        if CompressionRate > 98:
            percentage += 0.2
        percentage += 0.001
        print(percentage)

        pruned.fc1.weight.data, pruned.fc2.weight.data = _tg_prune_pair(
            (lin1_out_size, lin1_inp_size, lin2_out_size),
            pruned.fc1.weight.data, pruned.fc2.weight.data, percentage, biases)
        # NOTE(review): the size arguments below are kept exactly as they
        # were; confirm their order against Net's signature.
        pruned.fc2.weight.data, pruned.fc3.weight.data = _tg_prune_pair(
            (lin3_inp_size, lin2_out_size, lin3_out_size),
            pruned.fc2.weight.data, pruned.fc3.weight.data, percentage, biases)

        weights = (pruned.fc1.weight.data, pruned.fc2.weight.data, pruned.fc3.weight.data)
        SumZero = sum(np.sum(w.cpu().numpy() == 0) for w in weights)
        Total = sum(w.shape[0] * w.shape[1] for w in weights)
        CompressionRate = SumZero / Total * 100

        # Re-resolve the layers every step: train_net may hand back a different
        # model object. Clone so that retraining on CPU (where .to(device) does
        # not copy) cannot modify the tensors used by the next pruning step.
        for layer, weight in zip(_tg_fc_layers(target_network, architecture), weights):
            layer.weight.data = weight.clone()

        if retrain:
            # architecture is forwarded for every model, matching the other
            # train_net calls in this file.
            target_network = train_net(target_network.to(device),train_loader,test_loader,criterion,category,1e-4,5,architecture=architecture)
        acc = validate(target_network.to(device),test_loader)
        AccList.append(acc)
        PerList.append(CompressionRate)
        print('Successful','\n Accuracy = ',acc,'\n Compressing rate = ',CompressionRate )
        if CompressionRate == 100:
            break

    return target_network,AccList,PerList

def compressclassdist (tempModel, train_loader, test_loader, category,architecture, retrain=False):
    """
    This function is the experimental method for the "Class Distribution" pruning baseline (see https://arxiv.org/pdf/1606.09274.pdf), applied through ``ClassDist3Layer``.
    Returns a compressed model, accuracy of compressed model and percentage of compression corresponding to each accuracy.

    For each of 50 evenly spaced values in [0, 5], a fresh deep copy of
    ``tempModel`` is pruned, optionally retrained with the module-level
    ``criterion``, validated, and its fraction of zero FC weights is recorded.

    Parameters
    ----------
    tempModel : PyTorch Model
        A model that we want to compress. It is deep-copied for every step and never modified.
    train_loader : PyTorch Dataloader
        A PyTorch trainset dataloader (only used when ``retrain`` is true)
    test_loader : PyTorch Dataloader
        A PyTorch testset dataloader
    category : str
        A string to specify the retraining method
    architecture: str
        One of 'alexnet', 'vgg16' or 'lenet' (This is used since different models have different naming for the FC Layer)
    retrain: bool
        A boolean to specify whether or not we want to retrain the model after compression.

    Returns
    -------
    newmodel:
        The compressed model produced by the last step of the sweep.
    AccList:
        The list of accuracies of the compressed model, one per step.
    PerList:
        The list of compression rates (percentage of zero FC weights) corresponding to the accuracies.

    Raises
    ------
    ValueError
        If ``architecture`` is not one of the supported names.
    """
    # Reject unknown architectures before doing any work; otherwise the
    # sparsity computation below would hit unbound variables only after a
    # full validation pass.
    if architecture not in ('alexnet', 'vgg16', 'lenet'):
        raise ValueError(
            "Unsupported architecture {!r}; expected 'alexnet', 'vgg16' or 'lenet'".format(architecture)
        )

    AccList=[]
    PerList=[]
    # NOTE(review): percomp is presumably the pruning-strength parameter of
    # ClassDist3Layer; confirm its meaning there.
    for percomp in np.linspace(0,5,50):
        modelCopy = deepcopy(tempModel)
        newmodel = ClassDist3Layer(modelCopy,percomp,architecture)
        if retrain:
            newmodel = train_net(newmodel.to(device),train_loader,test_loader,criterion,category,1e-4,5,architecture=architecture)
        acc = validate(newmodel.to(device),test_loader)
        print(acc)

        # Locate the three fully connected layers for this architecture.
        if architecture == 'alexnet':
            fc_layers = (newmodel.fc[0], newmodel.fc[3], newmodel.fc[6])
        elif architecture == 'vgg16':
            fc_layers = (newmodel.classifier[0], newmodel.classifier[3], newmodel.classifier[6])
        else:
            fc_layers = (newmodel.fc1, newmodel.fc2, newmodel.fc3)

        weights = [layer.weight.data for layer in fc_layers]
        SumZero = sum(np.sum(w.cpu().numpy() == 0) for w in weights)
        Total = sum(w.shape[0] * w.shape[1] for w in weights)

        AccList.append(acc)
        PerList.append(SumZero/Total *100)
    return newmodel, AccList,PerList

def compressclassunif (tempModel, train_loader, test_loader,category,architecture, retrain=False):
    """
    This function is the experimental method for the "Class Uniform" pruning baseline (see https://arxiv.org/pdf/1606.09274.pdf), applied through ``ClassUniform3Layer``.
    Returns a compressed model, accuracy of compressed model and percentage of compression corresponding to each accuracy.

    For each integer 0..100, a fresh deep copy of ``tempModel`` is pruned,
    optionally retrained with the module-level ``criterion``, validated, and
    its fraction of zero FC weights is recorded.

    Parameters
    ----------
    tempModel : PyTorch Model
        A model that we want to compress. It is deep-copied for every step and never modified.
    train_loader : PyTorch Dataloader
        A PyTorch trainset dataloader (only used when ``retrain`` is true)
    test_loader : PyTorch Dataloader
        A PyTorch testset dataloader
    category : str
        A string to specify the retraining method
    architecture: str
        One of 'alexnet', 'vgg16' or 'lenet' (This is used since different models have different naming for the FC Layer)
    retrain: bool
        A boolean to specify whether or not we want to retrain the model after compression.

    Returns
    -------
    newmodel:
        The compressed model produced by the last step of the sweep.
    AccList:
        The list of accuracies of the compressed model, one per step.
    PerList:
        The list of compression rates (percentage of zero FC weights) corresponding to the accuracies.

    Raises
    ------
    ValueError
        If ``architecture`` is not one of the supported names.
    """
    # Reject unknown architectures before doing any work; otherwise the
    # sparsity computation below would hit unbound variables only after a
    # full validation pass.
    if architecture not in ('alexnet', 'vgg16', 'lenet'):
        raise ValueError(
            "Unsupported architecture {!r}; expected 'alexnet', 'vgg16' or 'lenet'".format(architecture)
        )

    AccList=[]
    PerList=[]
    # NOTE(review): percomp is presumably a pruning percentage; confirm
    # against ClassUniform3Layer.
    for percomp in range(0,101):
        modelCopy = deepcopy(tempModel)
        newmodel = ClassUniform3Layer(modelCopy,percomp,architecture)
        if retrain:
            newmodel = train_net(newmodel.to(device),train_loader,test_loader,criterion,category,1e-4,5,architecture=architecture)
        acc = validate(newmodel.to(device),test_loader)
        print(acc)

        # Locate the three fully connected layers for this architecture.
        if architecture == 'alexnet':
            fc_layers = (newmodel.fc[0], newmodel.fc[3], newmodel.fc[6])
        elif architecture == 'vgg16':
            fc_layers = (newmodel.classifier[0], newmodel.classifier[3], newmodel.classifier[6])
        else:
            fc_layers = (newmodel.fc1, newmodel.fc2, newmodel.fc3)

        weights = [layer.weight.data for layer in fc_layers]
        SumZero = sum(np.sum(w.cpu().numpy() == 0) for w in weights)
        Total = sum(w.shape[0] * w.shape[1] for w in weights)

        AccList.append(acc)
        PerList.append(SumZero/Total *100)
    return newmodel, AccList,PerList

def compressclassblind (tempModel, train_loader, test_loader, category,architecture, retrain=False):
    """ This function is the experimental method for the compression method "Class Blind" discussed in https://arxiv.org/pdf/1606.09274.pdf, applied through ``ClassBlind3Layer``.
    Returns a compressed model, accuracy of compressed model and percentage of compression corresponding to each accuracy.

    For each integer 0..100, a fresh deep copy of ``tempModel`` is pruned,
    optionally retrained with the module-level ``criterion``, validated, and
    its fraction of zero FC weights is recorded.

    Parameters
    ----------
    tempModel : PyTorch Model
        A model that we want to compress. It is deep-copied for every step and never modified.
    train_loader : PyTorch Dataloader
        A PyTorch trainset dataloader (only used when ``retrain`` is true)
    test_loader : PyTorch Dataloader
        A PyTorch testset dataloader
    category : str
        A string to specify the retraining method
    architecture: str
        One of 'alexnet', 'vgg16' or 'lenet' (This is used since different models have different naming for the FC Layer)
    retrain: bool
        A boolean to specify whether or not we want to retrain the model after compression.

    Returns
    -------
    newmodel:
        The compressed model produced by the last step of the sweep.
    AccList:
        The list of accuracies of the compressed model, one per step.
    PerList:
        The list of compression rates (percentage of zero FC weights) corresponding to the accuracies.

    Raises
    ------
    ValueError
        If ``architecture`` is not one of the supported names.
    """
    # Reject unknown architectures before doing any work; otherwise the
    # sparsity computation below would hit unbound variables only after a
    # full validation pass.
    if architecture not in ('alexnet', 'vgg16', 'lenet'):
        raise ValueError(
            "Unsupported architecture {!r}; expected 'alexnet', 'vgg16' or 'lenet'".format(architecture)
        )

    AccList=[]
    PerList=[]
    # NOTE(review): percomp is presumably a pruning percentage; confirm
    # against ClassBlind3Layer.
    for percomp in range(0,101):
        modelCopy = deepcopy(tempModel)
        newmodel = ClassBlind3Layer(modelCopy,percomp,architecture)
        if retrain:
            newmodel = train_net(newmodel.to(device),train_loader,test_loader,criterion,category,1e-4,5,architecture=architecture)
        acc = validate(newmodel.to(device),test_loader )
        print(acc)

        # Locate the three fully connected layers for this architecture.
        if architecture == 'alexnet':
            fc_layers = (newmodel.fc[0], newmodel.fc[3], newmodel.fc[6])
        elif architecture == 'vgg16':
            fc_layers = (newmodel.classifier[0], newmodel.classifier[3], newmodel.classifier[6])
        else:
            fc_layers = (newmodel.fc1, newmodel.fc2, newmodel.fc3)

        weights = [layer.weight.data for layer in fc_layers]
        SumZero = sum(np.sum(w.cpu().numpy() == 0) for w in weights)
        Total = sum(w.shape[0] * w.shape[1] for w in weights)

        AccList.append(acc)
        PerList.append(SumZero/Total *100)
    return newmodel, AccList, PerList


