import math
import scipy.io
import numpy as np
import scipy.sparse as sp
from scipy.sparse import diags
from scipy.linalg import eigh, qr
import matplotlib.pyplot as plt
from scipy.stats import t
# We evaluate the sharpness of Theorem 2.1 by comparing its bound with the actual error and with the EY-N bound in the setting
# A = BCSSTK09 (n = 1083), E = symmetric Gaussian noise, p = 8 (computed later).


# Read the MATLAB file. The [0, 0] index below unwraps the struct array that
# loadmat returns for 'Problem', exposing its 'A' field (a sparse matrix).
mat = scipy.io.loadmat('bcsstk09.mat')

# Densify as float64 so the dense NumPy/SciPy routines used later can consume it.
data = mat['Problem']['A'][0, 0].toarray().astype(np.float64)

# Singular values only (NumPy returns them in descending order, so s[-1] is the smallest).
s = np.linalg.svd(data, full_matrices=False)[1]
print(s)

# r = sum_i s_min / s_i, reported as the "stable rank of the inverse".
# NOTE(review): the usual stable rank of A^{-1} is ||A^{-1}||_F^2 / ||A^{-1}||_2^2
# = sum_i (s_min / s_i)^2; the unsquared sum used here equals
# ||A^{-1}||_* / ||A^{-1}||_2. Confirm which definition is intended.
r = sum(s[-1] / s)
print(f"Stable rank of inverse = {r}")

# Compute the low-rank parameter p such that the spectral tail is less than 0.05
def analyze_matrix(A):
    n = A.shape[0]

    # Step 1: Compute eigenvalues in increasing order: λ_n < ... < λ_1
    eigvals = eigh(A, eigvals_only=True)
    eigvals = np.sort(eigvals)  # eigvals[0] = λ_n, eigvals[n-1] = λ_1

    # The smallest eigenvalue
    lambda_n = eigvals[0]

    # Find smallest p such that λ_n / λ_(n-p) < 0.05
    p = None
    for i in range(1, n):
        if lambda_n / eigvals[ i] < 0.05:
            p = i
            break

    if p is None:
        print("Task 1: No such p found where λ_n / λ_(n-p) < 0.05")
    else:
        print(f"Task 1: Smallest p such that λ_n / λ_(n-p) < 0.05 is p = {p}")

    return p  # Return the computed value of p


def best_rank_p_inverse(A_inv, p):
    eigvals, eigvecs = eigh(A_inv)
    idx = np.argsort(eigvals)[::-1][:p]
    return (eigvecs[:, idx] * eigvals[idx]) @ eigvecs[:, idx].T

# Compute the actual error, the r.h.s of Theorem 2.1, The EY-N bound, and the r.h.s of Theorem B.2
def compute_perturbation_and_bound_debug(A, E, p, verbose=False):
    n = A.shape[0]
    A_inv_p = best_rank_p_inverse(np.linalg.inv(A), p)
    tilde_A_inv_p = best_rank_p_inverse(np.linalg.inv(A + E), p)
    actual_error = np.linalg.norm(tilde_A_inv_p - A_inv_p, 2)  # Compute the actual error

    # descending eigens
    eigvals, eigvecs = eigh(A)
    eigvals = np.sort(eigvals)[::-1]
    eigvecs = eigvecs[:, ::-1]

    λ_n = eigvals[-1]
    δ_np = eigvals[n - p - 1] - eigvals[n - p] if n - p - 1 >= 0 else eigvals[-1]
    δ = min(λ_n, δ_np)
    λ_np1 = eigvals[n - p]
    # compute double distance r for Theorem B.2
    r = next((r_try for r_try in range(1, n)
              if eigvals[n - r_try] >= 2 * λ_np1), n)
    λ_nr = eigvals[n - r]
    U_bot = eigvecs[:, -r:]
    # the interaction parameter, Theorem B.2
    x = np.max(np.abs(U_bot.T @ E @ U_bot))
    norm_E = np.linalg.norm(E, 2)

    # compute the bounds: bound4=Theorem 2.1, bound3=Theorem B.2, boundC=EY-N bound
    bound4 = 4 * norm_E / (λ_n ** 2) + 5 * norm_E / (δ_np * (eigvals[n - p - 1] + eigvals[n - p]))
    bound3 = (25 * norm_E) / (λ_n * λ_nr) + 4 * (r ** 2) * x / (λ_n ** 2) + 4 * (r ** 2) * x / (δ_np * (eigvals[n - p - 1] + eigvals[n - p]))
    upper_bound = bound4 / (2 * np.sqrt(n))
    boundC = (8 * norm_E / (3 * λ_n ** 2) + 2 / eigvals[n - p - 1]) / (2 * np.sqrt(n))

    return {
        "actual_error": actual_error,
        "upper_bound": upper_bound,
        "boundC": boundC,
        "ratio": actual_error / upper_bound,
        "r": r,
        "x": x,
        "norm_E": norm_E
    }


