import os

import matplotlib.pyplot as plt
import numpy as np

# Output settings shared by the figures in this script.
directory = 'plot_data/'      # NOTE(review): not referenced in the visible part of this script
save_directory = 'saved/'     # prefix prepended to every saved figure's file name
save_format = 'png'           # used both as the file extension and as savefig's format
save_plot = True              # when False, figures are not written to disk
saved_dpi = 1000              # resolution passed to savefig
# Keyword arguments forwarded to the figure-creation calls.
args = dict(figsize=(5, 5), dpi=100)

# Figure 0: colour illustration of a single parameter update on the loss
# curve loss(x) = 1.1 - sin(x).
plt.figure(0, **args)

# Number of samples per plotted curve; the later figures in this script
# read this value as well, so it stays at module level.
point_count = 1000


def func(X):
    """Return the loss 1.1 - sin(X) for parameter value(s) X."""
    return 1.1 - np.sin(X)


def dfunc(X):
    """Return the derivative of ``func`` at X, i.e. -cos(X)."""
    return -np.cos(X)


# Parameter value before (a) and after (b) the update.
a = np.pi * 0.3
b = np.pi * 0.65

# Loss curve over a window slightly wider than [a, b].
X = np.linspace(np.pi / 4, np.pi * 0.7, point_count)
Y = func(X)

# Tangent to the loss at a, drawn over [a, b].
Xgrad = np.linspace(a, b, point_count)
Ygrad = dfunc(a) * (Xgrad - a) + func(a)

# Secant through (a, func(a)) and (b, func(b)): its slope is the average
# gradient over the update.
Xavggrad = Xgrad
Yavggrad = (func(b) - func(a)) / (b - a) * (Xavggrad - a) + func(a)

plt.plot(X, Y, color='b', label='Loss in Terms of a Parameter')
plt.plot(Xgrad, Ygrad, color='r', label='Gradient Line', linestyle=':')
plt.plot(Xavggrad, Yavggrad, color='black', label='Average Gradient Line',
         linestyle='--')
plt.scatter(a, func(a), color='lightgray', label='Before Parameter Update')
plt.scatter(b, func(b), color='green', label='After Parameter Update')

plt.xlabel('Parameter')
plt.ylabel('Loss')
plt.legend()

if save_plot:
    # savefig does not create missing directories, so make sure the target
    # exists first (skipped when saving into the current directory).
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
    plt.savefig(save_directory + 'Illustration' + "." + save_format,
                bbox_inches='tight', dpi=saved_dpi, format=save_format)


# Monochrome illustrations: one two-panel figure (figure 1, saved as
# 'Illustration2') followed by each panel as a stand-alone figure (figures 2
# and 3, saved as 'Illustration2.1' and 'Illustration2.2'), then plt.show().

marker_size = 90        # scatter marker area; the local-minimum star uses twice this
plot_local_mins = True  # whether each panel marks the loss's local minimum


def _draw_update_panel(ax, offset, a, b, x_min, x_max, x_argmin):
    """Draw one parameter-update illustration for loss(x) = offset - sin(x).

    Args:
        ax: Matplotlib axes to draw on.
        offset: Constant term of the loss.
        a: Parameter value before the update.
        b: Parameter value after the update.
        x_min: Left end of the plotted loss curve.
        x_max: Right end of the plotted loss curve.
        x_argmin: Parameter value marked as the local minimum when
            ``plot_local_mins`` is true.

    Reads the module-level ``point_count``, ``marker_size`` and
    ``plot_local_mins`` settings.
    """

    def loss(x):
        return offset - np.sin(x)

    def slope(x):
        # Derivative of ``loss``; the offset drops out.
        return -np.cos(x)

    curve_x = np.linspace(x_min, x_max, point_count)
    ax.plot(curve_x, loss(curve_x), color='black',
            label='Loss in Terms of a Parameter')

    # Both straight lines are drawn over [a, b] and pass through (a, loss(a)).
    line_x = np.linspace(a, b, point_count)
    tangent_y = slope(a) * (line_x - a) + loss(a)
    secant_y = (loss(b) - loss(a)) / (b - a) * (line_x - a) + loss(a)
    ax.plot(line_x, tangent_y, color='black', label='Gradient Line',
            linestyle=':')
    ax.plot(line_x, secant_y, color='black', label='Average Gradient Line',
            linestyle='--')

    ax.scatter(a, loss(a), color='black', label='Before Parameter Update',
               marker='s', s=marker_size)
    ax.scatter(b, loss(b), color='black',
               label='After Parameter Update\nSuggested by the Gradient',
               s=marker_size)
    if plot_local_mins:
        ax.scatter(x_argmin, loss(x_argmin), color='black',
                   label='Local Minimum', marker='*', s=marker_size * 2)

    ax.set_xlabel('Parameter')
    ax.set_ylabel('Loss')
    ax.legend()


def _save_current_figure(stem):
    """Save the current pyplot figure as ``stem`` under ``save_directory``.

    Does nothing when ``save_plot`` is false.
    """
    if not save_plot:
        return
    # savefig does not create missing directories, so make sure the target
    # exists first (skipped when saving into the current directory).
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
    plt.savefig(save_directory + stem + "." + save_format,
                bbox_inches='tight', dpi=saved_dpi, format=save_format)


# Settings for the two panels; each is drawn once in the two-panel figure and
# once more as its own figure.
left_panel = dict(offset=1.85, a=np.pi * 0.35, b=np.pi * 0.95,
                  x_min=np.pi * 0.3, x_max=np.pi, x_argmin=np.pi / 2)
right_panel = dict(offset=1.1, a=np.pi * 0.3, b=np.pi * 0.65,
                   x_min=np.pi / 4, x_max=np.pi * 0.7, x_argmin=1 / 2 * np.pi)

base_width, base_height = args['figsize']

# Figure 1: both panels side by side on a canvas twice as wide as the base
# figure size.
fig, (ax1, ax2) = plt.subplots(
    ncols=2, **dict(args, figsize=(base_width * 2, base_height)))
_draw_update_panel(ax1, **left_panel)
_draw_update_panel(ax2, **right_panel)
_save_current_figure('Illustration2')

# Figures 2 and 3: the same panels on their own, on a square canvas whose
# side is the configured figure height.
single_args = dict(args, figsize=(base_height, base_height))

plt.figure(2, **single_args)
_draw_update_panel(plt.gca(), **left_panel)
_save_current_figure('Illustration2.1')

plt.figure(3, **single_args)
_draw_update_panel(plt.gca(), **right_panel)
_save_current_figure('Illustration2.2')

plt.show()