from scipy.optimize import fsolve
import numpy as np
# def relu(x):
#     return x*(x>0)

def neuron_output(x, para, k, m=3):
    """Return hidden unit ``k``'s contribution ``a_k * relu(w_k*x + b_k)``.

    ``para`` is laid out as ``[w_0..w_{m-1}, b_0..b_{m-1}, a_0..a_{m-1}]``.
    ``m`` defaults to 3, which reproduces the original fixed offsets
    (``para[3 + k]`` for the bias, ``para[6 + k]`` for the output weight).

    NOTE: a later ``def neuron_output(x, para, k,m)`` in this file rebinds the
    name, so at runtime callers get that definition, not this one.
    """
    return para[2 * m + k] * relu(para[k] * x + para[m + k])


def relu(x):
    """Activation used throughout this script: the hyperbolic tangent.

    Despite the name, this returns ``tanh(x)``. The ReLU variant
    ``x*(x>0)`` is kept below as a comment. ``np.tanh`` replaces the explicit
    ``(e^x - e^-x) / (e^x + e^-x)`` formula, because that formula overflows
    to ``inf/inf = nan`` once ``|x|`` exceeds roughly 710.

    Args:
        x: scalar or array, evaluated elementwise.

    Returns:
        ``tanh(x)``, with the same shape as ``x``.
    """
    # return x*(x>0)
    return np.tanh(x)

def relu1(x):
    """Derivative of ``relu`` (tanh): ``sech(x)**2 == 1 - tanh(x)**2``.

    The direct form ``4*e^{2x} / (e^{2x} + 1)**2`` gives ``inf/inf = nan``
    for x above roughly 355. sech^2 is an even function, so the same value
    is computed as ``4*e^{-2|x|} / (1 + e^{-2|x|})**2``, whose exponent is
    never positive and therefore cannot overflow.

    Args:
        x: scalar or array, evaluated elementwise.

    Returns:
        ``sech(x)**2``, with the same shape as ``x``.
    """
    # return 1*(x>0)   (derivative of the ReLU variant)
    e = np.exp(-2 * np.abs(x))
    return 4 * e / (1 + e) ** 2



def get_y(xs):  # Function to fit
    """Target function the network is trained to reproduce.

    Evaluates ``-relu(xs + 0.4) - relu(xs) + relu(3*(xs + 0.2))``
    elementwise, where ``relu`` is this file's activation.

    Targets tried in earlier runs (kept for reference):
        0.6*relu(0.7*xs - 0.4) - 0.7*relu(xs) + relu(-0.5*xs + 0.7)
        7*relu(10*xs+5) + 5*relu(-8*xs+7) - 9*relu(6*xs-8)
        0.4*relu(xs - 0.5) + 0.4*relu(xs) + 0.5*relu(xs + 0.6)
        -0.4*relu(xs - 0.5) - 0.4*relu(xs) + 0.5*relu(xs + 0.6)
        0.8*relu(-xs) + 0.45*relu(xs - 0.5) + 0.45*relu(xs + 0.5)
    """
    shifted_left = -1 * relu(xs + 0.4)
    centred = relu(xs)
    steep = relu(3 * (xs + 0.2))
    # Same left-to-right order as the single expression, so rounding matches.
    return shifted_left - centred + steep

m = 1  # number of hidden units; parameter vectors hold 3*m entries (w, b, a)
#
# def relu(x):
#     return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))
#     # return x*(x>0)
#
#
# def relu1(x):
#     return 4*np.exp(2*x)/(np.exp(2*x)+1)**2
#     # return 1*(x>0)


# def get_y(xs):  # Function to fit
#     tmp =  -0.4*relu(xs - 0.5) - 0.4*relu(xs) + 0.5*relu(xs + 0.6)
#     # tmp=7*relu(10*xs+5)+5*relu(-8*xs+7)-9*relu(6*xs-8)
#     return tmp


def model_output(x, para, m):
    """Evaluate the network ``f(x) = sum_i a_i * relu(w_i*x + b_i)``.

    Args:
        x: scalar or array of inputs.
        para: flat parameters laid out as
            ``[w_0..w_{m-1}, b_0..b_{m-1}, a_0..a_{m-1}]``.
        m: number of hidden units.

    Returns:
        The summed output of all ``m`` units (``0`` when ``m == 0``).
    """
    out = 0
    for unit in range(m):
        weight = para[unit]
        bias = para[m + unit]
        amplitude = para[2 * m + unit]
        out += amplitude * relu(weight * x + bias)
    return out


