import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.legend_handler import HandlerPatch
from matplotlib.patches import FancyArrow
from sympy import *


# def flatten(matrix):
#     flat_list = []
#     for row in matrix:
#         flat_list += row
#     return flat_list

# delta = 0.025
# x = y = np.arange(-3.0, 3.01, delta)
# X, Y = np.meshgrid(x, y)
# Z1 = np.exp(-X**2 - Y**2)
# Z2 = np.exp(-(X - 1)**2 - (Y - 1)**2)
# Z = (Z1 - Z2) * 2

# --- Input/output settings -------------------------------------------------
directory = 'plot_data/'  # not referenced elsewhere in this script
save_directory = 'saved/'
save_format = 'png'
save_plot = True
saved_dpi = 1000  # DPI used when writing the image file

# Figure size (inches) and figure DPI, forwarded to plt.subplots.
args = dict(figsize=(9, 10), dpi=100)

# --- Arrow styling ---------------------------------------------------------
# Shaft width of the gradient arrow (data units) and the factor by which the
# "lowest average gradient" arrow is made wider.
grad_arrow_width = 0.025
avg_grad_arrow_width_mul = 2.0

# Starting parameter values, one per subplot.
points = [
    (0.28, -0.46),
    (0.5, -0.7),
    (0.95, -0.04),
    (0, 0.55),
    (1.31, -1.11),
    (0.77, -1.81),
]
# 3 x 2 grid of panels; the plotting loop fills one panel per starting point.
fig1, axis = plt.subplots(3, 2, layout='constrained', **args)

# Arrow artists from the most recent panel, reused later as legend handles.
arrow_grad = arrow_avggrad = None

# Radius of the circle of parameter values reachable in one update step; the
# gradient arrow is also rescaled to this length.
update_step = 1


def _loss_surface(X, Y):
    """Return the synthetic loss at (X, Y).

    The loss is a sum of five exponential bumps plus a constant offset of 0.3.
    It works element-wise on NumPy arrays (for the contour grid) and on plain
    floats (for the finite-difference gradient), so the plotted surface and the
    gradient arrows are always derived from the same formula.
    """
    Z1 = np.exp(-np.abs(X)**2 - np.abs(Y)**2)
    Z2 = np.exp(0.5 + 0.5*(-np.abs(X + 0.5)**1.5 - np.abs(Y + 1.7)**1.5))
    Z3 = np.exp(1.5*(-np.abs(X - 1)**2.5 - np.abs(Y - 1)**2.5))
    Z4 = 1.7*np.exp(0.5*(-np.abs(X - 2)**2 - np.abs(Y + 1)**2))
    Z5 = np.exp(1.5*(-np.abs(X - 0.75)**2 - np.abs(Y + 2.5)**2))
    return Z2 + Z1 + Z3 + Z4 + Z5 + 0.3


def _grid_around(center, half_width=2.0, delta=0.025):
    """Return meshgrid arrays (X, Y) spanning center +/- half_width with spacing delta."""
    offsets = np.arange(-half_width, half_width + 0.01, delta)
    return np.meshgrid(offsets + center[0], offsets + center[1])


def _argmin_within_radius(X, Y, Z, center, radius):
    """Return the (row, col) index of the smallest Z whose (X, Y) lies within
    *radius* of *center* (boundary included).

    Ties resolve to the first qualifying cell in column-major order.
    Returns (0, 0) when no cell qualifies.
    """
    inside = (X - center[0])**2 + (Y - center[1])**2 <= radius**2
    # Cells outside the circle can never win the argmin.
    candidates = np.where(inside, Z, np.inf)
    flat_index = np.argmin(candidates.ravel(order='F'))
    row, col = np.unravel_index(flat_index, candidates.shape, order='F')
    return int(row), int(col)


def _forward_difference_gradient(f, px, py, h=1e-4):
    """Approximate (df/dx, df/dy) of f at (px, py) with forward differences of step h."""
    f0 = f(px, py)
    return (f(px + h, py) - f0) / h, (f(px, py + h) - f0) / h


