# pruning tests for two-layer MLPs

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import pickle
import copy 

import torch

# all experimental settings

# Settings shared by every run.
EXP_NAME = 'mnist_pruning_20Khid_00.pckl'  # file name of the aggregated results pickle
DATASET = 'mnist'                          # passed as --dataset
NUM_TRIALS = 2                             # repeated runs per training-set size
NUM_HIDS = 20000                           # passed as --num-hidden
BS = 128                                   # passed as --batch-size
PRUNE_SIZE = 200                           # passed as --prune-size
PRUNE_FREQ = 1000                          # passed as --prune-freq

# Per-experiment settings. The three lists below are consumed in lockstep
# (entry i of each describes experiment i), so they must have equal length.
NUM_SIZES = [1000, *range(5000, 50001, 5000)]  # passed as --trn-size
# Each epoch count is roughly 1e6 / training-set size (rounded up), so
# epochs * size stays close to one million for every experiment.
EPOCHS = [
    1000,
    200, 100, 67, 50, 40,
    34, 29, 25, 23, 20,
]
LRS = [1e-6] * len(NUM_SIZES)  # same learning rate for every experiment

# run all experiments
#
# For every (epochs, training-set size, learning rate) setting, train_mlp.py is
# launched NUM_TRIALS times through the shell. After each run the result pickle
# it left in ./results is read back, folded into running totals, and deleted.
# The aggregated `results` dict is re-saved after every setting.


def _accumulate(running, values):
    """Add `values` element-wise into the running totals and return them.

    On the first call (`running is None`) the totals start as a deep copy of
    `values`. The copy matters: the caller also stores `values` in its list of
    per-trial results, and the later in-place additions must not overwrite
    that stored trial.
    """
    if running is None:
        return copy.deepcopy(values)
    for i, value in enumerate(values):
        running[i] += value
    return running


results = {}
for e, ns, lr in zip(EPOCHS, NUM_SIZES, LRS):
    exp_str = f'datasize{ns}'
    # train_mlp.py is given --exp-name {exp_str}; its output is read from here.
    result_fp = os.path.join('./results', exp_str + '.pckl')
    all_prune_accs = []
    avg_prune_accs = None
    all_trn_accs = []
    avg_trn_accs = None
    all_trn_iter_accs = []
    avg_trn_iter_accs = None
    for trial in range(NUM_TRIALS):
        print(f'\n\nRunning {exp_str}, Trial {trial + 1} / {NUM_TRIALS}\n\n')

        # run single experimental trial
        command = (
                f'python train_mlp.py --trn-size {ns} --num-hidden {NUM_HIDS}'
                f' --lr {lr} --dataset {DATASET} --downsample-method uniform'
                f' --batch-size {BS} --prune all_epochs --prune-size {PRUNE_SIZE}'
                f' --prune-freq {PRUNE_FREQ} --epochs {e} --verbose'
                f' --exp-name {exp_str} --prune-epochs 0 --prune-by-iter')
        status = os.system(command)
        if status != 0:
            # Stop here rather than failing later on a missing result file or
            # silently reading one left over from an earlier run.
            raise RuntimeError(
                f'train_mlp.py exited with status {status} '
                f'for {exp_str}, trial {trial + 1}')

        # get the results and store them for the script
        with open(result_fp, 'rb') as f:
            result = pickle.load(f)
        train_accs = result['trn_acc']
        prune_accs = result['prune_acc']
        train_iter_accs = result['iter_accs']

        # keep the raw per-trial values ...
        all_trn_accs.append(train_accs)
        all_prune_accs.append(prune_accs)
        all_trn_iter_accs.append(train_iter_accs)

        # ... and sum them into separate running totals (never aliased to the
        # per-trial lists stored above)
        avg_prune_accs = _accumulate(avg_prune_accs, prune_accs)
        avg_trn_accs = _accumulate(avg_trn_accs, train_accs)
        avg_trn_iter_accs = _accumulate(avg_trn_iter_accs, train_iter_accs)

        os.remove(result_fp)  # get rid of experiment file

    # save the script results
    avg_trn_accs = [x / NUM_TRIALS for x in avg_trn_accs]
    avg_prune_accs = [x / NUM_TRIALS for x in avg_prune_accs]
    # NOTE(review): avg_trn_iter_accs is stored as the un-normalised sum over
    # trials (it is never divided by NUM_TRIALS). Left unchanged because the
    # element type of result['iter_accs'] is not visible here -- confirm
    # whether a mean was intended.
    results[exp_str] = (
            avg_trn_accs, all_trn_accs, avg_trn_iter_accs, all_trn_iter_accs,
            avg_prune_accs, all_prune_accs
    )
    # The output pickle is rewritten after each setting, so it always holds
    # every setting completed so far.
    fp = os.path.join('./results', EXP_NAME)
    with open(fp, 'wb') as f:
        pickle.dump(results, f)