def neuron_output(x, para, k, m):
    """Return hidden unit ``k``'s contribution ``a_k * relu(w_k*x + b_k)``.

    ``para`` uses the same ``[w..., b..., a...]`` layout (``3*m`` entries)
    as ``model_output``.
    """
    weight, bias, amplitude = para[k], para[m + k], para[2 * m + k]
    return amplitude * relu(weight * x + bias)


def cal_a1(k, num, para, m):
    """Average gradient term for output weight ``a_k``.

    Computes ``(1/num) * sum_i relu(w_k*x_i + b_k) * (f(x_i) - y_i)`` over the
    first ``num`` entries of the module-level ``x`` and ``y`` arrays, with
    ``f = model_output``. This equals the partial derivative of
    ``0.5 * mean((f - y)**2)`` with respect to ``a_k``. That this is the loss
    the script minimises is assumed; confirm.

    Args:
        k: index of the hidden unit.
        num: number of samples to average over.
        para: flat ``[w..., b..., a...]`` parameter vector.
        m: number of hidden units.
    """
    w_k = para[k]
    b_k = para[k + m]
    total = 0
    for i in range(num):
        residual = model_output(x[i], para, m) - y[i]
        total += relu(w_k * x[i] + b_k) * residual
    return total / num


def cal_w1(k, num, para, m):
    """Average gradient term for input weight ``w_k``.

    Computes ``(1/num) * sum_i a_k * relu1(w_k*x_i + b_k) * (f(x_i) - y_i) * x_i``
    over the first ``num`` entries of the module-level ``x`` and ``y``, where
    ``f = model_output`` and ``relu1`` is the derivative of ``relu``.
    """
    w_k, b_k, a_k = para[k], para[k + m], para[k + 2 * m]
    acc = 0
    for idx in range(num):
        x_i = x[idx]
        residual = model_output(x_i, para, m) - y[idx]
        # Keep the original left-to-right product order for identical rounding.
        acc += a_k * relu1(w_k * x_i + b_k) * residual * x_i
    return acc / num


def cal_b1(k, num, para, m):
    """Average gradient term for bias ``b_k``.

    Computes ``(1/num) * sum_i a_k * relu1(w_k*x_i + b_k) * (f(x_i) - y_i)``
    over the first ``num`` entries of the module-level ``x`` and ``y``, where
    ``f = model_output`` and ``relu1`` is the derivative of ``relu``.
    """
    w_k, b_k, a_k = para[k], para[k + m], para[k + 2 * m]
    acc = 0
    for idx in range(num):
        residual = model_output(x[idx], para, m) - y[idx]
        acc += a_k * relu1(w_k * x[idx] + b_k) * residual
    return acc / num

#
# def cal_c1(k, num, para):
#     # wk=para[k]
#     # bk=para[k+3]
#     # ak=para[k+6]
#     sum = 0
#     for i in range(num):
#         sum += (model_output(x[i], para) - y[i])
#     return sum / num


# Training grid: `num` evenly spaced points on [-1, 1] (endpoints included).
num = 20
# Test grid: `num2` points on the wider interval [-1.5, 1.5].
num2 = 100
x = np.linspace(-1, 1, num, endpoint=True)
y = get_y(x)  # targets at the training inputs
x_test = np.linspace(-1.5, 1.5, num2, endpoint=True)


def func(para, m):
    """Stack all gradient terms into one vector of length ``3*m``.

    The order matches the parameter layout: all ``w`` terms, then all ``b``
    terms, then all ``a`` terms. Each term is averaged over the first ``num``
    training samples (module-level ``num``).
    """
    w_terms = [cal_w1(unit, num, para, m) for unit in range(m)]
    b_terms = [cal_b1(unit, num, para, m) for unit in range(m)]
    a_terms = [cal_a1(unit, num, para, m) for unit in range(m)]
    return np.array(w_terms + b_terms + a_terms)

# x = np.linspace(-1, 1, num=20, endpoint=True)
path='/home/zhangzhongwang/data/loss_landscape/test85'
# Collect candidate points from the 200 result files into one list, one
# parameter vector per entry. Each row is assumed to be one vector of length
# 3*m; confirm against the file producer.
point_all=[]
for file_idx in range(200):
    # ndmin=2 keeps a single-row file 2-D, so its row is still one point.
    test=np.loadtxt('%s/3critical_%s.txt'%(path,file_idx), ndmin=2)
    # extend() with a 2-D array appends its rows. The old per-row
    # `point_all.extend(row)` flattened every row into scalars, which broke
    # the `point[...]` indexing done during classification.
    point_all.extend(test)
