################################################################################
# experilog/logger.py
#
# 
# 
# 
#
# Logging system for experiments.

import json
import numpy    as np
import platform
import sys

from datetime     import datetime
from numpy.typing import NDArray
from os.path      import isdir, isfile, join
from torch        import __version__ as torch_version
from typing       import Dict, List, Optional, Union
from warnings     import warn

from .measure import Measurement

# Any value that the ``json`` module can serialise directly.
JSONType = Union[
  str, int, float, bool, List["JSONType"], Dict[str, "JSONType"], None
]
# A (possibly nested) list of scalars.
ListArray = List[Union[int, float, bool, "ListArray"]]

# Title used when none is supplied.
TITLE           = "experiment"
# ``strftime`` layout for timestamps: undashed date and time, "Z" suffix.
DATETIME_FORMAT = "%Y%m%dT%H%M%SZ"

def _as_float(reduction):
  # Composes ``float`` with ``reduction`` so that the summary value is a
  # plain Python float.
  def summarise(x):
    return float(reduction(x))
  return summarise

# Maps the name of each summary statistic to the function computing it from
# a NumPy array. Every value is a plain Python float, int or list of ints.
ARRAY_SUMMARY_MAP = {
  "mean":       _as_float(np.mean),
  "pop_std":    _as_float(lambda x: np.std(x, ddof = 0)),
  "sample_std": _as_float(lambda x: np.std(x, ddof = 1)),
  "min":        _as_float(np.min),
  "max":        _as_float(np.max),
  "median":     _as_float(np.median),
  "iqr":        _as_float(
    lambda x: np.quantile(x, 0.75) - np.quantile(x, 0.25)
  ),
  "shape":      lambda x: [int(extent) for extent in x.shape],
  "size":       lambda x: int(x.size)
}

# Experiment status codes, numbered in lifecycle order.
WAITING, RUNNING, STOPPED = range(3)

def format_datetime(
    # Arguments:
    date: datetime
  ) -> str:
  """
  Formats a date/time as a string using ``DATETIME_FORMAT``.

  The value is formatted exactly as given; no timezone conversion is
  applied.

  Args:
    date (datetime):
      The date/time to format.

  Returns:
    str:
      The formatted date/time.
  """
  formatted = date.strftime(DATETIME_FORMAT)
  return formatted

