import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import os,sys
import pandas as pd


def plot1D(fdir, z, y_test, params, dataset):
    """Save a stacked histogram of 1-D latent values split by label.

    Samples whose label is 1 are drawn in red ("anomaly") and samples whose
    label is 0 in blue, stacked in 600 bins.  The image is written to
    ``<fdir>/z_plot_<dataset>_<params>.png``.

    Args:
        fdir: Directory the PNG is written to.
        z: Latent values indexed with ``np.where(y_test == label)``
            (presumably shape ``(N,)`` or ``(N, 1)`` -- confirm with caller).
        y_test: Labels aligned with ``z``; 1 = anomaly, 0 = ordinary.
        params: Tag interpolated into the output file name.
        dataset: Dataset name interpolated into the output file name.
    """
    # np.squeeze turns a single selected sample into a 0-d array, which cannot
    # be sliced or used as a hist dataset; atleast_1d keeps it a 1-element array.
    z_n = np.atleast_1d(np.squeeze(z[np.where(y_test == 1)]))
    z_p = np.atleast_1d(np.squeeze(z[np.where(y_test == 0)]))

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    # `normed` was removed from Axes.hist in Matplotlib 3.1; `density` replaces it.
    ax.hist([z_n, z_p], bins=600, density=False, color=['red', 'blue'],
            label=['anomaly', 'ordinally'], histtype='bar', stacked=True)
    ax.legend(loc='upper left')

    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # pyplot keeps every figure alive until closed; release it after saving.
    plt.close(fig)

def plot1D_pdf(fdir, pdf, params, dataset, pmin = -1000):
    """Save a 50-bin histogram of ``pdf`` annotated with its std/mean ratio.

    The histogram range runs from ``pmin`` to ``min(pdf.max(), mean + 3*std)``
    so that a few large outliers do not squash the bulk of the bins.  The
    ratio ``r = std / mean`` is printed on the figure.  The image is written
    to ``<fdir>/z_plot_<dataset>_<params>.png``.

    Args:
        fdir: Directory the PNG is written to.
        pdf: numpy array of values to histogram.
        params: Tag interpolated into the output file name.
        dataset: Dataset name interpolated into the output file name.
        pmin: Lower edge of the histogram range.  The sentinel -1000 (the
            default) means "not given" and uses 0, which is the lower edge the
            previous implementation always used.
    """
    mean_pred = np.mean(pdf)
    std_pred = np.std(pdf)
    # Coefficient of variation shown on the plot (inf/nan when the mean is 0).
    ratio = std_pred / mean_pred

    # Previously `pmin = 0` was assigned unconditionally, silently discarding
    # any caller-supplied value. Honour an explicit pmin; default stays 0.
    if pmin == -1000:
        pmin = 0
    pmax = pdf.max()
    if pmax > mean_pred + 3.0 * std_pred:
        pmax = mean_pred + 3.0 * std_pred

    # np.histogram raises ValueError when range max < min (e.g. all-negative
    # data with the default pmin of 0); fall back to the automatic range.
    hist_range = (pmin, pmax) if pmax >= pmin else None

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    # `normed` was removed from Axes.hist in Matplotlib 3.1; `density` replaces it.
    ax.hist(pdf, bins=50, density=False, histtype='bar', range=hist_range)
    fig.text(0.2, 0.8, 'r=%.4f' % (ratio))
    ax.ticklabel_format(style='sci', scilimits=(0, 0))

    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # pyplot keeps every figure alive until closed; release it after saving.
    plt.close(fig)

def plot2D(fdir,y_test_, y_pred_, params, dataset='hoge'):
    """Save a scatter of predictions against targets, annotated with Pearson r.

    The correlation is computed over all samples, but only the first 1000
    points are drawn.  Both axes start at 0; each upper limit is 1.05 * max,
    capped at mean + 3 * std of the plotted values.  The image is written to
    ``<fdir>/z_plot_<dataset>_<params>.png``.

    Args:
        fdir: Directory the PNG is written to.
        y_test_: Target values (sliceable sequence/array).
        y_pred_: Predicted values aligned with ``y_test_``.
        params: Tag interpolated into the output file name.
        dataset: Dataset name interpolated into the output file name.
    """
    fig = plt.figure()

    # Pearson correlation over *all* samples (pandas aligns the two Series on index).
    res = pd.Series(y_test_).corr(pd.Series(y_pred_))
    fig.text(0.2, 0.8, 'Corr=%.4f' % (res))

    # Only draw the first 1000 points to keep the scatter readable.
    s_max = min(len(y_test_), 1000)
    y_test = y_test_[:s_max]
    y_pred = y_pred_[:s_max]

    def _upper_limit(values):
        # 5% headroom above the max, but never past mean + 3 std so that a few
        # outliers don't compress the bulk of the points.
        return min(np.max(values) * 1.05, np.mean(values) + 3.0 * np.std(values))

    plt.xlim(0, _upper_limit(y_test))
    plt.ylim(0, _upper_limit(y_pred))
    plt.ticklabel_format(style='sci', scilimits=(0, 0))
    plt.scatter(y_test, y_pred, s=2)

    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # pyplot keeps every figure alive until closed; release it after saving.
    plt.close(fig)

