import numpy as np
import pdb
import scipy.integrate as integrate
from scipy.stats import norm
from scipy.special import ndtri

from quantile_eq import gaussian_pdf

def l_alpha(r, alpha=0.05):
    """
    Pinball (quantile) loss at level alpha, evaluated elementwise on r.

    Equals alpha * r where r >= 0 and -(1 - alpha) * r where r < 0.
    r may be a scalar or a numpy array; the result has the same shape.
    """
    # The loss is the larger of its two linear pieces (their difference is
    # exactly r).  Taking the maximum avoids the inf * 0 = nan that a
    # mask-multiply formulation produces for infinite residuals, so
    # r = +/-inf yields +inf; finite inputs give the same values as before.
    return np.maximum(alpha * r, (alpha - 1) * r)


def l_alpha_p(r, alpha=0.05):
    """
    Elementwise slope of l_alpha with respect to r.

    Evaluates to alpha where r >= 0 and to -(1 - alpha) where r < 0.
    """
    nonneg_term = alpha * (r >= 0)
    neg_term = (1 - alpha) * (r < 0)
    return nonneg_term - neg_term


def generate_data_box_model(n=2000, d=100, sigma_v=0.5,
                            w_star=None, v_star=None):
    r"""
    Generate data from model y = (w_\star + \eps v_\star)^T x

    X has shape (n, d) with entries drawn uniformly from [0, 1), and \eps is
    one standard normal draw per sample.

    w_star and v_star are defaulted independently when not supplied:
      - w_star: a standard normal vector rescaled to unit Euclidean norm;
      - v_star: (sigma_v / d) times a uniform [0, 1) vector.
    sigma_v is only used when v_star has to be drawn.

    Returns (X, y, w_star, v_star).
    """
    X = np.random.rand(n, d)
    # Default each parameter on its own; requiring both to be None at once
    # left the missing one as None whenever only one was supplied.
    # The draw order (w_star, then v_star) is unchanged for the all-default case.
    if w_star is None:
        w_star = np.random.randn(d)
        w_star /= np.linalg.norm(w_star)
    if v_star is None:
        v_star = (sigma_v / d) * np.random.rand(d)
    z = np.random.randn(n)
    y = np.dot(X, w_star) + np.dot(X, v_star) * z
    return X, y, w_star, v_star


def generate_data_sym_model(n=2000, d=100, sigma=0.5,
                            w_star=None,
                            bias=True):
    r"""
    Generate data from model y = w_\star^T x + sigma * z
    If bias=True, X has shape nx(d+1), w_star has shape (d+1)

    The first d columns of X are standard normal; with bias=True a column
    of ones is appended as the last column.  z is one standard normal draw
    per sample.

    When w_star is not supplied it is a standard normal vector of length d
    rescaled to unit Euclidean norm, with a trailing 0 appended for the bias
    coordinate when bias=True.  A supplied w_star is used as-is, so it must
    already match the number of columns of X.

    Returns (X, y, w_star).
    """
    X = np.random.randn(n, d)
    if bias:
        # Constant feature so the last coefficient acts as an intercept.
        X = np.hstack((X, np.ones((n, 1))))
    if w_star is None:
        w_star = np.random.randn(d)
        w_star /= np.linalg.norm(w_star)
        if bias:
            # The default ground truth has zero intercept.
            w_star = np.concatenate((w_star, [0]))
    z = np.random.randn(n)
    y = np.dot(X, w_star) + sigma * z
    return X, y, w_star


def quantile_reg_gd_solver(X, y, alpha=0.05,
                           eta=0.01, maxiter=10000,
                           avg=False,
                           X_val=None, y_val=None,
                           fix_b_val=None,
                           verbose=False,
                           decay_iters=None,
                           decay_factor=0.1,
                           log_every=10000):
    """
    Fit a linear quantile regressor by full-batch subgradient descent on the
    mean pinball loss l_alpha(y - X theta).

    X, y: training data; X has shape (n, d).
    alpha: quantile level passed to the loss.
    eta: initial step size.
    maxiter: number of iterations.
    avg: if True, losses are evaluated at, and the function returns, the
        running average of the iterates instead of the last iterate.
    X_val, y_val: optional validation data; when X_val is given a validation
        loss is recorded at every iteration.
    fix_b_val: if not None, the last coordinate of theta is pinned to this
        value at initialisation and after every update.
    verbose: if True, print the training loss every log_every iterations.
    decay_iters: optional iterable of 1-based iteration counts after which
        eta is multiplied by decay_factor.
    decay_factor: multiplicative step-size decay.
    log_every: logging period used when verbose is True (a falsy value
        disables logging).

    theta is initialised from np.random.randn, so results depend on the
    global numpy random state.

    Returns (theta, train_losses, val_losses); the loss lists hold one entry
    per iteration, recorded before that iteration's update (val_losses is
    empty when X_val is None).
    """
    n, d = X.shape
    theta = np.random.randn(d)
    if fix_b_val is not None:
        theta[-1] = fix_b_val
    theta_avg = np.copy(theta)
    # Build the decay schedule once: constant-time lookups inside the loop,
    # and any iterable (list, tuple, array, generator) is accepted.
    decay_steps = frozenset(decay_iters) if decay_iters is not None else frozenset()
    train_losses, val_losses = [], []
    for i in range(maxiter):
        # Residuals at the current iterate always drive the update, even when
        # the reported losses are measured at the averaged iterate.
        r = y - np.dot(X, theta)
        theta_eval = theta_avg if avg else theta
        r_eval = y - np.dot(X, theta_avg) if avg else r
        train_loss = np.mean(l_alpha(r_eval, alpha=alpha))
        train_losses.append(train_loss)
        if X_val is not None:
            r_val = y_val - np.dot(X_val, theta_eval)
            val_losses.append(np.mean(l_alpha(r_val, alpha=alpha)))
        # Subgradient of the mean pinball loss with respect to theta.
        signs = alpha * (r >= 0) - (1-alpha) * (r < 0)
        grad_theta = -np.dot(X.T, signs) / n
        theta = theta - eta * grad_theta
        if fix_b_val is not None:
            theta[-1] = fix_b_val
        # Running mean of the post-update iterates; at i == 0 the weight on
        # the previous average is 0, so the random initial point is dropped.
        theta_avg = i/(i+1) * theta_avg + 1/(i+1) * theta
        if verbose and log_every and (i+1) % log_every == 0:
            print(f"Iter [{i+1}/{maxiter}]: train_loss={train_loss:.6f}")
        if (i+1) in decay_steps:
            eta *= decay_factor
    if avg:
        return theta_avg, train_losses, val_losses
    return theta, train_losses, val_losses


