import os
import shutil

from graph_learning.utils import flatten_dict
import pandas as pd

class Logger(object):
    """Directory-backed logger that keeps a registry of named sub-loggers.

    By default each sub-logger is rooted in a sub-directory of this logger's
    ``root_dir`` named after its registry key.
    """

    def __init__(self, root_dir):
        """Store ``root_dir`` and start with an empty sub-logger registry.

        No directory is created here; call :meth:`init` for that.
        """
        self.root_dir = root_dir
        self._sub_loggers = {}

    def register(self, name, root_dir=None, builder=None, options=None):
        """Build, initialise and store a sub-logger under ``name``.

        Args:
            name: Registry key; also the default sub-directory name.
            root_dir: Directory handed to the sub-logger. Defaults to
                ``os.path.join(self.root_dir, name)``.
            builder: Callable invoked as ``builder(root_dir=..., **options)``
                whose result must provide ``init()``. Defaults to
                ``CommonLogger``.
            options: Extra keyword arguments for ``builder``; ``None`` (the
                default) means no extra arguments.

        Returns:
            The newly created sub-logger.

        Raises:
            AssertionError: If ``name`` is already registered.
        """
        # Reject duplicates *before* building anything, so a refused call
        # does not leave behind a freshly created directory or writer.
        assert name not in self._sub_loggers, (
            f'sub-logger {name!r} is already registered')
        if builder is None:
            builder = CommonLogger
        if root_dir is None:
            root_dir = os.path.join(self.root_dir, name)
        # ``options`` defaults to None rather than a shared mutable ``{}``.
        sub_logger = builder(root_dir=root_dir, **(options or {}))
        sub_logger.init()
        self._sub_loggers[name] = sub_logger
        return sub_logger

    def __getitem__(self, key):
        """Return the sub-logger stored under ``key``.

        An unknown ``key`` is registered on the fly with the default builder
        and default directory.
        """
        if key in self._sub_loggers:
            logger = self._sub_loggers[key]
        else:
            logger = self.register(key)
        return logger

    def init(self, exist_ok=True):
        """Create ``root_dir`` (and any missing parents).

        Args:
            exist_ok: Forwarded to ``os.makedirs``; when false, an existing
                directory makes ``os.makedirs`` raise.
        """
        os.makedirs(self.root_dir, exist_ok=exist_ok)

    def purge(self):
        """Delete ``root_dir`` and its contents, ignoring removal errors."""
        shutil.rmtree(self.root_dir, ignore_errors=True)

class CommonLogger(Logger):
    """Logger that writes text logs, CSV tables and file copies into ``root_dir``."""

    def __init__(self, root_dir):
        """Create a logger rooted at ``root_dir`` (no directory is created here)."""
        super().__init__(root_dir)

    def copy_file(self, src, name):
        """Copy the file ``src`` into ``root_dir`` under the file name ``name``.

        Copying a file onto itself is silently ignored.
        """
        try:
            shutil.copyfile(src, self.path(name))
        except shutil.SameFileError:
            # Source already is the destination; nothing to do.
            pass

    def path(self, name):
        """Return ``name`` joined onto ``root_dir``."""
        return os.path.join(self.root_dir, name)

    def log(self, name=None, content='', postfix=True, slient=False, append=True):
        """Print ``content`` and/or write it to a log file in ``root_dir``.

        Args:
            name: Base name of the log file; ``None`` skips the file write.
            content: Converted with ``str()``. A trailing newline is added
                unless the resulting text is empty.
            postfix: When true, ``'.log'`` is appended to ``name``.
            slient: When true, nothing is printed to stdout. (The spelling
                is part of the existing keyword interface.)
            append: Append to the file when true, overwrite it when false.
        """
        content = str(content)
        if content != '':
            content = content + '\n'
        if not slient:
            print(content, end='')
        if name is not None:
            log_file = self.path(name + ('.log' if postfix else ''))
            mode = 'a' if append else 'w'
            # The file is opened even for empty content, so it is created
            # (or truncated in overwrite mode) either way.
            with open(log_file, mode) as f:
                f.write(content)

    def log_csv(self, name, df):
        """Write ``df`` to ``<root_dir>/<name>.csv`` without the index column.

        Args:
            name: File name without the ``.csv`` extension.
            df: A DataFrame, or a dict that is passed through
                ``flatten_dict`` and turned into a one-row DataFrame.
                NOTE(review): ``flatten_dict`` presumably collapses nested
                dicts into scalar-valued keys; confirm against its
                definition.
        """
        if isinstance(df, dict):
            df = pd.DataFrame(flatten_dict(df), index=[0])
        df.to_csv(self.path(f'{name}.csv'), index=False)

    def load_csv(self, name):
        """Read ``<root_dir>/<name>.csv`` and return it as a DataFrame."""
        # pandas has no ``load_csv``; ``read_csv`` is the reader function.
        return pd.read_csv(self.path(f'{name}.csv'))

from tensorboardX import SummaryWriter

class TBLogger(Logger):
    """Logger that owns a tensorboardX ``SummaryWriter`` on ``root_dir``.

    Attributes not found on the logger itself are looked up on the writer,
    so writer methods can be called directly on the logger.
    """

    def __init__(self, root_dir):
        """Create the logger and a ``SummaryWriter`` with ``log_dir=root_dir``."""
        super().__init__(root_dir)
        self.writer = SummaryWriter(log_dir=root_dir)

    def __getattr__(self, name):
        """Forward attribute lookups that failed on ``self`` to ``self.writer``.

        Raises:
            AttributeError: If ``writer`` itself is missing, or the writer
                has no attribute ``name``.
        """
        # __getattr__ only runs after normal lookup has failed. If 'writer'
        # is the missing attribute (constructor raised, or the instance is
        # being rebuilt by copy/pickle), touching self.writer below would
        # re-enter this method forever; fail with AttributeError instead.
        if name == 'writer':
            raise AttributeError(name)
        return getattr(self.writer, name)
