from scipy.optimize import fsolve
import numpy as np
import os, sys
import time
import pickle
import warnings

warnings.filterwarnings("ignore")
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
import matplotlib
import math

matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from matplotlib import cm
import platform
import shutil

# Restrict CUDA to devices 0-3 (takes effect only if set before CUDA is initialised).
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

# Axes rectangle [left, bottom, width, height] in figure-fraction coordinates;
# save_fig applies it to its `ax` argument when isax == 1.
Leftp = Bottomp = 0.18
Widthp, Heightp = 0.88 - Leftp, 0.9 - Bottomp
pos = [Leftp, Bottomp, Widthp, Heightp]


def mkdir(fn):  # Create a directory
    """Create directory ``fn`` (and any missing parents) if it does not exist.

    ``os.makedirs(..., exist_ok=True)`` replaces the former ``isdir`` check
    followed by ``os.mkdir``.  That two-step version would raise if another
    process created the directory between the check and the creation, and it
    could not create intermediate directories.

    Raises:
        FileExistsError: if ``fn`` exists but is not a directory.
    """
    os.makedirs(fn, exist_ok=True)


def save_fig(pltm, fntmp, fp=1, ax=0, isax=0, iseps=0, isShowPic=0):  # Save the figure
    """Save the current figure of ``pltm`` under the base path ``fntmp``.

    A PNG (``<fntmp>.png``) is always written.

    Args:
        pltm: the pyplot module (the script passes ``plt``).
        fntmp: output path without extension.
        fp: when non-zero, also write ``<fntmp>.pdf`` with a tight bounding box.
        ax: axes to reposition; only used when ``isax == 1``.
        isax: when 1, set the xtick (18) / ytick (10) label-size rc params and
            move ``ax`` to the module-level rectangle ``pos``.
        iseps: when truthy, also write ``<fntmp>.eps`` at 600 dpi.
        isShowPic: 1 -> show the figure (left open); -1 -> leave it open
            without showing; anything else -> close it.
    """
    if isax == 1:
        for tick_group, label_size in (('xtick', 18), ('ytick', 10)):
            pltm.rc(tick_group, labelsize=label_size)
        ax.set_position(pos, which='both')

    pltm.savefig('{}.png'.format(fntmp))
    if iseps:
        pltm.savefig('{}.eps'.format(fntmp), format='eps', dpi=600)
    if fp != 0:
        pltm.savefig('{}.pdf'.format(fntmp), bbox_inches='tight')

    if isShowPic == 1:
        pltm.show()
    elif isShowPic != -1:
        pltm.close()


def neuron_output(x, para, k,m):
    """Return the output contribution of neuron ``k``.

    ``para`` is indexed as three consecutive groups of ``m`` entries:
    ``para[k]`` multiplies ``x``, ``para[m + k]`` is added to that product,
    and ``para[2m + k]`` scales the activated value (presumably input weight,
    bias and output weight -- confirm against the caller).  The activation is
    this module's ``relu`` helper, which computes tanh despite its name.
    """
    offset = para[m + k]
    slope = para[k]
    scale = para[2 * m + k]
    # builtin sum() seeds with 0, matching the original 0-initialised accumulator.
    return sum([scale * relu(slope * x + offset)])


def relu(x):
    """Hyperbolic-tangent activation, applied element-wise.

    Despite its name, this function computes ``tanh(x)``.  The name is kept
    for existing callers; the commented-out ``x*(x>0)`` in earlier versions
    suggests it was once a real ReLU.

    ``np.tanh`` is used instead of the explicit
    ``(e^x - e^-x) / (e^x + e^-x)`` formula.  That formula overflows to
    ``inf / inf = nan`` once ``|x|`` exceeds about 710, whereas ``np.tanh``
    saturates to +/-1.

    Args:
        x: scalar or array-like input.

    Returns:
        ``tanh(x)``, element-wise with the same shape as ``x``.
    """
    return np.tanh(x)


