import numpy as np 
import matplotlib.pyplot as plt

def _to_numpy(values):
	"""Return *values* as a NumPy array.

	Objects exposing ``detach`` (torch tensors) are detached from the autograd
	graph and copied to host memory first, so tensors living on a GPU do not
	make ``.numpy()`` raise; anything else goes through ``np.asarray``.
	"""
	if hasattr(values, "detach"):
		return values.detach().cpu().numpy()
	return np.asarray(values)


def save_interface_figure(args, after_learning, PATH, exp_name):
	"""Plot ``interface`` against ``boundary_disk_tspace`` and save the figure as a PNG.

	Both inputs are flattened and reordered by ascending ``boundary_disk_tspace``
	so the line plot is drawn left to right instead of zig-zagging between
	unsorted samples.

	Args:
		args: Namespace-like object. When ``args.p_nn_flag`` is the string
			``"True"`` (or the boolean ``True``), only the first column of
			``interface`` is plotted.
		after_learning: Mapping with ``'boundary_disk_tspace'`` and
			``'interface'`` entries: torch tensors (any device) or NumPy
			array-likes. The two must hold the same number of values after the
			optional column selection and flattening.
		PATH: Output directory prefix. The file is written to
			``PATH + "/plot of interface function_<exp_name>.png"``.
		exp_name: Experiment name embedded in the file name.

	Raises:
		KeyError: If a required entry is missing from ``after_learning``.
		ValueError: If the flattened inputs differ in length. Plotting them
			would otherwise pair unrelated points.
	"""
	boundary_disk_tspace = after_learning['boundary_disk_tspace']
	interface = after_learning['interface']

	# str() lets both an argparse-style "True" string and a real bool enable this.
	if str(args.p_nn_flag) == "True":
		# Keep only the first column (still 2-D; it is flattened below anyway).
		interface = interface[:, [0]]

	x_values = _to_numpy(boundary_disk_tspace).flatten()
	y_values = _to_numpy(interface).flatten()
	if x_values.size != y_values.size:
		raise ValueError(
			"boundary_disk_tspace and interface must contain the same number of "
			f"values after flattening, got {x_values.size} and {y_values.size}"
		)

	# Sort by x so the connecting line follows increasing x.
	order = np.argsort(x_values)
	x_sorted = x_values[order]
	y_sorted = y_values[order]

	fig, ax = plt.subplots(figsize=(10, 6))
	try:
		ax.plot(x_sorted, y_sorted, label='Data points')  # line plot through the sorted points
		ax.set_xlabel('X values')
		ax.set_ylabel('Y values')
		ax.set_title('Plot of function from X to Y')
		ax.legend()
		ax.grid(True)
		fig.savefig(f"{PATH}/plot of interface function_{exp_name}.png")
	finally:
		# pyplot keeps every figure alive until it is closed; without this,
		# calling the function once per experiment accumulates open figures.
		plt.close(fig)
