import sys
import logging

# Keep references to the interpreter's streams as they are when this module is
# imported; setup_logging() hands them to LoggerWriter so output can still be
# echoed to the screen after sys.stdout / sys.stderr have been replaced.
orig_stdout = sys.stdout
orig_stderr = sys.stderr


class LoggerWriter:
    """File-like object that forwards written text, line by line, to a logger.

    Text passed to :meth:`write` is accumulated until a newline is seen; each
    complete line is then handed to the ``level`` callable. Optionally the raw
    text is also echoed, unmodified, to ``orig_stream``.

    Args:
        level: Callable invoked with one complete line (without the trailing
            newline) per call, e.g. ``logging.debug``.
        orig_stream: Stream that receives the raw text when ``log_to_screen``
            is true. It is only touched when ``log_to_screen`` is true.
        log_to_screen: Whether to echo written text to ``orig_stream``.
    """

    def __init__(self, level, orig_stream, log_to_screen):
        # self.level is really like using log.debug(message)
        self.level = level
        self.orig_stream = orig_stream
        self.log_to_screen = log_to_screen
        # Text received since the last newline, i.e. a not-yet-complete line.
        self.buffer = ""

    def write(self, message):
        """Buffer ``message``, log every completed line, and optionally echo it.

        Lines that start with a carriage return are not logged (they are still
        echoed to the screen when echoing is enabled).

        Returns:
            The number of characters in ``message``.
        """
        self.buffer += message
        lines = self.buffer.split('\n')
        # The last element is whatever follows the final newline: an empty
        # string if the text ended with '\n', otherwise an unfinished line.
        # Either way it must stay buffered so a line that arrives in several
        # write() calls is logged once, in full, instead of in fragments.
        self.buffer = lines.pop()
        for line in lines:
            if not line.startswith('\r'):
                self.level(line)

        if self.log_to_screen:
            self.orig_stream.write(message)
        return len(message)

    def flush(self):
        """Flush the original stream when echoing is enabled.

        Buffered partial-line text is not logged by this call; it is logged
        once its terminating newline has been written.
        """
        if self.log_to_screen:
            self.orig_stream.flush()

    def close(self):
        """Do nothing; the wrapped stream is left open."""
        pass


def setup_logging(log_path, log_to_screen=True):
    """Send all logging and all stdout/stderr output to a log file.

    Replaces the root logger's handlers with a single file handler writing to
    ``log_path``, then replaces ``sys.stdout`` and ``sys.stderr`` with
    ``LoggerWriter`` objects that log each written line at DEBUG level.

    Args:
        log_path: Path of the log file, passed to ``logging.FileHandler``.
        log_to_screen: When true, text written to the replaced streams is
            also echoed to the original stdout/stderr.
    """
    root = logging.getLogger()

    # Remove *and close* the current handlers. Merely clearing the list would
    # leave a previously installed FileHandler's file open, leaking a file
    # handle each time this function is called again.
    for handler in list(root.handlers):
        root.removeHandler(handler)
        handler.close()
    root.setLevel(logging.DEBUG)

    log_format = '%(asctime)s %(message)s'

    # No console handler is attached to the root logger: screen output is
    # produced by LoggerWriter echoing to the original streams instead.

    # Add file handler
    fh = logging.FileHandler(log_path)
    fh.setFormatter(logging.Formatter(log_format))
    fh.setLevel(logging.DEBUG)
    root.addHandler(fh)

    # Redirect system streams
    sys.stdout = LoggerWriter(logging.debug, orig_stdout, log_to_screen)
    sys.stderr = LoggerWriter(logging.debug, orig_stderr, log_to_screen)
