import numpy as np
import scipy.integrate as integrate
from scipy.stats import norm
from scipy.special import ndtri


def loss(u, b, a):
    """
    Pinball (check) loss l_a(u - b) at quantile level ``a``.

    Equals -(1-a)*(u-b) when u <= b and a*(u-b) when u > b.
    Works elementwise on numpy arrays via boolean masks.
    """
    resid = u - b
    below = u <= b
    above = u > b
    return -(1-a) * resid * below + a * resid * above


def loss_p(u, b, a):
    """
    Subgradient of ``loss`` with respect to ``u``.

    Returns -(1-a) where u <= b (so the left slope is used at the kink)
    and a where u > b. Works elementwise on numpy arrays.
    """
    below = u <= b
    above = u > b
    return -(1-a) * below + a * above


def loss_p_b(u, b, a):
    """
    Derivative of ``loss`` with respect to the offset ``b``.

    ``loss`` depends on ``u`` and ``b`` only through ``u - b``, so this is the
    negated ``u``-subgradient ``loss_p``.
    """
    grad_u = loss_p(u, b, a)
    return -grad_u


def prox_op(u, b, kappa, a):
    """
    Exact proximal operator of the pinball loss l(t) := l_a(t - b).

    Solves argmin_v (u - v)^2 / (2*kappa) + l_a(v - b) in closed form.
    """
    # Stationary point of the branch v < b (slope -(1-a)).
    left_candidate = u + (1-a) * kappa
    # Stationary point of the branch v > b (slope a).
    right_candidate = u - a * kappa
    # If neither branch's stationary point lies in its own region, the minimiser sits at the kink b.
    return np.minimum(left_candidate, np.maximum(b, right_candidate))


def e_lb(u, b, kappa, a):
    """
    Moreau envelope of the pinball loss, i.e. the optimal value of the prox problem
    min_v (u - v)^2 / (2*kappa) + loss(v, b, a), evaluated at the prox point.
    """
    prox_point = prox_op(u, b, kappa, a)
    proximity = (u - prox_point) ** 2 / (2*kappa)
    return proximity + loss(prox_point, b, a)


def e_lb_p(u, b, kappa, a):
    """
    Derivative of the envelope ``e_lb`` with respect to ``u``.

    Uses the Moreau-envelope identity d/du e_lb = (u - prox(u)) / kappa.
    """
    displacement = u - prox_op(u, b, kappa, a)
    return displacement / kappa


def e_lb_p_b(u, b, kappa, a):
    """
    Derivative of the envelope ``e_lb`` with respect to the offset ``b``.

    The envelope depends on ``u`` and ``b`` only through ``u - b``, so this is
    the negated ``u``-derivative ``e_lb_p``.
    """
    grad_u = e_lb_p(u, b, kappa, a)
    return -grad_u


def lb_p_prox(u, b, kappa, a):
    """
    Closed-form envelope derivative: (u - b) / kappa clipped to [-(1-a), a].

    Mathematically equal to ``e_lb_p`` (up to floating-point rounding).
    """
    slope = (u - b) / kappa
    lower, upper = -(1-a), a
    return np.minimum(upper, np.maximum(slope, lower))


def gaussian_pdf(t):
    """Standard normal density exp(-t^2/2) / sqrt(2*pi), elementwise for arrays."""
    normaliser = np.sqrt(2*np.pi)
    return np.exp(-t**2 / 2) / normaliser


def _gauss_expectation_2d(f, lim):
    """
    Integrate f(g, z) * phi(g) * phi(z) over [-lim, lim]^2 with dblquad.

    phi is the standard normal density, so this is E[f(g, z)] for independent
    g, z ~ N(0, 1), truncated to the box (no renormalisation).
    """
    val, _ = integrate.dblquad(
        lambda g, z: f(g, z) * gaussian_pdf(g) * gaussian_pdf(z),
        -lim, lim, -lim, lim,
    )
    return val


