from typing import Dict, Optional, Type

from matplotlib.figure import Figure

from compute_result.output_manager.base import OutputManager
from compute_result.output_manager.writer_factory import WriterFactory
from compute_result.typing import Run, BudgetRun


class TensorBoardOutput(OutputManager):
    """Output manager that sends results to TensorBoard-style writers.

    Writers are obtained from ``writer_factory.create(name)``. Scalar
    progressions get a dedicated, short-lived writer per call, while figures
    share a single writer that is created lazily on the first
    :meth:`print_image` call and released by :meth:`finish`.
    """

    def __init__(self, writer_factory: Type[WriterFactory]):
        # Factory class (not an instance); ``create(name)`` returns a writer.
        self.writer_factory = writer_factory
        # Writer shared by print_image calls; opened on first use.
        self.writer = None

    def log_normalize_progression(
        self, results: BudgetRun, plot_name: str, graph_name: str
    ):
        """Log every ``(step, value)`` item of *results* as a scalar.

        A writer named *plot_name* is created for this call only and is
        always closed afterwards, even if logging fails part-way.

        Args:
            results: Iterable whose items unpack into ``(step, function_value)``.
            plot_name: Name handed to ``writer_factory.create``.
            graph_name: Scalar tag under which every value is logged.
        """
        writer = self.writer_factory.create(plot_name)
        try:
            for step, function_value in results:
                writer.add_scalar(graph_name, function_value, step)
        finally:
            # Release the writer even if iteration or logging raises.
            writer.close()

    def finish(self):
        """Close the shared figure writer, if one was opened.

        The attribute is cleared before closing so a failing ``close()`` can
        never be retried on a half-closed writer, and so a later
        :meth:`print_image` call opens a fresh writer.
        """
        writer, self.writer = self.writer, None
        if writer is not None:
            writer.close()

    def print_image(
        self,
        run: Run,
        fig: Figure,
        index: Optional[int] = None,
        graph_name: str = "",
        plot_name: str = "",
    ):
        """Log *fig* on the shared writer.

        Args:
            run: Pair unpacked as ``(algorithm, run_name)``; ``algorithm.value``
                is used in names and tags.
            fig: Matplotlib figure passed to ``add_figure``.
            index: Optional ``global_step`` for the figure.
            graph_name: Optional suffix appended to the figure tag.
            plot_name: Writer name; only used when the shared writer is first
                created. Defaults to ``"<algorithm.value>_<run_name>"``.
        """
        algorithm, run_name = run
        if self.writer is None:
            # NOTE(review): the writer name is fixed by the first call; later
            # calls with another plot_name/run still log to this writer until
            # finish() is called. Confirm this is intended.
            self.writer = self.writer_factory.create(
                plot_name or f"{algorithm.value}_{run_name}"
            )
        tag = f"{algorithm.value} {run_name}"
        if graph_name:
            tag = f"{tag} {graph_name}"
        self.writer.add_figure(tag, fig, global_step=index)

    def print_bars(self, bars: Dict[str, float], graph_name: str = ""):
        """Bar charts are not supported by this output manager."""
        raise NotImplementedError()
