import numpy as np
import math
from scipy.linalg import eigh
from scipy.sparse import diags
import matplotlib.pyplot as plt
from scipy.stats import t    # <-- import t for your CI
# We evaluate the sharpness of Theorem 2.1 by comparing its bound with the actual error and with the EY-N bound in the setting where
# A = discretized Hamiltonian (n is either 500 or 1000), E = Rademacher noise, p = 10.

#Generate A
def generate_A(n, L=8.0):
    """Build 2√n·H on [-L,L]."""
    h = 2*L/(n + 1)
    i = np.arange(1, n+1)
    x = -L + i*h
    main_diag = 2.0/h**2 + x**2
    off_diag  = -1.0/h**2 * np.ones(n-1)
    H = diags([off_diag, main_diag, off_diag], offsets=[-1,0,1], format='csr').toarray()
    return 8.0 * np.sqrt(n) * H

def best_rank_p_inverse(A_inv, p):
    """Return the best rank-``p`` approximation of the symmetric matrix ``A_inv``.

    For a symmetric matrix, the Eckart–Young–Mirsky theorem says the best
    rank-p approximation (in spectral and Frobenius norm) keeps the p
    eigenpairs of largest magnitude. For positive-definite input this is the
    same as keeping the p largest eigenvalues.

    Parameters
    ----------
    A_inv : ndarray, shape (n, n)
        Symmetric matrix. In this script it is the inverse of A or of A + E.
    p : int
        Target rank; p = 0 yields the zero matrix.

    Returns
    -------
    ndarray, shape (n, n)
        ``V_p diag(w_p) V_p^T`` built from the selected eigenpairs (w_p, V_p).
    """
    eigvals, eigvecs = eigh(A_inv)
    # Rank by |eigenvalue| so the truncation stays optimal even if A_inv is
    # indefinite (e.g. under a perturbation large enough to flip a sign).
    idx = np.argsort(np.abs(eigvals))[::-1][:p]
    return (eigvecs[:, idx] * eigvals[idx]) @ eigvecs[:, idx].T

