import numpy as np
import time
from quantile_reg import generate_data_sym_model, quantile_reg_gd_solver, analytical_coverage
from quantile_eq import quantile_fixed_point_iter
from scipy.special import ndtri

# Coverage experiments for high-dimensional quantile regression.
#
# `mode` selects the experiment:
#   "small"        -- one fit with kappa = d / n = 0.5 and alpha = 0.9; prints
#                     the analytical coverage and saves a 1-d projection plot
#                     to ./Figures/illustration.pdf.
#   "1"            -- empirical coverage on a grid of alpha for a few kappa,
#                     n_rep repetitions each -> ./data/coverage_mode_1.npy.
#   "2"            -- empirical coverage on a grid of kappa for a few alpha,
#                     n_rep repetitions each -> ./data/coverage_mode_2.npy.
#   "2_analytical" -- coverage predicted by iterating quantile_fixed_point_iter
#                     on the mode "2" grid -> ./data/coverage_mode_2_analytical.npy.
sigma = 0.5  # noise scale passed to the data generator and coverage formulas


mode = "small"

# Repetitions per (kappa, alpha) cell in modes "1" and "2". Defined before the
# result arrays below, which are allocated with it.
n_rep = 5

if mode == "1":
    kappa_vec = [0.05, 0.1, 0.2, 0.4]
    alpha_vec = np.linspace(0.5, 0.98, 25)
    coverages = np.zeros((len(kappa_vec), len(alpha_vec), n_rep))
    fn = "./data/coverage_mode_1.npy"
elif mode == "2":
    kappa_vec = np.linspace(0.02, 0.5, 25)
    alpha_vec = [0.8, 0.9, 0.95]
    coverages = np.zeros((len(kappa_vec), len(alpha_vec), n_rep))
    fn = "./data/coverage_mode_2.npy"
elif mode == "2_analytical":
    kappa_vec = np.linspace(0.02, 0.5, 25)
    alpha_vec = [0.8, 0.9, 0.95]
    fn = "./data/coverage_mode_2_analytical.npy"
elif mode == "small":
    kappa = 0.5
    alpha = 0.9
else:
    raise ValueError(f"unknown mode {mode!r}")

if mode == "small":
    np.random.seed(42)
    d = 100
    n = int(d / kappa)  # kappa = d / n, so n = 200 for kappa = 0.5
    eta, maxiter = 0.01, 50000
    decay_iters, decay_factor = [25000], 0.1
    X, y, w_star = generate_data_sym_model(n, d, sigma=sigma)
    # Held-out sample drawn with the same w_star, used only for the plot.
    X_test, y_test, _ = generate_data_sym_model(n, d, sigma=sigma, w_star=w_star)
    theta_hat, _, _ = quantile_reg_gd_solver(
        X, y, alpha=alpha,
        eta=eta, maxiter=maxiter,
        verbose=True,
        decay_iters=decay_iters, decay_factor=decay_factor
    )
    coverage = analytical_coverage(theta_hi=theta_hat, w_star=w_star, sigma_z=sigma)
    print(f"coverage={coverage:.6f}")

    # 1-d illustration: project the test inputs onto the unit direction of the
    # slope estimation error theta_hat - w_star. The last coordinate is left
    # out of the projection; theta_hat[-1] is used as the intercept b_hat.
    v = theta_hat[:-1] - w_star[:-1]
    v /= np.linalg.norm(v)
    X_proj = X_test[:, :-1].dot(v).squeeze()
    X_grid = np.linspace(-3.0, 3.0, 601)
    w_star_coeff = np.dot(w_star[:-1], v)
    theta_hat_coeff = np.dot(theta_hat[:-1], v)
    f_hat = X_grid * theta_hat_coeff + theta_hat[-1]
    # NOTE(review): f_star ignores w_star[-1]; presumably the true quantile
    # has no intercept beyond the noise quantile sigma * z_alpha -- confirm.
    f_star = X_grid * w_star_coeff + sigma * ndtri(alpha)

    import matplotlib.pyplot as plt  # only needed for this mode

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_facecolor('whitesmoke')
    ax.set_axisbelow(True)
    ax.tick_params(direction='in')
    plt.grid(ls='dotted')

    # X_proj is built from X_test, so it must be paired with y_test
    # (not the training responses y).
    plt.scatter(X_proj, y_test, label="test data")
    y_floor = np.min(y_test)
    plt.fill_between(X_grid, y_floor, f_hat, label=r"$\hat{f}(x)=\hat{w}^\top x + \hat{b}$", alpha=0.5)
    plt.fill_between(X_grid, y_floor, f_star, label=r"$q^\star_\alpha(x)=w_\star^\top x + z_\alpha$", alpha=0.5)
    plt.xlabel(r"$x$ (1d proj)")
    plt.ylabel(r"$y$")
    plt.legend(fontsize="small")
    plt.tight_layout()
    plt.savefig("./Figures/illustration.pdf")
    plt.close()

