"""Plotting functionality."""

from collections import Counter
import functools
import logging
import os
import pdb
from typing import Text, Optional, Dict

import jax.numpy as np
import matplotlib
from matplotlib import rc
import matplotlib.pyplot as plt
import numpy as onp


# =============================================================================
# MATPLOTLIB STYLING SETTINGS
# =============================================================================

# Start from matplotlib's default rc settings so the customisation below is
# applied on top of a known state.
matplotlib.rcdefaults()

# Text rendering: typeset all text with LaTeX, in a serif (Palatino) face.
rc('text', usetex=True)
rc('font', size='16', family='serif', serif=['Palatino'])

# Font sizes, grouped by the element they apply to.
rc('figure', titlesize='20')                       # figure title
rc('axes', titlesize='20', labelsize='18',         # axes title, x/y labels
   xmargin=0)                                      # no horizontal padding
rc('legend', fontsize='18')                        # legend entries
rc('xtick', labelsize='18')                        # x tick labels
rc('ytick', labelsize='18')                        # y tick labels

# Line, marker and grid appearance.
rc('lines', linewidth=3, markersize=10)
rc('grid', color='grey', linestyle='solid', linewidth=0.5)

# Shared keyword arguments and constants for the plotting code.
titlekws = {'y': 1.0}

FIGSIZE = (9, 6)
data_kwargs = {
    'alpha': 0.5,
    's': 5,
    'marker': '.',
    'c': 'grey',
    'label': "data",
}


# =============================================================================
# DECORATORS & GENERAL FUNCTIONALITY
# =============================================================================

def empty_fig_on_failure(func):
  """Decorator for individual plot functions to return empty fig on failure.

  The wrapped function is called with the given arguments. If it raises any
  `Exception`, the error is logged as a warning (with traceback) and a fresh,
  empty figure from `plt.figure()` is returned instead, so that one broken
  plot does not abort the caller.

  Args:
    func: The plotting function to wrap.

  Returns:
    A wrapper with the same name, docstring and call signature as `func`.
  """
  # `wraps` keeps `__name__`/`__doc__` of the plot function intact, which the
  # log message below (and any introspection by callers) relies on.
  @functools.wraps(func)
  def applicator(*args, **kwargs):
    # noinspection PyBroadException
    try:
      return func(*args, **kwargs)
    except Exception:  # pylint: disable=broad-except
      # Best effort by design: report the failure instead of hiding it
      # completely, then fall back to an empty figure.
      logging.getLogger(__name__).warning(
          "Plot function %r failed; returning an empty figure.",
          getattr(func, "__name__", func), exc_info=True)
      return plt.figure()
  return applicator


def save_plot(figure: plt.Figure, path: Text):
  """Write a figure to disk as a PDF and close it afterwards.

  Args:
    figure: The figure to store.
    path: Destination file path. If `None`, nothing happens: the figure is
      neither saved nor closed.
  """
  if path is None:
    return
  figure.savefig(path, format="pdf", bbox_inches="tight")
  plt.close(figure)


# =============================================================================
# FINAL AGGREGATE RESULTS
# =============================================================================

@empty_fig_on_failure
def plot_final_max_abs_diff(xstar: np.ndarray,
                            maxabsdiff: np.ndarray) -> plt.Figure:
  """Plot the final maximum constraint violation on a log-scaled y-axis.

  Args:
    xstar: Values used for the x-axis.
    maxabsdiff: Array indexed as `[:, 0]` (drawn as "lower") and `[:, 1]`
      (drawn as "upper").

  Returns:
    The newly created figure. Because of the decorator, an empty figure is
    returned instead if plotting raises.
  """
  fig = plt.figure()
  plt.semilogy(xstar, maxabsdiff[:, 0], 'g--x', label="lower", lw=2)
  plt.semilogy(xstar, maxabsdiff[:, 1], 'r--x', label="upper", lw=2)
  plt.xlabel("x")
  # Raw string: "\m" is not a valid Python escape sequence, the backslash must
  # reach LaTeX unchanged.
  plt.ylabel(r"$\max |LHS - RHS|$")
  plt.title("Final maximum violation of constraints")
  plt.legend()
  return fig


def plot_final_bounds(xstar: np.ndarray,
                      bounds: np.ndarray,
                      data_xstar: np.ndarray,
                      data_ystar: np.ndarray,
                      y_given_x,
                      base_dir) -> plt.Figure:
  """Plot lower/upper bounds together with two reference curves and save it.

  Args:
    xstar: x positions at which the bounds are drawn.
    bounds: Array indexed as `[:, 0]` (drawn as "lower") and `[:, 1]` (drawn
      as "upper").
    data_xstar: x positions of the two reference curves.
    data_ystar: Values labelled `E[Y | do(x^*)]`. If it has more than one
      dimension it is averaged over its first axis before plotting.
    y_given_x: Values labelled `E[Y | x]`, plotted against `data_xstar`.
    base_dir: Directory in which `final_non-nan_bounds.pdf` is written. If
      `None`, the figure is not saved.

  Returns:
    The created figure. When it was saved, `save_plot` has already closed it
    in pyplot; the returned object is still the same `Figure`.
  """
  fig = plt.figure()
  plt.plot(xstar, bounds[:, 0], 'g--x', label="lower", lw=2, markersize=10)
  plt.plot(xstar, bounds[:, 1], 'r--x', label="upper", lw=2, markersize=10)
  if data_ystar.ndim > 1:
    data_ystar = data_ystar.mean(0)
  plt.plot(data_xstar, data_ystar, label="$E[Y | do(x^*)]$", lw=2)
  plt.plot(data_xstar, y_given_x, label="$E[Y | x]$", lw=1)

  plt.xlabel('x')
  plt.ylabel('y')
  plt.title("Lower and upper bound on true effect")
  plt.legend()

  # Only build a path when there is a directory to write to;
  # `os.path.join(None, ...)` would raise a TypeError.
  if base_dir is not None:
    mode = 'non-nan'
    result_path = os.path.join(base_dir, f"final_{mode}_bounds.pdf")
    save_plot(fig, result_path)
  return fig


# =============================================================================
# INDIVIDUAL RUN RESULTS
# =============================================================================

@empty_fig_on_failure
def plot_lagrangian(values: np.ndarray) -> plt.Figure:
  """Draw the given values as a line plot labelled as the overall Lagrangian.

  Args:
    values: Values to plot against their index ("update steps").

  Returns:
    The newly created figure (an empty one if plotting raises, because of the
    decorator).
  """
  figure = plt.figure()
  axes = figure.gca()
  axes.plot(values)
  axes.set_xlabel("update steps")
  axes.set_ylabel("Lagrangian")
  axes.set_title("Overall Lagrangian")
  return figure


@empty_fig_on_failure
def plot_max_sq_lhs_rhs(lhs: np.ndarray, rhs: np.ndarray) -> plt.Figure:
  """Plot the maximum squared difference between LHS and RHS per round.

  For every entry `r` of `rhs`, the maximum of `(lhs - r)**2` is computed and
  the resulting sequence is drawn on a log-scaled y-axis.

  Args:
    lhs: Left-hand side values; must be broadcastable against each entry of
      `rhs`.
    rhs: Iterable of right-hand side values, one entry per optimization round.

  Returns:
    The newly created figure (an empty one if plotting raises, because of the
    decorator).
  """
  fig = plt.figure()
  tt = np.array([np.max((lhs - r)**2) for r in rhs])
  plt.semilogy(tt)
  plt.xlabel("optimization rounds")
  # Text is rendered with LaTeX (usetex=True), where `^` is only legal in math
  # mode; the label also says "max" because that is what is plotted.
  plt.ylabel("max($(LHS - RHS)^2$)")
  plt.title("max($(LHS - RHS)^2$)")
  return fig


@empty_fig_on_failure
def plot_max_abs_lhs_rhs(lhs: np.ndarray, rhs: np.ndarray) -> plt.Figure:
  """Plot the largest absolute LHS/RHS difference for every entry of `rhs`.

  Args:
    lhs: Left-hand side values, subtracted against each entry of `rhs`.
    rhs: Iterable of right-hand side values; one plotted point per entry.

  Returns:
    The newly created figure (an empty one if plotting raises, because of the
    decorator).
  """
  figure = plt.figure()
  worst_per_round = []
  for current_rhs in rhs:
    worst_per_round.append(np.max(np.abs(lhs - current_rhs)))
  axes = figure.gca()
  axes.semilogy(np.array(worst_per_round))
  axes.set_xlabel("optimization rounds")
  axes.set_ylabel("max($|LHS - RHS|$)")
  axes.set_title("max($|LHS - RHS|$)")
  return figure


@empty_fig_on_failure
def plot_max_rel_abs_lhs_rhs(lhs: np.ndarray, rhs: np.ndarray) -> plt.Figure:
  """Plot the largest LHS/RHS difference relative to LHS per entry of `rhs`.

  For each entry `r` of `rhs` the plotted value is the maximum of
  `|(lhs - r) / lhs|`; the y-axis is log-scaled.

  Args:
    lhs: Left-hand side values; also used as the divisor.
    rhs: Iterable of right-hand side values; one plotted point per entry.

  Returns:
    The newly created figure (an empty one if plotting raises, because of the
    decorator).
  """
  figure = plt.figure()
  worst_per_round = []
  for current_rhs in rhs:
    relative_gap = (lhs - current_rhs) / lhs
    worst_per_round.append(np.max(np.abs(relative_gap)))
  axes = figure.gca()
  axes.semilogy(np.array(worst_per_round))
  axes.set_xlabel("optimization rounds")
  axes.set_ylabel("max($|LHS - RHS| / |LHS|$)")
  axes.set_title("max($|LHS - RHS| / |LHS|$)")
  return figure


@empty_fig_on_failure
def plot_abs_lhs_rhs(lhs: np.ndarray, rhs: np.ndarray) -> plt.Figure:
  """Plot the absolute LHS/RHS difference of every component individually.

  For each entry `r` of `rhs`, `|lhs - r|` is computed; one log-scaled line
  is drawn per index `i < len(lhs)`, labelled `i + 1`.

  Args:
    lhs: Left-hand side values; its length determines the number of lines.
    rhs: Iterable of right-hand side values, one entry per optimization round.

  Returns:
    The newly created figure (an empty one if plotting raises, because of the
    decorator).
  """
  fig = plt.figure()
  tt = np.array([np.abs(lhs - r) for r in rhs])
  for i in range(len(lhs)):
    plt.semilogy(tt[:, i], label=f'{i + 1}')
  plt.xlabel("optimization rounds")
  # Math mode, as in the sibling plots, so the bars are typeset as
  # absolute-value bars by LaTeX (usetex=True).
  plt.ylabel("$|LHS - RHS|$")
  plt.title("individual $|LHS - RHS|$")
  # Put the legend outside the axes, to the right.
  plt.legend(loc='center left', bbox_to_anchor=(1.05, 0.5))
  return fig


@empty_fig_on_failure
def plot_min_max_rhs(rhs: np.ndarray) -> plt.Figure:
  """Plot the minimum and the maximum of every entry of `rhs`.

  Args:
    rhs: Iterable of right-hand side values; each entry contributes one
      (min, max) pair.

  Returns:
    The newly created figure (an empty one if plotting raises, because of the
    decorator).
  """
  figure = plt.figure()
  extremes = [(np.min(current_rhs), np.max(current_rhs))
              for current_rhs in rhs]
  axes = figure.gca()
  # One row per entry with columns (min, max).
  axes.plot(np.array(extremes))
  axes.set_xlabel("optimization rounds")
  axes.set_ylabel("RHS min and max")
  axes.set_title("min and max of RHS")
  return figure


@empty_fig_on_failure
def plot_tau(taus: np.ndarray) -> plt.Figure:
  """Plot the temperature parameter values as a line with 'x' markers.

  Args:
    taus: Values to plot against their index ("optimization rounds").

  Returns:
    The newly created figure (an empty one if plotting raises, because of the
    decorator).
  """
  figure = plt.figure()
  axes = figure.gca()
  axes.plot(taus, "-x")
  axes.set_xlabel("optimization rounds")
  axes.set_ylabel(r"temperature $\tau$")
  axes.set_title(r"temperature parameter $\tau$")
  return figure


@empty_fig_on_failure
def plot_lambda(lmbdas: np.ndarray) -> plt.Figure:
  """Plot the Lagrange multiplier values as thin lines with 'x' markers.

  Args:
    lmbdas: Values to plot against their index ("optimization rounds").

  Returns:
    The newly created figure (an empty one if plotting raises, because of the
    decorator).
  """
  fig = plt.figure()
  plt.plot(lmbdas, "-x", linewidth=0.8)
  plt.xlabel("optimization rounds")
  # Raw strings: "\l" is not a valid Python escape sequence, the backslash
  # must reach LaTeX unchanged.
  plt.ylabel(r"multipliers $\lambda$")
  plt.title(r"Lagrange multipliers $\lambda$")
  return fig


@empty_fig_on_failure
def plot_objective(objectives: np.ndarray) -> plt.Figure:
  """Plot the objective values as a line.

  Args:
    objectives: Values to plot against their index ("optimization rounds").

  Returns:
    The newly created figure (an empty one if plotting raises, because of the
    decorator).
  """
  figure = plt.figure()
  axes = figure.gca()
  axes.plot(objectives)
  axes.set_xlabel("optimization rounds")
  axes.set_ylabel("objective value")
  axes.set_title("Objective")
  return figure


@empty_fig_on_failure
def plot_constraint_term(constrs: np.ndarray) -> plt.Figure:
  """Plot the constraint term values on a log-scaled y-axis.

  Args:
    constrs: Values to plot against their index ("optimization rounds").

  Returns:
    The newly created figure (an empty one if plotting raises, because of the
    decorator).
  """
  figure = plt.figure()
  axes = figure.gca()
  axes.semilogy(constrs)
  axes.set_xlabel("optimization rounds")
  axes.set_ylabel("constraint term")
  axes.set_title("Constraint term")
  return figure


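# NOTE: The commented-out snippet below is not attached to any function. It
# appears to be left over from the final-bounds plotting code; `mode`,
# `base_dir`, `xstar_grid` and `results` are not defined at module level.
# Uncomment (inside a suitable function) to show the maximum absolute
# violation of the exact constraints:
#   result_path = os.path.join(base_dir, f"final_{mode}_maxabsdiff.pdf")
#   save_plot(plot_final_max_abs_diff(xstar_grid, results["maxabsdiff"]),
#             result_path)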


def plot_all(results,
             x: np.ndarray,
             y: np.ndarray,
             base_dir: Optional[Text] = None, suffix: Text = ""):
  """Call all relevant plotting functions.

  Each plot is created in turn and handed to `save_plot`, which writes it as
  a PDF and closes it. Without a `base_dir` the figures are created but
  neither saved nor closed.

  Args:
      results: The results dictionary. The keys "lagrangian", "tau",
            "lambda", "objective", "constraint_term", "lhs" and "rhs" are
            read.
      x: The x values of the original data. Currently unused.
      y: The y values of the original data. Currently unused.
      base_dir: The path where to store the figures. If `None` don't save the
            figures to disk. The directory is created if it does not exist.
      suffix: An optional suffix to each filename stored by this function.
            A non-empty suffix is appended as `_<suffix>` before `.pdf`.
  """
  if base_dir is not None:
    # `exist_ok` avoids the check-then-create race of a separate exists().
    os.makedirs(base_dir, exist_ok=True)

  suff = "_" + suffix if suffix else suffix

  # (filename stem, plot function, keys of `results` passed as arguments),
  # in the order in which the plots are produced.
  plots = (
      ("lagrangian", plot_lagrangian, ("lagrangian",)),
      ("tau", plot_tau, ("tau",)),
      ("lambda", plot_lambda, ("lambda",)),
      ("objective", plot_objective, ("objective",)),
      ("constraint_term", plot_constraint_term, ("constraint_term",)),
      ("max_abs_lhs_rhs", plot_max_abs_lhs_rhs, ("lhs", "rhs")),
      ("max_rel_abs_lhs_rhs", plot_max_rel_abs_lhs_rhs, ("lhs", "rhs")),
      # Disabled: individual |LHS - RHS| curves.
      # ("abs_lhs_rhs", plot_abs_lhs_rhs, ("lhs", "rhs")),
  )

  for stem, plot_fn, keys in plots:
    name = f"{stem}{suff}.pdf"
    path = None if base_dir is None else os.path.join(base_dir, name)
    figure = plot_fn(*(results[key] for key in keys))
    save_plot(figure, path)