# Compute the actual error, the r.h.s. of Theorem 2.1, the EY-N bound, and the r.h.s. of Theorem B.2.
def compute_perturbation_and_bound_debug(A, E, p, verbose=False):
    """Compare the actual rank-p inverse perturbation with its theoretical bounds.

    Computes ``||(A+E)^{-1}_p - A^{-1}_p||_2``, where ``X_p`` is the best
    rank-p approximation. It also computes the right-hand sides of
    Theorem 2.1, of Theorem B.2, and of the EY-N bound
    ``8||E||/(3 λ_n^2) + 2/λ_{n-p}``.

    Parameters
    ----------
    A : ndarray, shape (n, n)
        Symmetric, invertible matrix. Presumably positive definite: the
        formulas use the smallest eigenvalue λ_n as 1/||A^{-1}||. Confirm this
        for new inputs.
    E : ndarray, shape (n, n)
        Symmetric perturbation.
    p : int
        Target rank, 1 <= p < n.
    verbose : bool, optional
        If True, print the intermediate spectral quantities.

    Returns
    -------
    dict
        Keys:
        ``actual_error`` – spectral-norm error of the rank-p inverses.
        ``upper_bound`` – r.h.s. of Theorem 2.1 divided by 2*sqrt(n).
        ``boundC`` – EY-N bound divided by 2*sqrt(n).
        ``boundB2`` – r.h.s. of Theorem B.2 divided by 2*sqrt(n).
        ``ratio`` – actual_error / upper_bound.
        ``r``, ``x`` – the index r and interaction parameter x of Theorem B.2.
        ``norm_E`` – ||E||_2.

    Raises
    ------
    ValueError
        If p is outside [1, n-1]; the gap λ_{n-p} - λ_{n-p+1} is then undefined.
    """
    n = A.shape[0]
    if not 1 <= p < n:
        raise ValueError(f"p must satisfy 1 <= p < n (got p={p}, n={n})")

    A_inv_p       = best_rank_p_inverse(np.linalg.inv(A), p)
    tilde_A_inv_p = best_rank_p_inverse(np.linalg.inv(A + E), p)
    actual_error  = np.linalg.norm(tilde_A_inv_p - A_inv_p, 2)

    # Eigenpairs of A in descending order: eigvals[0] >= ... >= eigvals[n-1].
    eigvals, eigvecs = eigh(A)
    eigvals = np.sort(eigvals)[::-1]
    eigvecs = eigvecs[:, ::-1]

    λ_n   = eigvals[-1]          # smallest eigenvalue
    λ_np  = eigvals[n - p - 1]   # (n-p)-th largest (1-indexed λ_{n-p})
    λ_np1 = eigvals[n - p]       # λ_{n-p+1}
    δ_np  = λ_np - λ_np1         # gap separating the p smallest eigenvalues

    # r of Theorem B.2: smallest r in [1, n-1] with eigvals[n-r] >= 2*λ_{n-p+1}, else n.
    r = next((r_try for r_try in range(1, n)
              if eigvals[n - r_try] >= 2 * λ_np1), n)
    λ_nr  = eigvals[n - r]
    U_bot = eigvecs[:, -r:]      # eigenvectors of the r smallest eigenvalues
    # Interaction parameter x of Theorem B.2.
    x      = np.max(np.abs(U_bot.T @ E @ U_bot))
    norm_E = np.linalg.norm(E, 2)

    gap_term = δ_np * (λ_np + λ_np1)
    bound4 = 4 * norm_E / (λ_n**2) + 5 * norm_E / gap_term                     # Theorem 2.1
    bound3 = ((25 * norm_E) / (λ_n * λ_nr) + 4 * (r**2) * x / (λ_n**2)
              + 4 * (r**2) * x / gap_term)                                     # Theorem B.2
    bound_eyn = 8 * norm_E / (3 * λ_n**2) + 2 / λ_np                           # EY-N

    # NOTE(review): the bounds are divided by 2*sqrt(n) but actual_error is not.
    # Under (A, E) -> (cA, cE), actual_error and all three bounds scale as 1/c,
    # so this factor is not implied by the scaling of A. Confirm it is intended
    # before reading "ratio" as a sharpness measure.
    norm_factor = 2 * np.sqrt(n)
    upper_bound = bound4 / norm_factor
    boundC      = bound_eyn / norm_factor
    boundB2     = bound3 / norm_factor

    if verbose:
        print(f"n={n} p={p} λ_n={λ_n:.6e} λ_(n-p)={λ_np:.6e} "
              f"λ_(n-p+1)={λ_np1:.6e} δ={δ_np:.6e} r={r} x={x:.6e} "
              f"||E||={norm_E:.6e} actual={actual_error:.6e} "
              f"thm2.1={upper_bound:.6e} B.2={boundB2:.6e} EY-N={boundC:.6e}")

    return {
        "actual_error": actual_error,
        "upper_bound":  upper_bound,
        "boundC":       boundC,
        "boundB2":      boundB2,
        "ratio":        actual_error / upper_bound,
        "r":            r,
        "x":            x,
        "norm_E":       norm_E,
    }

def _symmetric_rademacher(n, sampler):
    """Return an (n, n) symmetric float64 matrix with i.i.d. ±1 entries on and above the diagonal."""
    upper = np.triu(sampler.choice([-1, 1], size=(n, n)), k=0).astype(np.float64)
    # Mirror the strict upper triangle; the diagonal is kept once.
    return upper + upper.T - np.diag(np.diag(upper))


