from datetime import datetime
from os import PathLike
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

from tensorboardX import SummaryWriter


class TensorboardAdapter():
    """Thin logging facade over a tensorboardX ``SummaryWriter``.

    Scalars and histograms are forwarded to the writer created under
    ``<work_dir>/logs``.  Free-form auxiliary text can additionally be
    appended to per-collection ``aux_<collection>`` files.

    The adapter owns open resources (the summary writer and any auxiliary
    files); call :meth:`close` or use the instance as a context manager to
    release them.
    """

    # ``log_histogram`` only writes on steps that are a multiple of this.
    HISTOGRAM_INTERVAL = 10000

    def __init__(self, work_dir: Union[str, PathLike]) -> None:
        """Create the summary writer under ``work_dir / "logs"``.

        Args:
            work_dir: Base directory of the run.  ``Path`` instances are used
                as-is; strings and other path-like objects are converted.
        """
        # Coerce instead of asserting: asserts vanish under ``python -O``.
        logdir = Path(work_dir) / "logs"

        # Group name -> format tuple, filled by ``register_group``.
        self.groups = {}
        self.writer = SummaryWriter(log_dir=logdir)

        # Collection name -> open text file, filled lazily by ``log_aux``.
        self.aux_writer = {}

    def __enter__(self):
        """Return the adapter itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Release all resources when leaving the ``with`` block."""
        self.close()

    def close(self) -> None:
        """Close every auxiliary file and the underlying summary writer."""
        for stream in self.aux_writer.values():
            stream.close()
        # Forget the closed handles so a later ``log_aux`` reopens the file.
        self.aux_writer.clear()
        self.writer.close()

    def log(self, name : str, value : float, step : Optional[int] = None):
        """Record the scalar ``value`` under tag ``name`` at ``step``."""
        self.writer.add_scalar(name, value, global_step=step)

    def log_param(self, name : str, value : float, step : Optional[int] = None):
        """Accept a parameter value and discard it (intentional no-op)."""

    def log_distrib(self, name : str, value : float, step : Optional[int] = None):
        """Record ``value`` as a histogram under tag ``name`` on every call.

        ``value`` is handed to ``SummaryWriter.add_histogram`` unchanged.
        """
        self.writer.add_histogram(name, value, global_step=step)

    def dump(self, *args, **kwargs):
        """Accept and ignore any arguments (intentional no-op)."""

    def log_histogram(self, name : str, value : float, step : Optional[int] = None):
        """Record a histogram, throttled to every ``HISTOGRAM_INTERVAL`` steps.

        When ``step`` is ``None`` there is nothing to throttle on, so the
        histogram is written unconditionally (the original code raised
        ``TypeError`` on ``None % 10000`` in that case).
        """
        if step is None or step % self.HISTOGRAM_INTERVAL == 0:
            self.writer.add_histogram(name, value, global_step=step)

    def register_group(self, name : str, format : Tuple[str, str, str], **kwargs):
        """Remember ``format`` for group ``name``; extra keywords are ignored."""
        self.groups[name] = format

    def log_data(self, group : str, value : Dict[str, Union[float, int]], step=None):
        """Record every entry of ``value`` as the scalar ``"<group>/<key>"``.

        Args:
            group: Prefix placed in front of each key.
            value: Mapping of metric name to scalar.
            step: Global step.  When ``None`` it falls back to
                ``value["epoch"]``, then ``value["iteration"]``, then ``None``.
        """
        # Compare against None explicitly: step 0 is a valid global step and
        # must not be replaced by the epoch/iteration fallback.
        if step is None:
            step = value.get("epoch", value.get("iteration", None))
        for key, scalar in value.items():
            self.writer.add_scalar(f"{group}/{key}", scalar, step)

    def log_aux(self, collection, data=None, flush=False):
        """Append one line of ``data`` to the file ``aux_<collection>``.

        The file is opened in append mode, relative to the current working
        directory, on first use of a collection; a "Started writing at ..."
        header line is written at that point.

        Args:
            collection: Name of the auxiliary collection.
            data: A list is rendered as its items joined by ``", "``; any
                other object is printed as-is.  ``None`` (the default) is
                treated as an empty list and yields an empty line.
            flush: Forwarded to ``print`` to force a flush after the line.

        Raises:
            TypeError: If ``collection`` is not a string.
        """
        if not isinstance(collection, str):
            raise TypeError(
                f"collection must be a str, got {type(collection).__name__}"
            )

        # ``None`` replaces the former mutable ``[]`` default.
        if data is None:
            data = []
        if isinstance(data, list):
            data = ", ".join(map(str, data))

        stream = self.aux_writer.get(collection)
        if stream is None:
            stream = Path(f"aux_{collection}").open("a")
            self.aux_writer[collection] = stream
            print(f"Started writing at {datetime.now()}", file=stream)

        print(data, file=stream, flush=flush)
