## this script implements multi-modal test objective functions
## source: http://infinity77.net/global_optimization/test_functions.html#test-functions-index
## the value landscape is visualized as a sanity check
## colormap choice is coolwarm, following the advice given at https://www.kennethmoreland.com/color-maps/
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np


def hard1d1(x):
    """Evaluate the 1-D multi-modal objective f(x) = sin(x) + sin(10x/3).

    Reference: https://machinelearningmastery.com/1d-test-functions-for-function-optimization/
    Domain: [-7.5, 7.5]

    Local minima of f on the domain (value @ location):
        -0.888315  @ x≈-6.21731
        -0.11909   @ x≈-4.1966
        -1.7283    @ x≈-2.29609
        -1.48843   @ x≈-0.548883
        -0.0135205 @ x≈1.39826
        -1.19992   @ x≈3.38725
        -1.8996    @ x≈5.14574   (lowest of the listed minima)
        -0.316996  @ x≈7.00015

    Args:
        x: iterable of samples; only the first coordinate of each sample
           (``sample[0]``) is read, e.g. an array of shape (n, 1).

    Returns:
        np.ndarray of shape (n,) with one objective value per sample.
    """
    def objective(t):
        return np.sin(t) + np.sin((10.0 / 3.0) * t)

    return np.array([objective(sample[0]) for sample in x])

def hard2d1(x):
    """Evaluate a 2-D extension of hard1d1.

        f(x, y) = sin(x) + sin(10x/3) + cos(y) + cos(10y/3)

    The x-part is the hard1d1 objective; the y-part is its cosine analogue
    applied to the second coordinate.
    Reference for the 1-D part:
        https://machinelearningmastery.com/1d-test-functions-for-function-optimization/
    Domain: [-7.5, 7.5] in each coordinate.

    Local minima of the x-part sin(x) + sin(10x/3) on [-7.5, 7.5] (value @ x):
        -0.888315 @ -6.21731, -0.11909 @ -4.1966, -1.7283 @ -2.29609,
        -1.48843 @ -0.548883, -0.0135205 @ 1.39826, -1.19992 @ 3.38725,
        -1.8996 @ 5.14574, -0.316996 @ 7.00015

    Args:
        x: iterable of samples; each sample must support ``sample[0]`` and
           ``sample[1]``, e.g. an array of shape (n, 2).

    Returns:
        np.ndarray of shape (n,) with one objective value per sample.
    """
    def objective(vec):
        px = vec[0]
        py = vec[1]
        fx = np.sin(px) + np.sin((10.0 / 3.0) * px)
        # second term depends on the second coordinate (previously it reused x,
        # which left y unused and made the function effectively 1-D)
        fy = np.cos(py) + np.cos((10.0 / 3.0) * py)
        return fx + fy

    return np.array([objective(sample) for sample in x])


def hard1d2(x):
    """Evaluate the 1-D multi-modal objective f(x) = -(1.4 - 3x) * sin(16x).

    Domain: [0, 1.2]
    Local minima, from
    https://www.wolframalpha.com/input/?i=local+minimum+of+-%281.4+-+3x%29+*+sin%2816x%29%2C+x+in+%5B0%2C+1.2%5D
        -1.12098   @ x≈0.0879524
        -0.0620975 @ x≈0.427608
        -0.686544  @ x≈0.703359
        -1.84923   @ x≈1.08621   (lowest of the listed minima)

    Args:
        x: iterable of samples; only ``sample[0]`` is read from each,
           e.g. an array of shape (n, 1).

    Returns:
        np.ndarray of shape (n,) with one objective value per sample.
    """
    values = []
    for sample in x:
        t = sample[0]
        values.append(-(1.4 - 3.0 * t) * np.sin(16.0 * t))
    return np.array(values)

def levy05(x):
    """Evaluate the Levy 5 function on each 2-D sample.

        f(x1, x2) = [sum_{i=1..5} i cos((i-1) x1 + i)]
                    * [sum_{j=1..5} j cos((j+1) x2 + j)]
                    + (x1 + 1.42513)^2 + (x2 + 0.80032)^2

    Args:
        x: iterable of samples; each must support ``sample[0]`` and
           ``sample[1]``, e.g. an array of shape (n, 2).

    Returns:
        np.ndarray of shape (n,) with one objective value per sample.
    """
    ks = range(1, 6)

    def _levy(point):
        x1, x2 = point[0], point[1]
        left = np.sum([k * np.cos((k - 1) * x1 + k) for k in ks])
        right = np.sum([k * np.cos((k + 1) * x2 + k) for k in ks])
        return left * right + (x1 + 1.42513) ** 2 + (x2 + 0.80032) ** 2

    return np.array([_levy(point) for point in x])

