from abc import ABC, abstractmethod
from acme.utils import loggers as acme_loggers
import neptune.new as neptune
import numpy as np
import tensorflow as tf
from typing import Any, Dict, List, Mapping, Union
import yaml

# Type aliases used by the PoppyLogger API.
Path = str  # filesystem path of a file to log
# An artifact maps a name to the path of a file to log. The ``Any`` member
# presumably stands in for a nested Artifact (a recursive structure), since
# the alias cannot name itself here.
Artifact = Mapping[str, Union[str, Any]]


class PoppyLogger(ABC):
    """Interface every Poppy logger implements.

    A concrete logger records metric data, records named configuration
    dictionaries, records artifacts, and releases its resources on ``close``.
    """

    @abstractmethod
    def write(self, data: acme_loggers.LoggingData) -> None:
        """Record one batch of logging data."""

    @abstractmethod
    def write_config(self, config: Dict, name="config") -> None:
        """Record the configuration dictionary ``config`` under ``name``."""

    @abstractmethod
    def write_artifact(self, artifact: Artifact) -> None:
        """Record the artifacts described by ``artifact``."""

    @abstractmethod
    def close(self):
        """Release any resources held by the logger."""


class TerminalLogger(acme_loggers.TerminalLogger, PoppyLogger):
    """Poppy logger that prints data and configs through acme's TerminalLogger.

    Metric writing (and ``close``) are inherited from
    ``acme_loggers.TerminalLogger``. This class adds the config and artifact
    hooks required by ``PoppyLogger``.
    """

    def __init__(self, label: str, time_delta: float, **kwargs: Any):
        """Create the terminal logger.

        Args:
            label: forwarded unchanged to acme's ``TerminalLogger``.
            time_delta: forwarded unchanged to acme's ``TerminalLogger``.
            **kwargs: any other acme ``TerminalLogger`` argument. ``print_fn``
                defaults to the builtin ``print`` but may be overridden.
        """
        # Default to the builtin print, but let callers supply their own
        # print_fn. Previously print_fn was always passed explicitly, so also
        # passing it in kwargs raised a duplicate-keyword TypeError.
        kwargs.setdefault("print_fn", print)
        super().__init__(label=label, time_delta=time_delta, **kwargs)

    def write_config(self, config: Dict, name="config") -> None:
        """Print ``config`` as YAML under a capitalised ``name`` header.

        ``_print_fn`` is assumed to be where acme's TerminalLogger stores the
        ``print_fn`` passed to its constructor. It is not set in this file.
        """
        self._print_fn(f"{name.capitalize()}:\n {yaml.dump(config)}")

    def write_artifact(self, artifact: Artifact) -> None:
        """Do nothing: the terminal logger does not record artifacts."""


class NeptuneLogger(acme_loggers.Logger, PoppyLogger):
    """Poppy logger that records everything into a single Neptune run."""

    def __init__(self, **kwargs: Any):
        """Start the Neptune run.

        Args:
            **kwargs: forwarded verbatim to ``neptune.init``.
        """
        super(NeptuneLogger, self).__init__()
        self.run = neptune.init(**kwargs)

    def write(self, data: acme_loggers.LoggingData) -> None:
        """Append each value of ``data`` to the run field named by its key.

        Values that ``np.isscalar`` rejects (e.g. TensorFlow tensors or 0-d
        numpy arrays) are cast to a Python ``float`` first. They must
        therefore hold a single element, otherwise ``float()`` raises.
        """
        for key, value in data.items():
            if not np.isscalar(value):
                if isinstance(value, tf.Tensor):
                    # Pull the tensor's value out of TF before casting.
                    value = tf.keras.backend.get_value(value)
                value = float(value)
            self.run[f"{key}"].log(value)

    def write_config(self, config: Dict, name="config") -> None:
        """Assign ``config`` to the run field ``name``."""
        self.run[name] = config

    def write_artifact(self, artifact: Artifact) -> None:
        """Upload every file referenced by ``artifact``.

        Nested mappings are walked recursively. Their keys are joined with
        "/", Neptune's namespace separator, so ``{"ckpt": {"best": path}}``
        uploads ``path`` to the field ``"ckpt/best"``. Flat artifacts behave
        exactly as before.
        """
        for field_path, file in self._flatten_artifact(artifact):
            self.run[field_path].upload(file)

    @staticmethod
    def _flatten_artifact(artifact, prefix=""):
        """Yield ``(field_path, leaf_value)`` for every leaf of a nested artifact."""
        for key, value in artifact.items():
            field_path = f"{prefix}{key}"
            if isinstance(value, Mapping):
                # Nested artifact: descend and namespace its keys under this one.
                yield from NeptuneLogger._flatten_artifact(value, prefix=f"{field_path}/")
            else:
                yield field_path, value

    def close(self) -> None:
        """Stop the Neptune run."""
        self.run.stop()


class EnsembleLogger(acme_loggers.Logger, PoppyLogger):
    """Poppy logger that forwards every call to a list of loggers, in order."""

    def __init__(self, loggers: List[PoppyLogger]):
        """Wrap ``loggers``. The list is stored as-is, not copied."""
        super().__init__()
        self.loggers = loggers

    def write(self, data: acme_loggers.LoggingData) -> None:
        """Forward ``data`` to each wrapped logger. The first error propagates."""
        for logger in self.loggers:
            logger.write(data)

    def write_config(self, config: Dict, name="config") -> None:
        """Forward ``config``/``name`` to each wrapped logger."""
        for logger in self.loggers:
            logger.write_config(config, name)

    def write_artifact(self, artifact: Artifact) -> None:
        """Forward ``artifact`` to each wrapped logger."""
        for logger in self.loggers:
            logger.write_artifact(artifact)

    def close(self) -> None:
        """Close every wrapped logger, even if some of them fail.

        Raises:
            The first exception raised by a wrapped logger's ``close``, after
            all remaining loggers have also been closed.
        """
        first_error = None
        for logger in self.loggers:
            try:
                logger.close()
            except Exception as err:
                # Keep going so one failing logger cannot leak the others'
                # resources; remember the first failure to re-raise below.
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error
