import math
import scipy.io
import numpy as np
import scipy.sparse as sp
from scipy.sparse import diags
from scipy.linalg import eigh, qr
import matplotlib.pyplot as plt
from scipy.stats import t
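# We evaluate the sharpness of Theorem 2.1 by comparing its bound with the actual error and the EY-N bound in the setting
# A = Census (n = 69), E = Rademacher noise, p = 17 (computed later).
# NOTE(review): the noise matrix E generated in batch_run_with_debug is drawn with np.random.randn (symmetrized
# Gaussian), not Rademacher signs -- confirm which distribution is intended for this experiment.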

def _int_or_zero(token):
    """Parse *token* as an int, mapping anything unparsable to 0."""
    try:
        return int(token)
    except ValueError:
        return 0


def file_to_matrix(file_path):
    """Read a comma-separated file of integers into a rectangular 2D list.

    The first line is treated as a header and discarded.  Every remaining
    non-blank line becomes one row; all of its comma-separated fields are
    kept (no column is dropped).  Fields that ``int()`` cannot parse are
    stored as 0, and rows shorter than the longest row are right-padded
    with 0 so that the result is rectangular.

    Args:
        file_path: Path to the input file.

    Returns:
        A list of equal-length lists of ints.  An empty or header-only file
        yields an empty list.
    """
    matrix = []
    with open(file_path, 'r') as file:
        # Skip the header; the default keeps an empty file from raising
        # StopIteration.
        next(file, None)
        for line in file:
            fields = line.strip()
            # ''.split(',') is [''], which would otherwise turn every blank
            # line into a spurious [0] row.
            if not fields:
                continue
            matrix.append([_int_or_zero(tok) for tok in fields.split(',')])

    # default=0 covers the "no data rows" case, where max() would raise.
    max_len = max(map(len, matrix), default=0)

    # Pad shorter rows with 0s to make all rows the same length.
    for row in matrix:
        shortfall = max_len - len(row)
        if shortfall > 0:
            row.extend([0] * shortfall)

    return matrix

# Load the census table; the path is resolved relative to the working directory.
file_path = 'USCensus1990.data.txt'  # Replace with your file path
matrix = file_to_matrix(file_path)

# Dense NumPy array built from the parsed rows.
matrix_np = np.array(matrix)

# Reduced SVD of the data matrix.  full_matrices=False only shrinks U / Vh to
# their economy shapes; both factors are still computed.
# NOTE(review): only the singular values `s` are used in this file; U and Vh
# are kept because they are module-level names -- confirm nothing else needs them.
U, s, Vh = np.linalg.svd(matrix_np, full_matrices=False)

# r = sum_i s_min^2 / s_i^2, reported below as the stable rank of the inverse.
# np.linalg.svd returns singular values in descending order, so s[-1] is the
# smallest one.
r = sum([s[-1]**2 / sv**2 for sv in s])
print(f"Stable rank of Inverse = {r}")




# Gram matrix A = M^T * M (the original comment calls this the covariance
# matrix; no centering or normalisation is applied here).
A = np.dot(matrix_np.T, matrix_np)

# Compute the low-rank parameter p such that the spectral tail ratio drops below 0.05
def analyze_matrix(A):
    """Find the smallest p with λ_n / λ_(n-p) < 0.05 and report it.

    The eigenvalues of ``A`` are computed with ``scipy.linalg.eigh`` (so ``A``
    is treated as symmetric) and arranged in increasing order, i.e. position 0
    holds the smallest eigenvalue λ_n and position i holds λ_(n-i).

    Args:
        A: Square matrix whose spectrum is analysed.

    Returns:
        The smallest index p in 1..n-1 satisfying the ratio test, or None
        when no index satisfies it.  The outcome is also printed.
    """
    size = A.shape[0]

    # Ascending spectrum: spectrum[0] = λ_n, spectrum[size-1] = λ_1.
    spectrum = np.sort(eigh(A, eigvals_only=True))
    smallest = spectrum[0]

    # Lazily scan i = 1, 2, ... and stop at the first index passing the test.
    p = next(
        (i for i in range(1, size) if smallest / spectrum[i] < 0.05),
        None,
    )

    if p is not None:
        print(f"Task 1: Smallest p such that λ_n / λ_(n-p) < 0.05 is p = {p}")
    else:
        print("Task 1: No such p found where λ_n / λ_(n-p) < 0.05")

    return p

def best_rank_p_inverse(A_inv, p):
    """Rebuild ``A_inv`` from its p algebraically largest eigenpairs.

    Args:
        A_inv: Square matrix, decomposed with ``scipy.linalg.eigh`` (hence
            treated as symmetric).
        p: Number of eigenpairs to keep.

    Returns:
        The matrix ``V_p diag(w_p) V_p^T`` where ``w_p`` are the p largest
        eigenvalues (by signed value, not magnitude) and ``V_p`` the matching
        eigenvectors.
    """
    values, vectors = eigh(A_inv)
    # Indices of the p largest eigenvalues, largest first.
    keep = np.argsort(values)[::-1][:p]
    top_vectors = vectors[:, keep]
    top_values = values[keep]
    # Scaling the columns by the eigenvalues equals multiplying by diag(top_values).
    return (top_vectors * top_values) @ top_vectors.T

def compute_perturbation_and_bound_debug(A, E, p, verbose=False):
    """Measure the rank-p inverse perturbation error and evaluate its bounds.

    The measured error is the spectral norm of the difference between the
    rank-p truncations (see ``best_rank_p_inverse``) of ``inv(A + E)`` and
    ``inv(A)``.  The bounds are built from the eigenvalues of ``A`` and the
    spectral norm of ``E``.

    Args:
        A: Square matrix; decomposed with ``scipy.linalg.eigh``.
        E: Perturbation added to ``A``; same shape as ``A``.
        p: Rank of the truncated inverses.
        verbose: Accepted but not used by this function.

    Returns:
        A dict with keys
        ``actual_error`` (measured error),
        ``upper_bound`` (Theorem 2.1 right-hand side divided by 2*sqrt(n)),
        ``boundC`` (EY-N bound, divided by 2*sqrt(n)),
        ``ratio`` (``actual_error / upper_bound``),
        ``r`` and ``x`` (parameters computed for Theorem B.2) and
        ``norm_E`` (spectral norm of ``E``).
    """
    dim = A.shape[0]

    # Measured error between the two rank-p truncated inverses.
    clean_trunc = best_rank_p_inverse(np.linalg.inv(A), p)
    noisy_trunc = best_rank_p_inverse(np.linalg.inv(A + E), p)
    actual_error = np.linalg.norm(noisy_trunc - clean_trunc, 2)

    # Eigen-decomposition of A, rearranged so that spectrum[k] is the
    # (k+1)-th largest eigenvalue and basis[:, k] its eigenvector.
    spectrum, basis = eigh(A)
    spectrum = np.sort(spectrum)[::-1]
    basis = basis[:, ::-1]

    # Positions of λ_(n-p) and λ_(n-p+1) in the descending spectrum.
    hi = dim - p - 1
    lo = dim - p

    lam_min = spectrum[-1]
    if hi >= 0:
        gap = spectrum[hi] - spectrum[lo]
    else:
        gap = spectrum[-1]
    # NOTE(review): this minimum is computed but never used below.
    gap_floor = min(lam_min, gap)
    lam_lo = spectrum[lo]

    # r for Theorem B.2: first r in 1..n-1 whose eigenvalue λ_(n-r+1) is at
    # least twice λ_(n-p+1); falls back to n when there is none.
    r = dim
    for candidate in range(1, dim):
        if spectrum[dim - candidate] >= 2 * lam_lo:
            r = candidate
            break
    lam_r = spectrum[dim - r]
    bottom_basis = basis[:, -r:]

    # Interaction parameter x for Theorem B.2: largest absolute entry of E
    # expressed in the bottom-r eigenbasis.
    x = np.max(np.abs(bottom_basis.T @ E @ bottom_basis))
    norm_E = np.linalg.norm(E, 2)

    lam_min_sq = lam_min**2
    pair_sum = spectrum[hi] + spectrum[lo]
    root_scale = 2*np.sqrt(dim)

    # bound4 = r.h.s. of Theorem 2.1, bound3 = r.h.s. of Theorem B.2,
    # boundC = EY-N bound.
    bound4 = 4*norm_E/lam_min_sq + 5*norm_E/(gap*pair_sum)
    # NOTE(review): bound3 is evaluated but not part of the returned dict.
    bound3 = (25*norm_E)/(lam_min*lam_r) + 4*(r**2)*x/lam_min_sq + 4*(r**2)*x/(gap*pair_sum)
    upper_bound = bound4/root_scale
    boundC = (8*norm_E/(3*lam_min_sq) + 2/spectrum[hi])/root_scale

    return {
        "actual_error": actual_error,
        "upper_bound":  upper_bound,
        "boundC":       boundC,
        "ratio":        actual_error/upper_bound,
        "r":            r,
        "x":            x,
        "norm_E":       norm_E
    }

_NOISE_KINDS = ("gaussian", "rademacher")


def _draw_symmetric_noise(n, noise):
    """Return one unscaled symmetric n x n noise matrix of the given kind."""
    if noise == "gaussian":
        G = np.random.randn(n, n)
        return (G + G.T) / 2
    if noise == "rademacher":
        signs = np.random.choice([-1.0, 1.0], size=(n, n))
        # Mirror the upper triangle so the matrix is symmetric with +/-1 entries.
        return np.triu(signs) + np.triu(signs, 1).T
    raise ValueError(
        f"unknown noise kind {noise!r}; expected one of {_NOISE_KINDS}"
    )


def batch_run_with_debug(A, p=None, noise_levels=None, trials=None,
                         noise="gaussian"):
    """Run repeated noisy trials per noise level and aggregate the metrics.

    For every noise level a fresh symmetric random matrix is drawn ``trials``
    times, scaled by the level, and passed with ``A`` and ``p`` to
    ``compute_perturbation_and_bound_debug``.  The mean and standard
    deviation of each metric are recorded per level.

    Randomness comes from NumPy's global ``np.random`` state; seed it with
    ``np.random.seed`` beforehand for reproducible runs.

    NOTE(review): the file header and the plot title in the ``__main__``
    section describe Rademacher noise, while this function has always drawn
    Gaussian noise.  The default stays "gaussian" so existing results do not
    change; pass ``noise="rademacher"`` if that was the intended setting.

    Args:
        A: Square matrix under study.
        p: Truncation rank.  When None, ``floor(log(n))`` is used, clamped to
            at least 1 (a rank of 0 makes the bound computation index out of
            range).
        noise_levels: Iterable of noise scales.  When None, the grid
            ``np.arange(1.2, 3.001, 0.2)`` is used.
        trials: Number of random draws per noise level.  When None, 10.
        noise: ``"gaussian"`` (default) draws ``(G + G.T) / 2`` with standard
            normal ``G``; ``"rademacher"`` draws a symmetric matrix of
            independent +/-1 signs.

    Returns:
        A dict holding ``noise_levels`` and, for each metric, a list of
        per-level means and standard deviations under the keys
        ``avg_actuals``/``std_actuals``, ``avg_bounds``/``std_bounds``,
        ``avg_boundC``/``std_boundC``, ``avg_ratios``/``std_ratios`` and
        ``avg_xs``/``std_xs``.

    Raises:
        ValueError: If ``noise`` is not a supported kind.
    """
    if noise not in _NOISE_KINDS:
        raise ValueError(
            f"unknown noise kind {noise!r}; expected one of {_NOISE_KINDS}"
        )

    n = A.shape[0]
    # Only fill in defaults for arguments the caller left as None.
    if p is None:
        p = max(1, math.floor(math.log(n)))
    if noise_levels is None:
        noise_levels = np.arange(1.2, 3.001, 0.2)
    if trials is None:
        trials = 10

    # Keys of the per-trial result dict that get aggregated.
    metrics = ("actual_error", "upper_bound", "boundC", "ratio", "x")
    means = {name: [] for name in metrics}
    stds = {name: [] for name in metrics}

    for eta in noise_levels:
        samples = {name: [] for name in metrics}
        for _ in range(trials):
            E = _draw_symmetric_noise(n, noise)
            E *= eta  # Scale the noise

            res = compute_perturbation_and_bound_debug(A, E, p)
            for name in metrics:
                samples[name].append(res[name])

        for name in metrics:
            means[name].append(np.mean(samples[name]))
            stds[name].append(np.std(samples[name]))

    return {
        "noise_levels": noise_levels,
        "avg_actuals":  means["actual_error"], "std_actuals":  stds["actual_error"],
        "avg_bounds":   means["upper_bound"],  "std_bounds":   stds["upper_bound"],
        "avg_boundC":   means["boundC"],       "std_boundC":   stds["boundC"],
        "avg_ratios":   means["ratio"],        "std_ratios":   stds["ratio"],
        "avg_xs":       means["x"],            "std_xs":       stds["x"]
    }

def plot_debug_results(results, title="Actual vs Bound vs Noise"):
    """Plot mean +/- std of the measured error and both bounds per noise level.

    Args:
        results: Dict with ``noise_levels`` plus the ``avg_*`` / ``std_*``
            lists for ``actuals``, ``bounds`` and ``boundC``.
        title: Figure title.

    The figure uses a linear x-axis and a logarithmic y-axis and is shown
    with ``plt.show()``; nothing is returned.
    """
    levels = results["noise_levels"]

    # (mean key, std key, marker/line format, legend label) per plotted series.
    # The first series is the measured error; its legend text is "Actual Bound".
    series = (
        ("avg_actuals", "std_actuals", 'o--', "Actual Bound"),
        ("avg_bounds",  "std_bounds",  's--', "Our Bound"),
        ("avg_boundC",  "std_boundC",  '^--', "EY-N Bound"),
    )

    plt.figure(figsize=(5,5))
    for mean_key, std_key, fmt, label in series:
        plt.errorbar(levels, results[mean_key], yerr=results[std_key],
                     fmt=fmt, label=label, capsize=3)

    plt.xscale('linear')
    plt.yscale('log')
    plt.xlabel("Noise level")
    plt.ylabel("Error")
    plt.title(title)
    plt.legend()
    plt.grid(True, which="both", ls=":")
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    # Gram matrix of the table loaded at module level into `matrix_np`.
    A = np.dot(matrix_np.T, matrix_np)
    n = A.shape[0]

    # Smallest p where λ_n / λ_(n-p) < 0.05 (None when no such p exists).
    p = analyze_matrix(A)
    trials = 100

    # Noise grid 1.0, 1.5, ..., 6.0; the 0.001 slack keeps the endpoint.
    noise_levels = np.arange(1, 6.001, 0.5)

    # Pass the grid explicitly instead of relying on the function default.
    results = batch_run_with_debug(
        A,
        p=p,
        noise_levels=noise_levels,
        trials=trials,
    )

    # 1) Plot everything.
    # NOTE(review): the title says "Rademacher Noise", but the noise drawn by
    # batch_run_with_debug defaults to symmetrized Gaussian -- confirm which
    # one is intended and align the label or the noise generation.
    plot_debug_results(
        results,
        title=f"1990 US Census (n={n}), Rademacher Noise"
    )

    # 2) Print mean & standard deviation for each metric, one row per level.
    print(f"{'Noise':>8s} | {'μ_actual':>10s} {'σ_actual':>10s} | "
          f"{'μ_bound':>10s} {'σ_bound':>10s} | "
          f"{'μ_boundC':>10s} {'σ_boundC':>10s} | "
          f"{'μ_ratio':>10s} {'σ_ratio':>10s}")
    print("-"*100)

    table_keys = (
        "noise_levels",
        "avg_actuals", "std_actuals",
        "avg_bounds",  "std_bounds",
        "avg_boundC",  "std_boundC",
        "avg_ratios",  "std_ratios",
    )
    for η, ma, sa, mb, sb, mbc, sbc, mr, sr in zip(
            *(results[key] for key in table_keys)):
        print(f"{η:8.1f} | "
              f"{ma:10.4e} {sa:10.4e} | "
              f"{mb:10.4e} {sb:10.4e} | "
              f"{mbc:10.4e} {sbc:10.4e} | "
              f"{mr:10.4e} {sr:10.4e}")