def _draw_white_arrow(ax, start, dx, dy, width, shorten=0.):
    """Draw a white arrow on *ax* from *start* by (dx, dy) and return the artist.

    If *shorten* is non-zero the arrow is pulled back by that many data units
    (never below zero length). A zero-length vector is passed through unchanged
    instead of being divided by its length.
    """
    length = float(np.hypot(dx, dy))
    if shorten and length > 0:
        scale = max(length - shorten, 0.) / length
        dx, dy = dx*scale, dy*scale
    return ax.arrow(start[0], start[1], float(dx), float(dy), width=width,
                    color='white', length_includes_head=True)


n_contour_levels = 25
shorten_arrow = 0.  # e.g. arrow width * 4 to keep the arrow head off its target

if len(points) > axis.size:
    raise ValueError(f'{len(points)} points but only {axis.size} subplots')

# Evaluate every panel first so all of them can share one set of contour
# levels; otherwise the single colorbar (built from the last ContourSet) would
# only describe the last panel's colors.
grids = [_grid_around(point) for point in points]
surfaces = [_loss_surface(X, Y) for X, Y in grids]
levels = np.linspace(min(Z.min() for Z in surfaces),
                     max(Z.max() for Z in surfaces),
                     n_contour_levels + 1)

for ax2, point, (X, Y), Z in zip(axis.flat, points, grids, surfaces):
    CS = ax2.contourf(X, Y, Z, levels)

    # Arrow to the grid point with the lowest loss inside the update-step circle.
    row, col = _argmin_within_radius(X, Y, Z, point, update_step)
    arrow_avggrad = _draw_white_arrow(
        ax2, point, X[row, col] - point[0], Y[row, col] - point[1],
        width=grad_arrow_width*avg_grad_arrow_width_mul, shorten=shorten_arrow)

    # Negative gradient rescaled to length update_step (left at zero length
    # if the gradient vanishes, instead of dividing by zero).
    gx, gy = _forward_difference_gradient(_loss_surface, point[0], point[1])
    grad_norm = float(np.hypot(gx, gy))
    if grad_norm > 0:
        gx, gy = update_step/grad_norm*gx, update_step/grad_norm*gy
    arrow_grad = _draw_white_arrow(ax2, point, -gx, -gy,
                                   width=grad_arrow_width, shorten=shorten_arrow)

    ax2.scatter(point[0], point[1], s=30., color='white', label='Before Parameter Update')
    ax2.add_patch(plt.Circle((point[0], point[1]), update_step, color='white',
                             linestyle='--', linewidth=2,
                             label='Possible Values After Parameter Update', fill=False))

    ax2.set_xlabel('Parameter 1')
    ax2.set_ylabel('Parameter 2')

# Colorbar built from CS (the ContourSet of the last contourf call), placed
# alongside every subplot in `axis`.
cbar = fig1.colorbar(CS, ax=axis)
cbar.ax.set_ylabel('Loss')

# Legend entries: the labelled artists of the first subplot (start point and
# update circle), followed by the two arrow artists, which carry no label.
lines, labels = axis.flat[0].get_legend_handles_labels()
lines.extend((arrow_grad, arrow_avggrad))
labels.extend(('Gradient Direction', 'Direction of the Lowest Average Gradient'))

def make_legend_arrow(legend, orig_handle,
                      xdescent, ydescent,
                      width, height, fontsize):
    """Return the legend glyph for a FancyArrow handle.

    Intended as the ``patch_func`` of a ``HandlerPatch`` (which passes these
    arguments by keyword, so their names must not change). The glyph is a
    horizontal arrow across the full legend-key width at half its height. Its
    shaft and head widths are scaled from the original arrow's private
    ``_width`` attribute, so wider plot arrows get wider legend arrows.
    """
    shaft = orig_handle._width
    return FancyArrow(
        0, 0.5 * height, width, 0.,
        width=shaft * 50.,
        length_includes_head=True,
        head_width=0.75 * height * shaft / 0.04,
    )

# Figure-level legend above the panels. FancyArrow handles are drawn by
# make_legend_arrow so they appear as arrows rather than generic patches.
fig1.legend(lines, labels, loc='upper center', ncol=2, labelspacing=0.,
            handler_map={FancyArrow: HandlerPatch(patch_func=make_legend_arrow)})

if save_plot:
    # savefig does not create missing directories, so ensure the target exists.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
    fig1.savefig(os.path.join(save_directory, 'Illustration_Multidimensional' + '.' + save_format),
                 bbox_inches='tight', dpi=saved_dpi, format=save_format)

plt.show()