from .utils import get_sharpness_grid
import matplotlib.pyplot as plt
import numpy as np

def eos_conf_ax(ax, fig, eos_mins, mu=1, copies=2, color='k'):
    """Overlay the curve ``|y| = mu**(1/copies) / |x|`` and mark ``eos_mins``.

    The current view limits are captured first and restored at the end, so
    the (very tall) curve near ``x == 0`` does not rescale the axes.

    Args:
        ax: Matplotlib axes to draw on.
        fig: Unused here; kept so all ``*_ax`` helpers share a call shape.
        eos_mins: Pair ``(x, y)`` drawn as a hollow circular marker.
        mu: Constant of the curve (default 1).
        copies: Root taken of ``mu`` (default 2).
        color: Color of both the marker edge and the curve (default black).
    """
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.plot(eos_mins[0], eos_mins[1], 'o', ms=4, markerfacecolor="None",
            markeredgecolor=color, markeredgewidth=1)
    xs = np.linspace(xlim[0], xlim[1], 1000)
    # |mu**(1/copies) / x|; the 1e-18 offset keeps x == 0 from dividing by zero.
    magnitude = np.abs(mu ** (1/copies) / (xs + 1e-18))
    # Draw the upper branches first, then their mirror image below the x-axis.
    for sign in (1, -1):
        ax.plot(xs, sign * magnitude, lw=0.8, color=color)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)

def eos_conf_ax_power(ax, fig, eos_mins, power, mu=1, copies=2, color='k'):
    """Overlay the curve ``|y| = mu**(1/copies) / |x|`` in powered coordinates.

    The axes are assumed to display ``x**power`` against ``y**power``. The
    current limits are mapped back to plain ``x`` with a ``1/power`` root and
    sampled evenly there. The curve and the ``eos_mins`` marker are then drawn
    in powered coordinates. The original limits are restored afterwards.

    Args:
        ax: Matplotlib axes to draw on.
        fig: Unused here; kept so all ``*_ax`` helpers share a call shape.
        eos_mins: Pair ``(x, y)``, drawn as a hollow marker at
            ``(x**power, y**power)``.
        power: Exponent applied to both coordinates.
        mu: Curve constant (default 1, i.e. ``|y| = 1/|x|``).
        copies: Root taken of ``mu`` (default 2).
        color: Color of the marker edge and the curve (default black).
    """
    x_l, x_h = ax.get_xlim()
    y_l, y_h = ax.get_ylim()
    ax.plot(eos_mins[0]**power, eos_mins[1]**power, 'o', ms=4, markerfacecolor="None",
            markeredgecolor=color, markeredgewidth=1)
    # Sample evenly in un-powered x, then map to powered coordinates when plotting.
    # NOTE(review): a negative limit under a fractional 1/power root is not real
    # (NaN); presumably callers only use non-negative powered axes — confirm.
    base_x = np.linspace(x_l**(1/power), x_h**(1/power), 1000)
    # The 1e-18 offset avoids division by exactly zero at x == 0.
    curve_y = np.abs(mu ** (1/copies) / (base_x + 1e-18))**power
    ax.plot(base_x**power, curve_y, lw=0.8, color=color)
    # Brace the whole exponent so multi-character powers (e.g. 10, 0.5) are
    # fully superscripted by mathtext instead of only their first character.
    ax.set_xlabel(r'$x^{{{}}}$'.format(power))
    ax.set_ylabel(r'$y^{{{}}}$'.format(power))
    ax.set_xlim(x_l, x_h)
    ax.set_ylim(y_l, y_h)

def sharpness_ax(ax, fig, mu, density=100):
    """Shade the current view with values from ``get_sharpness_grid``.

    A ``density``-point grid is taken along each axis over the current limits
    and passed to ``get_sharpness_grid`` together with ``mu``. The result is
    drawn behind other artists (``zorder=-1``) with a colorbar titled
    'sharpness', and the original limits are restored.

    NOTE(review): ``imshow(origin='lower')`` treats rows as y and columns as
    x, so this assumes ``get_sharpness_grid`` returns an array indexed
    ``[y, x]`` — TODO confirm against ``.utils``.

    Args:
        ax: Matplotlib axes to draw on.
        fig: Figure that receives the colorbar.
        mu: Parameter forwarded to ``get_sharpness_grid``.
        density: Number of samples along each axis (default 100).
    """
    (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
    extent = (x_lo, x_hi, y_lo, y_hi)
    grid = get_sharpness_grid(np.linspace(x_lo, x_hi, density),
                              np.linspace(y_lo, y_hi, density), mu)
    image = ax.imshow(grid, origin='lower', extent=extent, cmap='coolwarm', zorder=-1)
    bar = fig.colorbar(image, ax=ax, shrink=0.6)
    bar.ax.set_title('sharpness', fontsize=8)
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)

def residual_ax(ax, fig, mu, density=200):
    """Shade the current view with the residual ``x**2 * y**2 - mu``.

    The residual is evaluated on a ``density`` x ``density`` grid spanning
    the current axis limits. It is drawn behind other artists (``zorder=-1``)
    with a colorbar titled 'residual', and the original limits are restored.

    Args:
        ax: Matplotlib axes to draw on.
        fig: Figure that receives the colorbar.
        mu: Constant subtracted from ``x**2 * y**2``.
        density: Number of samples along each axis (default 200).
    """
    x_l, x_h = ax.get_xlim()
    y_l, y_h = ax.get_ylim()
    x_grid = np.linspace(x_l, x_h, density)
    y_grid = np.linspace(y_l, y_h, density)
    # Default 'xy' indexing gives arrays of shape (len(y_grid), len(x_grid)).
    # Rows follow y, which is the layout imshow(origin='lower') expects, so no
    # reshape or transpose is needed.
    x_mesh, y_mesh = np.meshgrid(x_grid, y_grid)
    residual_grid = x_mesh ** 2 * y_mesh ** 2 - mu
    im = ax.imshow(residual_grid, origin='lower', extent=(x_l, x_h, y_l, y_h),
                   cmap='coolwarm', zorder=-1)
    clb = fig.colorbar(im, ax=ax, shrink=0.6)
    clb.ax.set_title('residual', fontsize=8)
    ax.set_xlim(x_l, x_h)
    ax.set_ylim(y_l, y_h)