import sys
import yaml
from experiment import *
from unstructured_experiment import *
import random
import numpy as np
import torch
import matplotlib
import os
from pruning_main import struc_in_layer



def _iter_dir(exp_fol_dir, iteration):
    """Return the per-repetition output folder ``<exp_fol_dir>/iter_<iteration>``."""
    return os.path.join(exp_fol_dir, 'iter_' + str(iteration))


def _all_iteration_files_exist(exp_fol_dir, num_iters, filenames):
    """Return True if every name in ``filenames`` exists in each ``iter_<i>``, i < num_iters.

    Vacuously True when ``num_iters <= 0``; callers must guard that case.
    """
    return all(
        os.path.isfile(os.path.join(_iter_dir(exp_fol_dir, iteration), filename))
        for iteration in range(num_iters)
        for filename in filenames
    )


def _layer_pruning_percentages(model_stats, importance_vector, ratio=0.5):
    """Return, per prunable layer, the fraction of its structures in the first ``ratio`` of
    ``importance_vector``.

    The structure count of layer ``l`` is taken as
    ``model_stats.nn_shapes_prunable_layers[l][0]`` (presumably output channels/neurons
    — TODO confirm). Each of the first ``int(len(importance_vector) * ratio)`` entries of
    ``importance_vector`` is mapped to its layer with ``struc_in_layer`` and counted.
    Presumably the vector is ordered so that its first entries are pruned first; verify
    against the pruning code.

    Returns:
        numpy array of length ``model_stats.num_layers`` with pruned/total per layer.
    """
    num_layers = model_stats.num_layers
    layer_shapes = model_stats.nn_shapes_prunable_layers
    strucs_per_layer = np.array([int(layer_shapes[lay][0]) for lay in range(num_layers)])
    pruned_strucs_per_layer = np.zeros_like(strucs_per_layer)
    n_pruned = int(len(importance_vector) * ratio)
    for pr_idx in range(n_pruned):
        layer = struc_in_layer(num_layers, layer_shapes, importance_vector[pr_idx])
        pruned_strucs_per_layer[layer] += 1
    return pruned_strucs_per_layer / strucs_per_layer


def _resnet_block_percentages(pruning_percentages, block_arr):
    """Average per-layer pruning percentages over the stages of a ResNet.

    Layers are grouped in consecutive runs of ``2 * block_arr[0]``, presumably two
    prunable layers per residual block. The first group also contains layer 0
    (presumably the stem conv), so its divisor is one larger. Layers after the last
    complete group are ignored. Only ``block_arr[0]`` is used, so all stages are assumed
    to have the same number of blocks. Raises IndexError if there are more complete
    groups than ``len(block_arr)``.

    Returns:
        list with one averaged percentage per stage (0 for stages never completed).
    """
    block_percentages = [0] * len(block_arr)
    layers_per_stage = 2 * block_arr[0]
    running_sum = 0
    stage = 0
    divisor = float(layers_per_stage + 1)  # first stage includes layer 0
    for layer_idx, layer_pct in enumerate(pruning_percentages):
        running_sum += layer_pct
        if layer_idx != 0 and layer_idx % layers_per_stage == 0:
            block_percentages[stage] = running_sum / divisor
            running_sum = 0
            stage += 1
            divisor = float(layers_per_stage)
    return block_percentages