def plot2D_with_val(fdir,y_test, y_pred, params, dataset='hoge'):
    """Save a scatter of predictions vs. targets, annotated with std/mean of predictions.

    The x-axis spans ``[min(y_test), max(y_test)]`` and the y-axis spans
    ``[min(y_pred), 1.1 * max(y_pred)]``.  The value ``r`` printed on the
    figure is ``std(y_pred) / mean(y_pred)``.  The image is written to
    ``<fdir>/z_plot_<dataset>_<params>.png``.

    Args:
        fdir: Directory the PNG is written to.
        y_test: Target values.
        y_pred: Predicted values aligned with ``y_test``.
        params: Tag interpolated into the output file name.
        dataset: Dataset name interpolated into the output file name.
    """
    fig = plt.figure()

    mean_pred = np.mean(y_pred)
    std_pred = np.std(y_pred)
    # Coefficient of variation of the predictions (inf/nan when the mean is 0).
    ratio = std_pred / mean_pred

    plt.xlim(np.min(y_test, axis=0), np.max(y_test, axis=0))
    # NOTE(review): the 1.1 factor only adds headroom when max(y_pred) > 0.
    plt.ylim(np.min(y_pred, axis=0), np.max(y_pred, axis=0) * 1.1)
    plt.scatter(y_test, y_pred, s=2)
    fig.text(0.2, 0.8, 'r=%.4f' % (ratio))

    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # pyplot keeps every figure alive until closed; release it after saving.
    plt.close(fig)


def plot3D(fdir, z, y_test, params, dataset):
    """Save a 3-D scatter of the first three latent coordinates, coloured by label.

    The image is written to ``<fdir>/z_plot_<dataset>_<params>.png``.

    Args:
        fdir: Directory the PNG is written to.
        z: Latent array whose transpose yields rows ``z0``, ``z1``, ``z2``
            (presumably shape ``(N, >=3)`` -- confirm with caller).
        y_test: Per-sample values used as scatter colours.
        params: Tag interpolated into the output file name.
        dataset: Dataset name interpolated into the output file name.
    """
    fig = plt.figure()
    # Axes3D(fig) is no longer attached to the figure automatically since
    # Matplotlib 3.6, which produced a blank PNG; create the axes via the figure.
    ax = fig.add_subplot(111, projection='3d')

    ax.set_xlabel("z0")
    ax.set_ylabel("z1")
    ax.set_zlabel("z2")

    z = z.transpose()
    # ax.scatter draws markers only (no connecting lines), which suits a 3-D point cloud.
    # No legend: the points carry no labels, so ax.legend() would only emit a warning.
    ax.scatter(z[0], z[1], z[2], s=2, c=y_test)

    fig_pass = os.path.join(fdir, 'z_plot_%s_%s.png' % (dataset, params))
    fig.savefig(fig_pass)
    # pyplot keeps every figure alive until closed; release it after saving.
    plt.close(fig)

# Scratch demo: scale a small toy matrix column-wise into [0, 1] and keep the
# first three columns.
# NOTE(review): this runs (and imports sklearn) every time the module is
# imported; consider moving it under an `if __name__ == "__main__":` guard.
from sklearn.preprocessing import MinMaxScaler

mm = MinMaxScaler()
a = mm.fit_transform(
    np.array(
        [
            [0, 0, 0, 0],
            [1, 2, 3, 4],
            [1, 2, 3, 4],
            [1, 2, 3, 4],
            [2, 4, 6, 8],
        ]
    )
)
b = a[:, :3]

'''
import numpy as np #適当な配列作るためにNumpy使う
A = np.array([i for i in range(1,1000)]) #自然数の配列
B = np.sin(A) #特に意味のない正弦
C = np.log(B) #特に意味のない自然対数
#備考：Numpyだとnp.log()は自然対数。常用対数はnp.log10()
Z = [A, B, C]
Z = np.array(Z)

ans = np.zeros(A.shape)
ans[:500] = 1
import matplotlib.pyplot as plt

#seabornはimportしておくだけでもmatplotlibのグラフがきれいになる
#デフォルトでスタイルはdarkgridなのでsns.set_style()は書かなくても良い
#import seaborn as sns
#sns.set_style("darkgrid")

#3次元プロットするためのモジュール
from mpl_toolkits.mplot3d import Axes3D

#グラフの枠を作っていく
fig = plt.figure()
ax = Axes3D(fig)

#軸にラベルを付けたいときは書く
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")

i_n = np.where(ans == 1)
z0_n = np.squeeze(Z[0][i_n])
z1_n = np.squeeze(Z[1][i_n])
z2_n = np.squeeze(Z[2][i_n])

#print(i_n)
#print(z0_n.shape)
i_p = np.where(ans == 0)
z0_p = np.squeeze(Z[0][i_p])
z1_p = np.squeeze(Z[1][i_p])
z2_p = np.squeeze(Z[2][i_p])
#.plotで描画
#linestyle='None'にしないと初期値では線が引かれるが、3次元の散布図だと大抵ジャマになる
#markerは無難に丸
ax.scatter(z0_n, z1_n, z2_n, s=2, color='red',label='anomaly')
ax.scatter(z0_p, z1_p, z2_p, s=2, color='blue',label='ordinaly')
ax.legend(loc='upper left')
#ax.scatter(A,B,C, s=2, c=ans, cmap='coolwarm')
#ax.colorbar()
#最後に.show()を書いてグラフ表示
plt.show()
'''


