import numpy as np
import matplotlib.pyplot as plt


def create_plot(name, points, nash_gaps, r_values, c_values, logbs, xs, ys, vs, alpha, elev, azim, z_notation=False, ylim=None):
    """Save a 3-D surface plot and a Nash-gap / constraint plot as PDF files.

    Two files are written:

    * ``<name>_3d_surface.pdf`` -- the surface ``(xs, ys, vs)`` coloured by
      ``logbs``, with the iterates ``points`` scattered on top.
    * ``<name>_nash_gap.pdf`` -- both Nash gaps per iteration ``t`` on the
      left axis, and ``c_values`` plus a dashed threshold line at ``alpha``
      on a twin right axis.

    Note: this updates the global ``plt.rcParams`` (LaTeX text rendering,
    Times serif font, font sizes); the change persists after the call.

    Args:
        name: Path prefix for both output files.
        points: Iterable of ``(x, y)`` iterate coordinates.
        nash_gaps: Sequence of ``(gap_x, gap_y)`` pairs, one per iteration.
        r_values: Height of each point in ``points``; markers are drawn
            0.05 above these values.
        c_values: Per-iteration values plotted on the right axis.
        logbs: Values used to colour the surface (viridis colormap,
            normalised to ``[-10, -0.5]``); NaN entries are drawn gray.
            Presumably the same shape as ``xs``/``ys``/``vs`` -- confirm
            with callers.
        xs, ys, vs: Surface grids forwarded to ``plot_surface``.
        alpha: Height of the dashed threshold line on the right axis.
        elev, azim: Camera elevation / azimuth of the 3-D view.
        z_notation: If True, labels use ``z^{(t)}`` with subscripts
            ``x``/``y``; otherwise ``x^{(t)}`` with subscripts ``1``/``2``.
        ylim: Optional ``(bottom, top)`` limits for the right axis. When
            None, the axis spans
            ``[0, 1.5 * max(max(c_values), alpha, 1 - alpha)]`` so the
            threshold line is always visible.
    """
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Times New Roman"],
        "axes.labelsize": 20,
        "font.size": 18,
        "legend.fontsize": 16,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
    })

    _save_surface_figure(name, points, r_values, logbs, xs, ys, vs, elev, azim)
    _save_nash_gap_figure(name, nash_gaps, c_values, alpha, z_notation, ylim)


def _save_surface_figure(name, points, r_values, logbs, xs, ys, vs, elev, azim):
    """Render the coloured 3-D surface plus iterates; save ``<name>_3d_surface.pdf``."""
    fig = plt.figure(figsize=(8, 7))
    try:
        ax = fig.add_subplot(111, projection='3d')

        # NaN cells get a placeholder value so the colormap lookup succeeds;
        # their colour is then overwritten with gray.
        logb_clean = np.copy(logbs)
        nan_mask = np.isnan(logb_clean)
        logb_clean[nan_mask] = -15
        norm = plt.Normalize(vmin=-10, vmax=-0.5)
        colors = plt.cm.viridis(norm(logb_clean))
        colors[nan_mask] = [0.8, 0.8, 0.8, 1.0]

        ax.plot_surface(xs, ys, vs, facecolors=colors, shade=True, alpha=1.0, rstride=3, cstride=3)

        point_list = list(points)
        if point_list:  # zip(*[]) cannot be unpacked into two sequences
            x_points, y_points = zip(*point_list)
            # Markers are lifted by 0.05, presumably so the surface does not
            # occlude them.
            z_points = [v + 0.05 for v in r_values]
            ax.scatter(x_points, y_points, z_points, color='red', s=3, alpha=1.0)

        ax.set_xlabel(r'$x$', labelpad=12)
        ax.set_ylabel(r'$y$', labelpad=12)
        ax.view_init(elev=elev, azim=azim)

        fig.tight_layout()
        fig.savefig(name + '_3d_surface.pdf', bbox_inches='tight', dpi=300)
    finally:
        # Release the figure even if rendering/saving fails (e.g. no LaTeX).
        plt.close(fig)


