import numpy as np
import matplotlib.pyplot as plt

# Scale used inside the exponentials of f_den / f_mean / f_tail. Those functions
# compute exp(C * x), divide by its maximum, then raise the result to 2 * k / C.
# In exact arithmetic that equals exp(2 * k * (x - max(x))), so C cancels and
# only affects the range/rounding of intermediates (presumably chosen to keep
# exp() finite - confirm before changing it).
C = 100.0


def f_den(x, k, p):
    """Return the normalising constant and the normalised weight on grid ``x``.

    The unnormalised weight is

        y = (exp(C*x) / max(exp(C*x))) ** (2*k/C) * x**(-1/2) * (1-x)**((p-3)/2)

    which in exact arithmetic is exp(2*k*(x - max(x))) * x**(-1/2) * (1-x)**((p-3)/2).

    Args:
        x: 1-D NumPy array of grid points. It is assumed to be sorted ascending,
            because cell widths are taken as ``x[i+1] - x[i]``. Points with
            x == 0 make the weight infinite; x == 1 does too when p < 3.
        k: Exponent factor; the weight grows like exp(2*k*x).
        p: Parameter of the ``(1-x)**((p-3)/2)`` factor.

    Returns:
        ``(N, y / N)``:
            N: the left Riemann sum ``sum_i y[i] * (x[i+1] - x[i])``,
                accumulated in ``np.longdouble``.
            y / N: the weight normalised to unit left-Riemann integral.
    """
    r = 2 * k / C
    t = np.exp(C * x)
    # Rescale so the largest value is 1 before raising to the power r.
    t = t / np.max(t)
    y = np.power(t, r) / np.power(x, 1 / 2) * np.power(1 - x, (p - 3) / 2)
    # Left Riemann sum, vectorised. Each product is float64 (as in the old loop)
    # and the accumulation is done in extended precision. The last grid point
    # only acts as a right endpoint.
    N = np.sum(y[:-1] * np.diff(x), dtype=np.longdouble)
    return N, y / N


def f_mean(x, k, p):
    """Return the integral and normalised profile of the sqrt(x)-weighted integrand.

    The integrand is

        y = (exp(C*x) / max(exp(C*x))) ** (2*k/C) * x**(1/2) * (1-x)**((p-3)/2)

    It is the same as ``f_den``'s weight except that it is multiplied, rather than
    divided, by sqrt(x).

    Args:
        x: 1-D NumPy array of grid points. It is assumed to be sorted ascending,
            because cell widths are taken as ``x[i+1] - x[i]``.
        k: Exponent factor; the integrand grows like exp(2*k*x).
        p: Parameter of the ``(1-x)**((p-3)/2)`` factor.

    Returns:
        ``(M, y / M)``:
            M: the left Riemann sum ``sum_i y[i] * (x[i+1] - x[i])``,
                accumulated in ``np.longdouble``.
            y / M: the integrand scaled to unit left-Riemann integral.
    """
    r = 2 * k / C
    t = np.exp(C * x)
    # Rescale so the largest value is 1 before raising to the power r.
    t = t / np.max(t)
    y = np.power(t, r) * np.power(x, 1 / 2) * np.power(1 - x, (p - 3) / 2)
    # Vectorised left Riemann sum: float64 products, extended-precision sum.
    M = np.sum(y[:-1] * np.diff(x), dtype=np.longdouble)
    return M, y / M