path = '/home/dir/data/loss_landscape/test70/2'

# Per-case eigenvalues for m = 2, 3, 4 (one row per case, plotted below) and
# the matching mistake_used*.txt arrays (loaded but not plotted by this script).
eig2_all = np.loadtxt('%s/eig_used2.txt' % (path))
mis2 = np.loadtxt('%s/mistake_used2.txt' % (path))
eig3_all = np.loadtxt('%s/eig_used3.txt' % (path))
mis3 = np.loadtxt('%s/mistake_used3.txt' % (path))
eig4_all = np.loadtxt('%s/eig_used4.txt' % (path))
mis4 = np.loadtxt('%s/mistake_used4.txt' % (path))

# |eig| values below this are replaced by it so they can be drawn on a log axis.
EIG_FLOOR = 1e-40


def _fold_negative_eigs(eig, floor=EIG_FLOOR):
    """Turn ``eig`` into ``|eig|`` clamped below at ``floor``, in place.

    Args:
        eig: 1-D float ndarray (a row of one of the ``eig*_all`` arrays);
            modified in place.
        floor: entries with ``|eig| < floor`` are replaced by ``floor``.

    Returns:
        ndarray of the indices whose original value was negative.
    """
    neg_idx = np.flatnonzero(eig < 0)
    tiny = np.abs(eig) < floor  # taken from the original signed values
    eig[neg_idx] = -eig[neg_idx]
    eig[tiny] = floor
    return neg_idx


def _plot_eig_panel(ax, eig, neg_idx, point_size):
    """Scatter ``eig`` against its index on a log y-axis.

    Entries listed in ``neg_idx`` (originally negative) are drawn in blue and
    all other entries in red, so every point is drawn exactly once.
    """
    all_idx = np.arange(len(eig))
    pos_idx = np.setdiff1d(all_idx, neg_idx)
    ax.scatter(neg_idx, eig[neg_idx], c='b', label='negative', s=point_size)
    ax.scatter(pos_idx, eig[pos_idx], c='r', label='positive', s=point_size)
    ax.set_yscale('log')


point_size = 50
font_size = 20
# Per subplot row: (m, x position of the "m=..." tag, x tick locations).
panel_specs = (
    (2, 1.9, [0, 2, 4]),
    (3, 3, [0, 4, 8]),
    (4, 4.2, [0, 5, 10]),
)

for num in range(2):  # only the first two cases (rows) are plotted
    # Row views: folding them also updates eig*_all in place.
    eig2 = eig2_all[num]
    eig3 = eig3_all[num]
    eig4 = eig4_all[num]
    minus_ind_all = [_fold_negative_eigs(row) for row in (eig2, eig3, eig4)]
    print('case %d negative eigenvalue indices (m=2,3,4):' % num, minus_ind_all)

    fig, axes = plt.subplots(3, 1, figsize=(9, 6), sharey=True)
    for ax, eig, neg_idx, (m, tag_x, xticks) in zip(
            axes, (eig2, eig3, eig4), minus_ind_all, panel_specs):
        _plot_eig_panel(ax, eig, neg_idx, point_size)
        ax.tick_params(labelsize=18)
        ax.set_ylabel('abs(eig)', size=font_size)
        ax.text(tag_x, 1e-19, 'm=%d' % m, size=font_size)
        ax.set_xticks(xticks)
        ax.set_yticks([1e-8, 1e-26])
        # Dashed reference level at 1e-11 (presumably the numerical-zero
        # threshold -- confirm).
        ax.axhline(y=1e-11, color='black', linestyle='--')
    axes[2].set_xlabel('index', size=font_size)
    # A single colour legend on the top panel serves all three rows.
    axes[0].legend(fontsize=20, loc='lower left')

    plt.subplots_adjust(wspace=0.4, hspace=0.4)
    fntmp = '%s/eig_one_case%s' % (path, num)
    save_fig(plt, fntmp, ax=0, isax=0, iseps=0)