from datetime import datetime
from pathlib import Path
import logging
import re
import pickle
import numpy as np
import tqdm

class TqdmHandler(logging.Handler):
    def __init__(self):
        super().__init__(logging.NOTSET)

    def emit(self, record):
        try:
            msg = self.format(record)
            tqdm.tqdm.write(msg)
            self.flush()
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            self.handleError(record)


class ConsoleLogger(object):
    """
    This class implements the console logging functionality. It can be used to
    log text into the console and optionally save a log file.

    """
    def __init__(self, log_name, log_dir=None, suffix='',
                 log_file_name=None,
                 console_log_level=logging.DEBUG,
                 file_log_level=logging.DEBUG):
        """
        Constructor.

        Args:
            log_name (str): name of the current logger;
            log_dir (Path, str, None): path of the logging directory. If None,
                the console output is not logged into a file;
            suffix (str, ''): optional string appended to the logger id, and
                therefore to the default log file name;
            log_file_name (str, None): optional base name (without the
                ``.log`` extension) of the log file; the logger id is used by
                default;
            console_log_level (int, logging.DEBUG): logging level for console;
            file_log_level (int, logging.DEBUG): logging level for file.

        """
        self._log_id = log_name + suffix
        # Handlers created by this instance; they are closed in __del__.
        self._handlers = []

        formatter = logging.Formatter(fmt='%(asctime)s [%(levelname)s] %(message)s',
                                      datefmt='%d/%m/%Y %H:%M:%S')

        # logging.getLogger returns the same object for the same name, so
        # instances with the same id share one underlying logger.
        self._logger = logging.getLogger(self._log_id)
        # Let the logger pass everything either handler may want; each handler
        # then applies its own threshold.
        self._logger.setLevel(min(console_log_level, file_log_level))
        # Keep records away from the handlers of ancestor (e.g. root) loggers.
        self._logger.propagate = False

        console_handler = TqdmHandler()
        console_handler.setLevel(console_log_level)
        console_handler.setFormatter(formatter)
        self._logger.addHandler(console_handler)
        self._handlers.append(console_handler)

        if log_dir is not None:
            if log_file_name is None:
                log_file_name = self._log_id
            # Path() lets callers pass either a str or a Path.
            log_file_path = Path(log_dir) / (log_file_name + '.log')
            file_handler = logging.FileHandler(log_file_path)
            file_handler.setLevel(file_log_level)
            file_handler.setFormatter(formatter)
            self._logger.addHandler(file_handler)
            self._handlers.append(file_handler)

    def debug(self, msg):
        """
        Log a message with DEBUG level.

        """
        self._logger.debug(msg)

    def info(self, msg):
        """
        Log a message with INFO level.

        """
        self._logger.info(msg)

    def warning(self, msg):
        """
        Log a message with WARNING level.

        """
        self._logger.warning(msg)

    def error(self, msg):
        """
        Log a message with ERROR level.

        """
        self._logger.error(msg)

    def critical(self, msg):
        """
        Log a message with CRITICAL level.

        """
        self._logger.critical(msg)

    def exception(self, msg):
        """
        Log a message with ERROR level, including the current exception
        information. To be called only from an exception handler.

        """
        self._logger.exception(msg)

    def strong_line(self):
        """
        Log a line of # at INFO level.

        """
        self.info('###################################################################################################')

    def weak_line(self):
        """
        Log a line of - at INFO level.

        """
        self.info('---------------------------------------------------------------------------------------------------')

    def epoch_info(self, epoch, **kwargs):
        """
        Log the epoch info at INFO level with the format:
        Epoch <epoch number> | <label 1>: <data 1> <label 2>: <data 2> ...

        Args:
            epoch (int): epoch number;
            **kwargs: the labels and the data to be displayed, in the order
                they are passed.

        """
        fields = ''.join(f' {name}: {data!s}' for name, data in kwargs.items())
        self.info(f'Epoch {epoch!s} |' + fields)

    def __del__(self):
        """
        Detach every handler from the underlying logger and close the handlers
        created by this instance, which releases the log file if one was opened.

        Safe to call even if ``__init__`` failed before the logger was created.

        """
        logger = getattr(self, '_logger', None)
        if logger is not None:
            logger.handlers.clear()
        # Only close handlers we created: other code may still hold its own.
        for handler in getattr(self, '_handlers', ()):
            handler.close()
        self._handlers = []


