import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from sympy import *



# delta = 0.025
# x = y = np.arange(-3.0, 3.01, delta)
# X, Y = np.meshgrid(x, y)
# Z1 = np.exp(-X**2 - Y**2)
# Z2 = np.exp(-(X - 1)**2 - (Y - 1)**2)
# Z = (Z1 - Z2) * 2

# Current parameter values: the point the update step starts from.
# point=(0.28,-0.46)
point = (0.5, -0.7)

# The plotted window is centred on the current point.
move = point

delta = 0.025
offsets = np.arange(-2.0, 2.01, delta)
X, Y = np.meshgrid(offsets + move[0], offsets + move[1])


def loss_landscape(u, v, expf=np.exp):
    """Evaluate the synthetic loss surface at (u, v).

    One formula serves both the numeric grid and the symbolic expression, so
    the plotted surface and the differentiated one cannot drift apart.
    ``abs`` is the builtin: it dispatches to ``numpy.absolute`` for NumPy
    arrays and to ``sympy.Abs`` for sympy expressions.  Only the exponential
    must be chosen per backend.

    :param u: first model parameter (NumPy array or sympy expression).
    :param v: second model parameter (same kind as ``u``).
    :param expf: exponential matching the kind of ``u``/``v``; ``np.exp`` by
        default, sympy's ``exp`` for symbolic input.
    :return: the loss, elementwise for arrays or as a sympy expression.
    """
    return (expf(-abs(u)**2 - abs(v)**2)
            + expf(0.5 + 0.5*(-abs(u + 0.5)**1.5 - abs(v + 1.7)**1.5))
            + expf(1.5*(-abs(u - 1)**2.5 - abs(v - 1)**2.5))
            + 1.7*expf(0.5*(-abs(u - 2)**2 - abs(v + 1)**2))
            + expf(1.5*(-abs(u - 0.75)**2 - abs(v + 2.5)**2))
            + 0.3)


# Loss sampled on the plotting grid.
Z = loss_landscape(X, Y)

# Symbolic copy of the same surface.  The parameters are real, and declaring
# that lets sympy differentiate Abs into sign(...) terms instead of
# re()/im() expressions.
x = Symbol('x', real=True)
y = Symbol('y', real=True)
z = loss_landscape(x, y, expf=exp)

# Analytic gradient [dz/dx, dz/dy], printed for inspection.  (z.diff(x, y)
# would be the mixed second partial, not the gradient.)
dz = [z.diff(x), z.diff(y)]
print(dz)

# Z=np.empty_like(X)
# for iy, ix in np.ndindex(X.shape):
#     Z[ix,iy]=z.evalf(subs={x:X[ix,iy], y:Y[ix,iy]})


update_step = 1  # radius of one parameter update, measured from ``point``


def inside_update_step_criterion(x, y, z):
    """Tell whether (x, y) is no farther than ``update_step`` from ``point``.

    Follows the ``criterion(x, y, z)`` call convention; ``z`` is ignored.
    Works elementwise when given NumPy arrays.
    """
    dx = x - point[0]
    dy = y - point[1]
    return dx**2 + dy**2 <= update_step**2

def minimum_with_criterion(X, Y, Z, criterion=lambda x, y, z: True):
    """Return the index of the smallest ``Z`` value whose point passes ``criterion``.

    ``X``, ``Y`` and ``Z`` are 2-D arrays of the same shape (e.g. from
    ``np.meshgrid`` and a function evaluated on it).  Element ``[row, col]``
    of each describes one grid point.  Any rectangular shape is supported,
    not only square grids.

    :param criterion: predicate called as ``criterion(x, y, z)`` for every
        grid point; only points for which it is truthy are considered.
    :return: ``[row, col]`` index of the minimum.  On exact ties, the first
        point visited wins; the grid is scanned column by column.  NaN values
        are never selected.  If no point qualifies, ``[0, 0]`` is returned,
        so callers should make sure at least one point passes.
    """
    n_rows, n_cols = X.shape
    ind_of_min = [0, 0]
    min_value = np.inf
    # Column-major scan.  The order only matters for which of several equal
    # minima is reported.
    for col in range(n_cols):
        for row in range(n_rows):
            value = Z[row, col]
            # criterion runs first so it is still evaluated for every point.
            if criterion(X[row, col], Y[row, col], value) and value < min_value:
                ind_of_min = [row, col]
                min_value = value
    return ind_of_min


