import os
import pickle
from torch.utils.tensorboard import SummaryWriter
from omegaconf import DictConfig, ListConfig

def convert(d):
    """Recursively turn OmegaConf containers into plain Python containers.

    ``ListConfig`` and ``list`` values become ``list``. ``DictConfig`` and
    ``dict`` values become ``dict``. Elements and dict values are converted
    recursively. Any other value is returned unchanged.
    """
    if isinstance(d, (ListConfig, list)):
        return list(map(convert, d))
    if isinstance(d, (DictConfig, dict)):
        return {key: convert(value) for key, value in d.items()}
    return d

class Logger:
    """Base logging interface; every hook here must be overridden by subclasses."""

    def get_logger(self, name):
        """Subclass hook; the base implementation always raises NotImplementedError."""
        raise NotImplementedError()

    def add_scalar(self, name, value, epoch):
        """Record ``value`` under ``name`` at step ``epoch``.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError()

class TensorBoardLogger(Logger):
    """Log to TensorBoard and mirror every logged entry to a pickle file.

    The ``SummaryWriter`` is created lazily on the first ``add_*`` call.
    Every logged entry is also appended to an in-memory cache. When the cache
    holds more than ``max_cache_size`` entries, it is appended to
    ``<writer log dir>/values.pickle`` as one pickled list and then cleared.
    The file is opened in append mode, so it contains a *sequence* of pickled
    lists. Readers must call ``pickle.load`` repeatedly until ``EOFError``.

    Args:
        log_dir: Directory passed to ``SummaryWriter``.
        prefix: Optional tag prefix. When set, tags become ``"<prefix>/<name>"``.
        max_cache_size: Cache length above which entries are flushed to disk.
    """

    def __init__(self, log_dir, prefix=None, max_cache_size=1000):
        self._log_dir = log_dir
        self._logger = None  # SummaryWriter, created on first use
        self._prefix = prefix
        self._cache = []  # (tag, value, epoch[, bins]) tuples not yet pickled
        self._max_cache_size = max_cache_size

    @staticmethod
    def log_params(log_dir, params):
        """Pickle ``params`` (after ``convert``) to ``<log_dir>/params.pickle``.

        Creates ``log_dir`` if needed. An existing params file is overwritten.
        """
        params = convert(params)
        # exist_ok avoids the race between an existence check and makedirs.
        os.makedirs(log_dir, exist_ok=True)
        with open(os.path.join(log_dir, "params.pickle"), "wb") as file:
            pickle.dump(params, file)

    def _writer(self):
        """Return the SummaryWriter, creating it on first use."""
        if self._logger is None:
            self._logger = SummaryWriter(log_dir=self._log_dir)
        return self._logger

    def _tag(self, name):
        """Return ``name`` with the configured prefix applied, if any."""
        if self._prefix is None:
            return name
        return self._prefix + "/" + name

    def _record(self, entry):
        """Cache ``entry``; flush once the cache exceeds ``max_cache_size``."""
        self._cache.append(entry)
        if len(self._cache) > self._max_cache_size:
            self._flush()

    def add_scalar(self, name, value, epoch):
        """Log a scalar to TensorBoard and cache ``(tag, value, epoch)``."""
        writer = self._writer()
        tag = self._tag(name)
        writer.add_scalar(tag, value, epoch)
        self._record((tag, value, epoch))

    def add_histogram(self, name, values, epoch, bins='auto'):
        """Log a histogram to TensorBoard and cache ``(tag, values, epoch, bins)``."""
        writer = self._writer()
        tag = self._tag(name)
        writer.add_histogram(tag, values, epoch, bins=bins)
        self._record((tag, values, epoch, bins))

    def _flush(self):
        """Append cached entries to ``values.pickle`` and clear the cache."""
        if not self._cache:
            return
        # Entries are only cached after a writer exists, so self._logger is
        # set here. Use the writer's resolved log_dir, because SummaryWriter
        # picks its own default directory when self._log_dir is None.
        path = os.path.join(self._logger.log_dir, "values.pickle")
        with open(path, "ab") as file:
            pickle.dump(self._cache, file)
        self._cache = []

    def close(self):
        """Flush cached entries and close the TensorBoard writer.

        Safe to call more than once. A later ``add_*`` call opens a new writer.
        """
        self._flush()
        if self._logger is not None:
            self._logger.close()
            self._logger = None

    def __del__(self):
        # __del__ can run on an instance whose __init__ never completed.
        cache = getattr(self, "_cache", None)
        if cache is None:
            return
        print("Flushing log at closing time: ", len(cache))
        # Only the pickle cache is flushed here; call close() to also close
        # the SummaryWriter.
        self._flush()