# point_all=np.loadtxt('%s/3critical.txt'%(path))
# loss_used1=np.loadtxt('/home/zhangzhongwang/data/saddle_points/test90/1/loss_used.txt')
# loss_used2=np.loadtxt('/home/zhangzhongwang/data/saddle_points/test90/2/loss_used.txt')
# Result buckets. Only cri0_new and cri1_new are filled (and saved) by the
# active rule below. The remaining lists and `epi` were used by disabled
# multi-unit classification variants; they are kept so those names still exist.
cri2, cri1, cri3 = [], [], []
cri0_new, cri1_new, cri2_new, cri3_new = [], [], [], []
crinum = []
cri_str = []
epi = 1e-3       # tolerance of the disabled weight/bias-similarity checks
std_epi = 5e-4   # below this std, hidden unit 0's output counts as constant
for i, point in enumerate(point_all):
    # Spread of hidden unit 0's contribution over the training inputs.
    std = np.std(neuron_output(x, point, 0, m))
    # Near-constant unit -> cri0_new; anything else (including NaN) -> cri1_new.
    (cri0_new if std < std_epi else cri1_new).append(point)



    #
    #
    # for j in range(2):
    #     if abs(abs(point[0])-abs(point[j+1]))<epi:
    #         if abs(abs(point[3])-abs(point[j+4]))<epi:
    #             cri+=1
    # # print(cri)
    # if abs(abs(point[1]) - abs(point[2])) < epi:
    #     if abs(abs(point[4]) - abs(point[5])) < epi:
    #         cri += 1
    # for j in range(3):
    #     if delta_bw[j]<epi:
    #         cri+=1
    # # print(cri)
    # if cri==0:
    #     cri3.append(point)
    # elif cri==1:
    #     cri2.append(point)
    # else:
    #     cri1.append(point)
    # crinum.append(cri)
#
# for a in cri3:
#     flag=0
#
#     if std[0]<1e-2:
#         flag+=1
#     if std[1] < 1e-2:
#         flag += 1
#     if std[2]<1e-2:
#         flag+=1
#     if flag==0:
#         cri3_new.append(a)
#     elif flag==1:
#         cri2_new.append(a)
#     elif flag==2:
#         cri1_new.append(a)
#     else:
#         cri0_new.append(a)
#
#
#
# for a in cri2:
#     flag=0
#     std = [np.std(neuron_output(x, a, 0)), np.std(neuron_output(x, a, 1)), np.std(neuron_output(x, a, 2))]
#     if std[0]<1e-2:
#         flag+=1
#     if std[1] < 1e-2:
#         flag += 1
#     if std[2]<1e-2:
#         flag+=1
#     if flag==0:
#         cri2_new.append(a)
#     elif flag==1:
#         cri1_new.append(a)
#     elif flag==2:
#         cri1_new.append(a)
#     else:
#         cri0_new.append(a)
#
# for a in cri1:
#     flag=0
#     std = [np.std(neuron_output(x, a, 0)), np.std(neuron_output(x, a, 1)), np.std(neuron_output(x, a, 2))]
#     if std[0]<1e-2:
#         flag+=1
#     if std[1] < 1e-2:
#         flag += 1
#     if std[2]<1e-2:
#         flag+=1
#     if flag==0:
#         cri1_new.append(a)
#     elif flag==1:
#         cri1_new.append(a)
#     elif flag==2:
#         cri1_new.append(a)
#     else:
#         cri0_new.append(a)

# np.savetxt('%s/critical3.txt'%(path),cri3)
# np.savetxt('%s/critical2.txt'%(path),cri2)
# np.savetxt('%s/critical1.txt'%(path),cri1)
# np.savetxt('%s/critical_str_std_new.txt'%(path),cri_str)
# np.savetxt('%s/critical3_std_new.txt'%(path),cri3_new)
# np.savetxt('%s/critical2_std_new.txt'%(path),cri2_new)
# Write the two populated classes to disk, one point per row.
np.savetxt(fname='%s/critical1_std_new.txt' % path, X=cri1_new)
np.savetxt(fname='%s/critical0_std_new.txt' % path, X=cri0_new)
# np.savetxt('%s/crinum.txt'%(path),crinum)