# Sweeps a grid of hyperparameters for training MLP models, averaging results over repeated trials

import copy
import itertools
import os
import pickle
import subprocess

# set before importing torch; child training processes inherit this environment
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch

# all experimental settings

# name of the aggregate results pickle written under ./results
EXP_NAME = 'mnist_hyperparamer_30Kand50Kex_00.pckl'

# per-run settings
EPOCHS = 25      # epochs per training run (the sweep below chooses the count per data size)
NUM_TRIALS = 3   # repeated runs per configuration; their accuracies are averaged

# hyperparameter grid swept by the experiment loop
DATASETS = ['mnist']
NUM_SIZES = [30_000, 50_000]         # training-set sizes (--trn-size)
NUM_HIDS = [5_000, 10_000, 20_000]   # hidden sizes (--num-hidden)
BATCH_SIZES = [128]
LRS = [1e-8, 1e-7, 1e-6, 1e-5]       # learning rates (--lr)

def _run_trial(dataset, ns, nh, bs, lr, epochs, exp_str):
    """Run one training job via ``train_mlp.py`` and return its loaded result dict.

    The child process is expected to write ``./results/<exp_str>.pckl``. That
    file is loaded and then deleted, so one trial's output is never re-read
    by the next trial of the same configuration.

    Raises:
        subprocess.CalledProcessError: if ``train_mlp.py`` exits non-zero.
    """
    command = [
        'python', 'train_mlp.py',
        '--trn-size', str(ns),
        '--num-hidden', str(nh),
        '--lr', str(lr),
        '--dataset', str(dataset),
        '--downsample-method', 'uniform',
        '--batch-size', str(bs),
        '--epochs', str(epochs),
        '--exp-name', exp_str,
        '--prune', 'None',
    ]
    # check=True: fail here with the real exit status instead of later with a
    # confusing FileNotFoundError when the result file was never written.
    subprocess.run(command, check=True)

    result_fp = os.path.join('./results', exp_str + '.pckl')
    with open(result_fp, 'rb') as f:
        result = pickle.load(f)
    os.remove(result_fp)  # get rid of the per-trial experiment file
    return result


def _mean_over_trials(per_trial):
    """Return the element-wise mean of per-trial accuracy sequences.

    Args:
        per_trial: one sequence per trial; all must have the same length.

    Returns:
        A list whose entry ``i`` is the mean of entry ``i`` across all trials
        (an empty list when ``per_trial`` is empty).

    Raises:
        ValueError: if the trials produced sequences of different lengths.
    """
    if len({len(trial) for trial in per_trial}) > 1:
        raise ValueError('trials produced accuracy sequences of different lengths')
    return [sum(values) / len(per_trial) for values in zip(*per_trial)]


# run all experiments; itertools.product visits configurations in the same
# order as nested loops over (dataset, data size, hidden size, batch size, lr)
results = {}
for dataset, ns, nh, bs, lr in itertools.product(
        DATASETS, NUM_SIZES, NUM_HIDS, BATCH_SIZES, LRS):
    # 30K-example runs train for 35 epochs; every other size uses EPOCHS.
    # Kept local so the module-level EPOCHS constant is not overwritten.
    epochs = 35 if ns == 30000 else EPOCHS
    exp_str = f'ds{dataset}_hiddensize{nh}_bs{bs}_lr{lr}_datasize{ns}'
    print(f'\n\n{exp_str}, {epochs}')

    all_trn_accs = []
    all_trn_iter_accs = []
    for trial in range(NUM_TRIALS):
        # run single experimental trial
        print(f'Running Trial {trial + 1} / {NUM_TRIALS}')
        result = _run_trial(dataset, ns, nh, bs, lr, epochs, exp_str)
        all_trn_accs.append(result['trn_acc'])
        all_trn_iter_accs.append(result['iter_accs'])

    # save the script results; the aggregate file is rewritten after every
    # configuration so completed configurations survive an interrupted sweep
    avg_trn_accs = _mean_over_trials(all_trn_accs)
    avg_trn_iter_accs = _mean_over_trials(all_trn_iter_accs)
    results[exp_str] = (avg_trn_accs, all_trn_accs, avg_trn_iter_accs, all_trn_iter_accs)
    fp = os.path.join('./results', EXP_NAME)
    with open(fp, 'wb') as f:
        pickle.dump(results, f)
