import jax
import jax.numpy as jnp
from jax import vmap


import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors

from copy import deepcopy

def snapshot_model_state(model):
    """Capture the model's current predictions, gate outputs and training records.

    This is a snapshot: mutable training records (``logs`` and ``run_times``) are
    deep-copied, so later updates made to the model (e.g. by continued training)
    do not leak into the returned dict.

    Args:
        model: Model exposing ``problem_params`` (with ``'x'`` and ``'u_exact'``),
            ``u_net``, ``params_c``/``params_f``, ``coeffs_c``/``coeffs_f``,
            ``level_c.gate_net``/``level_f.gate_net``, ``logs`` and ``run_times``.

    Returns:
        dict with keys ``x``, ``u_exact`` (flattened), ``u_pred_c`` / ``u_pred_f``
        (flattened level-0 / level-1 predictions of ``u_net`` at ``x``),
        ``gate_pred_c`` / ``gate_pred_f`` (gate-network outputs at ``x``),
        ``logs`` and ``run_times``.
    """
    x = model.problem_params['x']
    u_exact = model.problem_params['u_exact'].reshape(-1)

    # Map u_net over the evaluation points only; the level index, parameters and
    # coefficients are shared across points.
    predict = vmap(model.u_net, (None, None, None, None, 0))
    u_pred_c = predict(0, model.params_c, model.params_f, model.coeffs_c, x).ravel()
    u_pred_f = predict(1, model.params_c, model.params_f, model.coeffs_f, x).ravel()

    gate_pred_c = vmap(model.level_c.gate_net, (None, 0))(model.params_c[0], x)
    gate_pred_f = vmap(model.level_f.gate_net, (None, 0))(model.params_f[0], x)

    return dict(
        x=x, u_exact=u_exact,
        u_pred_c=u_pred_c, u_pred_f=u_pred_f,
        gate_pred_c=gate_pred_c, gate_pred_f=gate_pred_f,
        logs=deepcopy(model.logs),
        # Copied like ``logs`` so the snapshot does not alias a list that may
        # keep growing on the model.
        run_times=deepcopy(model.run_times),
    )


def plot_losses(ax, loss_c, loss_f):
    """Draw the coarse (blue) and fine (red) gate-loss histories on a log y-axis.

    Args:
        ax: Matplotlib axes to draw on.
        loss_c: Coarse gate-loss values, plotted against their index.
        loss_f: Fine gate-loss values, plotted against their index.
    """
    curves = (
        (loss_c, "b", "Coarse Gate Loss"),
        (loss_f, "r", "Fine Gate Loss"),
    )
    for values, colour, legend_text in curves:
        ax.semilogy(values, c=colour, label=legend_text)

    ax.set_xlabel("Iterations")
    ax.set_ylabel("Loss")
    ax.grid()
    ax.legend()
    ax.set_title("Training Loss")

def plot_bc_losses(ax, loss_c, loss_f):
    """Draw the coarse and fine boundary-condition losses on a log y-axis.

    Args:
        ax: Matplotlib axes to draw on.
        loss_c: Coarse-level BC loss history (blue), plotted against its index.
        loss_f: Fine-level BC loss history (red), plotted against its index.
    """
    for history, colour, tag in zip((loss_c, loss_f), ("b", "r"), ("coarse", "fine")):
        ax.semilogy(history, c=colour, label=tag)

    ax.set_xlabel("Iterations")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid()
    ax.set_title("BC Loss")

def plot_componentwise_losses(ax,
                              loss_comp_c, loss_coop_c,
                              loss_comp_f, loss_coop_f):
    """Draw the competitive and cooperative loss terms of both levels on a log y-axis.

    Colour encodes the level (blue = coarse, red = fine); line style encodes the
    term (solid = competitive, dashed = cooperative).

    Args:
        ax: Matplotlib axes to draw on.
        loss_comp_c, loss_coop_c: Coarse competitive / cooperative loss histories.
        loss_comp_f, loss_coop_f: Fine competitive / cooperative loss histories.
    """
    solid = {}
    dashed = {"linestyle": 'dashed'}
    series = [
        (loss_comp_c, "b", "comp(c)", solid),
        (loss_comp_f, "r", "comp(f)", solid),
        (loss_coop_c, "b", "coop(c)", dashed),
        (loss_coop_f, "r", "coop(f)", dashed),
    ]
    for data, colour, tag, line_style in series:
        ax.semilogy(data, c=colour, label=tag, **line_style)

    ax.set_xlabel("Iterations")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid()
    ax.set_title("Component-wise Loss over Iterations")

