"""
Plotting utils: Can compute the gradient of a vanilla-DeepONet's prediction w.r.t. given inputs (please see the python code for one of the problems to see the sequence of the inputs passed)
"""

import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import jax.numpy as jnp
import jax


def get_max_grad_over_funcs(func, weights, input_, idx, num_groups=None, M=None):
    """Return, per point, the largest gradient norm over all input functions.

    For every input function ``i`` (entry ``i`` along axis 0 of ``input_[0]``)
    the summed model prediction is differentiated with respect to the input
    sequence. The gradient at position ``idx`` of that sequence is selected and
    its L2 norm along axis 1 is taken. The element-wise maximum of these norms
    over all functions is returned.

    Args:
        func: Model apply function. It is called as ``func(weights, None, inputs)``
            or, when ``num_groups`` is given, as
            ``func(weights, None, inputs, num_groups, M)``.
        weights: Model parameters. They are passed through unchanged and are not
            differentiated.
        input_: Sequence of model inputs. Element 0 holds the per-function
            (branch) inputs stacked along axis 0. The remaining elements (e.g. the
            trunk input at index 1) are shared by every function.
        idx: Position in ``input_`` of the input whose gradient is measured.
        num_groups: Optional extra argument forwarded to ``func``.
        M: Optional extra argument forwarded to ``func`` (only used together
            with ``num_groups``).

    Returns:
        1-D array holding, for each point, the maximum gradient norm over all
        functions.

    Raises:
        ValueError: If ``input_[0]`` contains no functions.
    """
    branch = input_[0]
    num_funcs = branch.shape[0]
    if num_funcs == 0:
        raise ValueError("input_[0] must contain at least one function")

    if num_groups is None:
        def forward_func(w, t):
            return func(w, None, t).sum()
    else:
        def forward_func(w, t):
            return func(w, None, t, num_groups, M).sum()

    # Build the gradient function once instead of on every loop iteration.
    # allow_int=True means integer-valued inputs in the sequence do not error.
    grad_fn = jax.grad(forward_func, argnums=1, allow_int=True)
    shared_inputs = list(input_[1:])

    norms = []
    for i in range(num_funcs):
        if len(branch.shape) > 2:
            # Slicing keeps a leading axis of size 1 for any rank >= 3
            # (the previous newaxis + explicit-colon indexing only worked for
            # rank >= 4).
            cur_func = branch[i:i + 1]
        else:
            cur_func = branch[i, :]
        grad_vals = grad_fn(weights, [cur_func, *shared_inputs])[idx]
        norms.append(jnp.linalg.norm(grad_vals, ord=2, axis=1).flatten())

    # Maximum over functions for each point.
    return jnp.stack(norms).max(0)


def format_ticks(value, pos):
    """Format a tick value, dropping the decimal part for whole numbers.

    Intended for use with ``matplotlib.ticker.FuncFormatter``, which expects
    the callback to return a string label.

    Args:
        value: Tick position.
        pos: Tick index. It is unused but required by the FuncFormatter
            signature.

    Returns:
        ``"1"`` for ``1.0``, ``"0"`` for ``-0.0``, and ``str(value)`` for
        non-integral values (e.g. ``"0.5"``, ``"nan"``, ``"inf"``).
    """
    # float(...).is_integer() is False for nan/inf, so no int() overflow/error.
    if float(value).is_integer():
        # int() also normalises -0.0 so the label never reads "-0".
        return str(int(value))
    # str() on the original value keeps its own repr (e.g. float32 stays short).
    return str(value)


