import os
import sys

import pandas as pd

import nninfo

def compress_measurements(exp_name, measurements_subdir='measurements'):
    """Flatten the PID dict columns of an experiment's measurements and save as HDF5.

    Loads the measurements of experiment ``exp_name`` with
    ``nninfo.file_io.MeasurementManager``. It then expands each of the
    dict-valued columns ``informative_pid``, ``misinformative_pid`` and
    ``average_pid`` into one column per dict key. The expanded columns are
    grouped under a two-level column index whose first level is the
    original column name. Finally, it writes the combined frame with
    ``DataFrame.to_hdf`` (key ``'df'``, mode ``'w'``) to
    ``<measurements_subdir>.h5`` inside the directory reported by
    ``nninfo.FileManager``.

    Args:
        exp_name: Experiment name; the experiment is looked up at
            ``../experiments/exp_<exp_name>/`` relative to the current
            working directory.
        measurements_subdir: Measurement sub-directory passed to the
            ``MeasurementManager``. With any trailing path separators
            removed, it is also the stem of the output ``.h5`` file.
    """
    exp_path = f'../experiments/exp_{exp_name}/'
    # NOTE(review): relies on the private FileManager._path attribute;
    # presumably the resolved experiment directory — confirm.
    full_path = nninfo.FileManager(exp_path, write=True)._path

    # Load dataframe
    print(f'Loading df from {exp_path}{measurements_subdir}...')
    mm = nninfo.file_io.MeasurementManager(exp_path, measurement_subdir=measurements_subdir)
    df = mm.load(tuple_as_str=True)

    # Flatten PID dictionaries. The list order is the column order of the
    # flattened blocks in the output.
    print('Flattening pid dicts...')
    pid_columns = ['informative_pid', 'misinformative_pid', 'average_pid']
    flattened = []
    for col in pid_columns:
        # apply(pd.Series) turns the keys of each cell's dict into columns.
        expanded = df[col].apply(pd.Series)
        expanded.columns = pd.MultiIndex.from_product([[col], expanded.columns])
        flattened.append(expanded)

    print('Combining results...')
    df = pd.concat([df.drop(pid_columns, axis=1), *flattened], axis=1)

    print('Writing results to hdf5')
    # os.path.join works whether or not full_path ends with a separator (and
    # accepts path-like objects). Stripping trailing separators keeps e.g. a
    # shell-completed 'measurements/' from yielding 'measurements/.h5'.
    out_file = os.path.join(full_path, measurements_subdir.rstrip('/\\') + '.h5')
    df.to_hdf(out_file, key='df', mode='w')
    print('Done')

def main(argv=None):
    """Command-line entry point: ``<script> EXP_NAME [MEASUREMENTS_SUBDIR]``.

    Args:
        argv: Argument vector in ``sys.argv`` form (program name first).
            Defaults to ``sys.argv`` when omitted. ``MEASUREMENTS_SUBDIR``
            defaults to ``'measurements'``.

    Raises:
        SystemExit: With a usage message (non-zero exit status) when
            EXP_NAME is missing.
    """
    if argv is None:
        argv = sys.argv

    # Fail with a readable usage message rather than an IndexError on argv[1].
    if len(argv) < 2:
        prog = argv[0] if argv else '<script>'
        sys.exit(f'usage: {prog} EXP_NAME [MEASUREMENTS_SUBDIR]')

    exp_name = argv[1]
    measurements_subdir = argv[2] if len(argv) > 2 else 'measurements'

    compress_measurements(exp_name, measurements_subdir)

if __name__ == '__main__':
    main(sys.argv)