"""
Core module for Effective Span Dimension (ESD) calculation.

This module implements the core functions for calculating the Effective Span
Dimension as defined in the paper. It provides tools for computing the
H-function and the ESD value from signal coefficients, eigenvalues, and the
noise variance.

Key Functions:
- compute_H(theta, lambda_vals): Orders the coefficients by decreasing
  eigenvalue and returns, for each k = 1..d, the sum of squared coefficients
  from position k onward divided by k.
- compute_esd(theta, lambda_vals, sigma2): Returns the smallest k whose
  H-value is at most the noise variance sigma2, or the full dimension if no
  H-value meets that threshold.
- compute_pc_error(y, theta_star, lambda_vals, sigma_sq, num_trials): Computes
  the squared error of the principal component estimator that keeps the
  top-ESD coordinates of y, along with the corresponding ESD.
"""


import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import LogLocator, NullFormatter, ScalarFormatter

# Core functions for Effective Span Dimension calculation

def compute_H(theta, lambda_vals):
    """
    Compute the H-function for ESD calculation.

    The coefficients are first reordered by decreasing eigenvalue. Writing
    ``theta_(1), ..., theta_(d)`` for that ordering, entry ``k - 1`` of the
    result (for ``k = 1..d``) is::

        H(k) = (theta_(k)**2 + theta_(k+1)**2 + ... + theta_(d)**2) / k

    Args:
        theta: Signal coefficients (array-like), one per eigenvalue.
        lambda_vals: Eigenvalues (array-like), same length as ``theta``.

    Returns:
        1-D array of H-function values with ``len(theta)`` entries.

    Raises:
        ValueError: If ``theta`` and ``lambda_vals`` have different lengths.
    """
    # Accept lists/tuples as well as arrays; fancy indexing below needs arrays.
    theta = np.asarray(theta)
    lambda_vals = np.asarray(lambda_vals)
    if len(theta) != len(lambda_vals):
        # Without this check a shorter lambda_vals silently truncates the result
        # and a longer one fails later with an opaque IndexError.
        raise ValueError(
            f"theta and lambda_vals must have the same length "
            f"(got {len(theta)} and {len(lambda_vals)})"
        )

    idx = np.argsort(lambda_vals)[::-1]  # indices in decreasing-eigenvalue order
    sq = theta[idx] ** 2
    # Reverse cumulative sum: tail[i] == sq[i:].sum()
    tail = np.cumsum(sq[::-1])[::-1]
    return tail / np.arange(1, len(sq) + 1)

def compute_esd(theta, lambda_vals, sigma2):
    """
    Return the Effective Span Dimension (ESD) per Definition 3.1.

    The ESD is the smallest (1-based) position ``k`` whose H-function value,
    as computed by ``compute_H``, is less than or equal to ``sigma2``.

    Args:
        theta: Signal coefficients.
        lambda_vals: Eigenvalues.
        sigma2: Noise variance.

    Returns:
        esd: The Effective Span Dimension (d^dagger in the paper), or
        ``len(theta)`` when no H-value falls at or below ``sigma2``.
    """
    hits = np.flatnonzero(compute_H(theta, lambda_vals) <= sigma2)
    if hits.size == 0:
        # Threshold never met: fall back to the full dimension.
        return len(theta)
    # Positions are 0-based; the ESD counts components, hence the shift.
    return hits[0] + 1

# Helper function for simulations

def compute_pc_error(y, theta_star, lambda_vals, sigma_sq, num_trials=100):
    """
    Compute the squared error of the principal component (PC) estimator.

    The estimator copies the observations ``y`` on the ``d_dagger``
    coordinates with the largest eigenvalues, where ``d_dagger`` is the
    Effective Span Dimension of ``theta_star``, and sets every other
    coordinate to zero.

    Args:
        y: Observations, one per coordinate (array-like).
        theta_star: True signal coefficients (array-like, unsquared).
        lambda_vals: Eigenvalues, one per coordinate.
        sigma_sq: Noise variance.
        num_trials: Unused by the current implementation, which evaluates
            the single observation vector ``y``; kept for backward
            compatibility with existing callers.

    Returns:
        sq_error: ||theta_hat - theta_star||^2.
        d_dagger: Number of components kept by the optimal PC estimator (ESD).
        Returns ``(0, 0)`` when ``theta_star`` is empty.
    """
    y = np.asarray(y)
    theta_star = np.asarray(theta_star)
    d = len(theta_star)
    if d == 0:
        return 0, 0

    # Pass the raw coefficients: compute_esd -> compute_H squares them itself,
    # so passing theta_star**2 would threshold fourth powers instead of the
    # squared coefficients.
    d_dagger = compute_esd(theta_star, lambda_vals, sigma_sq)

    # Coordinates by decreasing eigenvalue (same ordering compute_H uses).
    sorted_indices = np.argsort(lambda_vals)[::-1]

    # PC estimator: keep y on the top-d_dagger coordinates, zero elsewhere.
    theta_hat = np.zeros(d)
    top_k_indices = sorted_indices[:d_dagger]
    theta_hat[top_k_indices] = y[top_k_indices]

    sq_error = np.sum((theta_hat - theta_star) ** 2)
    return sq_error, d_dagger

