# -*- coding: utf-8 -*-
from __future__ import print_function

import argparse
import os
import traceback
from datetime import datetime

import numpy as np
import pandas as pd
import sys
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from tensorboard.backend.event_processing.event_file_inspector import get_inspection_units
from tensorboard.plugins.hparams import plugin_data_pb2


def parse_arguments(argv=None):
    """Parse the command-line options and store them in the global ``args``.

    :param argv: list of argument strings to parse; ``None`` (the default)
        makes argparse read ``sys.argv[1:]``.
    :return: the parsed ``argparse.Namespace``. It is also assigned to the
        module-level ``args``, which ``main`` reads.
    """
    global args
    parser = argparse.ArgumentParser(description='Extract data from Tensorboard logs to pandas and save')
    parser.add_argument('--force', help='Forces extraction even if up-to-date file exists.', action='store_true')
    parser.add_argument('--dir', help='Log Folder', default='../logs')
    args = parser.parse_args(argv)
    return args


def _hparam_value_to_python(value):
    """Convert a ``google.protobuf.Value`` hparam into a plain Python object.

    Numbers become float, strings become str, and bools become bool. Null or
    unset values become None. Lists and structs are converted recursively
    into list and dict.
    """
    kind = value.WhichOneof('kind')
    if kind is None or kind == 'null_value':
        return None
    if kind == 'list_value':
        return [_hparam_value_to_python(v) for v in value.list_value.values]
    if kind == 'struct_value':
        return {k: _hparam_value_to_python(v) for k, v in value.struct_value.fields.items()}
    # number_value / string_value / bool_value map directly onto Python types.
    return getattr(value, kind)


def get_hparams(acc: EventAccumulator):
    """
    Read the hyperparameters recorded by the TensorBoard HParams plugin.

    :param acc: an EventAccumulator that has already been ``Reload()``-ed.
    :return: dict mapping hparam name to its value (float, str, bool, ...).
        The dict is empty if the run has no ``_hparams_/session_start_info``
        summary.
    """
    row = {}
    if '_hparams_/session_start_info' in acc.summary_metadata:
        data = acc.SummaryMetadata('_hparams_/session_start_info').plugin_data.content
        pdata = plugin_data_pb2.HParamsPluginData.FromString(data)
        # Read the typed protobuf value instead of eval()-ing its text form.
        # The text form lowercased string values, broke on values containing
        # ':' and executed text read from the log files.
        for name, value in pdata.session_start_info.hparams.items():
            row[name] = _hparam_value_to_python(value)
    return row


def _summarize_run(acc, epoch_tags=('epoch',)):
    """Build one output row summarizing a single, already-loaded run.

    :param acc: an EventAccumulator that has already been ``Reload()``-ed.
    :param epoch_tags: scalar tags whose last value gives the total epoch
        count. Some tags are logged once per training sample, so the tags
        relevant for the epoch count are named explicitly.
    :return: dict mapping column name to value, or None when the run has no
        hparams or no scalars and should be skipped.

    If one of ``epoch_tags`` was not logged, the accumulator's lookup error
    propagates to the caller.
    """
    hparams = get_hparams(acc)
    scalar_tags = acc.Tags()['scalars']
    if not hparams or not scalar_tags:
        return None

    row = dict(hparams)
    first_timestamp = acc.FirstEventTimestamp()
    row['date'] = datetime.fromtimestamp(first_timestamp)
    row['total epochs'] = max(acc.Scalars(tag)[-1].value for tag in epoch_tags)

    last_timestamp = first_timestamp
    for tag in scalar_tags:
        events = acc.Scalars(tag)
        last_timestamp = max(last_timestamp, events[-1].wall_time)
        # The first event of each tag is excluded, so 'min epoch' is an index
        # into the remaining events, not a global step.
        # NOTE(review): confirm that dropping the first event is intended.
        values = np.array([e.value for e in events[1:]])
        if not values.size:
            continue
        row[tag + ' last'] = values[-1]
        # NaN-aware minimum. With builtin min(), a leading NaN became the
        # "minimum", and the lookup of its index then raised.
        if not np.isnan(values).all():
            min_idx = int(np.nanargmin(values))
            row[tag + ' min'] = values[min_idx]
            row[tag + ' min epoch'] = min_idx

    row['end_date'] = datetime.fromtimestamp(last_timestamp)
    return row


def main():
    """Summarize every TensorBoard run under ``args.dir`` into ``tb_data.csv``.

    Runs already present in an existing ``tb_data.csv`` are skipped unless
    ``--force`` is given; forced runs replace their previous row. Runs that
    fail to parse are reported and skipped.

    :return: 1 if the log directory does not exist, otherwise None (exit 0).
    """
    root_dir = args.dir
    output_file = 'tb_data.csv'
    if not os.path.isdir(root_dir):
        print(root_dir + " does not exist!")
        return 1

    tb_data_filename = os.path.join(root_dir, output_file)
    if os.path.isfile(tb_data_filename):
        df_tb_data = pd.read_csv(tb_data_filename, index_col=[0])
    else:
        df_tb_data = pd.DataFrame()

    print(f"Get inspection units from {root_dir}")
    inspect_units = get_inspection_units(logdir=root_dir)

    if not inspect_units:
        print("No inspection units found in " + root_dir)
        return

    new_rows = {}
    for run in inspect_units:
        # Row key. relpath is taken relative to the current working directory,
        # not root_dir, so keys depend on where the script is launched from.
        path = os.path.relpath(run.name)
        if path in df_tb_data.index and not args.force:
            continue  # already extracted

        try:
            acc = EventAccumulator(run.name)
            acc.Reload()
            row = _summarize_run(acc)
        except Exception:
            print("Error while parsing {}:".format(path))
            print(traceback.format_exc())
            # TODO: write erroneous paths to file
            continue

        if row is not None:
            new_rows[path] = row

    if new_rows:
        df_new = pd.DataFrame.from_dict(new_rows, orient='index')
        # Drop stale rows for re-extracted runs (--force) so each path appears once.
        df_old = df_tb_data.drop(index=list(new_rows), errors='ignore')
        df_tb_data = df_new if df_old.empty else pd.concat([df_old, df_new])

    if not df_tb_data.empty:
        df_tb_data.to_csv(tb_data_filename)

    print()
    print("All finished.")


if __name__ == '__main__':
    # Populate the module-level ``args`` first; main() reads it.
    parse_arguments()
    exit_code = main()
    sys.exit(exit_code)
