from matplotlib.axes import Axes


def force_tick_font(ax: Axes, xfmt=id, yfmt=id):
  """Replace the current tick labels of ``ax`` with explicit strings.

  The tick positions currently reported by the axes are pinned with
  ``set_xticks``/``set_yticks`` and each one is labelled with
  ``str(fmt(value))``. The view limits that were active on entry are
  restored afterwards.

  NOTE(review): judging by the name, the purpose is to make font settings
  apply to the tick labels by turning them into fixed text -- confirm
  against the callers.

  Args:
    ax: Axes whose tick labels are rewritten in place.
    xfmt: Callable mapping an x tick value to the object that is shown
      (it is passed through ``str``). The default leaves the value as is.
    yfmt: Same as ``xfmt`` for the y axis.
  """

  def _formatter(fmt):
    # The signature default is the builtin ``id``, which would label every
    # tick with an object address. It was evidently meant as "identity",
    # so map it to a pass-through instead of calling it.
    if fmt is id:
      return lambda value: value
    return fmt

  xfmt = _formatter(xfmt)
  yfmt = _formatter(yfmt)

  # Remember the view limits: pinning ticks may widen them.
  x_limits = ax.get_xlim()
  y_limits = ax.get_ylim()

  # Force update for x-axis
  x_values = ax.get_xticks()
  ax.set_xticks(x_values)  # Ensure ticks align with their current positions
  ax.set_xticklabels([str(xfmt(value)) for value in x_values])

  # Force update for y-axis
  y_values = ax.get_yticks()
  ax.set_yticks(y_values)  # Ensure ticks align with their current positions
  ax.set_yticklabels([str(yfmt(value)) for value in y_values])

  # Restore limits
  ax.set_xlim(x_limits)
  ax.set_ylim(y_limits)


def style_tr(ax: Axes):
  """Apply a fixed look to ``ax``.

  Sets the view to x in [0, 1e6] and y in [0, 1], hides the top and right
  spines, draws the left and bottom spines with line width 1, makes all
  ticks point outwards, and installs three ticks per axis: x ticks
  labelled through ``natfmt`` and y ticks labelled as whole percentages.

  Args:
    ax: Axes that is styled in place.
  """
  # Fixed data window.
  ax.set_xlim(0, 1e6)
  ax.set_ylim(0, 1)

  # Open frame: only the left and bottom spines stay visible.
  for side in ("top", "right"):
    ax.spines[side].set_visible(False)
  for side in ("left", "bottom"):
    ax.spines[side].set_linewidth(1)

  # Same tick geometry for major and minor ticks on both axes.
  ax.tick_params(axis="both", which="both", width=1, length=5, direction="out")

  # X axis: start, middle and end of the window, abbreviated by natfmt.
  ax.set_xticks([0, 5e5, 1e6])
  x_text = [rf"{natfmt(tick)}" for tick in ax.get_xticks()]
  ax.set_xticklabels(x_text)

  # Y axis: fractions shown as percentages. The percent sign is written
  # with a leading backslash -- presumably for LaTeX text rendering;
  # confirm against the rcParams in use.
  ax.set_yticks([0.0, 0.5, 1.0])
  y_text = [rf"{int(tick * 100)}\%" for tick in ax.get_yticks()]
  ax.set_yticklabels(y_text)


def natfmt(x):
  """Format a number compactly with a K / M / B magnitude suffix.

  The value is scaled by the largest of 1e3 ("K"), 1e6 ("M") and 1e9 ("B")
  that does not exceed its magnitude (no suffix below 1e3). A scaled
  magnitude strictly between 1 and 10 is shown with one decimal, anything
  else with none. If rounding pushes the shown number up to the next unit
  (e.g. 999_999 would read "1000K"), the next suffix is used instead
  ("1M"); likewise a one-decimal value that rounds to 10 is shown as "10"
  rather than "10.0".

  Args:
    x: Real number to format.

  Returns:
    The formatted string, e.g. ``"0"``, ``"1.5K"``, ``"500K"``, ``"1M"``.

  Raises:
    ValueError: If ``x`` is NaN.
  """
  magnitude = abs(x)
  if magnitude != magnitude:  # NaN is the only value unequal to itself
    raise ValueError("Cannot format NaN")

  units = ((1, ""), (1e3, "K"), (1e6, "M"), (1e9, "B"))
  last = len(units) - 1

  # Start from the largest unit that does not exceed the magnitude.
  tier = 0
  for index, (scale, _) in enumerate(units):
    if magnitude >= scale:
      tier = index

  while True:
    scale, suffix = units[tier]
    # Leave unscaled values untouched so that ints stay ints.
    scaled = x if tier == 0 else x / scale
    decimals = 1 if 1 < abs(scaled) < 10 else 0
    text = f"{scaled:.{decimals}f}"
    shown = abs(float(text))
    if decimals and shown >= 10:
      # e.g. 9.96 rounds to "10.0": drop the now pointless decimal.
      text = f"{scaled:.0f}"
    if shown >= 1000 and tier < last:
      # Rounding reached the next unit (e.g. "1000K"): promote and retry.
      tier += 1
      continue
    return f"{text}{suffix}"
