## Import libraries.
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick



def _plot_loss_panel(ax, history, title):
    """Draw one per-epoch loss history on a log-scaled y axis.

    The position of each value in ``history`` is used as its epoch number.
    """
    ax.semilogy(range(len(history)), history, '-')
    ax.set_xlabel('$n_{epoch}$')
    ax.set_ylabel('$\\phi_{n_{epoch}}$')
    ax.set_title(title)


def _plot_inverse_param_panel(ax, history, title):
    """Draw one inverse-parameter history on a symlog y axis and mark its final value.

    The final value is drawn as a red dashed horizontal line starting at
    epoch 100 and is printed as text at the right end of the curve.
    An empty history is plotted without the annotation instead of raising
    ``IndexError``.
    """
    ax.set_yscale('symlog')
    ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2e'))
    n_epochs = len(history)
    ax.plot(range(n_epochs), history, '-')

    if n_epochs:
        y_note = history[-1]
        # NOTE(review): the reference line starts at the hard-coded epoch 100;
        # for histories shorter than that it is drawn backwards from x=100.
        ax.hlines(y=y_note, xmin=100, xmax=n_epochs,
                  linestyle='--', linewidth=1, color='r')
        ax.text(n_epochs, y_note, f'{y_note:.2e}',
                verticalalignment='bottom', color='r')

    ax.set_xlabel('$n_{epoch}$')
    ax.set_title(title)


def plot_loss_inv(args, model, PATH, exp_name, *, show=False):
    """Save training-loss plots and, optionally, inverse-parameter history plots.

    Always writes ``PATH + "/plot of loss_<exp_name>.png"``. This is a 2x3 grid
    showing ``model.hist[0]`` (total loss), ``model.hist[1..4]`` (per-condition
    losses) and ``model.hist_numeric`` (PINN vs. numerical code).

    When ``args.p_nn_flag`` is the string ``"False"``, it also writes
    ``PATH + "/plot of inverse parameters(reals)_<exp_name>.png"``. This
    contains one panel per entry ``model.hist_inv[0..args.order_interface]``
    (parameters p0..pN), laid out five per row.

    Args:
        args: Config object; reads ``p_nn_flag`` (compared to the string
            ``"False"``) and ``order_interface`` (an int).
        model: Object exposing ``hist`` (at least 5 indexable histories),
            ``hist_numeric`` and, when plotting parameters, ``hist_inv`` with
            ``order_interface + 1`` histories.
        PATH: Output directory, concatenated as a string.
        exp_name: Experiment tag inserted into the file names.
        show: If True, call ``plt.show()`` after saving each figure. The
            default False keeps the original non-interactive behaviour.

    Every figure created here is closed before returning, so repeated calls
    (e.g. once per training run) do not accumulate open figures.
    """
    loss_panels = [
        (model.hist[0], 'Total loss'),
        (model.hist[1], 'loss for condition 1'),
        (model.hist[2], 'loss for condition 2'),
        (model.hist[3], 'loss for condition 3'),
        (model.hist[4], 'loss for condition 4'),
        (model.hist_numeric, 'loss for PINN vs. Numerical code'),
    ]

    fig = plt.figure(figsize=(35, 10))
    try:
        for idx, (history, title) in enumerate(loss_panels, start=1):
            _plot_loss_panel(fig.add_subplot(2, 3, idx), history, title)
        fig.savefig(PATH + "/plot of loss_%s.png" % (exp_name))
        if show:
            plt.show()
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)

    if args.p_nn_flag == "False":
        # Inverse parameters p0..p{order}: order_interface + 1 panels, 5 per row.
        num_col = 5
        num_row = args.order_interface // num_col + 1
        fig = plt.figure(figsize=(num_col * 7, num_row * 5))
        try:
            for i in range(args.order_interface + 1):
                _plot_inverse_param_panel(
                    fig.add_subplot(num_row, num_col, i + 1),
                    model.hist_inv[i],
                    'Inverse parameter p%d' % (i),
                )
            fig.savefig(PATH + "/plot of inverse parameters(reals)_%s.png" % (exp_name))
            if show:
                plt.show()
        finally:
            plt.close(fig)

