import pandas as pd
import numpy as np
from pretrained import whole
from  AP2_computation import AP2_class
from Montecarlo import montecarlo
from pathlib import Path
import torch
import matplotlib.pyplot as plt
import loadingdata
from AP3_Gaussian import AP3_Gaussian_class


class main:
    '''
    Experiment driver for the pruning study.

    For each run in ``RUNS`` it trains a benchmark network, prunes it with every
    (method, percentage) pair, fine-tunes the pruned network, and records per
    fine-tuning epoch the AP2 and AP3 values plus the absolute train/test
    accuracy differences against the unpruned network. Results are averaged
    over the runs and can be plotted with ``plot_test``.
    '''

    # Pruning methods and pruning ratios swept by ``execute`` and ``plot_test``.
    PRUNING_METHODS = ('lowest', 'highest', 'random')
    PERCENTAGES = (0.1, 0.3, 0.5, 0.8)
    # Repetition indices; each one is forwarded as ``itera`` to the training and
    # Monte Carlo helpers, and the per-run results are averaged.
    RUNS = (1, 2, 3)

    def __init__(self, pruned_epoch_number=20, trained_epoch_number=100, v=4, group=100, network='vgg16',
                 data='CIFAR10', n_samples=600000,
                 base_path='path address where you would like to save figures'):
        '''
        Args:
            pruned_epoch_number: number of fine-tuning epochs after pruning.
            trained_epoch_number: number of epochs used to train the benchmark.
            v, group, n_samples: forwarded to ``montecarlo`` for the AP3 estimate.
            network: network name forwarded to ``whole``.
            data: dataset name forwarded to ``loadingdata.Dataset`` and ``whole``.
            base_path: directory where ``plot_test`` writes its figures.
        '''
        self.pruned_epoch_number = pruned_epoch_number
        self.trained_epoch_number = trained_epoch_number
        self.dataframes = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.v = v
        self.group = group
        self.network = network
        self.data = data
        self.n_samples = n_samples
        self.base_path = base_path
        self.dataset = loadingdata.Dataset(data)
        self.trainset, self.testset = self.dataset.data_reader()

    def execute(self):
        '''
        Run every (pruning method, percentage) experiment for each run in
        ``RUNS`` and average the results over the runs.

        Returns:
            dict mapping "<method>-<percentage>" (e.g. "lowest-0.1") to a
            DataFrame with one row per fine-tuning epoch and the columns
            'AP2', 'AP3', 'Performance_Difference_train' and
            'Performance_Difference_test', averaged over the runs.
        '''
        runs = {}
        last_epoch = self.trained_epoch_number - 1
        for itera in self.RUNS:
            self.dataframes = {}

            # first train the pre-trained benchmark
            pretrain = whole(self.device, network_name=self.network, data=self.data, batch_size=128,
                             epoch_number=self.trained_epoch_number, flag='True', itera=itera,
                             trainset=self.trainset, testset=self.testset)
            original_model, acc_data, vanila_concat_set = pretrain.function(
                trained_epoch_num=self.trained_epoch_number, pruned_step='False',
                percentage=None, method=None, model=pretrain.model, mask=None)

            # reference entry taken from the last training epoch
            vanila_concat = vanila_concat_set[str(last_epoch)]

            # every AP2 list of this run; used by modify_first_column to rescale AP2
            lambda_max = []
            for pruned_method in self.PRUNING_METHODS:
                for percentage in self.PERCENTAGES:
                    print(percentage)

                    variable_name = "{}-{}".format(pruned_method, percentage)
                    dataframe = pd.DataFrame(np.zeros((self.pruned_epoch_number, 4)),
                                             columns=['AP2', 'AP3', 'Performance_Difference_train',
                                                      'Performance_Difference_test'])

                    # fine-tune the pruned network
                    pruned_model, pruned_acc_data, pruned_concat_set = pretrain.finetunning_pruned(
                        percentage=percentage, model=original_model,
                        pruned_epoch_number=self.pruned_epoch_number, pruned_method=pruned_method)

                    # To compute AP3 from a Gaussian with a non-diagonal covariance, swap the
                    # montecarlo line below for the commented AP3_Gaussian_class line.
                    # For alexnet and vgg the covariance is very large and needs more RAM/CPU/GPU.
                    # AP3_list = AP3_Gaussian_class(self.device, vanila_concat, pruned_concat_set, self.pruned_epoch_number).AP3_function()
                    AP3_list = montecarlo(self.device, vanila_concat, pruned_concat_set, itera, self.n_samples,
                                          self.v, self.group, self.pruned_epoch_number).AP3_computation()
                    AP2_list = [AP2_class(self.device, vanila_concat, pruned_concat_set[str(i)]).AP2_function()
                                for i in range(self.pruned_epoch_number)]
                    # row 0 / row 1 of the accuracy tables are compared as train / test
                    PD_train_list = [np.abs(acc_data.iloc[0, last_epoch] - pruned_acc_data.iloc[0, i])
                                     for i in range(self.pruned_epoch_number)]
                    PD_test_list = [np.abs(acc_data.iloc[1, last_epoch] - pruned_acc_data.iloc[1, i])
                                    for i in range(self.pruned_epoch_number)]
                    dataframe['AP3'] = AP3_list
                    dataframe['AP2'] = AP2_list
                    dataframe['Performance_Difference_train'] = PD_train_list
                    dataframe['Performance_Difference_test'] = PD_test_list
                    # Needed to find a suitable coefficient of the identity matrix used as the
                    # Gaussian covariance (see modify_first_column).
                    lambda_max.append(AP2_list)
                    self.dataframes[variable_name] = dataframe

            self.dataframes = self.modify_first_column(lambda_param=lambda_max, DataFrame=self.dataframes)
            runs[itera] = self.dataframes
        return self._average_runs(runs)

    @staticmethod
    def _average_runs(runs):
        '''
        Average per-run result dicts experiment by experiment.

        Args:
            runs: dict of run id -> {experiment name -> DataFrame}; every run
                must contain the same experiment names.
        Returns:
            dict of experiment name -> element-wise mean DataFrame over the runs,
            in the key order of the first run (empty when ``runs`` is empty).
        '''
        per_run = list(runs.values())
        if not per_run:
            return {}
        return {name: sum(run[name] for run in per_run) / len(per_run) for name in per_run[0]}

    def modify_first_column(self, lambda_param=None, DataFrame=None):
        '''
        Rescale the 'AP2' column of every DataFrame in ``DataFrame`` in place.

        Each AP2 column is multiplied by (2 * 1.9) / max(lambda_param), so the
        largest value in ``lambda_param`` is mapped to 3.8. This helps choose
        the covariance of the Gaussian so that the AP2 values stay bounded.
        NOTE(review): the original note said "all AP2 values are less than 2",
        but this factor maps the maximum to 3.8 -- confirm the intended constant.

        Args:
            lambda_param: nested list of AP2 values (one list per experiment).
            DataFrame: dict of experiment name -> DataFrame with an 'AP2' column.
        Returns:
            The same dict with rescaled AP2 columns. It is returned unchanged when
            either argument is missing/empty or the maximum is zero.
        '''
        if DataFrame is None or lambda_param is None or np.size(lambda_param) == 0:
            return DataFrame
        peak = np.max(lambda_param)
        if peak == 0:
            # dividing by zero would turn every AP2 value into inf/nan
            return DataFrame
        scale = (2 * 1.9) / peak
        for frame in DataFrame.values():
            frame['AP2'] = frame['AP2'] * scale
        return DataFrame

    def _figure_path(self, comparison, selector, param):
        '''Return a unique PNG path under ``base_path`` for one figure.'''
        name = '{}-{}-log-{}-{}-{}-kl-{}-{}-test.png'.format(
            self.v, self.group, comparison, selector, param, self.network, self.data)
        return Path(self.base_path) / name

    @staticmethod
    def _plot_pair(ax, frame, param, color, label):
        '''Plot ``param`` (solid) and the test performance difference (dashed) of one experiment.'''
        ax.plot(frame[param], linewidth=3, color=color, label='{} {}'.format(param, label))
        ax.plot(frame['Performance_Difference_test'], linewidth=3, linestyle='--', color=color,
                label='PD test {}'.format(label))

    @staticmethod
    def _finish_figure(fig, ax, save_path):
        '''Apply the shared axis formatting, save ``fig`` as PNG to ``save_path`` and close it.'''
        ax.xaxis.grid(True)
        ax.yaxis.grid(True)
        ax.set_yscale("log")
        ax.set_xlabel("Epoch")
        ax.legend()
        fig.savefig(save_path, format='png')
        # release the figure; otherwise every figure of the sweep stays open
        plt.close(fig)

    def plot_test(self):
        '''
        Run ``execute`` and save log-scale figures of AP2 or AP3 (solid lines)
        against the test performance difference (dashed lines) per fine-tuning epoch.

        For each of 'AP2' and 'AP3', two families of figures are written to ``base_path``:
          * one per pruning method, comparing all pruning percentages;
          * one per pruning percentage, comparing all pruning methods.
        '''
        results = self.execute()
        Path(self.base_path).mkdir(parents=True, exist_ok=True)

        # label rows as epochs 1..pruned_epoch_number for the x axis
        for frame in results.values():
            frame.index = range(1, self.pruned_epoch_number + 1)

        percentage_colors = ['red', 'blue', 'orange', 'purple']
        method_colors = ['red', 'blue', 'orange']
        for param in ['AP2', 'AP3']:
            # one figure per pruning method: compare the percentages
            for method in self.PRUNING_METHODS:
                fig, ax = plt.subplots()
                for color, percentage in zip(percentage_colors, self.PERCENTAGES):
                    self._plot_pair(ax, results['{}-{}'.format(method, percentage)], param, color,
                                    str(percentage))
                self._finish_figure(fig, ax, self._figure_path('all percentages',
                                                               'method{}'.format(method), param))

            # one figure per percentage: compare the pruning methods
            for percentage in self.PERCENTAGES:
                fig, ax = plt.subplots()
                for color, method in zip(method_colors, self.PRUNING_METHODS):
                    self._plot_pair(ax, results['{}-{}'.format(method, percentage)], param, color, method)
                self._finish_figure(fig, ax, self._figure_path('all methods',
                                                               'percentage{}'.format(percentage), param))
    


if __name__ == '__main__':
    # Sweep every network on every dataset with the same experiment settings.
    network_names = ('resnet', 'alexnet', 'vgg16')
    dataset_names = ('CIFAR10', 'CIFAR100')
    for network_name in network_names:
        for dataset_name in dataset_names:
            experiment = main(
                pruned_epoch_number=20,
                trained_epoch_number=100,
                v=4,
                group=100,
                network=network_name,
                data=dataset_name,
                n_samples=600000,
            )
            experiment.plot_test()


