import functools
import logging
import os
from typing import Text, Dict, Optional, Union

import matplotlib
from matplotlib import rc
import matplotlib.pyplot as plt

import numpy as np

# Shape of the per-run results consumed by the plotting functions below:
# metric name -> scalar or numpy array.
Results = Dict[Text, Union[float, np.ndarray]]


# ==============================================================================
# MATPLOTLIB SETTINGS
# ==============================================================================
# Start from matplotlib's stock configuration so the overrides below are the
# only deviations, regardless of any user matplotlibrc.
matplotlib.rcdefaults()

# rc group name -> keyword arguments forwarded to `rc(group, **kwargs)`.
# NOTE: `usetex=True` hands text rendering to an external LaTeX install (see
# matplotlib's usetex docs); without one, individual plots will fail.
matplotlib_rc = dict(
  text=dict(usetex=True),
  font=dict(size='16', family='serif', serif='Palatino'),
  figure=dict(titlesize='20', figsize=(8, 5)),
  axes=dict(titlesize='20', labelsize='18'),
  legend=dict(fontsize='18'),
  xtick=dict(labelsize='18'),
  ytick=dict(labelsize='18'),
  lines=dict(linewidth=3, markersize=10),
  grid=dict(color='grey', linestyle='solid', linewidth=0.5),
)

# Apply each group in insertion order.
for k, v in matplotlib_rc.items():
  rc(k, **v)


# ==============================================================================
# DECORATORS & GENERAL FUNCTIONALITY
# ==============================================================================
# There's nothing worse than the entire run failing because something is wrong
# with a plot that may not even be that important. Therefore, if something goes
# wrong in a plot, we simply return an empty figure and continue.
def empty_fig_on_failure(func):
  """Decorator for individual plot functions to return empty fig on failure.

  The wrapped function is called with the given arguments. If it raises any
  `Exception`, the error is logged together with its traceback, any figures
  opened by the failed call are closed, and a fresh empty figure is returned
  instead, so one broken plot cannot abort the whole run.

  Args:
    func: A plotting function that returns a `plt.Figure`.

  Returns:
    The wrapped function, carrying `func`'s name and docstring.
  """
  @functools.wraps(func)
  def applicator(*args, **kwargs):
    # Snapshot the open figures so that, on failure, only the ones created by
    # the failed call are closed (they would otherwise stay registered with
    # pyplot and leak, since nobody holds a reference to them).
    existing = set(plt.get_fignums())
    # noinspection PyBroadException
    try:
      return func(*args, **kwargs)
    except Exception:  # pylint: disable=broad-except
      logging.getLogger(__name__).exception(
          "Plot function %s failed; returning an empty figure.",
          getattr(func, "__qualname__", repr(func)))
      for num in set(plt.get_fignums()) - existing:
        plt.close(num)
      return plt.figure()
  return applicator


def save_plot(figure: plt.Figure, path: Optional[Text]):
  """Store a figure in a given location on disk as PDF, then close it.

  Args:
    figure: The figure to save.
    path: Destination file path. If `None`, nothing is written and the figure
      is left open.
  """
  if path is None:
    return
  try:
    figure.savefig(path, bbox_inches="tight", format="pdf")
  finally:
    # Close even when saving fails (bad path, full disk, ...) so the figure
    # does not stay registered with pyplot.
    plt.close(figure)


# ==============================================================================
# PLOTTING FUNCTIONS
# ==============================================================================

@empty_fig_on_failure
def plot_regret(results: Results) -> plt.Figure:
  """Plot the running total of `results['regret']` against the round index.

  Args:
    results: Run results; the 'regret' entry is accumulated with `np.cumsum`.

  Returns:
    A new figure with 'rounds' on the x-axis and 'regret' on the y-axis.
  """
  fig = plt.figure()
  # The freshly created figure has no axes yet; gca() creates its single one.
  axis = fig.gca()
  axis.plot(np.cumsum(results['regret']))
  axis.set_xlabel('rounds')
  axis.set_ylabel('regret')
  return fig


@empty_fig_on_failure
def plot_return(results: Results) -> plt.Figure:
  """Plot the running total of `results['return']` against the round index.

  Args:
    results: Run results; the 'return' entry is accumulated with `np.cumsum`.

  Returns:
    A new figure with 'rounds' on the x-axis and 'return' on the y-axis.
  """
  fig = plt.figure()
  # The freshly created figure has no axes yet; gca() creates its single one.
  axis = fig.gca()
  axis.plot(np.cumsum(results['return']))
  axis.set_xlabel('rounds')
  axis.set_ylabel('return')
  return fig


# ==============================================================================
# FUNCTION TO PLOT EVERYTHING
# ==============================================================================

def plot_all(results, base_dir: Optional[Text] = None, suffix: Text = ""):
  """Call all relevant plotting functions.

  Each plot is saved as `<name>.pdf`, or `<name>_<suffix>.pdf` when a suffix
  is given, inside `base_dir`.

  Args:
    results: The dictionary containing all results from the run.
    base_dir: The path where to store the figures; created (with any missing
      parents) if it does not exist. If `None`, don't save the figures to disk.
    suffix: An optional suffix to each filename stored by this function. Any
      falsy value (including `None`) means no suffix.
  """
  if base_dir is not None:
    # `exist_ok` avoids the race between checking for and creating the dir.
    os.makedirs(base_dir, exist_ok=True)

  # Use "" rather than `suffix` itself for falsy values so that e.g. `None`
  # does not end up in the filename as "regretNone.pdf".
  suff = "_" + suffix if suffix else ""

  # File stem for each plotting function, in the order they are produced.
  plotters = (
      ("regret", plot_regret),
      ("return", plot_return),
  )
  for stem, plot_fn in plotters:
    name = "{}{}.pdf".format(stem, suff)
    path = None if base_dir is None else os.path.join(base_dir, name)
    save_plot(plot_fn(results), path)