def langermann(x):
    """Evaluate the 2-D Langermann function on each sample.

        f(p) = sum_{i=1..5} c_i * exp(-||p - A_i||^2 / pi) * cos(pi * ||p - A_i||^2)

    with c = (1, 2, 5, 2, 3) and A rows (3,5), (5,2), (2,1), (1,4), (7,9).

    NOTE(review): this sum carries no leading minus sign; some references
    define the Langermann function as the negated sum — confirm which
    convention is intended.

    Args:
        x: iterable of samples, each a length-2 array-like, e.g. an array of
           shape (n, 2).

    Returns:
        np.ndarray of shape (n,) with one objective value per sample.
    """
    weights = np.array([1, 2, 5, 2, 3])
    centers = np.array([[3, 5], [5, 2], [2, 1], [1, 4], [7, 9]])

    def _single(point):
        terms = []
        for w, ctr in zip(weights, centers):
            sq_dist = np.sum((point - ctr) ** 2)
            terms.append(w * np.exp(-1 / np.pi * sq_dist) * np.cos(np.pi * sq_dist))
        return np.sum(terms)

    return np.array([_single(point) for point in x])

def ripple01(x):
    """Evaluate the Ripple 1 function on each sample (any dimension).

        f(x) = sum_i -exp(-2 ln2 ((x_i - 0.1) / 0.8)^2)
                     * (sin^6(5 pi x_i) + 0.1 cos^2(500 pi x_i))

    Args:
        x: iterable of samples; each sample is an iterable of coordinates,
           e.g. an array of shape (n, d).

    Returns:
        np.ndarray of shape (n,) with one objective value per sample.
    """
    ln2 = np.log(2)

    def _coord_term(t):
        # Gaussian-shaped envelope centred at 0.1 times a high-frequency ripple
        envelope = -np.exp(-2 * ln2 * ((t - 0.1) / 0.8) ** 2)
        ripple = np.sin(5 * np.pi * t) ** 6 + 0.1 * np.cos(500 * np.pi * t) ** 2
        return envelope * ripple

    return np.array([np.sum([_coord_term(t) for t in sample]) for sample in x])

def price02(x):
    """Evaluate the Price 2 function on each sample.

        f(x) = 1 + sum_i sin^2(x_i) - 0.1 * exp(-sum_i x_i^2)

    In 2-D this is 1 + sin^2(x1) + sin^2(x2) - 0.1 exp(-x1^2 - x2^2), and
    f(0, 0) = 0.9.

    Args:
        x: iterable of samples (array-likes of coordinates), e.g. an array of
           shape (n, 2). Samples given as plain Python lists are accepted.

    Returns:
        np.ndarray of shape (n,) with one objective value per sample.
    """
    def price_single(vec):
        # convert first: ``v ** 2`` raises TypeError on a plain Python list
        v = np.asarray(vec, dtype=float)
        return np.sum(np.sin(v) ** 2) - 0.1 * np.exp(-np.sum(v ** 2)) + 1

    return np.array([price_single(sample) for sample in x])

def synt(x):
    """Evaluate the synthetic objective on each sample.

        f(x) = mean_i (0.25 x_i^4 - 2 x_i^2) + 5

    Each coordinate term is minimised at x_i = ±2 (value -4), so the minimum
    of f is 1, reached when every coordinate is ±2.

    Args:
        x: np.ndarray of shape (n, d), or any iterable of n samples.

    Returns:
        np.ndarray of shape (n,) with one value per sample (not (n, 1)).
    """
    def synt_single(vec):
        # average the per-coordinate quartic term, then shift by 5
        return np.mean([0.25 * xi ** 4 - 2 * xi ** 2 for xi in vec]) + 5

    return np.array([synt_single(sample) for sample in x])

def styblinski_tang(x):
    """Evaluate the Styblinski-Tang function on each sample.

        f(x) = 0.5 * sum_i (x_i^4 - 16 x_i^2 + 5 x_i)

    Global min: f(x*) ≈ -39.16617 * dim at x* = (-2.903534, ..., -2.903534).
    Reference: http://infinity77.net/global_optimization/test_functions.html#test-functions-index

    Per coordinate, the quartic has local minima at -2.903534 and 2.746803,
    so in 2-D there are four local optima:
        -78.33233 @ (-2.904, -2.904)   (global)
        -64.2     @ (2.747, -2.904)
        -64.2     @ (-2.904, 2.747)
        -50.06    @ (2.747, 2.747)

    Args:
        x: iterable of samples; each sample is an iterable of coordinates,
           e.g. an array of shape (n, d).

    Returns:
        np.ndarray of shape (n,) with one objective value per sample.
    """
    def styblinski_tang_single(vec):
        return 0.5 * np.sum([xi ** 4 - 16 * xi ** 2 + 5 * xi for xi in vec])

    return np.array([styblinski_tang_single(sample) for sample in x])