def plot_rmse(ax, epochs, rmse_c, rmse_f):
    """Draw coarse (blue) and fine (red) RMSE against the logged iteration numbers.

    Args:
        ax: Matplotlib axes to draw on.
        epochs: Iteration numbers at which the errors were logged (x-axis).
        rmse_c: Coarse-level errors, one per entry of ``epochs``.
        rmse_f: Fine-level errors, one per entry of ``epochs``.
    """
    coarse_style = dict(color="b", label="Coarse RMSE")
    fine_style = dict(color="r", label="Fine RMSE")
    ax.semilogy(epochs, rmse_c, **coarse_style)
    ax.semilogy(epochs, rmse_f, **fine_style)

    ax.set_xlabel("Iterations")
    ax.set_ylabel("RMSE")
    ax.grid()
    ax.legend()
    ax.set_title("RMSE over Iterations")

def plot_gate_functions(model, axs, gate_pred_c, gate_pred_f, cut_off=0.15):
    """Visualise the coarse and fine gate functions on two axes.

    For a 1-D problem each gate is drawn as a curve over the grid. For a 2-D
    problem the support of each gate (where its value exceeds ``cut_off``) is
    drawn as a translucent filled region, one colour per gate. For zero or three
    or more dimensions nothing is drawn and ``axs`` is left untouched.

    Args:
        model: Object whose ``problem_params['grid_axes']`` is a sequence with
            one coordinate array per spatial dimension.
        axs: Indexable collection of at least two Matplotlib axes; ``axs[0]``
            receives the coarse gates and ``axs[1]`` the fine gates.
        gate_pred_c: Coarse gate values, rank 2, indexed ``[point, coarse_gate]``.
        gate_pred_f: Fine gate values, rank 3, indexed
            ``[point, coarse_gate, fine_gate]``. Each fine gate is multiplied by
            its parent coarse gate before plotting, and the two gate axes are
            flattened into a single axis of ``n_coarse * n_fine`` gates.
        cut_off: 2-D only. A gate counts as active where its value exceeds this.
    """
    grid_axes = model.problem_params['grid_axes']
    n_dims = len(grid_axes)
    if n_dims not in (1, 2):
        return

    # Effective fine gates: scale each fine gate by its parent coarse gate, then
    # flatten (coarse_gate, fine_gate) into one axis -> [point, n_c * n_f].
    g_pred_f = jnp.einsum("di,dij->dij", gate_pred_c, gate_pred_f)
    g_pred_f = g_pred_f.reshape(g_pred_f.shape[0], -1)

    if n_dims == 1:
        x = grid_axes[0]
        axs[0].plot(x, gate_pred_c, label="Coarse Gate Function")
        axs[1].plot(x, g_pred_f, '--', label="Fine Gate Function")
        for ax in (axs[0], axs[1]):
            ax.set_xlabel("x")
            ax.set_ylabel("Gate Function Value")
            # No legend: every gate column would repeat the same label.
            ax.set_title("Gate Function Predictions")
        return

    # 2-D. NOTE(review): assumes the evaluation points are ordered like a
    # flattened ``meshgrid(x1, x2, indexing='ij')`` — confirm against the grid
    # construction. The grid need not be square.
    x1, x2 = grid_axes
    grid_shape = (x1.shape[0], x2.shape[0])
    X1, X2 = jnp.meshgrid(x1, x2, indexing='ij')
    panels = (
        (axs[0], gate_pred_c, "viridis", "All Coarse Gate Supports"),
        (axs[1], g_pred_f, "plasma", "All Fine Gate Supports"),
    )
    for ax, gates, cmap_name, title in panels:
        n_gates = gates.shape[1]
        # One discrete colour per gate. ``matplotlib.cm.get_cmap`` was removed in
        # Matplotlib 3.9; ``pyplot.get_cmap(name, lut)`` is the supported equivalent.
        cmap = plt.get_cmap(cmap_name, n_gates)
        for k in range(n_gates):
            active = gates[:, k].reshape(grid_shape) > cut_off
            # ``active`` is a 0/1 mask, so the single band (0.5, 1.5] fills exactly
            # the active region for any ``cut_off`` (levels=[cut_off, 1] is not
            # an increasing pair once cut_off >= 1).
            ax.contourf(X1, X2, active, levels=[0.5, 1.5], colors=[cmap(k)], alpha=0.4)
        ax.set_title(title)
        ax.set_xlabel("$x$")