def _print_b_gradient_check(alpha, kappa, b, a, delta, sigma_z, lim, grad_delta):
    """
    Diagnostic: print analytical vs central finite-difference d/db of
    E[e_lb(alpha*g + sigma_z*z, b) - loss(sigma_z*z, b)], plus the loss value
    delta * E[e_lb] - alpha^2 / (2*kappa).
    """
    exp_3 = _gauss_expectation_2d(
        lambda g, z: e_lb_p_b(alpha*g+sigma_z*z, b, kappa, a), lim)
    exp_4 = _gauss_expectation_2d(
        lambda g, z: loss_p_b(sigma_z * z, b, a), lim)
    b_plus, b_minus = b + grad_delta, b - grad_delta
    exp_5_plus = _gauss_expectation_2d(
        lambda g, z: e_lb(alpha*g+sigma_z*z, b_plus, kappa, a) - loss(sigma_z*z, b_plus, a),
        lim)
    exp_5_minus = _gauss_expectation_2d(
        lambda g, z: e_lb(alpha*g+sigma_z*z, b_minus, kappa, a) - loss(sigma_z*z, b_minus, a),
        lim)
    exp_5 = _gauss_expectation_2d(
        lambda g, z: e_lb(alpha*g+sigma_z*z, b, kappa, a), lim)

    analytical_grad = exp_3 - exp_4
    finite_diff_grad = (exp_5_plus - exp_5_minus) / (2 * grad_delta)
    loss_val = delta * exp_5 - alpha**2 / (2*kappa)
    print(f"Analytical grad={analytical_grad:.6f}, finite_diff_grad={finite_diff_grad:.6f}, loss={loss_val:.6f}")
    print(f"exp_3={exp_3:.6f}")


def quantile_fixed_point_iter(state, a=0.95, delta=20, sigma_z=1.0,
                              LimInt=6.0,
                              gd_iter=5, gd_eta=0.1,
                              fix_b_val=None,
                              grad_delta=None,
                              output_eq_vals=True):
    """
    One fixed-point iteration for quantile regression.

    Args:
        state: (alpha, kappa, b). The __main__ driver prints these as tau, lambda, b.
        a: quantile level of the pinball loss.
        delta: ratio parameter of the equations (the driver prints it as 1/kappa).
        sigma_z: noise scale; the envelope argument is alpha*g + sigma_z*z.
        LimInt: half-width of the truncated integration box for g, z (and w).
        gd_iter, gd_eta: number and step size of gradient steps on b.
        fix_b_val: if given, overrides the incoming b. The gradient steps still
            move b afterwards; the fixed value is re-imposed on the next call.
        grad_delta: if given, prints an analytical vs finite-difference b-gradient check.
        output_eq_vals: if True, also return the equation residuals.

    Updates:
        alpha_new = kappa * sqrt(delta * E[e_lb_p(alpha*g + sigma_z*z)^2])
        kappa_new = alpha / delta / E[e_lb_p(alpha*g + sigma_z*z) * g]
        b: gd_iter steps of b <- b + gd_eta * E[e_lb_p(sqrt(alpha^2+sigma_z^2)*w, b)].
            Equivalently, gradient descent on E[e_lb], since e_lb depends on b only through u - b.

    Returns:
        (alpha_new, kappa_new, b_new), plus eq_vals when output_eq_vals is True.
        eq_vals = (alpha^2/delta - kappa^2*E1, alpha/delta - kappa*E2, b-gradient at the
        input b), and it always has three entries, including when gd_iter == 0.

    Side effects:
        Prints "grad_b=..." (the gradient used in the last step) when gd_iter > 0,
        and the gradient-check lines when grad_delta is given.
    """
    alpha, kappa, b = state
    if fix_b_val is not None:
        b = fix_b_val

    exp_1 = _gauss_expectation_2d(
        lambda g, z: e_lb_p(alpha*g+sigma_z*z, b, kappa, a)**2, LimInt)
    exp_2 = _gauss_expectation_2d(
        lambda g, z: e_lb_p(alpha*g+sigma_z*z, b, kappa, a) * g, LimInt)

    # perform fixed point update
    alpha_new = kappa * np.sqrt(delta * exp_1)
    kappa_new = alpha / delta / exp_2

    if grad_delta is not None:
        _print_b_gradient_check(alpha, kappa, b, a, delta, sigma_z, LimInt, grad_delta)

    # alpha*g + sigma_z*z ~ N(0, alpha^2 + sigma_z^2), so the b-gradient is a 1-D integral.
    scale = np.sqrt(alpha**2 + sigma_z**2)

    def b_gradient(b_val):
        val, _ = integrate.quad(
            lambda w: e_lb_p(scale*w, b_val, kappa, a) * gaussian_pdf(w),
            -LimInt, LimInt,
        )
        return val

    # Gradient at the incoming b. It is both the third residual and the first
    # GD step, so computing it up front keeps eq_vals 3-long even for gd_iter == 0.
    grad_b = b_gradient(b) if (output_eq_vals or gd_iter > 0) else None
    b_new = b
    for i in range(gd_iter):
        if i > 0:
            grad_b = b_gradient(b_new)
        b_new = b_new + gd_eta * grad_b
    if gd_iter > 0:
        print(f"grad_b={grad_b:.6f}")

    if output_eq_vals:
        eq_vals = (
            alpha ** 2 / delta - kappa ** 2 * exp_1,
            alpha / delta - kappa * exp_2,
            grad_b,
        )
        return (alpha_new, kappa_new, b_new), eq_vals
    return alpha_new, kappa_new, b_new


