import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas
import mpl_toolkits.axisartist as axisartist

# Directory prefix for every saved figure.  The trailing slash is required:
# the plot functions build their output path by plain string concatenation.
path = "logs/experiments_logs/images/"
def get_data(u=1, sig=0.1, min_=0, max_=2):
    """Sample the normal density N(u, sig**2) on a grid of 1000 points.

    The grid spans three standard deviations on either side of the mean,
    i.e. ``[u - 3 * sig, u + 3 * sig]``.

    Args:
        u: Mean of the distribution.
        sig: Standard deviation of the distribution.
        min_: Accepted for call compatibility but not used.
        max_: Accepted for call compatibility but not used.

    Returns:
        A tuple ``(x, y)`` of NumPy arrays: the grid and the density on it.
    """
    lower = u - 3 * sig
    upper = u + 3 * sig
    grid = np.linspace(lower, upper, 1000)

    exponent = -(grid - u) ** 2 / (2 * sig ** 2)
    normaliser = math.sqrt(2 * math.pi) * sig
    density = np.exp(exponent) / normaliser
    return grid, density

def get_pd(x_name, y_name, x, y):
    """Bundle two named columns into a plain dict.

    Returns ``{x_name: x, y_name: y}``.  If both names are equal, the ``y``
    value is the one that is kept.
    """
    columns = {}
    columns[x_name] = x
    columns[y_name] = y
    return columns
# timestep: 

def plot_eta_weight():
    """Plot weight densities for tau* = 1, 3 and 10 and save them as a PDF.

    A single-row figure with one panel per entry of ``steps`` is created.
    Each panel overlays, for every tau* setting, the normal density returned
    by ``get_data`` with mean 1 and that setting's standard deviation at the
    corresponding step.  The result is written to ``path + 'only_eta.pdf'``.

    NOTE(review): 'grid' is not one of Matplotlib's bundled styles; presumably
    it is provided by an installed style package -- confirm it is available.
    ``plt.style.use`` changes the global rcParams, so the style also affects
    figures created after this call.
    """
    plt.style.use(['ggplot','grid'])
    steps = [999, 5999, 10999, 20999]
    fig_path = path + 'only_eta.pdf'
    y_name = 'Density'

    # One row per curve: (legend label, std per step, min per step, max per step).
    # NOTE: the tau*=50 curve is disabled; its recorded statistics were
    #   std = [0.1, 0.027, 0.006, 0.001]
    #   min = [0.974, 0.964, 0.997, 0.998]
    #   max = [1.025, 1.065, 1.034, 1.004]
    series = [
        (r'$\tau^*=1$',
         [0.556, 0.922, 0.441, 0.767],
         [0.257, 0.017, 0.207, 0.106],
         [2.906, 3.713, 2.359, 3.906]),
        (r'$\tau^*=3$',
         [0.180, 0.487, 0.383, 0.594],
         [0.645, 0.175, 0.316, 0.322],
         [1.484, 1.961, 2.113, 2.705]),
        (r'$\tau^*=10$',
         [0.053, 0.206, 0.116, 0.05],
         [0.884, 0.589, 0.861, 0.953],
         [1.129, 1.297, 1.318, 1.174]),
    ]

    # Pre-compute one (x, y) curve per series and per step.
    curves = [
        (label,
         [get_data(1, std, low, high)
          for _, std, low, high in zip(steps, stds, lows, highs)])
        for label, stds, lows, highs in series
    ]

    fig = plt.figure(figsize=(16, 4))
    for i, _ in enumerate(steps):
        # 141, 142, ... -> 1 row, 4 columns, panel i + 1.
        ax = axisartist.Subplot(fig, 141 + i)
        fig.add_axes(ax)
        # Turn the grid on for this panel explicitly.  A pyplot-level
        # plt.grid() issued before the figure exists would act on some other
        # (or a freshly created, empty) figure instead of these panels.
        ax.grid(True)

        for label, per_step in curves:
            x, y = per_step[i]
            ax.plot(x, y, label=label)

        # Only the left-most panel carries the y-axis label.
        if i == 0:
            ax.set_ylabel(y_name, fontsize=34)
        ax.legend()

    fig.savefig(fig_path, dpi=300, format='pdf', pad_inches=0.0, bbox_inches='tight')
