from pathlib import Path

import jax
from jax import vmap
from matplotlib import pyplot as plt
from matplotlib import rc

from new_natgrad.pushforward import pushforward_factory

# Enable 64-bit (double-precision) arrays in JAX, which defaults to 32-bit.
# NOTE: jax.config is process-global, so this affects every JAX computation
# that runs after this module is imported, not just the functions below.
jax.config.update("jax_enable_x64", True)

def visualize_ng(model, params, nat_grad, x, output_name, out_dir="out"):
    """Scatter-plot the pushforward of a natural-gradient direction and save it.

    The function returned by ``pushforward_factory(model)(params, nat_grad)``
    is vmapped over the leading axis of ``x``. Its per-point values colour a
    scatter plot of ``x[:, 0]`` against ``x[:, 1]``. The figure is written to
    ``<out_dir>/<output_name>.png`` at 400 dpi.

    Args:
        model: Model handed to ``pushforward_factory``.
        params: Parameters at which the pushforward is evaluated.
        nat_grad: Natural-gradient direction to push forward.
        x: Sample points. Columns 0 and 1 are the plot coordinates
            (presumably shape (N, 2); confirm with callers).
        output_name: File stem of the saved PNG.
        out_dir: Directory for the PNG. It is created if it does not exist.
            Defaults to ``"out"``, matching the previous hard-coded path.
    """
    # Evaluate the pushforward point-wise over the rows of x.
    v_push = vmap(pushforward_factory(model)(params, nat_grad), in_axes=0)
    colors = v_push(x)

    out_path = Path(out_dir) / f"{output_name}.png"
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Scope the serif/LaTeX styling to this figure instead of permanently
    # mutating the global rcParams for the rest of the process.
    style = {
        "font.family": "serif",
        "font.serif": ["Computer Modern"],
        "text.usetex": True,
        "font.size": 16,
    }
    with plt.rc_context(style):
        # Use a dedicated figure so a caller's current figure is untouched.
        fig, ax = plt.subplots()
        try:
            points = ax.scatter(x[:, 0], x[:, 1], c=colors, s=10)
            ax.set_aspect(1.0)
            fig.colorbar(points, ax=ax)
            fig.savefig(out_path, bbox_inches="tight", dpi=400)
        finally:
            # Release the figure even if rendering or saving fails.
            plt.close(fig)

def visualize_er(model, params, u_star, x, output_name, out_dir="out"):
    """Scatter-plot the point-wise error ``u_star(x) - model(params, x)`` and save it.

    Both ``u_star`` and ``model`` are vmapped over the leading axis of ``x``.
    Their difference colours a scatter plot of ``x[:, 0]`` against
    ``x[:, 1]``. The figure is written to ``<out_dir>/<output_name>.png`` at
    400 dpi.

    Args:
        model: Callable ``model(params, x_i)`` evaluated per point.
        params: Parameters passed unchanged to every ``model`` call.
        u_star: Reference callable ``u_star(x_i)`` evaluated per point.
        x: Sample points. Columns 0 and 1 are the plot coordinates
            (presumably shape (N, 2); confirm with callers).
        output_name: File stem of the saved PNG.
        out_dir: Directory for the PNG. It is created if it does not exist.
            Defaults to ``"out"``, matching the previous hard-coded path.
    """
    # params is shared across points; only x is batched.
    v_model = vmap(model, in_axes=(None, 0))
    v_u_star = vmap(u_star, in_axes=0)
    error = v_u_star(x) - v_model(params, x)

    out_path = Path(out_dir) / f"{output_name}.png"
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Scope the serif/LaTeX styling to this figure instead of permanently
    # mutating the global rcParams for the rest of the process.
    style = {
        "font.family": "serif",
        "font.serif": ["Computer Modern"],
        "text.usetex": True,
        "font.size": 16,
    }
    with plt.rc_context(style):
        # Use a dedicated figure so a caller's current figure is untouched.
        fig, ax = plt.subplots()
        try:
            points = ax.scatter(x[:, 0], x[:, 1], c=error, s=10)
            ax.set_aspect(1.0)
            fig.colorbar(points, ax=ax)
            fig.savefig(out_path, bbox_inches="tight", dpi=400)
        finally:
            # Release the figure even if rendering or saving fails.
            plt.close(fig)