from torch.utils.tensorboard import SummaryWriter
from pipeline.registry import registry
import os
import collections.abc
import torch

def deep_update(d, u):
    """Recursively merge mapping ``u`` into mapping ``d`` in place and return ``d``.

    For every key in ``u``:
      * if the value is a mapping, it is merged recursively into ``d[key]``;
        when ``d`` has no such key, or ``d[key]`` is not a mapping (e.g. ``None``
        or a scalar), a fresh ``{}`` is used as the merge target;
      * otherwise the value overwrites ``d[key]``.

    Nested mappings from ``u`` are copied key by key into new dicts rather than
    aliased, so later mutation of ``d`` does not leak back into ``u``.

    Args:
        d: Destination mapping; mutated in place.
        u: Mapping whose entries take precedence over ``d``'s.

    Returns:
        ``d`` itself, after the merge.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            target = d.get(k)
            # A non-mapping target (None, scalar, list, ...) cannot be merged into;
            # replace it instead of crashing on item assignment.
            if not isinstance(target, collections.abc.Mapping):
                target = {}
            d[k] = deep_update(target, v)
        else:
            d[k] = v
    return d

@registry.register_utils("tensorboard_logger")
class TensorboardLogger(object):
    """Thin wrapper around ``torch.utils.tensorboard.SummaryWriter``.

    Registered in the project registry under ``"tensorboard_logger"``. The log
    directory is taken from ``cfg['logger']['args']['log_dir']``. Can be used
    as a context manager, in which case the writer is closed on exit.
    """

    def __init__(self, cfg):
        """Ensure the log directory exists and open a ``SummaryWriter`` on it.

        Args:
            cfg: Nested config mapping; only ``cfg['logger']['args']['log_dir']``
                is read here.
        """
        log_dir = cfg['logger']['args']['log_dir']
        # makedirs also creates missing parent directories, and exist_ok avoids
        # the check-then-create race of isdir() followed by mkdir().
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)

    def __enter__(self):
        """Return this logger so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the underlying writer; exceptions are not suppressed."""
        self.close()
        return False

    def log(self, log_dict, step=None):
        """Write every ``tag -> value`` pair in ``log_dict`` as a scalar at ``step``."""
        for k, v in log_dict.items():
            self.writer.add_scalar(k, v, step)

    def log_image(self, tag, img_tensor, step=None, dataformats=None):
        """Log a single image.

        Args:
            tag: TensorBoard tag.
            img_tensor: Image as a ``torch.Tensor`` or any array accepted by
                ``SummaryWriter.add_image``. ``None`` and empty tensors are
                skipped.
            step: Global step.
            dataformats: Optional layout string forwarded to ``add_image``
                (e.g. ``'CHW'``, ``'HWC'``, ``'HW'``). When omitted, a 3-D
                tensor whose last dimension is 1, 3 or 4 is assumed to be HWC
                and is permuted to CHW.
        """
        if img_tensor is None:
            return

        if isinstance(img_tensor, torch.Tensor):
            if img_tensor.numel() == 0:
                # .max() raises on an empty tensor and there is nothing to draw.
                return
            if (
                dataformats is None
                and img_tensor.dim() == 3
                and img_tensor.shape[-1] in (1, 3, 4)
            ):
                # Heuristic: trailing channel axis -> move it to the front (CHW).
                img_tensor = img_tensor.permute(2, 0, 1)
            # Values above 1.0 are assumed to be on a 0-255 scale; rescale to [0, 1].
            if img_tensor.max() > 1.0:
                img_tensor = img_tensor.float() / 255.0

        if dataformats is None:
            self.writer.add_image(tag, img_tensor, step)
        else:
            self.writer.add_image(tag, img_tensor, step, dataformats=dataformats)

    def log_histogram(self, tag, values, step=None):
        """Log a histogram of ``values`` at ``step``; ``None`` is skipped.

        Tensors are detached and moved to CPU before being handed to the writer.
        """
        if values is None:
            return
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu()
        self.writer.add_histogram(tag, values, step)

    def log_figure(self, tag, figure, step=None):
        """Log a figure object via ``add_figure`` at ``step``; ``None`` is skipped."""
        if figure is None:
            return
        self.writer.add_figure(tag, figure, step)

    def flush(self):
        """Flush pending events of the underlying writer."""
        self.writer.flush()

    def close(self):
        """Close the underlying writer."""
        self.writer.close()