def transform_from_bar(alpha_bar, kappa_bar, b_bar, delta,
                       a=0.95, sigma_z=0.5):
    """
    Map rescaled ("bar") coordinates to the raw state (alpha, kappa, b).

    alpha = alpha_bar / sqrt(delta), kappa = kappa_bar / delta, and
    b = sigma_z * ndtri(a) + b_bar / delta, i.e. an offset around the
    a-quantile of N(0, sigma_z^2).
    """
    root_delta = np.sqrt(delta)
    quantile_offset = sigma_z * ndtri(a)
    return (
        alpha_bar / root_delta,
        kappa_bar / delta,
        quantile_offset + b_bar / delta,
    )


def _rescaled_eq_vals(bar, delta, a, sigma_z):
    """
    Return the residuals from quantile_fixed_point_iter at the state mapped from
    ``bar``, rescaled elementwise by (delta^2, delta^1.5, delta).
    """
    # Forward a/sigma_z to the transform too, so z_alpha matches the model being solved.
    state = transform_from_bar(*bar, delta, a=a, sigma_z=sigma_z)
    eq_vals = quantile_fixed_point_iter(state, a=a, delta=delta, sigma_z=sigma_z)[1]
    return np.array(eq_vals) * [delta**2, delta**1.5, delta**1]


def numerical_grad(alpha_bar, kappa_bar, b_bar, delta,
                   a=0.95, sigma_z=0.5, eps=1e-6):
    """
    Central finite-difference Jacobian of the rescaled fixed-point residuals.

    Returns:
        A (3, 4) array. Row r is the r-th rescaled residual. Columns 0-2 are
        derivatives w.r.t. alpha_bar, kappa_bar and b_bar (step ``eps``).
        Column 3 is the derivative w.r.t. 1/delta (step 0.01/delta).

    Note: every residual evaluation calls quantile_fixed_point_iter, which prints.
    """
    grads = np.zeros((3, 4))
    bar = np.array([alpha_bar, kappa_bar, b_bar])
    for i in range(3):
        bar_plus, bar_minus = np.copy(bar), np.copy(bar)
        bar_plus[i] += eps
        bar_minus[i] -= eps
        grads[:, i] = (_rescaled_eq_vals(bar_plus, delta, a, sigma_z)
                       - _rescaled_eq_vals(bar_minus, delta, a, sigma_z)) / (2 * eps)
    # Derivative w.r.t. kappa := 1/delta, with a relative step on 1/delta.
    eps_kappa = 0.01 / delta
    delta_plus = 1. / (1./delta + eps_kappa)
    delta_minus = 1. / (1./delta - eps_kappa)
    grads[:, 3] = (_rescaled_eq_vals(bar, delta_plus, a, sigma_z)
                   - _rescaled_eq_vals(bar, delta_minus, a, sigma_z)) / (2 * eps_kappa)
    return grads