def batch_run_with_debug(A, p=None, noise_levels=None, trials=None, rng=None):
    """Run Monte-Carlo trials of the bound comparison over a sweep of noise levels.

    For each noise level η, the function draws ``trials`` symmetric Rademacher
    matrices and scales them by η to get E. It evaluates
    ``compute_perturbation_and_bound_debug(A, E, p)`` for each draw and
    records the mean and population standard deviation (``np.std``) of each
    metric.

    Parameters
    ----------
    A : ndarray, shape (n, n)
        Unperturbed matrix.
    p : int, optional
        Target rank. Falsy values fall back to ``floor(log(n))``.
    noise_levels : sequence of float, optional
        Noise scales η. None or empty falls back to ``np.logspace(-4, -1, 10)``.
    trials : int, optional
        Trials per noise level. Falsy values fall back to 10.
    rng : None, int or numpy.random.Generator, optional
        Source of randomness. None (default) draws from the global
        ``np.random`` state. Anything else is passed to
        ``np.random.default_rng``, so an int seed gives a reproducible run.

    Returns
    -------
    dict
        ``"noise_levels"`` plus ``"avg_<m>"`` / ``"std_<m>"`` lists (one entry
        per noise level) for m in actuals, bounds, boundC, ratios, xs.
    """
    n = A.shape[0]
    p = p or math.floor(math.log(n))
    # Explicit check: `noise_levels or default` raises ValueError for a NumPy array.
    if noise_levels is None or len(noise_levels) == 0:
        noise_levels = np.logspace(-4, -1, 10)
    trials = trials or 10
    sampler = np.random if rng is None else np.random.default_rng(rng)

    # Output-name suffix -> key in the per-trial result dict.
    metrics = {
        "actuals": "actual_error",
        "bounds":  "upper_bound",
        "boundC":  "boundC",
        "ratios":  "ratio",
        "xs":      "x",   # raw interaction parameter x of Theorem B.2
    }
    means = {name: [] for name in metrics}
    stds = {name: [] for name in metrics}

    for η in noise_levels:
        samples = {name: [] for name in metrics}
        for _ in range(trials):
            E = η * _symmetric_rademacher(n, sampler)
            res = compute_perturbation_and_bound_debug(A, E, p)
            for name, key in metrics.items():
                samples[name].append(res[key])
        for name, values in samples.items():
            means[name].append(np.mean(values))
            stds[name].append(np.std(values))

    results = {"noise_levels": noise_levels}
    for name in metrics:
        results[f"avg_{name}"] = means[name]
        results[f"std_{name}"] = stds[name]
    return results

def plot_debug_results(results, title="Actual vs Bound vs Noise"):
    """Plot the mean actual error and both bounds against noise level on log-log axes.

    Error bars show ±1 standard deviation as stored in ``results``.

    Parameters
    ----------
    results : dict
        Sweep summary as returned by ``batch_run_with_debug``. Reads
        "noise_levels" and the avg_/std_ entries for actuals, bounds and boundC.
    title : str, optional
        Figure title.
    """
    η = results["noise_levels"]
    plt.figure(figsize=(5, 5))
    # The first series is the measured error, not a bound.
    plt.errorbar(η, results["avg_actuals"], yerr=results["std_actuals"], fmt='o--', label="Actual Error", capsize=3)
    plt.errorbar(η, results["avg_bounds"],  yerr=results["std_bounds"],  fmt='s--', label="Our Bound", capsize=3)
    plt.errorbar(η, results["avg_boundC"],  yerr=results["std_boundC"],  fmt='^--', label="EY-N Bound", capsize=3)
    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel("Noise level")
    plt.ylabel("Error")
    plt.title(title)
    plt.legend()
    plt.grid(True, which="both", ls=":")
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    n, trials = 500, 100
    A = generate_A(n)
    results = batch_run_with_debug(A, p=10, trials=trials)  # 10 is computed value for linear spectrum of gap =4

    # 1) Plot everything
    plot_debug_results(results, title=f"n={n}, Rademacher Noise")

    # 2) Print mean & standard deviation for each metric
    header = (f"{'Noise':>8s} | {'μ_actual':>10s} {'σ_actual':>10s} | "
              f"{'μ_bound':>10s} {'σ_bound':>10s} | "
              f"{'μ_boundC':>10s} {'σ_boundC':>10s} | "
              f"{'μ_ratio':>10s} {'σ_ratio':>10s}")
    print(header)
    print("-" * len(header))
    for η, ma, sa, mb, sb, mbc, sbc, mr, sr in zip(
            results["noise_levels"],
            results["avg_actuals"],  results["std_actuals"],
            results["avg_bounds"],   results["std_bounds"],
            results["avg_boundC"],   results["std_boundC"],
            results["avg_ratios"],   results["std_ratios"]):
        # Scientific notation: the default noise levels span 1e-4..1e-1, which a
        # fixed-point "8.1f" format would all print as 0.0 or 0.1.
        print(f"{η:8.1e} | "
              f"{ma:10.4e} {sa:10.4e} | "
              f"{mb:10.4e} {sb:10.4e} | "
              f"{mbc:10.4e} {sbc:10.4e} | "
              f"{mr:10.4e} {sr:10.4e}")