def plot_helper(tri,
                plot_arr,
                title,
                save_file_name,
                clim_lower,
                clim_upper,
                axis_font,
                tick_fontsize,
                title_font,
                figsize,
                ticks=None,
                offset_fontsize=50):
    """Save a 2-D ``tripcolor`` plot of ``plot_arr`` over ``tri`` to a file.

    Args:
        tri: Triangulation passed as the first argument to ``plt.tripcolor``
            (presumably a ``matplotlib.tri.Triangulation`` -- confirm with
            callers).
        plot_arr: Colour values passed to ``plt.tripcolor``.
        title: Plot title, or ``None`` for no title.
        save_file_name: Destination passed to ``fig.savefig``.
        clim_lower: Lower colour-scale limit.
        clim_upper: Upper colour-scale limit.
        axis_font: Font size of the axis labels.
        tick_fontsize: Font size of the axis and colour-bar tick labels.
        title_font: Font size of the title.
        figsize: Figure size passed to ``plt.subplots``.
        ticks: Tick positions used on both axes. Defaults to ``[0, 0.5, 1]``.
        offset_fontsize: Font size of the colour-bar offset (exponent) text.

    Note:
        Sets ``mathtext.fontset`` and ``axes.linewidth`` in the global
        ``plt.rcParams``. These settings persist after the call.
    """
    if ticks is None:
        ticks = [0, 0.5, 1]

    plt.rcParams["mathtext.fontset"] = "cm"
    plt.rcParams["axes.linewidth"] = 0.5
    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Thick frame around the plot area (overrides the rcParams width).
        for side in ("top", "bottom", "left", "right"):
            ax.spines[side].set_linewidth(2.)
        plt.tripcolor(tri, plot_arr, cmap='jet', edgecolors='k', shading='gouraud')
        plt.clim(clim_lower, clim_upper)
        cbar = plt.colorbar()
        cbar.ax.yaxis.get_offset_text().set_fontsize(offset_fontsize)
        cbar.ax.tick_params(labelsize=tick_fontsize)
        if title is not None:
            plt.title(title, fontsize=title_font, pad=20)
        plt.xlabel(r"$y_1$", fontsize=axis_font)
        plt.ylabel(r"$y_2$", fontsize=axis_font)
        plt.xticks(ticks, fontsize=tick_fontsize)
        plt.yticks(ticks, fontsize=tick_fontsize)
        # Show whole-number ticks without a trailing ".0".
        plt.gca().xaxis.set_major_formatter(FuncFormatter(format_ticks))
        plt.gca().yaxis.set_major_formatter(FuncFormatter(format_ticks))
        plt.box(on=True)
        fig.savefig(save_file_name, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def plot_helper_3d(X,
                   plot_arr,
                   title,
                   save_file_name,
                   clim_lower,
                   clim_upper,
                   axis_font,
                   tick_fontsize,
                   title_font,
                   figsize,
                   ticks=None):
    """Save a 3-D scatter plot of ``plot_arr["a"]`` at the points ``X["X_out"]``.

    Args:
        X: Mapping whose ``"X_out"`` entry holds the point coordinates. Columns
            0, 1 and 2 are used as the x, y and z coordinates.
        plot_arr: Mapping whose ``"a"`` entry holds the values used as colours.
            It is flattened before plotting.
        title: Plot title, or ``None`` for no title.
        save_file_name: Destination passed to ``fig.savefig``.
        clim_lower: Lower colour-scale limit.
        clim_upper: Upper colour-scale limit.
        axis_font: Font size of the axis labels. The colour-bar offset text uses
            ``axis_font - 20``.
        tick_fontsize: Font size of the axis and colour-bar tick labels.
        title_font: Font size of the title.
        figsize: Figure size passed to ``plt.figure``.
        ticks: Tick positions used on all three axes. Defaults to
            ``[-1, -0.5, 0, 0.5, 1]``.

    Note:
        Sets ``mathtext.fontset`` and ``axes.linewidth`` in the global
        ``plt.rcParams``. These settings persist after the call.
    """
    if ticks is None:
        ticks = [-1, -0.5, 0, 0.5, 1]

    plot_arr = plot_arr["a"]
    X = X["X_out"]
    plt.rcParams["mathtext.fontset"] = "cm"
    plt.rcParams["axes.linewidth"] = 0.5
    fig = plt.figure(figsize=figsize)
    try:
        ax = fig.add_subplot(111, projection='3d')
        scat = ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=plot_arr.flatten(), cmap="jet", s=55, vmin=clim_lower, vmax=clim_upper)
        ax.set_xlabel(r"$y_1$", fontsize=axis_font)
        ax.set_ylabel(r"$y_2$", fontsize=axis_font)
        ax.set_zlabel(r"$y_3$", fontsize=axis_font)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_zticks(ticks)
        ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
        if title is not None:
            ax.set_title(title, fontsize=title_font)
        # Show whole-number ticks without a trailing ".0" on every axis.
        ax.xaxis.set_major_formatter(FuncFormatter(format_ticks))
        ax.yaxis.set_major_formatter(FuncFormatter(format_ticks))
        ax.zaxis.set_major_formatter(FuncFormatter(format_ticks))
        # Fixed, symmetric view box slightly larger than [-1, 1]^3.
        f = 1.1
        ax.auto_scale_xyz([-f, f], [-f, f], [-f, f])
        cbar = fig.colorbar(scat, shrink=0.5)
        cbar.ax.yaxis.get_offset_text().set_fontsize(axis_font-20)
        cbar.ax.tick_params(labelsize=tick_fontsize)
        ax.set_box_aspect([2, 2, 2])
        ax.xaxis.labelpad = 50
        ax.yaxis.labelpad = 50
        ax.zaxis.labelpad = 50
        ax.view_init(azim=270)
        fig.savefig(save_file_name, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