def plot_posterior(ax, x, posterior_c, posterior_f):
    """Plot coarse posteriors (solid) and fine posteriors (dashed) against ``x``.

    Args:
        ax: Matplotlib axes to draw on.
        x: Abscissa values, one per row of the posteriors.
        posterior_c: Coarse posterior values; each column is drawn as a line.
        posterior_f: Fine posterior values; all axes after the first are
            flattened so every trailing combination becomes one dashed line.
    """
    ax.plot(x, posterior_c, label="Coarse Posterior")

    n_rows = posterior_f.shape[0]
    fine_columns = posterior_f.reshape(n_rows, -1)
    ax.plot(x, fine_columns, '--', label="Fine Posterior")

    ax.set_xlabel("$x$")
    ax.set_ylabel("Posterior Value")
    # No legend is drawn; presumably because the many fine lines would all repeat
    # the same label.
    ax.set_title("Posterior Predictions")

def plot_predictions(ax, x, u_exact, u_pred_c, u_pred_f):
    """Plot the exact solution together with the coarse and fine predictions.

    Args:
        ax: Matplotlib axes to draw on.
        x: Evaluation points. A 1-D array or an ``(N, 1)`` column is used directly
            as the abscissa. An ``(N, d)`` array with ``d > 1`` cannot be drawn as a
            single curve, so all three series are flattened and plotted against
            their flattened grid index instead.
        u_exact: Exact solution at ``x``.
        u_pred_c: Coarse-level prediction at ``x``.
        u_pred_f: Fine-level prediction at ``x``.
    """
    multi_dim = len(x.shape) > 1 and x.shape[1] > 1
    if multi_dim:
        u_exact = u_exact.ravel()
        u_pred_c = u_pred_c.ravel()
        u_pred_f = u_pred_f.ravel()
        x = jnp.arange(u_exact.size)

    ax.plot(x, u_exact, label="Exact Solution")
    ax.plot(x, u_pred_c, '--', label="Coarse Prediction")
    ax.plot(x, u_pred_f, '--', label="Fine Prediction")
    ax.set_ylabel("$u(x)$")
    ax.legend()
    ax.set_title("Model Predictions")

    # Label by what the abscissa actually is. The rank of ``u_exact`` is not a
    # usable signal: it is always 1-D after flattening, and callers commonly pass
    # it already flattened, which previously left the axis unlabelled.
    ax.set_xlabel("Flattened grid index" if multi_dim else "$x$")

def plot_error(ax, x, u_exact, u_pred_c, u_pred_f):
    """Plot the pointwise absolute error of the coarse and fine predictions (log y).

    Args:
        ax: Matplotlib axes to draw on.
        x: Evaluation points. A 1-D array or an ``(N, 1)`` column is used directly
            as the abscissa. For an ``(N, d)`` array with ``d > 1`` all series are
            flattened and plotted against their flattened grid index.
        u_exact: Exact solution at ``x``.
        u_pred_c: Coarse-level prediction; must have as many elements as ``u_exact``.
        u_pred_f: Fine-level prediction; must have as many elements as ``u_exact``.

    Raises:
        An error from the reshape if a prediction's size differs from ``u_exact``'s.
    """
    multi_dim = len(x.shape) > 1 and x.shape[1] > 1
    if multi_dim:
        u_exact = u_exact.ravel()
        u_pred_c = u_pred_c.ravel()
        u_pred_f = u_pred_f.ravel()
        x = jnp.arange(u_exact.size)

    # Align each prediction with u_exact before subtracting: an (N,) prediction
    # minus an (N, 1) exact solution would otherwise broadcast to an (N, N) "error".
    target_shape = jnp.shape(u_exact)
    err_c = jnp.abs(jnp.reshape(u_pred_c, target_shape) - u_exact)
    err_f = jnp.abs(jnp.reshape(u_pred_f, target_shape) - u_exact)

    ax.semilogy(x, err_c, '--', label="Coarse")
    ax.semilogy(x, err_f, '--', label="Fine")
    ax.set_ylabel("$|u(x)-u_{exact}|$")
    ax.grid()
    ax.legend()
    ax.set_title("Model Prediction Error")

    # Label by what the abscissa actually is (u_exact's rank is always 1 after
    # flattening, so it cannot tell the two cases apart).
    ax.set_xlabel("Flattened grid index" if multi_dim else "$x$")

