import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np

# Qualitative color palettes for categorical series in publication figures.
# Each palette below is a list of seven hex color strings.

# IEEE style colors - widely used in scientific publications.
# NOTE(review): these seven values appear identical to MATLAB's default line
# color order; confirm the "IEEE" attribution before citing it.
IEEE_COLORS = [
  "#0072BD",  # blue
  "#D95319",  # orange
  "#EDB120",  # yellow
  "#7E2F8E",  # purple
  "#77AC30",  # green
  "#4DBEEE",  # light-blue
  "#A2142F",  # red
]

# Science style colors - based on Science journal.
# NOTE(review): these appear to match the color cycle of the SciencePlots
# "science" style; confirm the source.
SCIENCE_COLORS = [
  "#0C5DA5",  # blue
  "#00B945",  # green
  "#FF9500",  # orange
  "#FF2C00",  # red
  "#845B97",  # purple
  "#474747",  # gray
  "#9e9e9e",  # light gray
]

# Nature style colors - based on Nature journal.
# NOTE(review): these appear to be the first seven entries of the ggsci "npg"
# (Nature Publishing Group) palette; confirm the source.
NATURE_COLORS = [
  "#E64B35",  # red
  "#4DBBD5",  # blue
  "#00A087",  # teal
  "#3C5488",  # dark blue
  "#F39B7F",  # light red
  "#8491B4",  # light blue
  "#91D1C2",  # light teal
]


def plot_miscoverage_rate_vs_alpha(
  alpha_levels: np.ndarray,
  miscoverage_rates: np.ndarray,
  *,
  ax: plt.Axes | None = None,
  label: str | None = None,
) -> plt.Axes:
  """
  Plot the observed miscoverage rate against the nominal alpha level.

  A dashed green ``y = x`` reference line labelled "Optimal" is drawn first.
  The gap between the curve and that reference is shaded red where the
  miscoverage rate exceeds alpha and green where it is below alpha.

  The function may be called repeatedly on the same ``ax`` to compare
  several curves. The legend is rebuilt from every labelled artist on the
  axes, with duplicate labels (e.g. the repeated "Optimal" line) listed
  once.

  Parameters:
  -----------
  alpha_levels : array-like, 1-D
      Nominal alpha levels (x-axis). They should be sorted ascending for the
      shaded regions to be meaningful (TODO confirm callers always sort).
  miscoverage_rates : array-like, 1-D, same length as alpha_levels
      Observed miscoverage rate at each alpha level (y-axis).
  ax : matplotlib.axes.Axes, optional
      Axes to draw on. When omitted, a new 8x5 inch figure is created.
  label : str, optional
      Legend label for the curve. Defaults to "miscoverage rate".

  Returns:
  --------
  ax : matplotlib.axes.Axes
      The axes that were drawn on.

  Raises:
  -------
  ValueError
      If the inputs are not 1-D arrays of the same shape.
  """
  # Coerce to arrays so the elementwise `where=` comparisons below work for
  # lists too. A plain `list > list` returns a single bool, not a mask.
  alpha_levels = np.asarray(alpha_levels)
  miscoverage_rates = np.asarray(miscoverage_rates)
  if alpha_levels.ndim != 1 or alpha_levels.shape != miscoverage_rates.shape:
    raise ValueError(
      "alpha_levels and miscoverage_rates must be 1-D arrays of equal length, "
      f"got shapes {alpha_levels.shape} and {miscoverage_rates.shape}"
    )

  if ax is None:
    _, ax = plt.subplots(figsize=(8, 5))

  ax.plot(alpha_levels, alpha_levels, "g--", label="Optimal")
  # Shade the gap to the reference line. interpolate=True makes each shaded
  # region end exactly where the curve crosses y = x.
  ax.fill_between(
    alpha_levels,
    miscoverage_rates,
    alpha_levels,
    where=miscoverage_rates > alpha_levels,
    interpolate=True,
    color="red",
    alpha=0.2,
  )
  ax.fill_between(
    alpha_levels,
    miscoverage_rates,
    alpha_levels,
    where=miscoverage_rates < alpha_levels,
    interpolate=True,
    color="green",
    alpha=0.2,
  )
  curve_label = "miscoverage rate" if label is None else label
  ax.plot(alpha_levels, miscoverage_rates, label=curve_label)

  # The fill_between collections carry no label, so legend entries for the
  # shaded regions come from proxy patches with the same color and alpha.
  over_patch = mpatches.Patch(
    color="red", alpha=0.2, label="miscoverage rate > alpha"
  )
  under_patch = mpatches.Patch(
    color="green", alpha=0.2, label="miscoverage rate < alpha"
  )

  ax.set_xlabel("Alpha Level")
  ax.set_ylabel("Miscoverage Rate")

  # Gather every labelled artist on the axes: the "Optimal" reference, this
  # curve, and any curves from earlier calls on the same ax. De-duplicate by
  # label so repeated calls do not list "Optimal" several times.
  handles, labels = ax.get_legend_handles_labels()
  by_label = dict(zip(labels, handles))
  ax.legend(handles=[over_patch, under_patch, *by_label.values()])

  return ax


def add_identity_line(
  ax, color="k", style="--", alpha=0.5, label="Identity Line", equal_axes=True
):
  """
  Add an identity line (y=x) to a matplotlib axis.

  The line spans the union of the axis's current x and y limits, so call
  this after the data has been plotted.

  Parameters:
  -----------
  ax : matplotlib.axes.Axes
      The axis to add the identity line to
  color : str, default='k'
      Color of the line. Any matplotlib color spec is accepted, e.g.
      single-letter codes ('k'), cycle references ('C0'), named colors
      ('red') or hex strings ('#0072BD').
  style : str, default='--'
      Line style as a matplotlib format string without a color part
      (e.g., '-' for solid, '--' for dashed, ':' for dotted). Markers may be
      included, e.g. '--o'.
  alpha : float, default=0.5
      Transparency of the line (0 to 1)
  label : str, default='Identity Line'
      Label for the line in the legend. Set to None to omit from legend
  equal_axes : bool, default=True
      If True, sets both axes to have the same limits

  Returns:
  --------
  line : matplotlib.lines.Line2D
      The identity line object
  """
  # Get current axis limits
  xlim = ax.get_xlim()
  ylim = ax.get_ylim()

  # Take the extremes over all four limits rather than pairing lower/upper
  # bounds. An inverted axis reports its limits as (upper, lower), and this
  # still yields the full range. For normal axes the result is unchanged.
  min_val = min(*xlim, *ylim)
  max_val = max(*xlim, *ylim)

  # Pass the color as a keyword instead of prefixing it to the format string.
  # Matplotlib rejects format strings such as "red--" or "#0072BD--", but
  # color= accepts any color spec.
  line = ax.plot(
    [min_val, max_val],
    [min_val, max_val],
    style,
    color=color,
    label=label,
    alpha=alpha,
  )[0]

  # Set equal axes if requested
  if equal_axes:
    ax.set_xlim(min_val, max_val)
    ax.set_ylim(min_val, max_val)

  return line
