from scipy.optimize import fsolve
import numpy as np
import os, sys
import time
import pickle
import warnings

warnings.filterwarnings("ignore")
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
import matplotlib
import math

matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from matplotlib import cm
import platform
import shutil

# Restrict CUDA to devices 0-3 for this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

# Axes rectangle [left, bottom, width, height] passed to ax.set_position in
# save_fig (when isax == 1): right edge at 0.88, top edge at 0.9.
Leftp, Bottomp = 0.18, 0.18
Widthp, Heightp = 0.88 - Leftp, 0.9 - Bottomp
pos = [Leftp, Bottomp, Widthp, Heightp]


def mkdir(fn):  # Create a directory
    """Create directory ``fn`` (including missing parents) if it does not exist.

    ``os.makedirs(..., exist_ok=True)`` replaces the original
    ``isdir``-then-``mkdir`` pair. That pair could race with another process
    creating the directory in between, and it failed when a parent directory
    was missing. A ``FileExistsError`` is still raised when ``fn`` exists but
    is not a directory.

    Args:
        fn: path of the directory to create.
    """
    os.makedirs(fn, exist_ok=True)


def save_fig(pltm, fntmp, fp=0, ax=0, isax=0, iseps=0, isShowPic=0):  # Save the figure
    """Save the current figure as ``<fntmp>.png``, plus optional extra formats.

    Args:
        pltm: pyplot-like object whose ``rc``/``savefig``/``show``/``close``
            are called (callers in this file pass ``plt``).
        fntmp: output path without extension.
        fp: when not 0, an object whose ``savefig`` receives ``<fntmp>.pdf``
            (with ``bbox_inches='tight'``).
        ax: axes moved to the module-level ``pos`` rectangle when ``isax == 1``.
        isax: 1 to set tick-label sizes via ``pltm.rc`` and reposition ``ax``.
        iseps: truthy to additionally write ``<fntmp>.eps`` at 600 dpi.
        isShowPic: 1 calls ``pltm.show()``; -1 leaves the figure open;
            any other value calls ``pltm.close()``.
    """
    if isax == 1:
        pltm.rc('xtick', labelsize=18)
        pltm.rc('ytick', labelsize=10)
        ax.set_position(pos, which='both')

    pltm.savefig('%s.png' % (fntmp))
    if iseps:
        pltm.savefig('%s.eps' % (fntmp), format='eps', dpi=600)
    if fp != 0:
        fp.savefig("%s.pdf" % (fntmp), bbox_inches='tight')

    if isShowPic == 1:
        pltm.show()
    elif isShowPic != -1:
        pltm.close()


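##### para = [w_1..w_m, b_1..b_m, a_1..a_m]: hidden-layer input weights, hidden biases, then output weights of the m hidden neurons (there is no separate constant/output-bias parameter)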




# def relu(x):
#     return x*(x>0)
#
#     # return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))
#
# def relu1(x):
#     return 1*(x>0)
#
#     # return 4*np.exp(2*x)/(np.exp(2*x)+1)**2


def relu2(x):
    """Second derivative of the ``relu`` activation (tanh): -2 tanh(x) sech^2(x).

    This is algebraically equal to the previous expression
    ``(8 e^{2x} - 8 e^{4x}) / (e^{2x} + 1)^3``. That form overflowed
    (``-inf / inf = nan``) for x greater than about 177. Here sech^2 is
    evaluated as ``4 z / (1 + z)^2`` with ``z = exp(-2|x|)``, which never
    exponentiates a positive number (sech^2 is even, so using |x| is exact).

    Args:
        x: scalar or numpy array of pre-activations.

    Returns:
        tanh''(x), with the same shape as ``x``.
    """
    z = np.exp(-2 * np.abs(x))
    sech2 = 4 * z / (1 + z) ** 2
    return -2 * np.tanh(x) * sech2
def get_y(xs):  # Function to fit
    """Target function: relu(xs) + relu(xs + 7) + relu(xs - 7).

    ``relu`` here is whatever activation the module defines under that name
    (currently tanh), so the target is a sum of three shifted activations.

    Previously tried targets (kept for reference):
        0.6*relu(0.7*xs - 0.4) - 0.7*relu(xs) + relu(-0.5*xs + 0.7)
        7*relu(10*xs + 5) + 5*relu(-8*xs + 7) - 9*relu(6*xs - 8)
        -1*relu(xs + 0.4) - relu(xs) + relu(3*(xs + 0.2))
        0.4*relu(xs - 0.5) + 0.4*relu(xs) + 0.5*relu(xs + 0.6)
        -0.4*relu(xs - 0.5) - 0.4*relu(xs) + 0.5*relu(xs + 0.6)
        0.8*relu(-xs) + 0.45*relu(xs - 0.5) + 0.45*relu(xs + 0.5)
    """
    return relu(xs) + relu(xs + 7) + relu(xs - 7)



def relu(x):
    """Hyperbolic-tangent activation (the name ``relu`` is historical).

    Every caller in this file uses this name, so it is kept. The previous
    explicit ratio ``(e^x - e^-x) / (e^x + e^-x)`` overflowed to
    ``inf / inf = nan`` once |x| exceeded about 710. ``np.tanh`` computes the
    same function and saturates cleanly to +/-1.

    Args:
        x: scalar or numpy array of pre-activations.

    Returns:
        tanh(x), with the same shape as ``x``.
    """
    return np.tanh(x)
    # ReLU alternative: return x*(x>0)


def relu1(x):
    """First derivative of the ``relu`` activation (tanh): sech^2(x).

    This is evaluated as ``4 z / (1 + z)^2`` with ``z = exp(-2|x|)``. It is
    algebraically identical to the previous ``4 e^{2x} / (e^{2x} + 1)^2``
    (sech^2 is even), but never exponentiates a positive number. The old form
    returned ``nan`` for x greater than about 355.

    Args:
        x: scalar or numpy array of pre-activations.

    Returns:
        tanh'(x), with the same shape as ``x``.
    """
    z = np.exp(-2 * np.abs(x))
    return 4 * z / (1 + z) ** 2
    # ReLU-derivative alternative: return 1*(x>0)

#
# def get_y(xs):  # Function to fit
#     tmp =  -0.4*relu(xs - 0.5) - 0.4*relu(xs) + 0.5*relu(xs + 0.6)
#     # tmp=7*relu(10*xs+5)+5*relu(-8*xs+7)-9*relu(6*xs-8)
#     return tmp


def model_output(x, para, m):
    """Evaluate the network f(x) = sum_{k<m} a_k * relu(w_k * x + b_k).

    ``para`` is the flat vector [w_0..w_{m-1}, b_0..b_{m-1}, a_0..a_{m-1}].
    ``x`` may be a scalar or a numpy array (callers pass both); the sum is
    taken element-wise.
    """
    out = 0
    for neuron in range(m):
        coeff = para[2 * m + neuron]
        pre_activation = para[neuron] * x + para[m + neuron]
        out += coeff * relu(pre_activation)
    return out


def neuron_output(x, para, k, m):
    """Return the contribution a_k * relu(w_k * x + b_k) of hidden neuron ``k``.

    Uses the same flat parameter layout as ``model_output``; ``x`` may be a
    scalar or a numpy array.
    """
    coeff = para[2 * m + k]
    pre_activation = para[k] * x + para[m + k]
    contribution = 0  # accumulate from 0, exactly like model_output does
    contribution += coeff * relu(pre_activation)
    return contribution


def cal_a1(k, num, para, m):
    """Gradient component dL/da_k for hidden neuron ``k``.

    Returns (1/num) * sum_i relu(w_k x_i + b_k) * (f(x_i) - y_i). This is the
    derivative of (1/(2*num)) * sum_i (f(x_i) - y_i)^2 with respect to the
    output weight a_k. Reads the module-level training arrays ``x`` and ``y``.
    """
    w_k, b_k, _ = para[k], para[k + m], para[k + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        acc += relu(w_k * x_i + b_k) * (model_output(x_i, para, m) - y[i])
    return acc / num


def cal_w1(k, num, para, m):
    """Gradient component dL/dw_k for hidden neuron ``k``.

    Returns (1/num) * sum_i a_k * relu1(w_k x_i + b_k) * (f(x_i) - y_i) * x_i.
    Reads the module-level training arrays ``x`` and ``y``.
    """
    w_k, b_k, a_k = para[k], para[k + m], para[k + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        residual = model_output(x_i, para, m) - y[i]
        acc += a_k * relu1(w_k * x_i + b_k) * residual * x_i
    return acc / num


def cal_b1(k, num, para, m):
    """Gradient component dL/db_k for hidden neuron ``k``.

    Returns (1/num) * sum_i a_k * relu1(w_k x_i + b_k) * (f(x_i) - y_i).
    Reads the module-level training arrays ``x`` and ``y``.
    """
    w_k, b_k, a_k = para[k], para[k + m], para[k + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        residual = model_output(x_i, para, m) - y[i]
        acc += a_k * relu1(w_k * x_i + b_k) * residual
    return acc / num


def cal_aa1(k, num, para, m):
    """Hessian entry d^2L/da_k^2 for hidden neuron ``k``.

    Returns (1/num) * sum_i relu(w_k x_i + b_k)^2. f is linear in a_k, so
    no residual term appears. Reads the module-level training inputs ``x``.
    """
    w_k, b_k, _ = para[k], para[k + m], para[k + 2 * m]
    acc = 0
    for i in range(num):
        acc += relu(w_k * x[i] + b_k) ** 2
    return acc / num

def cal_aw1(k, num, para, m):
    """Hessian entry d^2L/(da_k dw_k) for a single hidden neuron ``k``.

    With f(x) = sum_j a_j relu(w_j x + b_j), r_i = f(x_i) - y_i and
    z_i = w_k x_i + b_k, the exact mixed second derivative of
    (1/(2*num)) * sum_i r_i^2 is

        (1/num) * sum_i [ relu1(z_i) x_i r_i + a_k relu1(z_i) x_i relu(z_i) ].

    The first (residual) term had been dropped, surviving only as a comment.
    That was inconsistent with cal_ww1 / cal_wb1 / cal_bb1, which keep their
    residual terms. The term vanishes at a critical point with a_k != 0,
    because dL/dw_k = a_k * (1/num) * sum_i relu1(z_i) x_i r_i = 0 there.
    It does not vanish in general.

    Reads the module-level training arrays ``x`` and ``y``.
    """
    wk = para[k]
    bk = para[k + m]
    ak = para[k + 2 * m]
    total = 0
    for i in range(num):
        z = wk * x[i] + bk
        residual = model_output(x[i], para, m) - y[i]
        total += relu1(z) * x[i] * residual + ak * relu1(z) * x[i] * relu(z)
    return total / num

def cal_ab1(k, num, para, m):
    """Hessian entry d^2L/(da_k db_k) for a single hidden neuron ``k``.

    With r_i = f(x_i) - y_i and z_i = w_k x_i + b_k, the exact mixed second
    derivative of (1/(2*num)) * sum_i r_i^2 is

        (1/num) * sum_i [ relu1(z_i) r_i + a_k relu1(z_i) relu(z_i) ].

    The first (residual) term had been dropped, surviving only as a comment.
    That was inconsistent with cal_ww1 / cal_wb1 / cal_bb1, which keep their
    residual terms. It vanishes at a critical point with a_k != 0, because
    there dL/db_k = a_k * (1/num) * sum_i relu1(z_i) r_i = 0.

    Reads the module-level training arrays ``x`` and ``y``.
    """
    wk = para[k]
    bk = para[k + m]
    ak = para[k + 2 * m]
    total = 0
    for i in range(num):
        z = wk * x[i] + bk
        residual = model_output(x[i], para, m) - y[i]
        total += relu1(z) * residual + ak * relu1(z) * relu(z)
    return total / num

def cal_ww1(k, num, para, m):
    """Hessian entry d^2L/dw_k^2 for hidden neuron ``k``.

    Returns (1/num) * sum_i [ a_k relu2(z_i) x_i^2 r_i + a_k^2 relu1(z_i)^2 x_i^2 ],
    where z_i = w_k x_i + b_k and r_i = f(x_i) - y_i. Reads the module-level
    training arrays ``x`` and ``y``.
    """
    w_k, b_k, a_k = para[k], para[k + m], para[k + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        z = w_k * x_i + b_k
        residual = model_output(x_i, para, m) - y[i]
        residual_term = a_k * relu2(z) * x_i ** 2 * residual
        jacobian_term = a_k ** 2 * relu1(z) ** 2 * x_i ** 2
        acc += residual_term + jacobian_term
    return acc / num

def cal_wb1(k, num, para, m):
    """Hessian entry d^2L/(dw_k db_k) for hidden neuron ``k``.

    Returns (1/num) * sum_i [ a_k relu2(z_i) x_i r_i + a_k^2 relu1(z_i)^2 x_i ],
    where z_i = w_k x_i + b_k and r_i = f(x_i) - y_i. Reads the module-level
    training arrays ``x`` and ``y``.
    """
    w_k, b_k, a_k = para[k], para[k + m], para[k + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        z = w_k * x_i + b_k
        residual = model_output(x_i, para, m) - y[i]
        residual_term = a_k * relu2(z) * x_i * residual
        jacobian_term = a_k ** 2 * relu1(z) ** 2 * x_i
        acc += residual_term + jacobian_term
    return acc / num

def cal_bb1(k, num, para, m):
    """Hessian entry d^2L/db_k^2 for hidden neuron ``k``.

    Returns (1/num) * sum_i [ a_k relu2(z_i) r_i + a_k^2 relu1(z_i)^2 ],
    where z_i = w_k x_i + b_k and r_i = f(x_i) - y_i. Reads the module-level
    training arrays ``x`` and ``y``.
    """
    w_k, b_k, a_k = para[k], para[k + m], para[k + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        z = w_k * x_i + b_k
        residual = model_output(x_i, para, m) - y[i]
        residual_term = a_k * relu2(z) * residual
        jacobian_term = a_k ** 2 * relu1(z) ** 2
        acc += residual_term + jacobian_term
    return acc / num

def cal_aa2(k, k2, num, para, m):
    """Hessian entry d^2L/(da_k da_k2) for two neurons (used when k != k2).

    Returns (1/num) * sum_i relu(w_k x_i + b_k) * relu(w_k2 x_i + b_k2).
    f is a sum of per-neuron terms, so its second derivative across two
    different neurons is zero and only the product of first derivatives
    remains. Reads the module-level training inputs ``x``.
    """
    w_k, b_k, _ = para[k], para[k + m], para[k + 2 * m]
    w_k2, b_k2, _ = para[k2], para[k2 + m], para[k2 + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        acc += relu(w_k * x_i + b_k) * relu(w_k2 * x_i + b_k2)
    return acc / num

def cal_aw2(k, k2, num, para, m):
    """Hessian entry d^2L/(da_k dw_k2) for two neurons (used when k != k2).

    Returns (1/num) * sum_i relu(z_k,i) * a_k2 * relu1(z_k2,i) * x_i, with
    z_j,i = w_j x_i + b_j. Note the asymmetry: the ``a`` derivative belongs
    to neuron ``k`` and the ``w`` derivative to neuron ``k2``. Reads the
    module-level training inputs ``x``.
    """
    w_k, b_k, _ = para[k], para[k + m], para[k + 2 * m]
    w_k2, b_k2, a_k2 = para[k2], para[k2 + m], para[k2 + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        acc += relu(w_k * x_i + b_k) * a_k2 * relu1(w_k2 * x_i + b_k2) * x_i
    return acc / num

def cal_ab2(k, k2, num, para, m):
    """Hessian entry d^2L/(da_k db_k2) for two neurons (used when k != k2).

    Returns (1/num) * sum_i relu(z_k,i) * a_k2 * relu1(z_k2,i), with
    z_j,i = w_j x_i + b_j. Reads the module-level training inputs ``x``.
    """
    w_k, b_k, _ = para[k], para[k + m], para[k + 2 * m]
    w_k2, b_k2, a_k2 = para[k2], para[k2 + m], para[k2 + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        acc += relu(w_k * x_i + b_k) * a_k2 * relu1(w_k2 * x_i + b_k2)
    return acc / num

def cal_ww2(k, k2, num, para, m):
    """Hessian entry d^2L/(dw_k dw_k2) for two neurons (used when k != k2).

    Returns (1/num) * sum_i a_k relu1(z_k,i) * a_k2 relu1(z_k2,i) * x_i^2,
    with z_j,i = w_j x_i + b_j. Reads the module-level training inputs ``x``.
    """
    w_k, b_k, a_k = para[k], para[k + m], para[k + 2 * m]
    w_k2, b_k2, a_k2 = para[k2], para[k2 + m], para[k2 + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        acc += a_k * relu1(w_k * x_i + b_k) * a_k2 * relu1(w_k2 * x_i + b_k2) * x_i ** 2
    return acc / num

def cal_wb2(k, k2, num, para, m):
    """Hessian entry d^2L/(dw_k db_k2) for two neurons (used when k != k2).

    Returns (1/num) * sum_i a_k relu1(z_k,i) * a_k2 relu1(z_k2,i) * x_i,
    with z_j,i = w_j x_i + b_j. The expression is symmetric in (k, k2).
    Reads the module-level training inputs ``x``.
    """
    w_k, b_k, a_k = para[k], para[k + m], para[k + 2 * m]
    w_k2, b_k2, a_k2 = para[k2], para[k2 + m], para[k2 + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        acc += a_k * relu1(w_k * x_i + b_k) * a_k2 * relu1(w_k2 * x_i + b_k2) * x_i
    return acc / num

def cal_bb2(k, k2, num, para, m):
    """Hessian entry d^2L/(db_k db_k2) for two neurons (used when k != k2).

    Returns (1/num) * sum_i a_k relu1(z_k,i) * a_k2 relu1(z_k2,i), with
    z_j,i = w_j x_i + b_j. Reads the module-level training inputs ``x``.
    """
    w_k, b_k, a_k = para[k], para[k + m], para[k + 2 * m]
    w_k2, b_k2, a_k2 = para[k2], para[k2 + m], para[k2 + 2 * m]
    acc = 0
    for i in range(num):
        x_i = x[i]
        acc += a_k * relu1(w_k * x_i + b_k) * a_k2 * relu1(w_k2 * x_i + b_k2)
    return acc / num


#
# def cal_c1(k, num, para):
#     # wk=para[k]
#     # bk=para[k+3]
#     # ak=para[k+6]
#     sum = 0
#     for i in range(num):
#         sum += (model_output(x[i], para) - y[i])
#     return sum / num

# Output root for this experiment; created at import time.
path = '/home/zhangzhongwang/data/loss_landscape/test70'
mkdir(path)

# Training data: `num` equally spaced samples of the target on [-15, 15].
# `x_test` is a wider, denser grid used only for plotting the fitted curve.
num, num2 = 50, 100
x = np.linspace(-15, 15, num=num, endpoint=True)
y = get_y(x)
x_test = np.linspace(-20, 20, num=num2, endpoint=True)


def func(para, m):
    """Full loss gradient, ordered [dL/dw_0.., dL/db_0.., dL/da_0..].

    The ordering matches the flat layout of ``para``. Returns a numpy array
    of length 3*m. Uses the module-level sample count ``num``.
    """
    grad_w = [cal_w1(neuron, num, para, m) for neuron in range(m)]
    grad_b = [cal_b1(neuron, num, para, m) for neuron in range(m)]
    grad_a = [cal_a1(neuron, num, para, m) for neuron in range(m)]
    return np.array(grad_w + grad_b + grad_a)

# ---------------------------------------------------------------------------
# Main experiment. Load critical points of a width-2 network and embed each
# one into widths m = 2, 3, 4 by neuron splitting. For every embedded point,
# record its gradient L1 norm ("mistake"), path norm, loss and Hessian
# eigenvalues. Results are written as text files under save_path, and one
# plot of the fitted function is saved per point.
# ---------------------------------------------------------------------------
m_ori=2
for numm in range(3):
    m=m_ori+numm  # hidden width for this pass: 2, 3, 4
    # Unused accumulators (left over from earlier experiment variants).
    lst = []
    delta_used=[]
    loss_all=[]
    std_all=[]
    # NOTE(review): the directory uses m_ori, not m, so all three widths share
    # one directory. The .txt outputs below carry an m suffix, but the figures
    # written to output_all/ do not and can overwrite each other.
    save_path='%s/%s'%(path,m_ori)
    mkdir(save_path)
    mkdir('%s/output_all' % (save_path))
    mkdir('%s/output_all_3.5' % (save_path))  # created, but nothing below writes to it
    # Hard-coded input. Each row is used below as a width-2 parameter vector
    # [w0, w1, b0, b1, a0, a1].
    critical3=np.loadtxt('/home/zhangzhongwang/data/loss_landscape/test69/6.0/2/3critical.txt')
    critical3_new=[]
    delta_bw_all=[]
    loss_used=[]
    index_used=[]
    para_used=[]  # NOTE(review): never appended to, so para_used<m>.txt is written empty
    path_norm_used=[]
    path_norm2_used=[]
    std_used=[]
    delta_bw_used=[]
    eig_used=[]
    mistake_used=[]
    # Only the first and third stored critical points are analysed.
    critical3=[critical3[0],critical3[2]]
    for i, a in enumerate(critical3):
        # Embed the width-2 point into width m. Splitting a neuron copies its
        # (w, b) and halves its output weight, so the represented function is
        # unchanged.
        if numm==0:
            print(i)
            para_new = [a[0],a[1],a[2],a[3],a[4],a[5]]  # width 2: as loaded
        if numm==1:

            print(i)
            # width 3: neuron 0 split into two copies
            para_new = [a[0],a[0],a[1],a[2],a[2],a[3],1/2*a[4],1/2*a[4],a[5]]

        if numm==2:
            print(i)
            # width 4: both neurons split
            para_new = [a[0],a[0],a[1],a[1],a[2],a[2],a[3],a[3],1/2*a[4],1/2*a[4],1/2*a[5],1/2*a[5]]
        # L1 norm of the gradient; close to 0 if para_new is a critical point.
        mistake=sum(abs(func(para_new,m)))
        mistake_used.append(mistake)
        para = para_new  # not used below
        A = []
        ori = []  # not used below
        # Path norm: sum over neurons of sqrt(a_j^2 + w_j^2 + b_j^2).
        for j in range(m):
            wei1 = para_new[j]
            bias = para_new[j + m]
            wei2 = para_new[j + 2 * m]
            A.append(math.sqrt(wei2**2+wei1 ** 2 + bias ** 2))
        pathnorm = sum(A)
        path_norm_used.append(pathnorm)

        # Mean squared residual over the training points.
        # NOTE(review): 50 is hard-coded (it equals num here). This loss has no
        # 1/2 factor, whereas the cal_* gradient/Hessian formulas are the
        # derivatives of (1/(2*num)) * sum of squared residuals.
        loss=0
        for j in range(50):
            loss+=(model_output(x[j],para_new,m)-y[j])**2
        loss=loss/50
        loss_used.append(loss)
        critical3_new.append(para_new)
        # Hessian of the loss in the parameter order [w..., b..., a...].
        # Blocks for two different neurons (k != k2) use the cal_*2 helpers;
        # same-neuron blocks use the cal_*1 helpers. Only lower-triangle
        # entries survive: the symmetrisation loop below overwrites every
        # entry above the diagonal. Some of those upper entries, e.g.
        # H[k][k2+2*m] = cal_aw2(k, k2, ...), hold the transposed quantity and
        # would otherwise be wrong.
        H=np.zeros((3*m,3*m))
        for k in range(m):
            for k2 in range(m):
                if k != k2:
                    H[k][k2]=cal_ww2(k,k2,num,para_new,m)
                    H[k+m][k2]=cal_wb2(k,k2,num,para_new,m)
                    H[k+2*m][k2]=cal_aw2(k,k2,num,para_new,m)
                    H[k][k2+m] = cal_wb2(k, k2, num, para_new,m)
                    H[k+m][k2+m] = cal_bb2(k, k2, num, para_new,m)
                    H[k+2*m][k2+m] = cal_ab2(k, k2, num, para_new,m)
                    H[k][k2+2*m] = cal_aw2(k, k2, num, para_new,m)
                    H[k+m][k2+2*m] = cal_ab2(k, k2, num, para_new,m)
                    H[k+2*m][k2+2*m] = cal_aa2(k, k2, num, para_new,m)
                else:
                    H[k][k]=cal_ww1(k,num,para_new,m)
                    H[k+m][k]=cal_wb1(k,num,para_new,m)
                    H[k+2*m][k]=cal_aw1(k,num,para_new,m)
                    H[k][k+m] = cal_wb1(k, num, para_new,m)
                    H[k+m][k+m] = cal_bb1(k, num, para_new,m)
                    H[k+2*m][k+m] = cal_ab1(k, num, para_new,m)
                    H[k][k+2*m] = cal_aw1(k, num, para_new,m)
                    H[k+m][k+2*m] = cal_ab1(k, num, para_new,m)
                    H[k+2*m][k+2*m] = cal_aa1(k, num, para_new,m)

        # Mirror the lower triangle into the upper triangle.
        for k in range(3 * m):
            for k2 in range(3 * m):
                if k2 > k:
                    H[k][k2] = H[k2][k]

        # NOTE(review): H is symmetric, so np.linalg.eigh would return real,
        # ascending eigenvalues directly. np.linalg.eig returns them in no
        # particular order.
        e = np.linalg.eig(H)
        eig = e[0].real
        eig_used.append(eig)

        # Plot the network output on x_test against the training data.
        plt.figure()
        ax = plt.gca()  # not used below
        plt.plot(x_test, model_output(x_test, para_new, m), 'r-', label='Test')
        plt.plot(x, y, 'b*', label='True')
        # NOTE(review): the title shows m_ori, not the current width m.
        plt.title('m=%s' % (m_ori))
        # NOTE(review): eig[i] selects the i-th Hessian eigenvalue using the
        # critical-point index i. This looks like a leftover from an earlier
        # version in which `eig` held one value per critical point; confirm
        # intent. A negative eigenvalue gives log10 -> nan, i.e. a file named
        # 'nan.png'.
        fntmp = '%s/output_all/%s' % (save_path, np.log10(eig[i]))
        save_fig(plt, fntmp, ax=0, isax=0, iseps=0)
    # Per-width results, one row per analysed critical point.
    np.savetxt('%s/critical3_new%s.txt' % (save_path,m), critical3_new)
    np.savetxt('%s/loss_used%s.txt' % (save_path,m), loss_used)
    np.savetxt('%s/para_used%s.txt' % (save_path,m), para_used)
    np.savetxt('%s/pathnorm_used%s.txt' % (save_path,m), path_norm_used)
    np.savetxt('%s/mistake_used%s.txt' % (save_path,m), mistake_used)
    np.savetxt('%s/eig_used%s.txt' % (save_path,m), eig_used)

