from typing import Any, Optional

import numpy as np
from torch.utils.tensorboard import SummaryWriter


class TensorboardLogger:
    """Thin wrapper around ``SummaryWriter`` that degrades to a no-op.

    When constructed with ``log_dir=None`` no writer is created and every
    logging method silently does nothing, so callers can log unconditionally.
    Instances can also be used as a context manager; the writer is closed on
    exit.
    """

    # Scalar types accepted by ``log_values``. NumPy scalars are listed
    # explicitly because e.g. ``np.float32`` / ``np.int64`` / ``np.bool_`` are
    # not subclasses of the Python ``float`` / ``int`` builtins and would
    # otherwise be rejected.
    _SCALAR_TYPES = (bool, int, float, np.bool_, np.integer, np.floating)

    def __init__(
            self,
            log_dir=None,
            comment="",
            purge_step=None,
            max_queue=10,
            flush_secs=120,
            filename_suffix=""
    ) -> None:
        """Create the logger.

        Args:
            log_dir: Directory for the event files, or ``None`` to disable
                logging entirely (no ``SummaryWriter`` is created).
            comment, purge_step, max_queue, flush_secs, filename_suffix:
                Forwarded unchanged to ``SummaryWriter``.
        """
        assert log_dir is None or isinstance(log_dir, str), f'log_dir must be a string or None: {log_dir}'
        self.writer: Optional[SummaryWriter] = SummaryWriter(
            log_dir=log_dir,
            comment=comment,
            purge_step=purge_step,
            max_queue=max_queue,
            flush_secs=flush_secs,
            filename_suffix=filename_suffix
        ) if log_dir is not None else None

    def __enter__(self) -> 'TensorboardLogger':
        """Return ``self`` so the logger can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the writer on exit; exceptions from the body propagate."""
        self.close()

    def log_values(self, step: int, values: dict[str, Any], prefix: str = '') -> None:
        """Log every scalar in ``values`` at ``step``.

        Nested dictionaries are flattened recursively: ``{'a': {'b': 1}}``
        logged with ``prefix='p/'`` is written under the tag ``'p/a/b'``.
        Every scalar is converted to ``float`` before being logged.

        Args:
            step: Global step; a Python or NumPy integer.
            values: Mapping from tag suffix to a scalar (Python or NumPy
                bool/int/float) or to a nested dictionary of the same form.
            prefix: String prepended verbatim to every tag.

        Raises:
            AssertionError: If an argument has an unsupported type.
            ValueError: If a value has an unsupported type and assertions are
                disabled (``python -O``).
        """
        if self.writer is None:
            return
        assert isinstance(step, (int, np.integer)), f'step must be an integer: {step}'
        assert isinstance(values, dict), f'values must be a dictionary: {values}'
        assert all(isinstance(key, str) for key in values), f'all keys must be strings: {values.keys()}'
        assert all(isinstance(value, (*self._SCALAR_TYPES, dict)) for value in values.values()), \
            f'all values must be floats, ints, bools, or dictionaries: {values.values()}'
        assert isinstance(prefix, str), f'prefix must be a string: {prefix}'
        # Normalise NumPy integer steps to a plain Python int.
        step = int(step)
        for key, value in values.items():
            tag = f'{prefix}{key}'
            if isinstance(value, self._SCALAR_TYPES):
                self.writer.add_scalar(tag, float(value), step)
            elif isinstance(value, dict):
                self.log_values(step, value, prefix=f'{tag}/')
            else:
                # Only reachable when assertions are stripped.
                raise ValueError(f'unsupported type {type(value)} for {key}')

    def log_images(self, tag: str, images: np.ndarray, step: int, dataformats: str = 'NCHW') -> None:
        """Log a batch of images whose pixel values are expected in ``[-1, 1]``.

        Pixels are mapped linearly from ``[-1, 1]`` to ``[0, 255]``, clipped,
        rounded and cast to ``uint8`` before being passed to
        ``SummaryWriter.add_images``.

        Args:
            tag: Tag under which the images are stored.
            images: Array-like batch of images; values outside ``[-1, 1]``
                are clipped.
            step: Global step.
            dataformats: Axis layout of ``images``, forwarded to
                ``add_images`` (``'NCHW'`` matches that method's default).
        """
        if self.writer is None:
            return
        # [-1, 1] -> [0, 1] -> [0, 255]; clipping keeps values inside the
        # uint8 range before the cast.
        scaled = np.clip(np.asarray(images) * 0.5 + 0.5, 0, 1) * 255
        log_images: np.ndarray = np.round(scaled).astype(np.uint8)
        self.writer.add_images(tag, log_images, step, dataformats=dataformats)

    def close(self) -> None:
        """Close the underlying ``SummaryWriter`` if one was created."""
        if self.writer is not None:
            self.writer.close()