class DataLogger(object):
    """
    This class implements the data logging functionality. It can be used to
    save numpy data arrays, agents and datasets into a logging directory.

    """
    def __init__(self, results_dir, suffix='', append=False):
        """
        Constructor.

        Args:
            results_dir (Path): path of the logging directory;
            suffix (string, ''): optional string to add a suffix to each
                data file logged;
            append (bool, False): If true, the logger will append the new
                data logged to the one already existing in the directory.

        """
        self._results_dir = results_dir
        self._suffix = suffix
        # Maps each quantity name to the list of values logged so far.
        self._data_dict = dict()

        # Best performance seen so far; used by log_best_agent.
        self._best_J = -np.inf

        if append:
            self._load_numpy()

    def log_numpy(self, **kwargs):
        """
        Log scalars into numpy arrays.

        Each call appends the values to the in-memory history of each quantity
        and rewrites the whole ``<name><suffix>.npy`` file with that history.

        Args:
            **kwargs: set of named scalar values to be saved. The argument name
                will be used to identify the given quantity and as base file name.

        """
        for name, data in kwargs.items():
            history = self._data_dict.setdefault(name, list())
            history.append(data)

            path = self._results_dir / (name + self._suffix + '.npy')
            np.save(path, np.array(history))

    def log_agent(self, agent, epoch=None, full_save=False):
        """
        Log agent into the log folder as ``agent<suffix>[-<epoch>].msh``.

        Args:
            agent (Agent): The agent to be saved;
            epoch (int, None): optional epoch number to
                be added to the agent file currently saved;
            full_save (bool, False): whether to save the full
                data from the agent or not.

        """
        epoch_suffix = '' if epoch is None else '-' + str(epoch)

        filename = 'agent' + self._suffix + epoch_suffix + '.msh'
        path = self._results_dir / filename
        agent.save(path, full_save=full_save)

    def log_best_agent(self, agent, J, full_save=False):
        """
        Log the best agent so far into the log folder as
        ``agent<suffix>-best.msh``. The agent is logged only if the current
        performance is at least as good as the best one seen so far (ties
        also trigger a save).

        Args:
            agent (Agent): The agent to be saved;
            J (float): The performance metric of the current agent;
            full_save (bool, False): whether to save the full
                data from the agent or not.

        """
        if J >= self._best_J:
            self._best_J = J

            filename = 'agent' + self._suffix + '-best.msh'
            path = self._results_dir / filename
            agent.save(path, full_save=full_save)

    def log_dataset(self, dataset):
        """
        Pickle ``dataset`` into ``dataset<suffix>.pkl`` in the log folder,
        overwriting any previous file with the same name.

        Args:
            dataset: any picklable object.

        """
        filename = 'dataset' + self._suffix + '.pkl'
        path = self._results_dir / filename

        with path.open(mode='wb') as f:
            pickle.dump(dataset, f)

    @property
    def path(self):
        """
        Property to return the path to the current logging directory.

        """
        return self._results_dir

    def _load_numpy(self):
        """
        Reload the arrays previously saved by ``log_numpy`` so that new values
        are appended to them.

        Only ``.npy`` files whose stem ends with this logger's suffix are read.
        The quantity name is recovered by stripping exactly that suffix, which
        mirrors how ``log_numpy`` builds file names.

        """
        suffix_len = len(self._suffix)
        for file in self._results_dir.iterdir():
            if not (file.is_file() and file.suffix == '.npy'):
                continue
            stem = file.stem
            if not stem.endswith(self._suffix):
                continue
            # Strip exactly our suffix. Stripping any trailing "-<digits>"
            # would merge other seeds' files into the same key when the suffix
            # is empty, and would leave custom suffixes in the key.
            name = stem[:len(stem) - suffix_len]
            if not name:
                continue
            self._data_dict[name] = np.load(str(file)).tolist()


class Logger(DataLogger, ConsoleLogger):
    """
    This class implements the logging functionality by combining DataLogger
    and ConsoleLogger. It can be used to create automatically a log directory,
    save numpy data arrays and the current agent, and log text to the console
    and optionally to a log file in the same directory.

    """

    def __init__(
        self,
        log_name="",
        results_dir="./logs",
        log_console=False,
        use_timestamp=False,
        append=False,
        seed=None,
        **kwargs,
    ):
        """
        Constructor.

        Args:
            log_name (string, ''): name of the current experiment directory. If
                not specified, the current timestamp is used;
            results_dir (string, './logs'): name of the base logging directory.
                The experiment directory ``results_dir/log_name`` is created if
                missing. If set to None, no directory is created;
            log_console (bool, False): whether to also save the console output
                into a log file in the experiment directory;
            use_timestamp (bool, False): If true, adds the current timestamp to
                the folder name;
            append (bool, False): If true, the logger will append the new
                data logged to the one already existing in the directory;
            seed (int, None): seed for the current run. It can be optionally
                specified to add a ``-<seed>`` suffix to each data file logged;
            **kwargs: other parameters for ConsoleLogger class.

        Raises:
            ValueError: if ``log_console`` is True but ``results_dir`` is None
                or empty.

        """
        # Explicit check instead of assert (stripped under -O). An empty
        # results_dir is rejected too, since no directory would exist for the
        # log file.
        if log_console and not results_dir:
            raise ValueError("log_console=True requires a non-empty results_dir")

        timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

        if not log_name:
            log_name = timestamp
        elif use_timestamp:
            log_name += "_" + timestamp

        if results_dir:
            results_dir = Path(results_dir) / log_name

            print("Logging in folder: " + str(results_dir))
            results_dir.mkdir(parents=True, exist_ok=True)

        suffix = "" if seed is None else "-" + str(seed)

        DataLogger.__init__(self, results_dir, suffix=suffix, append=append)
        ConsoleLogger.__init__(
            self,
            log_name,
            results_dir if log_console else None,
            suffix=suffix,
            **kwargs,
        )

    def log_best_agents_overall(self, agents, J, full_save=False):
        """
        Log the best agents so far into the log folder as
        ``agent<suffix>_<i>-best.msh``, where ``i`` is each agent's position in
        ``agents``. The agents are logged only if the current overall
        performance (e.g. J, success rate) is at least as good as the best one
        seen so far (ties also trigger a save).

        NOTE(review): this shares the best-performance record with
        ``log_best_agent``, so mixing the two methods in one run makes them
        compare against each other's scores.

        Args:
            agents (iterable of Agent): The agents to be saved;
            J (float): The overall performance metric of the current agents;
            full_save (bool, False): whether to save the full
                data from the agents or not.

        """
        if J >= self._best_J:
            self._best_J = J

            for i, agent in enumerate(agents):
                filename = f"agent{self._suffix}_{i}-best.msh"
                path = self._results_dir / filename
                agent.save(path, full_save=full_save)