import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import os,sys


def plot1D_AE(fdir, z, y_test, params, dataset, pmin=-1000):
    """Save a stacked histogram of the first three latent dimensions of *z*.

    After transposing *z*, rows 0, 1 and 2 (i.e. the first three columns of
    the original array) are drawn as one stacked histogram with 600 bins,
    coloured red / blue / green and labelled ``z1`` / ``z2`` / ``z3``.

    Args:
        fdir: Directory the PNG file is written to.
        z: Array of latent codes; at least three columns are required.
            Presumably shaped (n_samples, n_dims) -- TODO confirm with callers.
        y_test: Not used by this function; kept so the signature matches the
            other plotting helpers in this module.
        params: Inserted into the output filename.
        dataset: Inserted into the output filename.
        pmin: Lower edge of the histogram range. ``-1000`` (the default) or
            ``None`` means "use ``z.min()``". The upper edge is always
            ``z.max()``.

    The figure is written to ``<fdir>/z_plot_<dataset>_<params>.png`` and
    then closed.
    """
    z = z.transpose()
    # -1000 is the historical "auto" sentinel; None is accepted as well so a
    # caller can request the automatic minimum explicitly.
    if pmin is None or pmin == -1000:
        pmin = z.min()
    pmax = z.max()

    fig, ax = plt.subplots()
    # `normed` was removed in Matplotlib 3.1; `density` is its replacement.
    ax.hist([z[0], z[1], z[2]], bins=600, density=False,
            color=['red', 'blue', 'green'], label=['z1', 'z2', 'z3'],
            histtype='bar', stacked=True, range=(pmin, pmax))
    ax.legend(loc='upper left')

    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)


def plotline(fdir, z, y_test, params, dataset):
    """Save a line plot comparing predictions with ground truth.

    Rows 0 and 1 of *z* are plotted as ``predict1`` / ``predict2`` against
    rows 0 and 1 of *y_test* (``gt1`` / ``gt2``). The x axis is
    ``0 .. z.shape[1] - 1``, so each row must have ``z.shape[1]`` values.
    Only the first two rows are plotted.

    Args:
        fdir: Directory the PNG file is written to.
        z: Predictions; indexed as ``z[row]`` with ``z.shape[1]`` points per row.
        y_test: Ground truth, indexed the same way as *z*.
        params: Inserted into the output filename.
        dataset: Inserted into the output filename.

    The figure is written to ``<fdir>/z_plot_<dataset>_<params>.png`` and
    then closed.
    """
    fig, ax = plt.subplots()
    x = np.arange(0, z.shape[1])
    # Plot order (predict1, gt1, predict2, gt2) fixes the colour-cycle
    # assignment and the legend order.
    for k in range(2):
        ax.plot(x, z[k], label='predict%d' % (k + 1))
        ax.plot(x, y_test[k], label='gt%d' % (k + 1))

    ax.legend(loc='upper left')
    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)


def plot1D(fdir, z, y_test, params, dataset):
    """Save a stacked histogram of values split by label (anomaly / ordinary).

    Entries of *z* where ``y_test == 1`` are drawn in red (``anomaly``) and
    entries where ``y_test == 0`` in blue (``ordinally``), stacked, with 600
    bins. Entries with any other label are not plotted.

    Args:
        fdir: Directory the PNG file is written to.
        z: Values indexed by the positions where *y_test* matches. After
            selection and ``np.squeeze`` each group must be at most 1-D,
            e.g. *z* of shape (n,) or (n, 1) -- TODO confirm with callers.
        y_test: Label for each entry of *z*.
        params: Inserted into the output filename.
        dataset: Inserted into the output filename.

    The figure is written to ``<fdir>/z_plot_<dataset>_<params>.png`` and
    then closed.
    """
    # atleast_1d: squeezing a single selected sample yields a 0-d array,
    # which cannot be sliced or histogrammed as a series.
    z_n = np.atleast_1d(np.squeeze(z[np.where(y_test == 1)]))
    z_p = np.atleast_1d(np.squeeze(z[np.where(y_test == 0)]))

    fig, ax = plt.subplots()
    # `normed` was removed in Matplotlib 3.1; `density` is its replacement.
    ax.hist([z_n, z_p], bins=600, density=False, color=['red', 'blue'],
            label=['anomaly', 'ordinally'], histtype='bar', stacked=True)
    ax.legend(loc='upper left')

    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)