def plot_eta_queue10_weight():
    """Plot weight densities for four (tau*, |M|) settings and save a PDF.

    A single-row figure with one panel per entry of ``steps`` is created.
    Each panel overlays, for every setting, the normal density returned by
    ``get_data`` for that setting's mean and standard deviation at the
    corresponding step.  The result is written to
    ``path + 'eta_queue10.pdf'``.

    NOTE(review): 'light' and 'grid' are not Matplotlib's bundled styles;
    presumably they are provided by an installed style package -- confirm.
    ``plt.style.use`` changes the global rcParams, so the style also affects
    figures created after this call.
    """
    plt.style.use(['light','grid'])
    steps = [999, 5999, 10999, 20999]
    fig_path = path + 'eta_queue10.pdf'
    y_name = 'Density'

    # One row per curve: (legend label, mean per step, std per step).
    series = [
        # eta5 k=10
        (r'$\tau^*=5; |\mathcal{M}|=800$',
         [0.993, 1.010, 0.991, 0.984],
         [0.102, 0.337, 0.211, 0.220]),
        # eta5 k100
        (r'$\tau^*=5; |\mathcal{M}|=8000$',
         [0.999, 0.983, 0.982, 0.941],
         [0.102, 0.332, 0.178, 0.192]),
        # eta10 k10
        (r'$\tau^*=10; |\mathcal{M}|=800$',
         [1.001, 0.992, 0.993, 0.995],
         [0.052, 0.199, 0.113, 0.101]),
        # eta10 k100
        (r'$\tau^*=10; |\mathcal{M}|=8000$',
         [1.003, 0.987, 0.989, 0.997],
         [0.052, 0.198, 0.100, 0.052]),
    ]

    # Pre-compute one (x, y) curve per series and per step.
    curves = [
        (label,
         [get_data(mean, std) for _, mean, std in zip(steps, means, stds)])
        for label, means, stds in series
    ]

    fig = plt.figure(figsize=(16, 4))
    for i, _ in enumerate(steps):
        # 141, 142, ... -> 1 row, 4 columns, panel i + 1.
        ax = axisartist.Subplot(fig, 141 + i)
        fig.add_axes(ax)
        # Turn the grid on for this panel explicitly.  A pyplot-level
        # plt.grid() issued before the figure exists would act on some other
        # (or a freshly created, empty) figure instead of these panels.
        ax.grid(True)

        for label, per_step in curves:
            x, y = per_step[i]
            ax.plot(x, y, label=label)

        # Only the left-most panel carries the y-axis label.
        if i == 0:
            ax.set_ylabel(y_name, fontsize=34)
        ax.legend()

    fig.savefig(fig_path, dpi=300, format='pdf', pad_inches=0.0, bbox_inches='tight')

def funcational_eta():
    """Plot weight densities for four list-valued tau* settings and save a PDF.

    A single-row figure with one panel per entry of ``steps`` is created.
    Each panel overlays, for every tau* setting, the normal density returned
    by ``get_data`` for that setting's mean and standard deviation at the
    corresponding step.  The result is written to
    ``path + 'funcational_eta.pdf'``.

    NOTE(review): 'grid' is not one of Matplotlib's bundled styles; presumably
    it is provided by an installed style package -- confirm it is available.
    ``plt.style.use`` changes the global rcParams, so the style also affects
    figures created after this call.
    """
    # Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*'
    # and 3.8 removed the old names; use whichever this installation offers.
    paper_style = 'seaborn-paper'
    if 'seaborn-v0_8-paper' in plt.style.available:
        paper_style = 'seaborn-v0_8-paper'
    plt.style.use([paper_style, 'grid'])

    steps = [999, 5999, 10999, 20999]
    fig_path = path + 'funcational_eta.pdf'
    y_name = 'Density'

    # One row per curve: (legend label, mean per step, std per step).
    series = [
        (r'$\tau^*=[50,5,3]$',
         [1.000, 1.000, 1.000, 0.999],
         [0.010, 0.006, 0.011, 0.009]),
        (r'$\tau^*=[100,5,3]$',
         [1.000, 1.000, 1.000, 0.999],
         [0.005, 0.003, 0.011, 0.009]),
        (r'$\tau^*=[100,20,3]$',
         [1.000, 1.000, 1.000, 0.999],
         [0.005, 0.002, 0.003, 0.008]),
        (r'$\tau^*=[100,20,5]$',
         [1.000, 1.000, 1.000, 0.999],
         [0.005, 0.003, 0.003, 0.005]),
    ]

    # Pre-compute one (x, y) curve per series and per step.
    curves = [
        (label,
         [get_data(mean, std) for _, mean, std in zip(steps, means, stds)])
        for label, means, stds in series
    ]

    fig = plt.figure(figsize=(16, 4))
    for i, _ in enumerate(steps):
        # 141, 142, ... -> 1 row, 4 columns, panel i + 1.
        ax = axisartist.Subplot(fig, 141 + i)
        fig.add_axes(ax)
        # Turn the grid on for this panel explicitly.  A pyplot-level
        # plt.grid() issued before the figure exists would act on some other
        # (or a freshly created, empty) figure instead of these panels.
        ax.grid(True)

        for label, per_step in curves:
            x, y = per_step[i]
            ax.plot(x, y, label=label)

        # Only the left-most panel carries the y-axis label.
        if i == 0:
            ax.set_ylabel(y_name, fontsize=34)
        ax.legend()

    fig.savefig(fig_path, dpi=300, format='pdf', pad_inches=0.0, bbox_inches='tight')

# Render and save all three figures, in this order, whenever the module runs.
for _render in (plot_eta_weight, plot_eta_queue10_weight, funcational_eta):
    _render()