# Disabled scratch check (change ``False`` to ``True`` to run). It prints the
# finite-difference Jacobian from ``numerical_grad`` alongside closed-form
# integrals labelled grad_ij, then exits the script.
if False:
    alpha_bar, kappa_bar, b_bar = 1.0, 4.0, -0.4
    delta = 5000.0
    sigma_z = 0.5
    a = 0.95
    LimInt = 6.0
    z_alpha = sigma_z * ndtri(a)
    # Pass a/sigma_z explicitly so the numerical Jacobian uses the same model
    # as the closed-form integrals below (they previously relied on defaults).
    grads = numerical_grad(alpha_bar, kappa_bar, b_bar, delta,
                           a=a, sigma_z=sigma_z, eps=1e-6)
    print(grads)
    grad_12 = -2*kappa_bar * integrate.quad(
        lambda w: loss_p(sigma_z*w, z_alpha, a)**2 * gaussian_pdf(w), -LimInt, LimInt,
    )[0]
    grad_21 = 1 - kappa_bar * integrate.quad(
        lambda w: loss_p(sigma_z*w, z_alpha, a) * w * gaussian_pdf(w), -LimInt, LimInt,
    )[0] / sigma_z
    grad_22 = -alpha_bar * integrate.quad(
        lambda w: loss_p(sigma_z*w, z_alpha, a) * w * gaussian_pdf(w), -LimInt, LimInt,
    )[0] / sigma_z
    grad_31 = alpha_bar * integrate.quad(
        lambda w: (loss_p(sigma_z*w, z_alpha, a) * w**2/sigma_z**2 - loss(sigma_z*w, z_alpha, a) * w/sigma_z**3) * gaussian_pdf(w),
        -LimInt, LimInt,
    )[0]
    grad_32 = -integrate.quad(
        lambda w: loss_p(sigma_z*w, z_alpha, a)**2 * w/(2*sigma_z) * gaussian_pdf(w),
        -LimInt, LimInt,
    )[0]
    grad_33 = -integrate.quad(
        lambda w: loss_p(sigma_z*w, z_alpha, a) * w/sigma_z * gaussian_pdf(w), -LimInt, LimInt,
    )[0]
    print(f"Analytical grad_12={grad_12:.6f}")
    print(f"Analytical grad_21={grad_21:.6f}")
    print(f"Analytical grad_22={grad_22:.6f}")
    print(f"Analytical grad_31={grad_31:.6f}")
    print(f"Analytical grad_32={grad_32:.6f}")
    print(f"Analytical grad_33={grad_33:.6f}")
    import sys
    sys.exit()


def bar_checker(sigma_z=0.5, a=0.95, LimInt=6.0):
    """
    Predict (tau_bar, lambda_bar, b_bar) from one-dimensional Gaussian integrals.

    Each expectation integrates against the standard normal density over
    [-LimInt, LimInt], with the pinball loss taken at sigma_z * w and the
    offset fixed at z_alpha = sigma_z * ndtri(a).

    Prints the two terms of the b_bar numerator before returning.

    Returns:
        (tau_bar, lambda_bar, b_bar)
    """
    z_alpha = sigma_z * ndtri(a)

    def expect(h):
        # Truncated standard-normal integral of h(w).
        return integrate.quad(lambda w: h(w) * gaussian_pdf(w), -LimInt, LimInt)[0]

    exp_ppp = expect(
        lambda w: loss_p(sigma_z*w, z_alpha, a) * w**2/sigma_z**2
        - loss(sigma_z*w, z_alpha, a) * w/sigma_z**3
    )
    exp_pp = expect(lambda w: loss_p(sigma_z*w, z_alpha, a) * w/sigma_z)
    # Closed form of E[loss_p(sigma_z*w)^2] at b = z_alpha, since P(w <= ndtri(a)) = a.
    exp_p_sq = a*(1-a)
    exp_p_pp = expect(lambda w: loss_p(sigma_z*w, z_alpha, a)**2 * w/(2*sigma_z))

    lambda_bar = 1. / exp_pp
    tau_bar = np.sqrt(exp_p_sq) / exp_pp
    first_term = tau_bar**2/2 * exp_ppp
    second_term = lambda_bar * exp_p_pp
    b_bar = (first_term - second_term) / exp_pp
    print(f"quantity_1 = {first_term:.6f}, quantity_2={second_term:.6f}")
    return tau_bar, lambda_bar, b_bar