def plot_all_metrics(model):
    """
    Plots all relevant metrics in a single figure with subplots.

    Layout (2 rows x 3 columns):
      * row 0: RMSE history, predictions vs. exact solution, pointwise error;
      * row 1: coarse/fine gate functions (1-D and 2-D problems) and the
        coarse/fine posteriors (1-D problems only).
    Panels that do not apply to the problem's dimension are switched off instead
    of being left as empty, ticked axes.

    Args:
        model: Trained model instance containing training logs and parameters.
            Reads ``problem_params`` (``'x'``, ``'u_exact'``, ``'grid_axes'``),
            ``training_params['log_interval']``, ``logs``, ``params_c``/``params_f``,
            ``coeffs_c``/``coeffs_f``, ``u_net``, ``level_c.gate_net``,
            ``level_f.gate_net`` and ``E_step``.

    Returns:
        ``(u_pred_c, u_pred_f, gate_net_c, gate_net_f)``: the flattened level-0
        and level-1 predictions at ``problem_params['x']`` and the two gate networks.
    """
    x = model.problem_params['x']
    u_exact = model.problem_params['u_exact'].reshape(-1)
    dim = len(model.problem_params['grid_axes'])

    # Predictions at every evaluation point; only x is mapped over.
    predict = vmap(model.u_net, (None, None, None, None, 0))
    u_pred_c = predict(0, model.params_c, model.params_f, model.coeffs_c, x).ravel()
    u_pred_f = predict(1, model.params_c, model.params_f, model.coeffs_f, x).ravel()

    gate_net_c = model.level_c.gate_net
    gate_net_f = model.level_f.gate_net
    gate_pred_c = vmap(gate_net_c, (None, 0))(model.params_c[0], x)
    gate_pred_f = vmap(gate_net_f, (None, 0))(model.params_f[0], x)

    nrows = 2
    _, axs = plt.subplots(nrows, 3, figsize=(12, 4 * nrows))

    # Row 0: RMSE history, predictions, pointwise error.
    epochs = model.training_params['log_interval'] * jnp.arange(len(model.logs["coarse"]["l2_error"]))
    plot_rmse(axs[0, 0], epochs,
              model.logs["coarse"]["l2_error"], model.logs["fine"]["l2_error"])
    plot_predictions(axs[0, 1], x, u_exact, u_pred_c, u_pred_f)
    plot_error(axs[0, 2], x, u_exact, u_pred_c, u_pred_f)
    # NOTE: plot_bc_losses / plot_componentwise_losses can be placed in row 0
    # instead when the logs contain "loss_bcs" / "loss_comp" / "loss_coop".

    # Row 1, columns 0-1: gate functions (only drawable for 1-D and 2-D grids).
    if dim < 3:
        plot_gate_functions(model, axs[1, :], gate_pred_c, gate_pred_f)
    else:
        axs[1, 0].axis("off")
        axs[1, 1].axis("off")

    # Row 1, column 2: posteriors (1-D only). NOTE(review): E_step is now only
    # evaluated when its result is plotted; this assumes it has no side effects.
    if dim == 1:
        posterior_c = model.E_step(0, model.params_c, model.params_f, model.coeffs_c)
        posterior_f = model.E_step(1, model.params_c, model.params_f, model.coeffs_f)
        plot_posterior(axs[1, 2], x, posterior_c, posterior_f)
    else:
        axs[1, 2].axis("off")

    plt.tight_layout()
    plt.show()

    return u_pred_c, u_pred_f, gate_net_c, gate_net_f