def plot2D(fdir, z, y_test, params, dataset):
    """Save a 2-D scatter of the first two columns of *z*, coloured by label.

    Points are coloured by *y_test* through the ``tab10`` colormap, with a
    colorbar alongside.

    Args:
        fdir: Directory the PNG file is written to.
        z: 2-D array; columns 0 and 1 are used as x and y.
        y_test: Per-point values used for colouring.
        params: Inserted into the output filename.
        dataset: Inserted into the output filename.

    The figure is written to ``<fdir>/z_plot_<dataset>_<params>.png`` and
    then closed.
    """
    # Use a dedicated figure so the scatter and colorbar are not drawn on top
    # of whatever figure pyplot currently has active.
    fig, ax = plt.subplots()
    points = ax.scatter(z[:, 0], z[:, 1], c=y_test, cmap='tab10', s=2)
    fig.colorbar(points, ax=ax)

    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)

def plot3D(fdir, z, y_test, params, dataset):
    """Save a 3-D scatter of the first three columns of *z*, coloured by label.

    Points with ``y_test == 1`` are red, ``y_test == 0`` blue and
    ``y_test == 2`` green. Points with any other label are not drawn.

    Args:
        fdir: Directory the PNG file is written to.
        z: Array whose first three columns are used as the x/y/z
            coordinates (axes labelled z0, z1, z2).
        y_test: Label for each point.
        params: Inserted into the output filename.
        dataset: Inserted into the output filename.

    The figure is written to ``<fdir>/z_plot_<dataset>_<params>.png`` and
    then closed.
    """
    fig = plt.figure()
    # `Axes3D(fig)` no longer attaches itself to the figure in recent
    # Matplotlib, which leaves the saved image blank; add the 3-D axes
    # explicitly instead.
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    ax.set_xlabel("z0")
    ax.set_ylabel("z1")
    ax.set_zlabel("z2")

    z = z.transpose()

    # Groups are drawn in this order (red, blue, green), so later groups
    # overlay earlier ones.
    for label, color in ((1, 'red'), (0, 'blue'), (2, 'green')):
        idx = np.where(y_test == label)
        ax.scatter(np.squeeze(z[0][idx]),
                   np.squeeze(z[1][idx]),
                   np.squeeze(z[2][idx]),
                   s=2, color=color)

    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)

def plot3D_pdf(fdir, z, y_test, params, dataset):
    """Save a 3-D scatter of the first three columns of *z*.

    Points are coloured by *y_test* through Matplotlib's default colormap.

    Args:
        fdir: Directory the PNG file is written to.
        z: Array whose first three columns are used as the x/y/z
            coordinates (axes labelled z0, z1, z2).
        y_test: Per-point values used for colouring.
        params: Inserted into the output filename.
        dataset: Inserted into the output filename.

    The figure is written to ``<fdir>/z_plot_<dataset>_<params>.png`` and
    then closed.
    """
    fig = plt.figure()
    # `Axes3D(fig)` no longer attaches itself to the figure in recent
    # Matplotlib, which leaves the saved image blank; add the 3-D axes
    # explicitly instead.
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    ax.set_xlabel("z0")
    ax.set_ylabel("z1")
    ax.set_zlabel("z2")

    z = z.transpose()
    ax.scatter(z[0], z[1], z[2], s=2, c=y_test)

    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)


# Scratch check of MinMaxScaler: `a` is min-max scaled per column and `b`
# keeps its first three scaled columns. Kept at module level so the names
# `mm`, `a` and `b` remain importable from this module.
# NOTE(review): this runs (and requires scikit-learn) on every import of this
# module; consider moving it under `if __name__ == "__main__":`.
from sklearn.preprocessing import MinMaxScaler
mm = MinMaxScaler()

a = np.array([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
    [1, 2, 3, 4],
    [1, 2, 3, 4],
    [2, 4, 6, 8],
])
a = mm.fit_transform(a)
b = a[:, 0:3]


