import pickle
import numpy as np
import pandas as pd
import argparse

'''
CREATES CSV FROM PICKLED DICTIONARY FILE
Computes averages over the independent runs stored in the pickle, along
with their standard errors, and writes the resulting table to a CSV file.
'''
def summarize_runs(runs):
    """Return ``[train_avg, train_se, test_avg, test_se]`` in percent.

    ``runs`` is an iterable of per-run results; the first two entries of
    each run are read as the train and test scores respectively. Values
    are scaled to percent *before* rounding to two decimals so the
    rounding applies to the reported number, not to the raw fraction.
    """
    arr = np.array([list(run) for run in runs])
    # Use the actual number of runs rather than a hard-coded constant.
    n_runs = arr.shape[0]
    avg = (100 * np.average(arr, axis=0)).round(decimals=2)
    # np.std defaults to ddof=0 (population standard deviation); kept as-is.
    std_err = (100 * np.std(arr, axis=0) / np.sqrt(n_runs)).round(decimals=2)
    return [avg[0], std_err[0], avg[1], std_err[1]]


def build_table(data):
    """Build the summary DataFrame from the unpickled nested dictionary.

    ``data`` maps a top-level key to a dictionary of
    ``dataset name -> runs``. Each top-level key contributes four columns
    (train, std error, test, std error); rows are indexed by the dataset
    names of the first top-level key.

    Raises:
        ValueError: if ``data`` is empty.
    """
    dataset_names = None
    averaged_out = []
    for column in data:
        if dataset_names is None:
            dataset_names = list(data[column].keys())
        print(column)
        averaged_out.append(
            [summarize_runs(runs) for runs in data[column].values()]
        )

    if not averaged_out:
        raise ValueError("pickled dictionary contains no entries")

    # Place the four-column group of every top-level key side by side.
    table = np.hstack(averaged_out)
    dt = pd.DataFrame(table, index=dataset_names)
    # One header group per top-level key instead of a fixed count of four.
    dt.columns = ['train', 'std error', 'test', 'std error'] * len(averaged_out)
    return dt


def main():
    """Parse the command line, load the pickle, and write/print the table."""
    parser = argparse.ArgumentParser()
    parser.add_argument("pickledfile", help="Name of the .pkl file containing the data (include the extension)")
    parser.add_argument("output_csv", help="Output file name")
    args = parser.parse_args()

    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(args.pickledfile, "rb") as experiment_file:
        data = pickle.load(experiment_file)

    dt = build_table(data)
    # The ".csv" extension is always appended to the given output name.
    dt.to_csv(args.output_csv + ".csv")
    print(dt)


if __name__ == "__main__":
    main()