def _save_nash_gap_figure(name, nash_gaps, c_values, alpha, z_notation, ylim):
    """Plot both Nash gaps and ``c_values`` per iteration; save ``<name>_nash_gap.pdf``."""
    if z_notation:
        gap_labels = (r'Nash-Gap$_x(z^{(t)})$', r'Nash-Gap$_y(z^{(t)})$')
        gap_axis_label = r'Nash-Gap$_i(z^{(t)})$'
        c_label = r'$c(z^{(t)})$'
    else:
        gap_labels = (r'Nash-Gap$_1(x^{(t)})$', r'Nash-Gap$_2(x^{(t)})$')
        gap_axis_label = r'Nash-Gap$_i(x^{(t)})$'
        c_label = r'$c(x^{(t)})$'

    gap_x = [gx for gx, _ in nash_gaps]
    gap_y = [gy for _, gy in nash_gaps]
    gap_steps = range(len(nash_gaps))
    c_steps = range(len(c_values))

    fig = plt.figure(figsize=(8, 4.5))
    try:
        ax_gap = fig.add_subplot(111)
        p1, = ax_gap.plot(gap_steps, gap_x, label=gap_labels[0])
        p2, = ax_gap.plot(gap_steps, gap_y, label=gap_labels[1])

        ax_c = ax_gap.twinx()
        p3, = ax_c.plot(c_steps, c_values, label=c_label, color='green')
        p4, = ax_c.plot(c_steps, [alpha] * len(c_values), color='red', linestyle='--', label=r'$\alpha$')

        if ylim is not None:
            ax_c.set_ylim(ylim)
        else:
            # Include ``alpha`` so the dashed threshold line is never clipped;
            # ``default`` avoids ValueError on an empty c_values.
            c_max = max(c_values, default=0.0)
            ax_c.set_ylim(0, max(c_max, alpha, 1 - alpha) * 1.5)

        handles = [p1, p2, p3, p4]
        ax_gap.set_xlabel(r'$t$')
        ax_gap.set_ylabel(gap_axis_label, rotation=90)
        ax_c.set_ylabel(c_label, rotation=90)
        # Single legend on the left axes collecting handles from both axes.
        ax_gap.legend(handles, [h.get_label() for h in handles], loc='upper right', bbox_to_anchor=(1, 1), bbox_transform=ax_gap.transAxes)
        ax_gap.grid(True)

        fig.tight_layout()
        fig.savefig(name + '_nash_gap.pdf', bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)


def create_plot_N(nash_gaps, r_values, c_values, alpha_1, alpha_2, name='routing', ylim=(0, 7.5)):
    """Save a Nash-gap plot with one or more constraint curves as a PDF.

    The Nash gap per iteration ``t`` is drawn on the left axis; each
    constraint series ``c_k`` and two dashed threshold lines (``alpha_1``,
    ``alpha_2``) are drawn on a twin right axis. The figure is written to
    ``<name>_nash_gap.pdf`` (``routing_nash_gap.pdf`` by default).

    Note: this updates the global ``plt.rcParams`` (LaTeX text rendering,
    Times serif font, font sizes); the change persists after the call.

    Args:
        nash_gaps: Per-iteration Nash-gap values.
        r_values: Currently unused by this function.
        c_values: Constraint values, either 1-D (a single curve) or 2-D
            with one column per constraint ``c_k`` and one row per iteration.
        alpha_1, alpha_2: Heights of the two dashed threshold lines.
        name: Path prefix for the output file.
        ylim: ``(bottom, top)`` limits for the right axis; pass None to keep
            matplotlib's autoscaling.
    """
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Times New Roman"],
        "axes.labelsize": 20,
        "font.size": 18,
        "legend.fontsize": 16,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
    })

    # Rows are iterations, columns are constraints; a 1-D input is a single
    # constraint (previously it crashed on ``shape[1]``).
    c_matrix = np.asarray(c_values)
    if c_matrix.ndim == 1:
        c_matrix = c_matrix[:, np.newaxis]
    c_steps = range(len(c_matrix))

    fig = plt.figure(figsize=(8, 4.5))
    try:
        ax_gap = fig.add_subplot(111)
        gap_handle, = ax_gap.plot(range(len(nash_gaps)), nash_gaps, label=r'Nash-Gap$(x^{(t)})$')
        ax_c = ax_gap.twinx()

        # Constraint curves alternate between two greens; with more than two
        # columns the colours repeat.
        palette = ['limegreen', 'darkgreen']
        c_handles = []
        for k in range(c_matrix.shape[1]):
            # Braced subscript so multi-digit indices (c_{10}, ...) render correctly.
            handle, = ax_c.plot(c_steps, c_matrix[:, k], label=rf'$c_{{{k + 1}}}(x^{{(t)}})$', color=palette[k % len(palette)])
            c_handles.append(handle)

        # Dashed threshold lines.
        alpha_1_handle, = ax_c.plot(c_steps, [alpha_1] * len(c_matrix),
                                    color='orangered', linestyle='--', label=r'$\alpha_1$')
        alpha_2_handle, = ax_c.plot(c_steps, [alpha_2] * len(c_matrix),
                                    color='darkred', linestyle='--', label=r'$\alpha_2$')

        # Fixed right-axis range (default (0, 7.5)) unless the caller opts out.
        if ylim is not None:
            ax_c.set_ylim(ylim)

        handles = [gap_handle, *c_handles, alpha_1_handle, alpha_2_handle]
        ax_gap.set_xlabel(r'$t$')
        ax_gap.set_ylabel(r'Nash-Gap$(x^{(t)})$', rotation=90)
        ax_c.set_ylabel(r'$c_i(x^{(t)})$', rotation=90)
        # Single legend on the left axes collecting handles from both axes.
        ax_gap.legend(handles, [h.get_label() for h in handles], loc='upper right', bbox_to_anchor=(1, .94), bbox_transform=ax_gap.transAxes)
        ax_gap.grid(True)

        fig.tight_layout()
        fig.savefig(name + '_nash_gap.pdf', bbox_inches='tight', dpi=300)
    finally:
        # Release the figure even if rendering/saving fails (e.g. no LaTeX).
        plt.close(fig)