def coverage(X, y, theta_lo=None, theta_hi=None):
    """
    Fraction of samples whose y lies within linear bounds.

    With only theta_lo: mean of y >= X theta_lo (one-sided, lower bound).
    With only theta_hi: mean of y <= X theta_hi (one-sided, upper bound).
    With both: mean of samples satisfying both inequalities.
    """
    # One-sided cases first; the missing bound is never evaluated.
    if theta_hi is None:
        return np.mean(y >= np.dot(X, theta_lo))
    if theta_lo is None:
        return np.mean(y <= np.dot(X, theta_hi))
    above_lower = y >= np.dot(X, theta_lo)
    below_upper = y <= np.dot(X, theta_hi)
    return np.mean(above_lower & below_upper)


def analytical_coverage(theta_hi=None, w_star=None, sigma_z=0.5,
                        LimInt=6.0,
                        w_err=None, b=None):
    """
    Numerically integrate
        norm.cdf((w_err * G + b) / sigma_z) * gaussian_pdf(G)
    over G in [-LimInt, LimInt] and return the value of the integral.

    w_err defaults to the Euclidean norm of theta_hi[:-1] - w_star[:-1]
    (all coordinates except the last), and b defaults to theta_hi[-1].
    theta_hi and w_star are only needed for the quantities left as None.

    NOTE(review): gaussian_pdf comes from quantile_eq and is presumably the
    standard normal density -- confirm against that module.
    """
    if w_err is None:
        w_err = np.linalg.norm(theta_hi[:-1] - w_star[:-1])
    if b is None:
        b = theta_hi[-1]

    def integrand(G):
        return norm.cdf((w_err * G + b) / sigma_z) * gaussian_pdf(G)

    # quad returns the integral first, followed by an error estimate.
    quad_result = integrate.quad(integrand, -LimInt, LimInt)
    return quad_result[0]


if __name__ == "__main__":
    # ---- Experiment configuration ----
    model = 'sym'
    n, d = 4000, 200
    sigma = 0.5
    alpha_lo, alpha_hi = 0.05, 0.9
    eta, maxiter = 0.01, 50000
    avg, fix_b_val = False, None
    decay_iters, decay_factor = [25000], 0.1

    # Seed once, before any sampling, so the data and the solver's random
    # initialisation are reproducible.
    np.random.seed(100)

    # Draw a training set, then a validation set sharing the same ground truth.
    # NOTE(review): analytical_coverage below treats the last coordinate as a
    # bias term, while the 'box' data has no bias column -- confirm before
    # switching to model = 'box'.
    if model == 'box':
        X, y, w_star, v_star = generate_data_box_model(
            n=n, d=d, sigma_v=sigma)
        X_val, y_val, _, _ = generate_data_box_model(
            n=n, d=d, w_star=w_star, v_star=v_star)
    elif model == 'sym':
        X, y, w_star = generate_data_sym_model(
            n=n, d=d, sigma=sigma, bias=True)
        X_val, y_val, _ = generate_data_sym_model(
            n=n, d=d, sigma=sigma, bias=True, w_star=w_star)

    # Only the upper quantile (alpha_hi) is fitted here.
    solver_options = dict(
        alpha=alpha_hi,
        eta=eta,
        maxiter=maxiter,
        avg=avg,
        X_val=X_val,
        y_val=y_val,
        verbose=True,
        fix_b_val=fix_b_val,
        decay_iters=decay_iters,
        decay_factor=decay_factor,
    )
    theta_hi, train_losses_hi, val_losses_hi = quantile_reg_gd_solver(
        X, y, **solver_options)

    # ---- Report ----
    print(f"Nominal quantile levels: alpha_lo={alpha_lo:.2f}, alpha_hi={alpha_hi:.2f}")
    print(f"One-sided coverage of learned theta_hi on test set: "
          f"{coverage(X_val, y_val, theta_hi=theta_hi)}")
    print(f"kappa={d/n:.6f}")
    print(f"Analytical coverage of learned theta_hi={analytical_coverage(theta_hi, w_star, sigma_z=sigma):.6f}")
    print(f"Formula (1st order at kappa=0): alpha - (1-alpha/2)*kappa={alpha_hi - (1-alpha_hi/2) * (d/n):.6f}")