fig1, ax2 = plt.subplots(layout='constrained')
CS = ax2.contourf(X, Y, Z, 25)  # , cmap=plt.cm.bone)

# Dashed circle of radius update_step around the current point.
c = plt.Circle((point[0], point[1]), update_step, color='white', linestyle='--',
               linewidth=2, label='After Parameter Update', fill=False)
ax2.add_patch(c)

# Black arrow: from the current point to the lowest grid loss inside the circle.
min_coords = minimum_with_criterion(X, Y, Z, inside_update_step_criterion)
arrow_width = 0.02
# Pull the tip back by this much.  FancyArrow draws its head beyond the
# (dx, dy) endpoint unless length_includes_head=True, and this roughly
# compensates for that.
shorten_arrow = arrow_width*4

target_x = X[min_coords[0], min_coords[1]]
target_y = Y[min_coords[0], min_coords[1]]
dx = target_x - point[0]
dy = target_y - point[1]
vec_len = np.hypot(dx, dy)
# Skip the arrow when the target is the current point (zero length: no
# direction) or closer than the tip offset (shortening would flip it).
if vec_len > shorten_arrow:
    scale = 1 - shorten_arrow / vec_len
    ax2.add_patch(matplotlib.patches.FancyArrow(point[0], point[1], dx*scale, dy*scale,
                                                width=arrow_width, color='black'))
def derivative(f, x, delta=0.0001):
    """Estimate the slope of ``f`` at ``x`` with a one-sided forward difference.

    :param f: one-argument callable to differentiate.
    :param x: point at which the slope is estimated.
    :param delta: forward step size.
    :return: ``(f(x + delta) - f(x)) / delta``.
    """
    ahead = f(x + delta)  # the shifted point is evaluated first
    here = f(x)
    rise = ahead - here
    return rise / delta

# Red arrow: a plain gradient-descent step of length update_step.
# Partial derivatives of the symbolic loss at ``point`` come from forward
# differences.  They are converted to floats so the rest is ordinary float math.
dz = [float(derivative(lambda _x: z.evalf(subs={x: _x, y: point[1]}), point[0])),
      float(derivative(lambda _y: z.evalf(subs={x: point[0], y: _y}), point[1]))]
length = np.hypot(dz[0], dz[1])
# A zero gradient has no direction: normalising it would divide by zero,
# and float() of sympy's zoo would raise.  A step no longer than the tip
# offset would flip the arrow.  Draw only when neither applies.
if length > 0 and update_step > shorten_arrow:
    dz = [update_step / length * d for d in dz]
    # Descend (negate), and pull the tip back by shorten_arrow so the head
    # ends near the step's end point.
    tip_scale = 1 - shorten_arrow / update_step
    ax2.arrow(point[0], point[1], -dz[0]*tip_scale, -dz[1]*tip_scale,
              width=arrow_width, color='red')

ax2.scatter(point[0], point[1], color='lightgray', label='Before Parameter Update')
# NOTE: 'label=' values only appear if a legend is drawn (e.g. ax2.legend());
# none is added in this script.

ax2.set_title('Loss Landscape')
ax2.set_xlabel('Model Parameter Number 1')
ax2.set_ylabel('Model Parameter Number 2')

# Colour bar for the filled contours, labelled as the loss.
cbar = fig1.colorbar(CS)
cbar.ax.set_ylabel('Loss')

plt.show()