elif mode in ["1", "2"]:
    d = 100
    eta, maxiter = 0.01, 50000
    decay_iters, decay_factor = [25000], 0.1
    np.random.seed(42)
    for i_kappa, kappa in enumerate(kappa_vec):
        n = int(d / kappa)  # kappa = d / n
        for i_alpha, alpha in enumerate(alpha_vec):
            for i_rep in range(n_rep):
                curr = time.time()
                X, y, w_star = generate_data_sym_model(n, d, sigma=sigma)
                theta_hat, _, _ = quantile_reg_gd_solver(
                    X, y, alpha=alpha,
                    eta=eta, maxiter=maxiter,
                    verbose=True,
                    decay_iters=decay_iters, decay_factor=decay_factor
                )
                coverage = analytical_coverage(theta_hi=theta_hat, w_star=w_star, sigma_z=sigma)
                print(f"kappa={kappa:.2f}, alpha={alpha:.2f}, rep={i_rep}, coverage={coverage:.4f}")
                print(f"elapsed time={time.time() - curr:.4f}s")
                coverages[i_kappa, i_alpha, i_rep] = coverage
            # Checkpoint after every (kappa, alpha) cell.
            np.save(fn, coverages)
elif mode == "2_analytical":
    maxiter = 200
    rel_tol = 1e-4
    # Reuse an existing results file if there is one, otherwise start empty.
    try:
        coverages = np.load(fn)
    except FileNotFoundError:
        coverages = np.zeros((len(kappa_vec), len(alpha_vec)))
    # Alpha indices skipped in this run (presumably filled by an earlier run).
    # NOTE(review): hard-coded resume list -- empty it for a fresh run.
    skip_alpha_idx = {0, 1}
    for i_alpha, alpha in enumerate(alpha_vec):
        if i_alpha in skip_alpha_idx:
            continue
        # Initial guess; the solution for each kappa is then reused as the
        # starting point for the next (smaller) kappa.
        tau, lam, b = 0.7, 7.0, sigma * ndtri(alpha)
        for i_kappa in range(len(kappa_vec) - 1, -1, -1):
            curr = time.time()
            kappa = kappa_vec[i_kappa]
            delta = 1. / kappa
            for i in range(maxiter):
                tau_new, lam_new, b_new = quantile_fixed_point_iter(
                    (tau, lam, b), a=alpha, delta=delta, sigma_z=sigma,
                    output_eq_vals=False
                )
                tau_rel = np.abs(tau_new - tau) / tau
                lam_rel = np.abs(lam_new - lam) / lam
                b_rel = np.abs(b_new - b) / np.abs(b)
                # Keep the latest iterate so the coverage below and the warm
                # start for the next kappa use the most recent values.
                tau, lam, b = tau_new, lam_new, b_new
                # np.max propagates NaN, so a diverged iterate never "converges".
                if np.max([tau_rel, lam_rel, b_rel]) < rel_tol:
                    break
                if (i + 1) % 10 == 0:
                    print(f"Iter [{i+1}/{maxiter}], tau={tau:.6f}, lam={lam:.6f}, b={b:.6f}")
                    print(f"tau_rel={tau_rel:.6f}, lam_rel={lam_rel:.6f}, b_rel={b_rel:.6f}")
            else:
                print(f"warning: fixed point not converged after {maxiter} iterations "
                      f"(kappa={kappa:.2f}, alpha={alpha:.2f})")
            coverage = analytical_coverage(w_err=tau, b=b, sigma_z=sigma)
            coverages[i_kappa, i_alpha] = coverage
            print(f"kappa={kappa:.2f}, alpha={alpha:.2f}, theoretical coverage={coverage:.4f}")
            print(f"elapsed time={time.time() - curr:.4f}s")
            # Checkpoint after every kappa.
            np.save(fn, coverages)
