import matplotlib.pyplot as plt
from CLOUD import *
import torch
from Numeric_edit import *


def _as_flat_cpu(tensor):
    """Return ``tensor`` detached from autograd, moved to the CPU and flattened to 1-D."""
    return tensor.detach().cpu().reshape(-1)


def _pad_with_zeros(values, count):
    """Append ``count`` zeros (same dtype as ``values``) to a 1-D tensor; return a NumPy array."""
    return torch.cat((values, values.new_zeros(count)), dim=0).numpy()


def _draw_panel(fig, position, x, y, values, title, cmap):
    """Add one filled-contour subplot, with its own colorbar, to ``fig``."""
    ax = fig.add_subplot(position)
    ax.set_xlabel('x', size=10)
    ax.set_ylabel('y', size=10)
    contours = ax.tricontourf(x, y, values, levels=40, cmap=cmap)
    ax.set_title(title, fontsize=10)
    ax.tick_params(labelsize=15)
    # Bind the colorbar to this panel explicitly instead of relying on
    # pyplot's implicit "current image" / "current axes" state.
    fig.colorbar(contours, ax=ax)
    return ax


def result_and_plot(model, after_learning, PATH, exp_name):
    """Compare the network output with a reference solution and save a 4-panel plot.

    The network is evaluated on the exterior points and compared against the
    reference values ``u_p_out`` and against ``model.background_field``. A
    figure with four filled-contour panels (prediction, reference, pointwise
    relative error, pointwise background error) is written to
    ``PATH + "/plot of prediction and numerical solution_<exp_name>.png"``.

    Args:
        model: Object providing ``get_u_ext(points)`` and
            ``background_field(points)``; both are called with
            ``after_learning["exterior_psi"]`` and must return tensors.
        after_learning: Mapping with the tensor entries ``"u_p_out"``,
            ``"exterior_psi"``, ``"interior_psi"`` and ``"boundary_psi"``.
            The ``*_psi`` tensors are indexed as ``[:, 0]`` (x) and
            ``[:, 1]`` (y). ``u_p_out`` is flattened before use and must hold
            one value per exterior point.
        PATH: Directory (string) the figure is saved into; it is not created
            here.
        exp_name: Value substituted into the output file name.

    Returns:
        dict: ``{"Relative_L2_error": ..., "Background_L2_error": ...}`` where
        both values are 0-dim CPU tensors.
    """
    exterior_psi = after_learning["exterior_psi"]
    interior_psi = after_learning["interior_psi"]
    boundary_psi = after_learning["boundary_psi"]

    # Normalise the reference exactly like every other field. Without this a
    # CUDA tensor cannot be combined with the CPU prediction, and an (N, 1)
    # tensor would silently broadcast against the (N,) prediction into (N, N).
    u_p_out = _as_flat_cpu(after_learning["u_p_out"])

    u_NN_exterior = _as_flat_cpu(model.get_u_ext(exterior_psi))

    xspace = _as_flat_cpu(exterior_psi[:, 0])
    yspace = _as_flat_cpu(exterior_psi[:, 1])

    # Pointwise squared error normalised by the mean square of the reference;
    # its mean is the reported relative error.
    Numerics_L2_norm = torch.mean(torch.square(u_p_out))
    Relative_L2_error_pointwise = torch.divide(
        torch.square(u_NN_exterior - u_p_out), Numerics_L2_norm
    )
    Relative_L2_error = torch.mean(Relative_L2_error_pointwise)

    H = _as_flat_cpu(model.background_field(exterior_psi))

    Background_L2_error_pointwise = torch.square(u_NN_exterior - H)
    Background_L2_error = torch.mean(Background_L2_error_pointwise)

    print("### Final results ---- Relative_L2_error with numerical simulation : %.2E, Background_L2_error : %.2E ###"%(Relative_L2_error, Background_L2_error))

    # Interior and boundary points carry no exterior solution. They are added
    # to the triangulation with value 0 so that region is drawn as zero.
    points_masked = torch.cat((interior_psi, boundary_psi), dim=0)
    xspace_masked = _as_flat_cpu(points_masked[:, 0])
    yspace_masked = _as_flat_cpu(points_masked[:, 1])
    n_masked = xspace_masked.shape[0]

    xspace_plot = torch.cat((xspace, xspace_masked), dim=0).reshape(-1).numpy()
    yspace_plot = torch.cat((yspace, yspace_masked), dim=0).reshape(-1).numpy()

    u_NN_exterior_plot = _pad_with_zeros(u_NN_exterior, n_masked)
    u_p_out_plot = _pad_with_zeros(u_p_out, n_masked)
    Relative_L2_error_plot = _pad_with_zeros(Relative_L2_error_pointwise, n_masked)
    Background_L2_error_plot = _pad_with_zeros(Background_L2_error_pointwise, n_masked)

    # Reset any rcParams changed elsewhere so the figure uses matplotlib defaults.
    plt.rcParams.update(plt.rcParamsDefault)
    fig = plt.figure(figsize=(32, 6))
    try:
        _draw_panel(fig, 141, xspace_plot, yspace_plot, u_NN_exterior_plot,
                    'Prediction', 'RdBu_r')
        _draw_panel(fig, 142, xspace_plot, yspace_plot, u_p_out_plot,
                    'Numerical', 'RdBu_r')
        _draw_panel(fig, 143, xspace_plot, yspace_plot, Relative_L2_error_plot,
                    'Relative L2 error', 'viridis')
        _draw_panel(fig, 144, xspace_plot, yspace_plot, Background_L2_error_plot,
                    'Background L2 error', 'viridis')

        fig.savefig(PATH+"/plot of prediction and numerical solution_%s.png" %(exp_name))
    finally:
        # Release the figure; otherwise every call leaves one more open figure
        # registered with pyplot.
        plt.close(fig)

    dict_result_error = {
        "Relative_L2_error" : Relative_L2_error,
        "Background_L2_error" : Background_L2_error
    }

    return dict_result_error