def f_tail(x, k, p, xi, *, verbose=True):
    """Return the fraction of the Riemann mass of ``f_den``'s weight lying at x <= xi.

    The weight is

        y = (exp(C*x) / max(exp(C*x))) ** (2*k/C) * x**(-1/2) * (1-x)**((p-3)/2)

    Its left Riemann sum over ``x`` is the total N. The partial sum P includes
    only the cells ``[x[i], x[i+1]]`` whose right endpoint satisfies
    ``x[i+1] <= xi``.

    Args:
        x: 1-D NumPy array of grid points, assumed sorted ascending.
        k: Exponent factor; the weight grows like exp(2*k*x).
        p: Parameter of the ``(1-x)**((p-3)/2)`` factor.
        xi: Scalar threshold.
        verbose: If True (the default, matching previous behaviour), print
            the ratio before returning it.

    Returns:
        ``P / N`` as an ``np.longdouble``.
    """
    r = 2 * k / C
    t = np.exp(C * x)
    # Rescale so the largest value is 1 before raising to the power r.
    t = t / np.max(t)
    y = np.power(t, r) / np.power(x, 1 / 2) * np.power(1 - x, (p - 3) / 2)

    # Per-cell left-Riemann contributions y[i] * (x[i+1] - x[i]).
    cells = y[:-1] * np.diff(x)
    N = np.sum(cells, dtype=np.longdouble)
    # A cell counts toward P only when its right endpoint is <= xi.
    P = np.sum(cells[x[1:] <= xi], dtype=np.longdouble)
    ratio = P / N
    if verbose:
        print(ratio)
    return ratio


def f_tail_xi(x, k, p, xi_list):
    """Evaluate ``f_tail(x, k, p, xi)`` for every threshold in ``xi_list``.

    Args:
        x: Grid passed unchanged to ``f_tail``.
        k: Exponent factor passed to ``f_tail``.
        p: Parameter passed to ``f_tail``.
        xi_list: Sequence of scalar thresholds.

    Returns:
        A float64 array with one entry per threshold, in the same order as
        ``xi_list``.
    """
    # f_tail's result is cast to float64, just as assigning into a float array would.
    return np.array([f_tail(x, k, p, threshold) for threshold in xi_list], dtype=float)


def main():
    """Plot the ``f_tail`` curve against xi for a sweep of ``au`` values.

    For each au, k = sqrt(p - 1) / 2 / au / sqrt(s2), and ``f_tail_xi`` is
    evaluated on ``xi_list``. There is one coloured curve per au. The figure is
    written to ./fig/tail_prob_xi.png (the directory is created if missing)
    and then shown.
    """
    import os

    p = 100
    s2 = 1

    # Grid on (0, 1]:
    #   - 100 log-spaced points in [1e-12, 1e-2) to resolve the
    #     x**(-1/2) growth near 0;
    #   - then 2000 evenly spaced points on [0.01, 1].
    x_small = np.logspace(-12, -2, 100, endpoint=False)
    x_bulk = np.linspace(0.01, 1, 2000, endpoint=True)
    x = np.concatenate((x_small, x_bulk))

    au_list = np.power(0.1, np.arange(0.5, 2, 0.05))
    xi_list = np.concatenate(
        (np.power(0.1, np.arange(7, 3, -1)), np.arange(1e-3, 1e-0, 1e-2))
    )
    n_au = len(au_list)

    for i, au in enumerate(au_list):
        # NOTE(review): an earlier (unused) formula used sqrt(2 * lr * wd) in
        # place of au here - confirm that is the intended interpretation.
        k = np.sqrt(p - 1) / 2 / au / np.sqrt(s2)
        y = f_tail_xi(x, k, p, xi_list)
        # Colour fades from red (first au) toward blue (last au).
        plt.plot(
            xi_list,
            y,
            linewidth=1,
            color=[1 - i / n_au, 0, i / n_au],
            label="au=" + str(round(au, 2)),
        )

    plt.ylim(-0.05, 1.05)
    plt.xlim(0, 1)
    plt.xlabel("xi")
    # Legend intentionally off: one entry per au value clutters the plot.
    # Call plt.legend() to show the "au=..." labels.
    plt.axhline(0, linewidth=0.5, linestyle='--', color='black')
    plt.axhline(1, linewidth=0.5, linestyle='--', color='black')
    # savefig does not create missing directories.
    os.makedirs('./fig', exist_ok=True)
    plt.savefig('./fig/tail_prob_xi.png', dpi=800)
    plt.show()


if __name__ == "__main__":
    # Produce the plot only when run as a script, not when imported.
    main()