if __name__ == "__main__":
    import os

    # BEGIN: input arguments
    mode = 7  # selects the objective; see func_name below for the mode -> function mapping
    render = 1  # truthy: save the landscape plots under img/
    # END: input arguments
    func_name = {1: "synt", 2: "styb_tang", 3: "ripple01", 4: "price02", 5: "levy05",
                 6: "langermann", 7: "hard1d1", 8: "hard1d2", 9: "hard2d1"}
    objectives = {1: synt, 2: styblinski_tang, 3: ripple01, 4: price02, 5: levy05,
                  6: langermann, 7: hard1d1, 8: hard1d2, 9: hard2d1}
    # (low, high) sampling range, shared by every input axis
    domains = {1: (-5, 5), 2: (-5, 5), 3: (0, 1), 4: (-10, 10), 5: (-10, 10),
               6: (0, 10), 7: (-7.5, 7.5), 8: (0, 1.2), 9: (-7.5, 7.5)}
    if mode not in objectives:
        raise ValueError(f"unknown mode {mode}; expected one of {sorted(objectives)}")
    # modes 7 and 8 are 1-D objectives; every other mode takes 2-D samples
    is_2d = mode not in (7, 8)
    n = 200

    low, high = domains[mode]
    axis = np.linspace(low, high, n)
    if is_2d:
        # one row of sample points per input dimension
        x = np.array([axis, axis])
        # 'xy' indexing: xx[iy, ix] = x[0, ix], yy[iy, ix] = x[1, iy]
        xx, yy = np.meshgrid(x[0, :], x[1, :], sparse=False)
        x_to_eval = np.dstack((xx, yy)).reshape([-1, 2])
        # row-major reshape restores the grid layout: z[iy, ix] = f(x[0, ix], x[1, iy])
        z = objectives[mode](x_to_eval).reshape(n, n)
    else:
        x = np.array([axis])
        x_to_eval = x.reshape(n, 1)
        z = objectives[mode](x_to_eval)

    if render:
        os.makedirs('img', exist_ok=True)  # savefig does not create missing directories
        if is_2d:
            # 3d surface plot
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            sc = ax.plot_surface(xx, yy, z, cmap=cm.coolwarm)
            # add a color bar which maps values to colors
            fig.colorbar(sc, shrink=0.5, aspect=5)
            fig.savefig(f'img/landscape_{func_name[mode]}.pdf', bbox_inches='tight', dpi=600)
            # filled contour plot
            fig3 = plt.figure()
            ax3 = fig3.add_subplot(111)
            sc3 = ax3.contourf(xx, yy, z, cmap=cm.coolwarm)
            fig3.colorbar(sc3, shrink=0.5, aspect=5)
            fig3.savefig(f'img/contour_{func_name[mode]}.pdf', bbox_inches='tight', dpi=600)
        else:
            fig = plt.figure()
            fontsize = 14
            ax = fig.add_subplot(111)
            ax.plot(x_to_eval, z)
            ax.set_xlabel('x', fontsize=fontsize)
            ax.set_ylabel('f(x)', fontsize=fontsize)
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            plt.grid(linestyle='--')
            color = ["#ff6eb4", "#ff2500"]
            # local optima as [x, f(x)]; the LAST entry is the global optimum
            if mode == 8:
                # hard1d2
                ax.plot([1.086, 1.086], [-1.849, 1.4], transform=ax.transData, ls='--', color=color[1], alpha=0.5)
                ax.set_ylim([-2.2, 1.4])
                local_optima = [[0.088, -1.121], [0.427, -0.062], [0.703, -0.687], [1.086, -1.849]]
            else:
                # hard1d1
                ax.plot([5.1457, 5.1457], [-1.8996, 2], transform=ax.transData, ls='--', color=color[1], alpha=0.5)
                ax.set_ylim([-2.2, 2])
                local_optima = [[-6.217, -.888], [-4.197, -0.119], [-2.296, -1.728], [-0.549, -1.488],
                                [1.398, -0.014], [3.387, -1.200], [7, -0.317], [5.146, -1.900]]
            for i, (ox, oy) in enumerate(local_optima):
                is_global = i == len(local_optima) - 1
                dot_color = color[1] if is_global else color[0]
                text_color = dot_color if is_global else 'k'
                label = r'{}@x*={}' if is_global else r'{}@x={}'
                ax.annotate(label.format(oy, ox), xy=(ox, oy),
                            xytext=(ox - 0.3, oy - 0.2), color=text_color, fontsize=fontsize - 2)
                ax.plot([ox], [oy], 'o', color=dot_color, markersize=4)
            fig.savefig(f'img/landscape_{func_name[mode]}.pdf', bbox_inches='tight', dpi=600)

    # Level-set thresholds per mode (mode 2: the GACEM paper has a typo on the
    # threshold and the function). Kept for reference; not used further below.
    th = {1: 2.0, 2: -60, 3: -1.5}.get(mode, 1.5)

    # analyze the global minima on the sampled grid
    min_cost = np.min(z)
    print(f'min of the cost function is {min_cost:.5f}')
    min_indices = np.nonzero(z == min_cost)
    # for 2-D z, min_indices[0] indexes rows (y) and min_indices[1] indexes columns (x)
    print(f'min index is {min_indices}')
    flattened_indices = np.ravel_multi_index(min_indices, z.shape)
    print(f'min flattend index is {flattened_indices}')
    print(f'number of global optima {len(flattened_indices)}')
    print(f'all x,y values at the global optima from the flattened index: {x_to_eval[flattened_indices]}')
    if is_2d:
        # row ii of z corresponds to y = x[1, ii], column jj to x = x[0, jj]
        for ii, jj in zip(min_indices[0], min_indices[1]):
            print(f"the x,y values at the global optima: {x[0,jj],x[1,ii]}")
    