def b_bar_gaussian(sigma_z=0.5, a=0.95):
    """
    Closed-form b_bar written in terms of the N(0, sigma_z^2) density at its a-quantile.

    The __main__ driver prints this next to the numerical b_bar from bar_checker.
    """
    q = ndtri(a)
    density = gaussian_pdf(q) / sigma_z
    numerator = a*(1-a)*q/sigma_z - (2*a-1)*density
    return numerator / (2*density**2)


def gaussian_pdf_derivatives():
    """
    Plot a diagnostic ratio over quantile levels alpha in [0.51, 0.99] (49 points).

    With q = ndtri(alpha), phi the standard normal density and phi' = -q*phi,
    the plotted curve is [(2*alpha - 1) / (alpha*(1 - alpha))] / [-phi'(q) / phi(q)^2] - 1.
    The plot is shown with matplotlib.
    """
    import matplotlib.pyplot as plt

    levels = np.linspace(0.51, 0.99, 49)
    q = ndtri(levels)
    density = np.exp(-q**2 / 2) / np.sqrt(2 * np.pi)
    density_prime = -density * q
    lhs = -density_prime / density**2
    rhs = (2*levels - 1) / levels / (1-levels)
    plt.plot(levels, rhs / lhs - 1, label="ratio-1")
    plt.legend()
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    # --- problem setup -------------------------------------------------------
    delta, sigma_z = 100.0, 1.0          # delta is reported as 1/kappa below
    a = 0.93                             # quantile level
    z_a = sigma_z * ndtri(a)             # a-quantile of N(0, sigma_z^2)
    # starting (tau, lambda, b) state for the fixed-point iteration
    alpha, kappa, b = 0.095508, 0.037921, 0.735105 * 2
    fix_b_val = None
    grad_delta = None
    print(f"1/kappa={delta:.2f}")

    # --- closed-form predictions of the rescaled limits -----------------------
    tau_bar, lambda_bar, b_bar = bar_checker(sigma_z=sigma_z, a=a)
    print(f"Predicted tau_bar={tau_bar:.6f} lambda_bar={lambda_bar:.6f} b_bar={b_bar:.6f}")
    b_bar_new = b_bar_gaussian(sigma_z=sigma_z, a=a)
    print(f"Predicted b_bar using gaussian densities={b_bar_new:.6f}")
    gaussian_pdf_derivatives()

    # --- fixed-point iteration --------------------------------------------------
    show_eq_vals = False
    maxiter = 100
    for i in range(maxiter):
        state = (alpha, kappa, b)
        (alpha, kappa, b), eq_vals = quantile_fixed_point_iter(
            state, a=a, delta=delta, sigma_z=sigma_z,
            fix_b_val=fix_b_val, grad_delta=grad_delta,
        )
        tau_scaled = alpha * np.sqrt(delta)
        lambda_scaled = kappa * delta
        b_scaled = (b - z_a) * delta
        print(f"tau={alpha:.6f}, lambda={kappa:.6f}, b={b:.6f}")
        print(f"tau/sqrt(kappa)={tau_scaled:.6f}, lambda/kappa={lambda_scaled:.6f}, "
              f"(b-z_a)/kappa={b_scaled:.6f}")
        if not show_eq_vals:
            continue
        eq_1, eq_2, eq_3 = eq_vals[0], eq_vals[1], eq_vals[2]
        print(f"Equation values: eq_1={eq_1:.8f}, eq_2={eq_2:.8f}, eq_3={eq_3:.8f}")
        print(f"Rescaled Eq values: "
              f"eq_1/kappa^2={eq_1*delta**2:.8f}, "
              f"eq_2/kappa^1.5={eq_2*delta**(1.5):.8f}, "
              f"eq_3/kappa={eq_3*delta**1:.8f}")