def batch_run_with_debug(A, p=None, noise_levels=None, trials=None):
    n = A.shape[0]
    if p is None:
        p = math.floor(math.log(n))

    if noise_levels is None:
        noise_levels = np.arange(1.2, 3.001, 0.2)

    if trials is None:
        trials = 10

    avg_act = []; std_act = []
    avg_bnd = []; std_bnd = []
    avg_bc = []; std_bc = []
    avg_rat = []; std_rat = []
    avg_xs = []; std_xs = []

    for η in noise_levels:
        acts, bnds, bcs, rats, xs = [], [], [], [], []
        for _ in range(trials):
            E = np.random.randn(n, n) #Generate the standard Gaussian Noise.
            E = (E + E.T) / 2
            E *= η #Scale the noise

            res = compute_perturbation_and_bound_debug(A, E, p)
            acts.append(res["actual_error"])
            bnds.append(res["upper_bound"])
            bcs.append(res["boundC"])
            rats.append(res["ratio"])
            xs.append(res["x"])

        avg_act.append(np.mean(acts)); std_act.append(np.std(acts))
        avg_bnd.append(np.mean(bnds)); std_bnd.append(np.std(bnds))
        avg_bc.append(np.mean(bcs)); std_bc.append(np.std(bcs))
        avg_rat.append(np.mean(rats)); std_rat.append(np.std(rats))
        avg_xs.append(np.mean(xs)); std_xs.append(np.std(xs))

    return {
        "noise_levels": noise_levels,
        "avg_actuals": avg_act, "std_actuals": std_act,
        "avg_bounds": avg_bnd, "std_bounds": std_bnd,
        "avg_boundC": avg_bc, "std_boundC": std_bc,
        "avg_ratios": avg_rat, "std_ratios": std_rat,
        "avg_xs": avg_xs, "std_xs": std_xs
    }


def plot_debug_results(results, title="Actual vs Bound vs Noise"):
    η = results["noise_levels"]
    plt.figure(figsize=(5, 5))
    plt.errorbar(η, results["avg_actuals"], yerr=results["std_actuals"], fmt='o--', label="Actual Bound", capsize=3)
    plt.errorbar(η, results["avg_bounds"], yerr=results["std_bounds"], fmt='s--', label="Our Bound", capsize=3)
    plt.errorbar(η, results["avg_boundC"], yerr=results["std_boundC"], fmt='^--', label="EY-N Bound", capsize=3)
    plt.xscale('linear'); plt.yscale('log')
    plt.xlabel("Noise level"); plt.ylabel("Error")
    plt.title(title)
    plt.legend(); plt.grid(True, which="both", ls=":")
    plt.tight_layout(); plt.show()


if __name__ == "__main__":
    # Reuse the matrix loaded into `data` at module level.
    A = data
    n = A.shape[0]

    # Step 1: choose p as the smallest index with λ_n / λ_(n-p) < 0.05.
    p = analyze_matrix(A)

    if p is None:
        print("No valid p found. Skipping further analysis.")
    else:
        trials = 100
        # Noise scales η = 1.2, 1.4, ..., 3.0; retune this grid for a different A.
        noise_levels = np.arange(1.2, 3.001, 0.2)

        results = batch_run_with_debug(
            A, p=p, noise_levels=noise_levels, trials=trials
        )

        # 1) Plot the actual error and both bounds (mean ± std over trials).
        plot_debug_results(results, title=f"BCSSTK09 (n={n}), Gaussian Noise")

        # 2) Tabulate the mean and standard deviation of each metric per noise level.
        column_pairs = (
            ("μ_actual", "σ_actual"),
            ("μ_bound", "σ_bound"),
            ("μ_boundC", "σ_boundC"),
            ("μ_ratio", "σ_ratio"),
        )
        header_cells = [f"{'Noise':>8s}"]
        header_cells += [f"{mu:>10s} {sd:>10s}" for mu, sd in column_pairs]
        print(" | ".join(header_cells))
        print("-" * 100)

        stat_keys = (
            "avg_actuals", "std_actuals",
            "avg_bounds", "std_bounds",
            "avg_boundC", "std_boundC",
            "avg_ratios", "std_ratios",
        )
        for eta, *stats in zip(results["noise_levels"], *(results[k] for k in stat_keys)):
            row_cells = [f"{eta:8.1f}"]
            row_cells += [f"{mean:10.4e} {std:10.4e}" for mean, std in zip(stats[0::2], stats[1::2])]
            print(" | ".join(row_cells))