def main():
    """Run one repetition of a pruning (or baseline training) experiment.

    Command line: ``<script> <config.yml> <outdir> <run_index>``.

    The YAML config supplies the dataset, model, training and pruning settings.
    ``pretrained_model_path`` and ``importance_vector_path`` are optional (missing or
    null both mean "not provided"). The model seed is ``start_seed + run_index``.

    Outputs:
        * ``train_type == 'baseline'``: writes ``iter_<run_index>/best_acc.npy``.
          Once every repetition has done so, it also writes ``best_accs_list.txt``.
        * otherwise: runs pruning and plots this run's results. Once every repetition
          has written its result files, it also creates the averaged plots.
    """
    matplotlib.use('Agg')
    if len(sys.argv) < 4:
        sys.exit('usage: <script> <config.yml> <outdir> <run_index>')
    run_var = int(sys.argv[3])
    outdir = sys.argv[2]

    # Load info from the config file.
    with open(sys.argv[1], "r") as ymlfile:
        # NOTE(review): FullLoader accepts more YAML tags than SafeLoader. Configs are
        # assumed to be trusted local files; switch to yaml.safe_load if that changes.
        cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)

    experiment_name = cfg['name']
    dataset_name = cfg['dataset']
    modelname = cfg['modelname']
    train_mode = cfg['train_mode']
    max_epochs = cfg['max_epochs']
    train_type = cfg['train_type']
    pruning_algorithms = cfg['pruning_algorithms']
    pruning_ratios = cfg['pruning_ratios']
    # Optional keys: absent and explicit null are treated the same.
    pretrained_model_path = cfg.get('pretrained_model_path')
    load_pretrained = pretrained_model_path is not None
    importance_vector_path = cfg.get('importance_vector_path')

    num_iters = cfg['repetitions']
    if num_iters <= run_var:
        # Grow the repetition count so this run index is in range. cfg is passed to the
        # Experiment below, presumably so it sees the updated count as well.
        cfg['repetitions'] = run_var + 1
        num_iters = run_var + 1

    width_factor = cfg['width_factor']
    data_dir = cfg['data_dir']
    start_seed = cfg['start_seed']
    pretrain_epochs = cfg['pretrain_epochs']
    max_finetuning_epochs = cfg['max_fine_tuning_epochs']
    samples_N = cfg['samples_N']
    unstructured = bool(cfg['unstructured'])
    shuffle_mask_layer = bool(cfg['shuffle_mask_layer'])
    pruning_layer_cap = bool(cfg['pruning_layer_cap'])
    debugging_mode = bool(cfg['debugging_mode'])
    if debugging_mode:
        max_epochs = 1

    # (Earlier versions used start_seed + run_var + 10 here.)
    seed = start_seed + run_var

    ##########################################
    ###########Start Experiment###############
    ##########################################
    experiment_cls = Unstructured_Experiment if unstructured else Experiment
    my_experiment = experiment_cls(experiment_name, outdir, cfg, debugging_mode)

    my_experiment.load_data(dataset_name, data_dir)
    my_experiment.prepare_iteration(run=run_var)
    my_experiment.create_model(modelname, seed, width_factor, pretrained_model_path,
                               load_pretrained)

    if train_type != 'baseline':
        my_experiment.prepare_model(train_type, train_mode, max_epochs, load_pretrained,
                                    pretrain_epochs)

    if unstructured:
        raise NotImplementedError

    if train_type == 'baseline':
        best_acc = my_experiment.train_model(my_experiment.model, train_mode, max_epochs,
                                             iter=run_var, save_best_model=True)
    else:
        my_experiment.prepare_pruning(pruning_algorithms, samples_N, importance_vector_path)
        if train_type == 'initmaskretrain':
            my_experiment.train_model(my_experiment.model, train_mode, max_epochs)
        if train_type == 'rewind':
            # Re-create the model from the same seed. The remaining create_model
            # arguments are left to their defaults.
            my_experiment.create_model(modelname, seed, width_factor)

        if train_type in ('expand_init', 'expand_retrain'):
            if len(my_experiment.pruning_importance_vectors) != 1:
                raise Exception('This mode only works for a single pruning method')
            # A 50% pruning ratio is used to detect bottleneck layers.
            pruning_percentages = _layer_pruning_percentages(
                my_experiment.model_stats, my_experiment.pruning_importance_vectors[0],
                ratio=0.5)

            model_type = my_experiment.model.model_type
            if model_type == 'vgg19':
                # Indices of the 5 least pruned layers, which are to be expanded.
                mult_arr = np.argsort(pruning_percentages)[:5]
            elif model_type in ('resnet56', 'resnet32'):
                block_arr = [9, 9, 9] if model_type == 'resnet56' else [5, 5, 5]
                block_percentages = _resnet_block_percentages(pruning_percentages, block_arr)
                mult_arr = np.ones_like(block_percentages)
                # Expand (x2) the least pruned stage.
                mult_arr[np.argmin(block_percentages)] = 2.0
            else:
                raise NotImplementedError

            # Rebuild the model with the expansion and redo training/pruning setup.
            # TODO: expand this to rewind and initmaskretrain
            my_experiment.create_model(modelname, seed, width_factor, pretrained_model_path,
                                       load_pretrained, mult_arr)
            my_experiment.prepare_model(train_type, train_mode, max_epochs, load_pretrained,
                                        pretrain_epochs)
            my_experiment.prepare_pruning(pruning_algorithms, samples_N, importance_vector_path)

        my_experiment.start_pruning(pruning_ratios, train_mode, max_epochs, train_type,
                                    max_finetuning_epochs, seed, shuffle_mask_layer,
                                    my_experiment.imagenet, pruning_layer_cap)

    if train_type == 'baseline':
        exp_fol_dir = my_experiment.exp_fol_dir
        np.save(os.path.join(_iter_dir(exp_fol_dir, run_var), 'best_acc.npy'),
                np.array([best_acc]))
        # Once every repetition has saved its best accuracy, collect them in one file.
        best_acc_paths = [os.path.join(_iter_dir(exp_fol_dir, iteration), 'best_acc.npy')
                          for iteration in range(num_iters)]
        if all(os.path.isfile(path) for path in best_acc_paths):
            list_best_accs = [np.load(path) for path in best_acc_paths]
            np.savetxt(os.path.join(exp_fol_dir, 'best_accs_list.txt'),
                       np.array(list_best_accs))
    else:
        my_experiment.plot_results(pruning_ratios, modelname, train_type)

        # Average plots only once every repetition has written all of its result files.
        # num_iters > 0 is required: averaging over zero repetitions is meaningless.
        result_files = ('model_accs.npy', 'pruning_param_ratios.npy', 'pruning_ratios.npy')
        if num_iters > 0 and _all_iteration_files_exist(my_experiment.exp_fol_dir, num_iters,
                                                        result_files):
            print('Creating average plots')
            my_experiment.plot_average_results(pruning_ratios, modelname, train_type, num_iters)


if __name__ == "__main__":
    main()
