import os
import os.path as osp
import sys

import tensorboard as tb
import numpy as np
import pandas as pd

from typing import Dict, List, Tuple

from collections import defaultdict, namedtuple
from packaging import version

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

from expground.logger import Log


# `release` is the numeric part of the version, e.g. "2.15.1" -> (2, 15, 1) and
# "2.3" -> (2, 3). Pad it so a short release still yields (major, minor) instead
# of failing a fixed three-way unpack with ValueError at import time.
major_ver, minor_ver = (tuple(version.parse(tb.__version__).release) + (0, 0))[:2]
# Compare (major, minor) as a tuple: the former per-component check
# `major_ver >= 2 and minor_ver >= 2` wrongly rejected e.g. TensorBoard 3.0/3.1.
assert (major_ver, minor_ver) >= (2, 2), (
    f"This module requires TensorBoard 2.2 or later. {(major_ver, minor_ver)}"
)
Log.info("TensorBoard version: {}".format(tb.__version__))


def tabulate_event(dpath: str, max_scalars: int = 0) -> Dict[str, List[Tuple]]:
    """Read every scalar series from the event data at `dpath`.

    Args:
        dpath (str): Path handed to ``EventAccumulator`` (an event file).
        max_scalars (int, optional): Maximum number of points kept per scalar
            tag. ``0`` (the default) keeps every point. Without an explicit
            value, ``EventAccumulator`` falls back to its own default size
            guidance, which samples each scalar tag down to a fixed number of
            points and silently drops the rest.

    Returns:
        Dict[str, List[Tuple]]: Maps each scalar tag to a list of
        ``(wall_time, value, step)`` tuples sorted by ascending step.
    """
    accumulator = EventAccumulator(dpath, size_guidance={"scalars": max_scalars})
    accumulator.Reload()

    res = {}
    for tag in accumulator.Tags()["scalars"]:
        # sorted() is stable, so points sharing a step keep their read order.
        res[tag] = sorted(
            (
                (event.wall_time, event.value, event.step)
                for event in accumulator.Scalars(tag)
            ),
            key=lambda record: record[2],
        )

    return res


def get_file_path(dir_path: str, tag: str) -> str:
    """Build the csv file path for `tag` inside `dir_path`.

    Every ``/`` in `tag` is replaced with ``_`` so the tag maps to a single
    file name. `dir_path` (including missing parents) is created if needed.

    Args:
        dir_path (str): Directory path; ``""`` means the current directory.
        tag (str): Tag name.

    Returns:
        str: ``dir_path`` joined with ``<sanitized tag>.csv``.
    """
    file_name = tag.replace("/", "_") + ".csv"
    # exist_ok avoids the check-then-create race of a separate exists() test;
    # an empty dir_path is the current directory and needs no creation
    # (os.makedirs("") would raise).
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    return os.path.join(dir_path, file_name)


def to_csv(event_path: str, out_path, filter=None):
    """Export each scalar tag of an event file to its own csv file.

    Each csv has columns ``wall_time`` and ``value``, indexed by step
    (index label ``"index"``).

    Args:
        event_path (str): Event file, read with ``tabulate_event``.
        out_path (str): Directory receiving one ``<tag>.csv`` per tag (``/`` in
            the tag becomes ``_``); created if missing.
        filter (str, optional): If given, only tags containing this substring
            are exported.
    """
    res: Dict = tabulate_event(event_path)

    # Iterating the dict directly also handles files without scalar tags, where
    # the former `zip(*res.items())` unpack raised ValueError.
    for tag, records in res.items():
        if filter and filter not in tag:
            continue
        wall_times, values, steps = zip(*records) if records else ((), (), ())
        df = pd.DataFrame(
            {
                # float64: epoch seconds (~1.7e9) exceed float32's precision,
                # which would round them to multiples of 128 s.
                "wall_time": np.asarray(wall_times, dtype=np.float64),
                "value": np.asarray(values, dtype=np.float32),
            },
            # Convert steps straight to int64: routing them through float32
            # corrupted every step above 2**24.
            index=np.asarray(steps, dtype=np.int64),
        )
        csv_path = get_file_path(out_path, tag)
        df.to_csv(csv_path, index_label="index")
        print(f"\t* converted tag={tag} to csv={csv_path}")


def walk(source_dir: str, filter=None):
    """Yield the path of every event file under `source_dir`.

    The tree is searched recursively with ``os.walk``. A file is treated as an
    event file when its name starts with ``events``. Within each directory,
    file names are yielded in sorted order so results are deterministic.

    Args:
        source_dir (str): Root directory to search.
        filter: Currently unused; kept for call-site compatibility.

    Yields:
        str: Path of one event file (``os.walk``'s directory joined with the
        file name).

    Raises:
        AssertionError: If `source_dir` does not exist. Because this is a
            generator, the check runs on the first ``next()``, not at call time.
    """
    assert osp.exists(
        source_dir
    ), f"Path `{source_dir}` doesn't exist, please check it again"

    for dir_path, _, file_names in os.walk(source_dir):
        for file_name in sorted(file_names):
            if file_name.startswith("events"):
                yield osp.join(dir_path, file_name)


def load_csv_to_pd(root_dir: str, filter=None) -> Dict[str, pd.DataFrame]:
    """Load every csv file under `root_dir` (recursively) into a DataFrame.

    Args:
        root_dir (str): Directory searched recursively for ``*.csv`` files.
        filter (str, optional): If given, only files whose name contains this
            substring are loaded. ``/`` in the filter is matched as ``_``,
            the same substitution ``get_file_path`` applies to tag names.
            Defaults to None (load every csv file).

    Returns:
        Dict[str, pd.DataFrame]: Maps each csv file *name* (not its full path)
        to its DataFrame. Same-named files in different sub-directories
        overwrite each other; the one visited last by ``os.walk`` wins.
    """
    name_filter = filter.replace("/", "_") if filter else None
    pd_dict = {}
    for dir_path, _, file_names in os.walk(root_dir):
        for file_name in file_names:
            if not file_name.endswith(".csv"):
                continue
            if name_filter and name_filter not in file_name:
                continue
            csv_path = osp.join(dir_path, file_name)
            Log.info(f"* found one csv file: {csv_path}")
            pd_dict[file_name] = pd.read_csv(csv_path)
    return pd_dict


# if __name__ == "__main__":
#     # Walk through a directory to find all event files and export the filtered tags to csv files
#     #   (one csv per tag; tags from the same event file are saved under the same folder).
#     # usage: ./tools.py {event_dir} {csv_dir} {event_tag_substring_filter}
#     # for example: python3 ./plot/tools.py ./data/events/MPE/share/maddpg/ ./csv/maddpg/simple_spread reward
#     # NOTE: stale -- `walk` now takes (source_dir, filter) and only yields event-file paths;
#     #   before re-enabling, loop over `walk(...)` and call `to_csv` on each path instead.
#     args = sys.argv[1:]
#     if len(args) < 3:
#         args.append(None)
#     walk(args[0], args[1], args[2])
