from click import style
from matplotlib import markers
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from np_functions import *
from mpl_toolkits.mplot3d import Axes3D



def _sample_landscape(fun, X, Y):
    """Evaluate ``fun`` at every node of the meshgrid (X, Y) and return the value grid."""
    n_rows, n_cols = X.shape
    values = np.zeros((n_rows, n_cols))
    # fun is called one point at a time with a (1, 2) array, exactly as the
    # original script did; whether it accepts batched input is unknown.
    point = np.zeros((1, 2))
    for i in tqdm.tqdm(range(n_rows)):
        for j in range(n_cols):
            point[0, 0] = X[i, j]
            point[0, 1] = Y[i, j]
            # Second argument 0: per the original author's note, this samples
            # the first function.
            values[i, j] = fun(point, 0)
    return values


def D3(ub=100, lb=-100, fun=None, s0=None, s11=None, s1=None, id=1, step=1):
    """Render the landscape of ``fun`` over [lb, ub)^2 with populations overlaid.

    The landscape is sampled on a square grid and drawn as a translucent 3-D
    surface. Up to three populations are scattered on top, and the lowest-fitness
    member of ``s1`` is highlighted. The figure is written to
    ``landscape_<id>.png`` and ``landscape_<id>.svg`` in the working directory.

    Args:
        ub: Upper (exclusive) bound of both plotted axes.
        lb: Lower (inclusive) bound of both plotted axes.
        fun: Callable invoked as ``fun(point, 0)``, where ``point`` is a (1, 2)
            array. Each result is stored as one float grid value.
        s0, s11, s1: Populations as 2-D arrays whose rows are ``[fitness, x, y]``.
            Column 0 is plotted as height; columns 1 and 2 are plotted as x and y.
            ``None`` skips that population. Legend labels are ``S<id-1>``,
            ``S<id-1>'`` and ``S<id>`` respectively.
        id: Iteration number used in labels and output filenames. It shadows the
            builtin and is kept for backward compatibility.
        step: Grid spacing used to sample the landscape.

    Returns:
        The minimum of ``s1[:, 0]``, or ``None`` when ``s1`` is not given.

    Raises:
        TypeError: If ``fun`` is not provided.
    """
    if fun is None:
        raise TypeError("D3() requires a callable 'fun' to sample the landscape")

    # Display range of the plot: one shared 1-D grid for both axes.
    axis = np.arange(lb, ub, step)
    X, Y = np.meshgrid(axis, axis)
    fplot = _sample_landscape(fun, X, Y)

    fig = plt.figure(figsize=(8.6, 8.6))
    # Axes3D(fig) stopped attaching itself to the figure in matplotlib 3.6,
    # which leaves the saved image blank; add_subplot works on every version.
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, fplot, rstride=1, cstride=1, cmap='rainbow', alpha=0.1)
    ax.tick_params(labelsize=16)

    # (population, scatter style) in drawing / legend order.
    layers = (
        (s0, dict(c='blue', marker='v', linewidths=2, s=50,
                  label='S%d' % (id - 1), alpha=0.5)),
        (s11, dict(c='grey', marker='^', linewidths=3, s=80,
                   label='S%d\'' % (id - 1), alpha=1)),
        (s1, dict(c='green', marker='o', linewidths=2, s=35,
                  label='S%d' % id, alpha=0.6)),
    )
    drew_any = False
    for pop, style in layers:
        if pop is None:
            continue
        pop = np.asarray(pop)
        ax.scatter(pop[:, 1], pop[:, 2], pop[:, 0], **style)
        drew_any = True

    best_fitness = None
    if s1 is not None:
        s1 = np.asarray(s1)
        best = np.argmin(s1[:, 0])
        ax.scatter(s1[best, 1], s1[best, 2], s1[best, 0], c='red', linewidths=12,
                   marker='*', s=60, label='Best', alpha=1)
        best_fitness = np.min(s1[:, 0])

    ax.set_xlabel('\nx', color='black', fontsize=20)
    ax.set_ylabel('\ny', color='black', fontsize=20)
    ax.set_zlabel('\n\n\nfitness', color='black', fontsize=20)
    # The surface carries no label, so only draw a legend when a population did.
    if drew_any:
        ax.legend(fontsize=20, markerscale=3)
    ax.view_init(elev=45, azim=27.5)

    fig.savefig("landscape_%d.png" % (id))
    fig.savefig("landscape_%d.svg" % (id))
    # pyplot keeps figures alive until closed; D3 is presumably called once per
    # iteration, so close explicitly to avoid accumulating figures.
    plt.close(fig)
    return best_fitness


if __name__ == '__main__':
    # D3 indexes s0/s11/s1 as arrays of [fitness, x, y] rows, so the demo must
    # supply populations. Sample a small seeded random one on the plotted range,
    # evaluating cec_fun1 one (1, 2) point at a time exactly as D3 does.
    rng = np.random.default_rng(0)
    coords = rng.uniform(-100, 100, size=(20, 2))
    demo_pop = np.zeros((coords.shape[0], 3))
    demo_pop[:, 1:] = coords
    probe = np.zeros((1, 2))
    for k in range(coords.shape[0]):
        probe[0, 0] = coords[k, 0]
        probe[0, 1] = coords[k, 1]
        demo_pop[k, 0] = cec_fun1(probe, 0)
    D3(fun=cec_fun1, s0=demo_pop, s11=demo_pop, s1=demo_pop)