class Logger:
  """
  Logs a single experiment (controls, timestamped results and run metadata)
  and writes it to a JSON file.

  The logger moves one way through three states: waiting (controls may be
  set), running (results may be recorded) and stopped (metadata may be read
  and the log written). A logger that has been started cannot be restarted.
  """

  def __init__(self,
      # Arguments:
      directory: str,
      # Keyword Arguments:
      title: Optional[str] = None
    ) -> None:
    """
    Initializes ``Logger``.

    Args:
      directory (str):
        The directory to save the experiment log in. Must already exist.
      title (str, optional):
        The title of the experiment, used as the prefix of the log filename.
        Defaults to ``TITLE`` if ``None``.

    Raises:
      AssertionError:
        If ``directory`` is not a string naming an existing directory, or
        ``title`` is neither a string nor ``None``.
    """
    # Directory.
    assert isinstance(directory, str), \
      "directory must be a string."
    assert isdir(directory), \
      f"{directory} is not a valid directory."
    self.directory = directory
    # Title.
    assert isinstance(title, str) or title is None, \
      "title must be a string or None."
    self.title = TITLE if title is None else title
    self._status = WAITING
    # Default the controls so that an experiment which never calls
    # ``set_controls`` can still be serialised (as JSON ``null``).
    self._controls = None

  def get_time(self) -> datetime:
    """
    Gets the current time (in UTC).

    Returns:
      datetime:
        The current UTC time as a naive ``datetime`` (``tzinfo`` is
        ``None``). Pass it to ``format_datetime`` to obtain the string form.
    """
    # ``datetime.utcnow`` is deprecated as of Python 3.12; this yields the
    # same naive UTC value.
    from datetime import timezone
    return datetime.now(timezone.utc).replace(tzinfo = None)

  def start_experiment(self) -> None:
    """
    Records the date and time (in UTC) of experiment start.

    Warns and does nothing if the experiment has already been started or
    stopped.
    """
    if self._status != WAITING:
      warn(
        "Experiment has been previously started and/or stopped. Create a "
        "new experiment."
      )
      return
    self._experiment_start = self.get_time()
    self._status = RUNNING
    self._results = {}

  def stop_experiment(self) -> None:
    """
    Records the date and time (in UTC) of experiment end.

    Warns and does nothing if the experiment is not currently running.
    """
    if self._status != RUNNING:
      warn(
        "Experiment is not currently running, and so cannot be stopped."
      )
      return
    self._experiment_end = self.get_time()
    self._status = STOPPED

  def record_result(self,
      # Arguments:
      results: JSONType
    ) -> None:
    """
    Records a result, keyed by the formatted current time (in UTC).

    Args:
      results (JSONType):
        The result to record.

    Raises:
      AssertionError:
        If the experiment is not running, or if a result is already stored
        under the same formatted timestamp.
    """
    assert self._status == RUNNING, \
      "Cannot record result unless experiment is running."
    # Results are keyed by the *formatted* timestamp, so the collision check
    # must use that string; checking the raw datetime would never match.
    result_key = format_datetime(self.get_time())
    assert result_key not in self._results, \
      "Result recording time already exists in results dictionary. " + \
      "Aborting to prevent overwrite."
    self._results[result_key] = results

  def get_metadata(self) -> JSONType:
    """
    Gets the metadata for the experiment.

    Returns:
      JSONType:
        The metadata as a JSON dict, with the keys ``"start"``, ``"end"``,
        ``"duration"``, ``"system"`` and ``"language"``.

    Raises:
      AssertionError:
        If the experiment has not been stopped.
    """
    assert self._status == STOPPED, \
      "Cannot get metadata if experiment has not been stopped."
    metadata = {}
    # Start and end.
    metadata["start"] = format_datetime(self._experiment_start)
    metadata["end"]   = format_datetime(self._experiment_end)
    # Duration.
    elapsed = self._experiment_end - self._experiment_start
    # ``total_seconds`` includes whole days; ``timedelta.seconds`` on its own
    # wraps around every 24 hours.
    duration = Measurement(elapsed.total_seconds(), "seconds")
    metadata["duration"] = duration.get_json_dict()
    # Device and language information.
    metadata["system"] = {
      "platform":     platform.platform(),
      "release":      platform.release(),
      "version":      platform.version(),
      "architecture": platform.machine(),
      "processor":    platform.processor()
    }
    metadata["language"] = {
      "python": sys.version,
      "numpy":  np.version.version,
      "torch":  torch_version
    }
    return metadata

  def set_controls(self,
      # Arguments:
      controls: JSONType
    ) -> None:
    """
    Sets the controls for the experiment.

    Args:
      controls (JSONType):
        The controls for the experiment.

    Raises:
      AssertionError:
        If the experiment has already been started.
    """
    assert self._status == WAITING, \
      "Can only set controls before experiment has started."
    self._controls = controls

  def get_full_dict(self) -> JSONType:
    """
    Gets the full dictionary of the experiment.

    Returns:
      JSONType:
        The JSON dict of the whole experiment, with the keys ``"metadata"``,
        ``"controls"`` (``None`` if no controls were set) and ``"results"``.

    Raises:
      AssertionError:
        If the experiment has not been stopped.
    """
    assert self._status == STOPPED, \
      "Cannot get full dictionary if experiment has not been stopped."
    return {
      "metadata": self.get_metadata(),
      "controls": self._controls,
      "results":  self._results
    }

  @property
  def full_filename(self) -> str:
    """
    Gets the full filename (directory + name + datetime).

    Returns:
      str:
        The full filename.

    Raises:
      AssertionError:
        If the experiment has not been started.
    """
    assert self._status >= RUNNING, \
      "Cannot get experiment start date/time unless experiment has started."
    return join(
      self.directory,
      f"{self.title}_{format_datetime(self._experiment_start)}.json"
    )

  def write(self) -> None:
    """
    Writes the experiment data to the JSON file.

    Raises:
      AssertionError:
        If the experiment has not been stopped, or the target file already
        exists.
    """
    assert self._status == STOPPED, \
      "Can only write to JSON file if the experiment has been stopped."
    filename = self.full_filename
    assert not isfile(filename), \
      f"File '{filename}' already exists. Aborting to prevent overwrite."
    # Serialise before opening the file, so that a serialisation failure
    # does not leave an empty file behind that would block later writes.
    content = json.dumps(self.get_full_dict(), indent = 2)
    with open(filename, "w") as file:
      file.write(content)

  # Numeric data helpers:

  def from_numpy(self,
      # Arguments:
      array: NDArray
    ) -> ListArray:
    """
    Converts a NumPy array into a (nested) list "array".

    Every element is converted to ``float`` first.

    Args:
      array (NDArray):
        The NumPy array.

    Returns:
      ListArray:
        The array represented as nested lists.
    """
    return array.astype(float).tolist()

  def array_summary(self,
      # Arguments:
      array: NDArray,
      # Keyword Arguments:
      include: Optional[List[str]] = None,
      exclude: Optional[List[str]] = None
    ) -> JSONType:
    """
    Generates a numeric summary of a NumPy array.

    Args:
      array (NDArray):
        The NumPy array.
      include (List[str], optional):
        Accumulations to include in the summary.
        Defaults to all possible accumulations if ``None``.
      exclude (List[str], optional):
        Accumulations to exclude in the summary - usually for if ``include``
        is left as ``None``.
        Defaults to ``[]`` if ``None``.

    Returns:
      JSONType:
        The JSON dictionary of the summaries, with keys in the order given
        by ``include`` (or by ``ARRAY_SUMMARY_MAP`` if ``include`` is
        ``None``).

    Raises:
      AssertionError:
        If a non-excluded name in ``include`` is not a valid accumulator.
    """
    requested = list(ARRAY_SUMMARY_MAP) if include is None else include
    excluded = set() if exclude is None else set(exclude)
    summary = {}
    # Walk the requested names in order (rather than as a set) so that the
    # key order of the summary is the same on every run.
    for accumulator in requested:
      if accumulator in excluded or accumulator in summary:
        continue
      assert accumulator in ARRAY_SUMMARY_MAP, \
        f"'{accumulator}' is not a valid summary accumulator."
      summary[accumulator] = ARRAY_SUMMARY_MAP[accumulator](array)
    return summary
