# The Experiment class connects all the different parts, stores files, and starts pruning and training.
from models.models import *
from pruning_methods import *
from pruning_main import *
import matplotlib.pyplot as plt
import random
import numpy as np
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import time
import torch.optim as optim
from torch.autograd import Variable
from numpy import save,load
from utilities import *
import copy
from train import pre_train_nn
import os
from train import *
from pruning_utilities import *
import torch.nn.utils.prune as prune
from utilities import validate
import yaml
#from pruning_utilities import calc_parameters_pruned_resnet

#todo: save yaml-file in exp folder
#todo: make loading a state-dict an option instead of fully training the model

class Experiment:
    def __init__(self,experiment_name,outdir,yaml_file,debugging_mode):
        """Set up an experiment and its output folder.

        Creates ``<outdir>/experiments/<experiment_name>`` (including missing
        parent folders) and, only when that folder is newly created, dumps
        ``yaml_file`` into ``experiment.yml`` inside it. An already existing
        folder is reused and its contents are left untouched.

        Args:
            experiment_name: Name of the experiment; also used as folder name.
            outdir: Base directory that holds the ``experiments`` folder.
            yaml_file: Object handed to ``yaml.dump`` (presumably the parsed
                experiment configuration -- confirm against the caller).
            debugging_mode: Flag stored on the instance; ``prepare_pruning``
                checks it to skip the second-order computation.
        """
        print('-----------------------------------------------------')
        print('Staring experiment ',experiment_name )
        print('-----------------------------------------------------')

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.name = experiment_name
        self.pruning_algorithm_name = []
        self.pruning_importance_vectors = []
        self.debugging_mode = debugging_mode

        experiments_root = os.path.join(outdir, 'experiments')
        self.exp_fol_dir = os.path.join(experiments_root, self.name)

        try:
            # makedirs (not mkdir): the parent 'experiments' folder may not
            # exist yet on a fresh output directory. It still raises
            # FileExistsError when the experiment folder itself exists.
            os.makedirs(self.exp_fol_dir)
        except FileExistsError:
            print('Directory already exists.')
        else:
            print('Created new folder: ', self.exp_fol_dir)
            path = os.path.join(self.exp_fol_dir, 'experiment.yml')
            with open(path, 'w') as file:
                yaml.dump(yaml_file, file)



    def prepare_iteration(self,run=0):
        """Announce iteration ``run`` and create its output folder.

        The folder ``iter_<run>`` is created inside ``self.exp_fol_dir`` and
        its path is stored in ``self.iter_folder_dir``.

        Raises:
            FileExistsError: If the iteration folder already exists
                (raised by ``os.mkdir``).
        """
        separator = '-----------------------------------------------------'
        print(separator)
        print('Staring Iteration ',str(run))
        print(separator)
        folder_name = 'iter_' + str(run)
        self.iter_folder_dir = os.path.join(self.exp_fol_dir, folder_name)
        os.mkdir(self.iter_folder_dir)

    def load_data(self,dataset_name,data_dir):
        """Create the data loaders and record the number of output classes.

        Sets ``self.trainloader`` / ``self.testloader`` from
        ``get_datalaoders`` and ``self.num_outputs`` from the dataset name
        ('Cifar100' -> 100, 'ImageNet' -> 1000, 'Cifar10' -> 10).
        ``self.imagenet`` is True only for 'ImageNet'.

        Raises:
            Exception: If ``dataset_name`` is none of the three known names.
                The loaders have already been requested at that point.
        """
        print('Loading: ',dataset_name)
        self.imagenet = False
        loaders = get_datalaoders(dataset_name,data_dir)
        self.trainloader, self.testloader = loaders

        # (dataset name, number of classes, is ImageNet)
        known_datasets = (
            ('Cifar100', 100, False),
            ('ImageNet', 1000, True),
            ('Cifar10', 10, False),
        )
        for known_name, class_count, is_imagenet in known_datasets:
            if dataset_name == known_name:
                self.num_outputs = class_count
                if is_imagenet:
                    self.imagenet = True
                break
        else:
            raise Exception('Unknown dataset')


    def create_model(self,model_name,set_seed,width_factor =1,pretrained_model_path=None,load_pretrained=False,mult_arr=None):
        """Seed all RNGs, build the model and print its initial metrics.

        Stores the loss in ``self.criterion``, the network in ``self.model``
        and its statistics object in ``self.model_stats``. The model is left
        in training mode. ``load_data`` must have run before, because
        ``self.num_outputs``, ``self.trainloader`` and ``self.testloader``
        are read here.

        Args:
            model_name: Forwarded to ``create_model_by_name``.
            set_seed: Seed applied to numpy, ``random`` and torch (CPU + CUDA).
            width_factor: Forwarded to ``create_model_by_name``.
            pretrained_model_path: Forwarded to ``create_model_by_name``.
            load_pretrained: Forwarded to ``create_model_by_name``; when falsy
                the initial training loss is printed as well (non-ImageNet).
            mult_arr: Forwarded to ``create_model_by_name``.

        Raises:
            AttributeError: If the installed torch offers no way to request
                deterministic algorithms, or the created model has no
                ``model_type`` attribute.
        """
        # set random seeds
        np.random.seed(set_seed)
        random.seed(set_seed)
        torch.manual_seed(set_seed)
        torch.cuda.manual_seed(set_seed)
        torch.cuda.manual_seed_all(set_seed)

        # ``torch._set_deterministic`` is a private hook that only exists in
        # old PyTorch releases; prefer the public API and fall back to the
        # older spellings so the call works across versions.
        # NOTE(review): with deterministic algorithms enabled, torch may raise
        # at runtime for ops without a deterministic implementation -- confirm
        # for the models used here.
        for setter_name in ('use_deterministic_algorithms', 'set_deterministic', '_set_deterministic'):
            deterministic_setter = getattr(torch, setter_name, None)
            if deterministic_setter is not None:
                deterministic_setter(True)
                break
        else:
            raise AttributeError('This torch version offers no API to enable deterministic algorithms')
        #torch.backends.cudnn.benchmarks = False
        #torch.backends.cudnn.deterministic = True

        self.criterion = nn.CrossEntropyLoss()
        self.model = create_model_by_name(model_name,self.device,width_factor=width_factor,num_classes=self.num_outputs,pretrained_model_path=pretrained_model_path,load_pretrained=load_pretrained,mult_arr=mult_arr)
        self.model_stats = model_stats(self.model)
        self.model_stats.print_model_stats()

        if self.num_outputs== 1000:
            # 1000 outputs is how this class recognises ImageNet
            print('Initial acc:',str(validate(self.testloader,self.model,self.criterion,self.device,initial=False)))
        else:
            if not load_pretrained:
                print('Inital loss: ', str(loss_net(self.model,self.criterion,self.trainloader,self.device)))
            self.model.eval()
            print('Inital acc: ', str(acc_net(self.model,self.testloader, self.device)))
        self.model.train()
        # Fails with AttributeError if the model does not expose ``model_type``
        print('The model Type is: ' + self.model.model_type)

    def prepare_model(self,train_type,train_mode,max_epochs,load_pretrained,pretrain_epochs):

        if (train_type == 'train' or train_type == 'retrain' or train_type == 'rewind' or train_type == 'expand_retrain') and not (
        load_pretrained):
            print('Training Net for ' + str(max_epochs) + ' epochs')
            self.train_model(self.model, train_mode, max_epochs)
            #print('Start second training')
            #self.train_model(self.model, train_mode, max_epochs)

        elif train_type == 'init' or train_type == 'doubletrain' or train_type == 'initmaskretrain' or train_type == 'expand_init' or load_pretrained:
            print('Proceeding with initialized Net')

        elif train_type == 'pretraining':
            print('Pretraining initialized net for ' + str(pretrain_epochs) + ' epochs')
            self.pretrain_model(pretrain_epochs)
        else:
            raise Exception(
                'Train_type unknown must be either "train", "init", "rewind", "retrain", "doubletrain" ,"initmaskretrain","expand_retrain" or "pretraining"')

        print("Currently models are not saved due to space limitations")
        # my_experiment.save_model()

    def save_model(self,iter):
        print('-----------------------------------------------------')
        path = os.path.join(self.iter_folder_dir, str(self.model_type) + '_'+str(iter)+".pt")
        if self.num_outputs== 1000:
            print('Saving currently disabled for imagenet')
            raise
        else:
            print('Saving model to '+ str(path))
            torch.save(self.model.state_dict(), path)

    def train_cifar_model(self,model,train_mode, max_epochs):
        """Train ``model`` with ``full_cifar_train_nn``.

        Args:
            model: Network to train. This may be ``self.model`` or a pruned
                copy of it.
            train_mode: Forwarded to ``full_cifar_train_nn``.
            max_epochs: Forwarded to ``full_cifar_train_nn``.

        Returns:
            The ``(best_acc, state_dict)`` pair returned by
            ``full_cifar_train_nn``.
        """
        print('-----------------------------------------------------')
        print('Staring training: ')
        # Switch the mode of the network that is actually trained; toggling
        # ``self.model`` here would touch the wrong object when a pruned copy
        # is passed in.
        model.train()
        best_acc,state_dict = full_cifar_train_nn(model, self.trainloader, self.testloader, self.criterion, self.device, train_mode, max_epochs)
        model.eval()
        return best_acc,state_dict

    def train_imagenet_model(self,model, max_epochs):
        """Train ``model`` with ``full_imagenet_train_nn``.

        Args:
            model: Network to train. This may be ``self.model`` or a pruned
                copy of it.
            max_epochs: Forwarded to ``full_imagenet_train_nn``.

        Returns:
            The value returned by ``full_imagenet_train_nn`` (used by callers
            as the best accuracy).
        """
        print('-----------------------------------------------------')
        print('Staring training: ')
        # Switch the mode of the network that is actually trained; toggling
        # ``self.model`` here would touch the wrong object when a pruned copy
        # is passed in.
        model.train()
        best_acc = full_imagenet_train_nn(model, self.trainloader, self.testloader, self.criterion, self.device,max_epochs)
        model.eval()
        return best_acc

    def train_model(self, model,train_mode, max_epochs,iter = 0, save_best_model= False):
        """Train ``model`` with the routine matching the loaded dataset.

        For ImageNet (``self.imagenet``) the ImageNet trainer is used and
        nothing is saved. Otherwise the CIFAR trainer is used and, when
        ``save_best_model`` is set, the returned state dict is written to
        ``<model_type>_<iter>.pt`` in ``self.iter_folder_dir`` (the file name
        uses the type of ``self.model``).

        Returns:
            The best accuracy reported by the trainer.
        """
        if self.imagenet == True:
            return self.train_imagenet_model(model,max_epochs)

        best_acc,state_dict = self.train_cifar_model(model,train_mode, max_epochs)
        if save_best_model == True:
            file_name = str(self.model.model_type) + '_' + str(iter) + ".pt"
            torch.save(state_dict, os.path.join(self.iter_folder_dir, file_name))
        return best_acc

    def pretrain_model(self, pretrain_epochs):
        """Run ``pre_train_nn`` on ``self.model`` for ``pretrain_epochs`` epochs.

        The model is switched to training mode beforehand and to evaluation
        mode afterwards. Nothing is returned.
        """
        for line in ('-----------------------------------------------------',
                     'Staring training: '):
            print(line)
        self.model.train()
        #todo:inspect
        pre_train_nn(
            self.model,
            self.trainloader,
            self.testloader,
            self.criterion,
            self.device,
            pretrain_epochs,
        )
        self.model.eval()


    def prepare_pruning(self,pruning_algorithms,samples_N,importance_vector_path=None):
        """Compute the derivative terms and one importance vector per algorithm.

        Stores ``self.first_order`` and ``self.second_order``, resets and
        refills ``self.pruning_algorithm_name`` and
        ``self.pruning_importance_vectors``, and saves the importance vectors
        as ``importance_vectors`` in ``self.iter_folder_dir``.

        Args:
            pruning_algorithms: Iterable of algorithm names.
            samples_N: Forwarded to ``calc_second_order``.
            importance_vector_path: When given, the importance vectors are
                loaded from this file with ``np.load`` instead of being
                computed (the first/second order terms are computed either way).
        """
        print('-----------------------------------------------------')
        print('Preparing for pruning: ')
        self.model.eval()
        print('Calculating first order terms: ')
        self.first_order = calc_first_order(
            self.model, self.trainloader, self.model_stats.num_strucs,
            self.model_stats.layer_names, self.device)

        print('Calculating second order terms: ')
        print(samples_N)
        skip_second_order = self.debugging_mode==True
        if skip_second_order:
            print('Using first order as second_order in debugging mode!')
            self.second_order = self.first_order
        else:
            # Hint kept from the original: when only SOSP-H is used, the
            # first-order term can stand in here to speed things up, since
            # SOSP-H does not use the second-order term directly.
            self.second_order = calc_second_order(
                self.model, self.trainloader, self.model_stats.num_strucs,
                self.model_stats.layer_names, self.model_stats.num_layers,
                self.device, self.num_outputs, samples_N)
        #todo: save the two matrices
        self.model.train()
        # To store the two matrices, np.save them as "first_order" and
        # "second_order" in self.iter_folder_dir at this point.

        self.pruning_algorithm_name = []
        self.pruning_importance_vectors = []
        load_from_file = importance_vector_path is not None

        print('Creating importance vectors:')
        for algorithm in pruning_algorithms:
            print('Preparing '+ str(algorithm))
            self.pruning_algorithm_name.append(algorithm)
            if load_from_file:
                continue
            vector = create_impoprtance_vector_by_name(
                algorithm, self.first_order, self.second_order,
                self.model, self.model_stats.layer_names,
                self.criterion, self.device, self.trainloader)
            self.pruning_importance_vectors.append(vector)

        if load_from_file:
            self.pruning_importance_vectors = np.load(importance_vector_path)

        target = os.path.join(self.iter_folder_dir, "importance_vectors")
        np.save(target, np.array(self.pruning_importance_vectors))


    def start_pruning(self,pruning_ratios,train_mode, max_epochs,train_type,max_finetuning_epochs,set_seed,shuffle_mask_layer=False,imagenet=False,pruning_layer_cap=False):
        # not a nice solution with global variable but for now a fix
        global pruning_channel_oh

        self.pruning_para_ratio = []
        self.pruning_macs_ratio = []

        self.pruning_best_acc = []
        self.pruning_acc = []
        self.pruning_loss = []

        self.unpruned_strucs = []
        self.pruned_strucs = []
        self.pruning_channel_list_oh = []

        self.mac_counts = []
        self.parameter_counts = []

        for pruning_name in self.pruning_algorithm_name:
            self.pruning_para_ratio.append([])
            self.pruning_macs_ratio.append([])

            self.pruning_acc.append([])
            self.pruning_best_acc.append([])
            self.pruning_loss.append([])
            self.unpruned_strucs.append([])
            self.pruned_strucs.append([])
            self.pruning_channel_list_oh.append([])

            self.mac_counts.append([])
            self.parameter_counts.append([])

        for prun_idx in range(len(self.pruning_algorithm_name)):
            np.random.seed(set_seed)
            random.seed(set_seed)
            torch.manual_seed(set_seed)
            torch.cuda.manual_seed(set_seed)
            torch.cuda.manual_seed_all(set_seed)
            print('-------------------------------------------------------')
            print('Starting pruning importance: ' + str(self.pruning_algorithm_name[prun_idx]))
            for prun_num,prun_ratio in enumerate(pruning_ratios):
                print('Pruning Ratio: ' + str(prun_ratio))
                net_struct_prune_iterative = copy.deepcopy(self.model)
                structs_to_prune_iterative = define_struct_to_prune(net_struct_prune_iterative,self.model_stats.layer_names)


                pruning_chanel_oh = []
                prune.global_unstructured(structs_to_prune_iterative, pruning_method=struct_importance_pruning,
                                          amount=prun_ratio, shape=self.model_stats.nn_shapes_prunable_layers,
                                          importance_matrix_idx=self.pruning_importance_vectors[prun_idx],pruning_chanel_oh= pruning_chanel_oh,
                                          shuffle_mask_layer=shuffle_mask_layer,pruning_layer_cap=pruning_layer_cap)


                #todo: add here optimistic pesimistic and correct calculartion if available
                low_bound_pruned_para_ratio,unprun_struc,pruned_struc = underestimated_calc_parameters_pruned_resnet(net_struct_prune_iterative, structs_to_prune_iterative)




                #todo: later include this correclty without this trick
                self.pruning_channel_list_oh[prun_idx].append(pruning_chanel_oh)
                #self.pruning_para_ratio[prun_idx].append(prun_para_ratio)
                self.pruned_strucs[prun_idx].append(pruned_struc)
                self.unpruned_strucs[prun_idx].append(unprun_struc)


                #todo: add total parameter and mac count ....
                if self.model.model_type == 'resnet56':
                    correct_numb_params,correct_numb_macs  = resnet56_param_count( np.array(self.model_stats.layer_names), np.array(self.pruning_channel_list_oh[prun_idx]),
                                                                                   np.array(self.pruned_strucs[prun_idx]),  np.array(self.unpruned_strucs[prun_idx]), ratio_num=prun_num, num_classes=self.num_outputs )
                elif self.model.model_type == 'resnet32':
                    correct_numb_params,correct_numb_macs = resnet32_param_count( np.array(self.model_stats.layer_names), np.array(self.pruning_channel_list_oh[prun_idx]),
                                                                                  np.array(self.pruned_strucs[prun_idx]),  np.array(self.unpruned_strucs[prun_idx]), ratio_num=prun_num, num_classes=self.num_outputs )
                elif self.model.model_type == 'resnet18_I':
                    correct_numb_params,correct_numb_macs  = resnet18_param_count( np.array(self.model_stats.layer_names), np.array(self.pruning_channel_list_oh[prun_idx]),
                                                                                  np.array(self.pruned_strucs[prun_idx]),  np.array(self.unpruned_strucs[prun_idx]), ratio_num=prun_num, num_classes=self.num_outputs )
                elif self.model.model_type == 'resnet50_I':
                    correct_numb_params,correct_numb_macs  = resnet50_param_count( np.array(self.model_stats.layer_names), np.array(self.pruning_channel_list_oh[prun_idx]),
                                                                                  np.array(self.pruned_strucs[prun_idx]),  np.array(self.unpruned_strucs[prun_idx]), ratio_num=prun_num, num_classes=self.num_outputs )
                elif self.model.model_type == 'densenet40':
                    correct_numb_params, correct_numb_macs = densenet40real_param_count( np.array(self.model_stats.layer_names), np.array(self.pruning_channel_list_oh[prun_idx]),
                                                                                  np.array(self.pruned_strucs[prun_idx]),  np.array(self.unpruned_strucs[prun_idx]), ratio_num=prun_num, num_classes=self.num_outputs )
                elif self.model.model_type == 'mobilenetv2_I':
                    correct_numb_params,correct_numb_macs = None,None
                elif self.model.model_type == 'vgg19':
                    correct_numb_params,correct_numb_macs = vgg19_param_count(net_struct_prune_iterative, structs_to_prune_iterative)
                else:
                    correct_numb_params,correct_numb_macs = None,None

                self.mac_counts[prun_idx].append(correct_numb_macs)
                self.parameter_counts[prun_idx].append(correct_numb_params)

                print("The number of Parameters/Macs are: " +str(correct_numb_params)+ '/ '+ str(correct_numb_macs))

                if correct_numb_params is None:
                    para_ratio = low_bound_pruned_para_ratio
                else:
                    para_ratio = float(correct_numb_params)/float(self.model_stats.num_params)

                if (correct_numb_macs is None) or (self.model.macs_forward is None):
                    mac_ratio = None
                else:
                    mac_ratio = float(correct_numb_macs)/float(self.model.macs_forward )

                print("The ratios are of Parameters/Macs are: " + str(para_ratio) + '/ ' + str(mac_ratio))

                self.pruning_macs_ratio[prun_idx].append(mac_ratio)
                self.pruning_para_ratio[prun_idx].append(para_ratio)

                net_struct_prune_iterative.train()


                if train_type== 'train':
                    print('Starting finetuning:')
                    if imagenet:
                        raise #print('Not implemented for imagenet')
                    else:
                        best_acc = fine_tune_nn(net_struct_prune_iterative, self.trainloader, self.testloader, self.criterion, self.device, max_finetuning_epochs)
                else:
                    print('Starting full training:')
                    best_acc = self.train_model(net_struct_prune_iterative,train_mode,max_epochs)

                    if train_type=='doubletrain':
                        best_acc = self.train_model(net_struct_prune_iterative,train_mode, max_epochs)

                self.pruning_best_acc[prun_idx].append(best_acc)
                print('The best performing Model has acc: '+str(best_acc))
                net_struct_prune_iterative.eval()
                if self.imagenet==True:
                    acc_tmp = validate(self.testloader,net_struct_prune_iterative,self.criterion,self.device)
                    print('The final model has acc: '+ str(acc_tmp))
                    self.pruning_acc[prun_idx].append(acc_tmp)
                    self.pruning_loss[prun_idx].append(0)
                else:
                    print('The final model has acc: '+ str(acc_net(net_struct_prune_iterative,self.testloader,self.device)))
                    self.pruning_acc[prun_idx].append(acc_net(net_struct_prune_iterative,self.testloader,self.device))
                    self.pruning_loss[prun_idx].append(loss_net(net_struct_prune_iterative,self.criterion,self.testloader,self.device))

        self.save_files(pruning_ratios)


    def save_files(self,pruning_ratios):
        #todo: save additional values .....


        path_acc = os.path.join(self.iter_folder_dir, "model_accs")
        path_best_acc = os.path.join(self.iter_folder_dir, "model_best_accs")
        path_loss = os.path.join(self.iter_folder_dir, "model_losses")


        path_unpruned_strucs = os.path.join(self.iter_folder_dir, "unpruned_strucs")
        path_pruned_strucs = os.path.join(self.iter_folder_dir, "pruned_strucs")
        path_pruning_channel_list_oh = os.path.join(self.iter_folder_dir, "pruned_channels_oh")
        path_layer_names = os.path.join(self.iter_folder_dir, "layer_names")

        path_pruning_ratios = os.path.join(self.iter_folder_dir, "pruning_ratios")
        path_pruned_params = os.path.join(self.iter_folder_dir, "pruned_params")
        path_pruned_macs = os.path.join(self.iter_folder_dir, "pruned_macs")
        path_pruning_param_ratios = os.path.join(self.iter_folder_dir, "pruning_param_ratios")
        path_pruning_mac_ratios = os.path.join(self.iter_folder_dir, "pruning_mac_ratios")

        np.save(path_pruning_channel_list_oh, np.array(self.pruning_channel_list_oh))
        np.save(path_layer_names, np.array(self.model_stats.layer_names))
        np.save(path_acc, np.array(self.pruning_acc))
        np.save(path_best_acc, np.array(self.pruning_best_acc))
        np.save(path_loss, np.array(self.pruning_loss))
        np.save(path_unpruned_strucs, np.array(self.unpruned_strucs))
        np.save(path_pruned_strucs, np.array(self.pruned_strucs))

        np.save(path_pruning_ratios, np.array(pruning_ratios))
        np.save(path_pruning_param_ratios, np.array(self.pruning_para_ratio))
        np.save(path_pruning_mac_ratios, np.array(self.pruning_macs_ratio))

        np.save(path_pruned_params, np.array(self.parameter_counts))
        np.save(path_pruned_macs, np.array(self.mac_counts))




    def plot_results(self,pruning_ratios,modelname,train_type):
        """Plot the accuracies of the current iteration and save both figures.

        Two figures are written to ``self.iter_folder_dir``: accuracy over the
        pruned parameter ratio (``1 - pruning_para_ratio``) and accuracy over
        ``pruning_ratios``. Each pruning algorithm contributes one line.
        """
        print('-----------------------------------------------------')
        print('Plotting Results')

        def finish_figure(x_label, file_stem):
            # decorate, save and close the figure currently being drawn
            plt.legend()
            plt.xlabel(x_label)
            plt.title(modelname + ' at ' + train_type)
            plt.ylabel('Accuracy')
            plt.savefig(os.path.join(self.iter_folder_dir, file_stem))
            plt.close()

        # accuracy vs. ratio of pruned parameters
        for algo_idx, _ in enumerate(self.pruning_algorithm_name):
            pruned_share = 1.0-np.array(self.pruning_para_ratio[algo_idx])
            plt.plot(pruned_share, self.pruning_acc[algo_idx], label=self.pruning_algorithm_name[algo_idx])
        finish_figure('Pruned Parameters Ratio', 'accuracies_vs_pruned_parameters')

        # accuracy vs. ratio of pruned structures
        for algo_idx, _ in enumerate(self.pruning_algorithm_name):
            plt.plot(pruning_ratios, self.pruning_acc[algo_idx], label=self.pruning_algorithm_name[algo_idx])
        finish_figure('Pruned Structures Ratio', 'accuracies_vs_pruned_structurs')

    def define_pruning_algorithm_names(self,pruning_algorithms):
        """Append every name in ``pruning_algorithms`` to ``self.pruning_algorithm_name``.

        Names already present are kept; nothing is returned.
        """
        self.pruning_algorithm_name.extend(pruning_algorithms)


    def plot_average_results(self,pruning_ratios,modelname,train_type,runs):
        """Plot mean and standard deviation of the results over all runs.

        Loads ``pruning_param_ratios.npy`` and ``model_accs.npy`` from the
        folders ``iter_0`` .. ``iter_<runs-1>`` in ``self.exp_fol_dir``,
        averages them over the runs and saves three error-bar figures in
        ``self.exp_fol_dir``: pruned parameter ratio over pruned structures,
        accuracy over pruned structures, and accuracy over pruned parameters.
        """
        pruned_param_share_per_run = []
        accuracy_per_run = []
        for run_idx in range(runs):
            run_dir = os.path.join(self.exp_fol_dir, 'iter_' + str(run_idx))
            kept_ratio = np.load(os.path.join(run_dir, "pruning_param_ratios.npy"))
            pruned_param_share_per_run.append(np.abs(1 - kept_ratio))
            accuracy_per_run.append(np.load(os.path.join(run_dir, "model_accs.npy")))

        # statistics across runs (axis 0 indexes the run)
        para_prune_ratio_mean = np.mean(pruned_param_share_per_run, axis=0)
        para_prune_ratio_std = np.std(pruned_param_share_per_run, axis=0)

        total_results_mean = np.mean(accuracy_per_run, axis=0)
        total_results_std = np.std(accuracy_per_run, axis=0)

        def errorbar_figure(x_for, means, stds, x_label, title_suffix, y_label, file_stem):
            # one error-bar line per pruning algorithm, then decorate and save
            for algo_idx in range(len(self.pruning_algorithm_name)):
                plt.errorbar(x_for(algo_idx), means[algo_idx], yerr=stds[algo_idx], capsize=5,
                             label=self.pruning_algorithm_name[algo_idx])
            plt.legend()
            plt.xlabel(x_label)
            plt.title(modelname + ' at ' + train_type + title_suffix)
            plt.ylabel(y_label)
            plt.savefig(os.path.join(self.exp_fol_dir, file_stem))
            plt.close()

        #plot pruned parameter distribution
        errorbar_figure(lambda _: pruning_ratios, para_prune_ratio_mean, para_prune_ratio_std,
                        'Pruned Structs Ratio', 'Paramters Ratios', 'Pruned Parameters Ratio',
                        'parameters_structs_ratio')

        # plot pruned acc over pruned structures
        errorbar_figure(lambda _: pruning_ratios, total_results_mean, total_results_std,
                        'Pruned Structs Ratio', '', 'Accuracy',
                        'average_acc')

        # plot pruned acc over pruned parameters
        errorbar_figure(lambda algo_idx: para_prune_ratio_mean[algo_idx], total_results_mean, total_results_std,
                        'Pruned Parameters Ratio', 'Parameters', 'Accuracy',
                        'average_acc_parameters')
