from torch.utils.tensorboard import SummaryWriter
import shutil
import os
from collections import defaultdict


from fnmatch import fnmatchcase as match


class TBWritter(object):
    """Wrapper around ``SummaryWriter`` that logs scalars against one step counter.

    Each call to :meth:`append` writes its keyword arguments as scalars tagged
    ``"<type>/<name>"`` at the current value of a single internal step.
    The instance can be used as a context manager; leaving the ``with`` block
    closes the underlying writer.
    """

    def __init__(self, log_dir, cover=True):
        """Open a writer on ``log_dir``.

        Args:
            log_dir: Directory handed to ``SummaryWriter``.
            cover: When true and ``log_dir`` is an existing directory, delete
                the files directly inside it whose names match
                ``events.out.*`` before the writer is opened.
        """
        self.log_dir = log_dir

        if cover:
            self._remove_event_files(log_dir)

        self._log_writer = SummaryWriter(log_dir)
        self._step = 0

    @staticmethod
    def _remove_event_files(log_dir):
        """Delete regular files named ``events.out.*`` directly inside ``log_dir``."""
        if not os.path.isdir(log_dir):
            return
        for filename in os.listdir(log_dir):
            path = os.path.join(log_dir, filename)
            # Only event files are removed; everything else in log_dir is
            # kept. os.remove() fails on directories, so skip those.
            if match(filename, 'events.out.*') and os.path.isfile(path):
                os.remove(path)

    def append(self, type='train', update=True, **kwargs):
        """Log every keyword argument as a scalar at the current step.

        Args:
            type: Tag prefix; each scalar is written as ``"<type>/<name>"``.
            update: When true, the step is incremented *before* logging, so
                the first logged step is 1.
            **kwargs: ``name=value`` pairs passed to ``add_scalar``. The names
                ``type`` and ``update`` are taken by this method's own
                parameters and therefore cannot be used as scalar names.
        """
        if update:
            self._step += 1

        for name, value in kwargs.items():
            self._log_writer.add_scalar('{}/{}'.format(type, name), value, self._step)

    def flush(self):
        """Flush the underlying writer."""
        self._log_writer.flush()

    def close(self):
        """Close the underlying writer."""
        self._log_writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so an exception raised inside the block propagates.
        self.close()


class TBWritter2(object):
    """Wrapper around ``SummaryWriter`` that keeps several named step counters.

    Each call to :meth:`append` writes its keyword arguments as scalars tagged
    ``"<type>/<name>"`` at the current value of the counter selected by
    ``counter``. The instance can be used as a context manager; leaving the
    ``with`` block closes the underlying writer.
    """

    def __init__(self, log_dir, cover=True, step_counter: dict = None):
        """Open a writer on ``log_dir``.

        Args:
            log_dir: Directory handed to ``SummaryWriter``.
            cover: When true and ``log_dir`` is an existing directory, delete
                the files directly inside it whose names match
                ``events.out.*`` before the writer is opened.
            step_counter: Optional mapping from counter name to step. It is
                stored by reference and mutated in place by :meth:`append`.
                When omitted, a fresh mapping with ``'train'`` at 0 is created.
        """
        self.log_dir = log_dir

        if cover:
            self._remove_event_files(log_dir)

        self._log_writer = SummaryWriter(log_dir)

        if step_counter is None:
            step_counter = defaultdict(int)
            step_counter['train'] = 0
        self.step_counter = step_counter

    @staticmethod
    def _remove_event_files(log_dir):
        """Delete regular files named ``events.out.*`` directly inside ``log_dir``."""
        if not os.path.isdir(log_dir):
            return
        for filename in os.listdir(log_dir):
            path = os.path.join(log_dir, filename)
            # os.remove() fails on directories, so only plain files are removed.
            if match(filename, 'events.out.*') and os.path.isfile(path):
                os.remove(path)

    def append(self, type='train', update=True, counter='train', **kwargs):
        """Log every keyword argument as a scalar at the selected counter's step.

        Args:
            type: Tag prefix; each scalar is written as ``"<type>/<name>"``.
            update: When true, the selected counter is incremented *before*
                logging.
            counter: Name of the step counter to use. A name that is not yet
                present in ``step_counter`` starts at 0, whether the mapping
                is the default one or a plain dict supplied by the caller.
            **kwargs: ``name=value`` pairs passed to ``add_scalar``. The names
                ``type``, ``update`` and ``counter`` are taken by this
                method's own parameters and cannot be used as scalar names.
        """
        if update:
            # .get() rather than [] so a caller-supplied plain dict does not
            # raise KeyError on a counter name it has not seen yet.
            self.step_counter[counter] = self.step_counter.get(counter, 0) + 1

        if not kwargs:
            return

        step = self.step_counter.setdefault(counter, 0)
        for name, value in kwargs.items():
            self._log_writer.add_scalar('{}/{}'.format(type, name), value, step)

    def flush(self):
        """Flush the underlying writer."""
        self._log_writer.flush()

    def close(self):
        """Close the underlying writer."""
        self._log_writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so an exception raised inside the block propagates.
        self.close()
