"""
Comprehensive utilities for implementing the ESD framework in fixed-design linear regression.

This module provides functions to create design matrices, generate signal vectors, perform
SVD computations, and implement principal component regression (PCR). It also includes
functions for measuring prediction error and creating various alignment matrices.

Function Categories:
1. Design Matrix Generation:
   - make_design(n, p, par, seed): Creates n×p design matrices with specified covariance.
   - make_beta(p, type, power): Generates dense signal vectors with power or log decay.

2. SVD & Parameter Translation:
   - cache_svd(X): Computes and caches the thin SVD of X/sqrt(n) for efficient reuse.
   - translate_theta_lambda(svd_obj, beta): Translates beta to theta/lambda using SVD.
   - cum_col(M): Cumulative sums across the columns of a matrix.

3. Principal Component Regression:
   - pcr_factory(X): Computes one SVD of X and returns a dict holding the closures
     below together with the number of components "r".
   - fit(k, y): Computes PCR with k components.
   - fit_all(y, k_max): Computes PCR for all k values up to k_max.
   - mse_pred(X, beta_hat, beta_true): Calculates prediction MSE.

4. Alignment Matrix Construction:
   - build_A_diag_scale/exp/perm: Different strategies for diagonal alignment matrices.

5. Alignment Application:
   - apply_alignment(X, beta, A): Applies alignment matrix to design and coefficients.
   - trace_esd_path(X, beta, sigma2, alpha_seq, build_A_fn): Computes ESD along parameter path.
"""

import numpy as np
from scipy import linalg


def make_design(n, p, par=0.95, seed=None):
    """
    Create an n×p Gaussian design matrix with controlled eigenvalue decay.

    The population covariance is diagonal, Σ = diag(eig_vals). Its eigenvalues
    follow a geometric decay, a logarithmic decay, or explicitly provided values.

    Parameters:
    -----------
    n : int
        Number of observations (rows)
    p : int
        Number of variables/features (columns)
    par : float or str or array-like
        Controls eigenvalue decay pattern:
        - scalar (Python number, NumPy scalar, or 0-d array): geometric decay
          par^j for j = 0, ..., p-1 (decaying when 0 < par < 1)
        - 'log': logarithmic decay 1/log(j+1) for j = 1, ..., p
        - 1-D array-like of length p: explicit eigenvalues
    seed : int, optional
        If given, NumPy's global (legacy) RNG is seeded with it via
        np.random.seed before sampling.

    Returns:
    --------
    X : ndarray
        Design matrix (n×p) with rows drawn i.i.d. from N(0, Σ)

    Raises:
    -------
    ValueError
        If par is a string other than 'log', or an array-like that is not
        one-dimensional with length p.
    """
    if seed is not None:
        # Seeds the *global* RNG, so it also affects later np.random calls
        # elsewhere; kept so a given seed reproduces the same draws as before.
        np.random.seed(seed)

    if isinstance(par, str):
        if par != "log":
            raise ValueError(f"Unknown eigenvalue pattern {par!r}; expected 'log'")
        eig_vals = 1 / np.log(np.arange(1, p + 1) + 1)
    elif np.ndim(par) == 0:
        # Any scalar, including NumPy scalars such as np.float32 / np.int64
        # that are not instances of the builtin int/float.
        eig_vals = par ** np.arange(p)
    else:
        eig_vals = np.array(par)
        # Explicit check instead of `assert` so validation survives `python -O`.
        if eig_vals.ndim != 1 or eig_vals.shape[0] != p:
            raise ValueError(
                f"par must be a 1-D array of length p={p}, got shape {eig_vals.shape}"
            )

    # Diagonal population covariance
    Sigma = np.diag(eig_vals)

    # Generate multivariate normal samples
    X = np.random.multivariate_normal(np.zeros(p), Sigma, size=n)
    return X


def make_beta(p, type="power", power=1):
    """
    Generate a dense signal vector with controlled decay pattern.

    Creates a length-p coefficient vector β* following either a power law
    or logarithmic decay, useful for simulation studies with structured signals.

    Parameters:
    -----------
    p : int
        Length of the signal vector
    type : str
        "power" (default) or "log" decay pattern
    power : float
        For type="power", controls decay rate: β_j = j^(-power), j = 1, ..., p.
        Any real value is accepted (negative values give a growing signal).
        Ignored for logarithmic decay.

    Returns:
    --------
    beta : ndarray
        Float signal vector of length p with specified decay structure

    Raises:
    -------
    ValueError
        If type is neither "power" nor "log".
    """
    if type == "power":
        # Float indices: an int64 base raised to a large integer power
        # silently overflows (e.g. 100**10), and NumPy refuses negative
        # integer powers of integer arrays.
        idx = np.arange(1, p + 1, dtype=float)
        return 1 / idx ** power
    elif type == "log":
        return 1 / np.log(np.arange(1, p + 1) + 1)
    else:
        raise ValueError(f"Unknown type for beta generator: {type}")


# ════════════════════════════════════════════════════════════════════════════
# 2  SVD caching & θ/λ translation
# ════════════════════════════════════════════════════════════════════════════

def cache_svd(X):
    """Compute the thin SVD of X/sqrt(n) once and package it for reuse.

    The design is scaled by 1/sqrt(n) (n = number of rows) before the
    decomposition, so d**2 are the eigenvalues of X'X/n.

    Args:
        X: Design matrix (n×p)

    Returns:
        Dictionary containing:
        - U: Left singular vectors (n×r)
        - V: Right singular vectors (p×r), i.e. the transpose of SciPy's vt
        - d: Singular values of X/sqrt(n), in descending order
        - r: Number of singular values returned, min(n, p). This is not the
          numerical rank, since zero or near-zero singular values are counted.
    """
    n_rows = X.shape[0]
    scaled = X / np.sqrt(n_rows)
    left, sing, right_t = linalg.svd(scaled, full_matrices=False)
    return dict(U=left, V=right_t.T, d=sing, r=len(sing))


def translate_theta_lambda(svd_obj, beta):
    """Re-express coefficients in the canonical θ/λ coordinates of an SVD.

    With the decomposition stored in svd_obj (singular values d, right
    singular vectors V), this returns θ = diag(d)·V'·β and λ = d².
    When svd_obj comes from cache_svd, λ are the eigenvalues of X'X/n.

    Args:
        svd_obj: Dict with at least keys "d" (singular values) and "V"
            (right singular vectors as columns), e.g. from cache_svd
        beta: Coefficient vector in original space

    Returns:
        Dictionary containing:
        - theta: Coefficient vector in the SVD-rotated, singular-value-scaled basis
        - lambda: Squared singular values
    """
    sing_vals = svd_obj["d"]
    rotated_beta = svd_obj["V"].T @ beta
    return {"theta": sing_vals * rotated_beta, "lambda": sing_vals ** 2}


def cum_col(M):
    """
    Apply cumulative sum transformation to each column of a matrix.

    Given n×J matrix M with columns m_1, ..., m_J, returns the n×J matrix
    whose column k is m_1 + ... + m_k.

    Uses np.cumsum, which is O(nJ) and needs no J×J helper matrix. A
    non-finite entry affects only its own and later columns; a dense
    triangular matmul would spread NaN into earlier columns too, since
    inf*0 and nan*0 are NaN.

    Args:
        M: Input matrix (n×J)

    Returns:
        Matrix with cumulative column sums (n×J). Its dtype is at least
        float64, so integer or boolean input is promoted.
    """
    M = np.asarray(M)
    return np.cumsum(M, axis=1, dtype=np.result_type(M.dtype, np.float64))


# ════════════════════════════════════════════════════════════════════════════
# 3  Principal Component Regression (single‑SVD factory)
# ════════════════════════════════════════════════════════════════════════════

def pcr_factory(X):
    """Factory for Principal Component Regression (PCR) sharing a single SVD.

    Computes the thin SVD X = U diag(D) V^T once (X is not rescaled) and
    returns closures that reuse it.

    Args:
        X: Design matrix (n×p).

    Returns:
        Dictionary containing:
        - fit: callable fit(k, y), returning the length-p PCR estimate from
          the leading k components.
        - fit_all: callable fit_all(y, k_max=None), returning a p×k_max matrix
          of PCR estimates for k = 1, ..., k_max.
        - r: number of singular values from the thin SVD, min(n, p). Any
          (numerically) zero singular values are included, so for
          rank-deficient X the trailing components divide by ~0.
    """
    u, d, vt = linalg.svd(X, full_matrices=False)  # X = U D V^T
    U, D, V = u, d, vt.T
    r = len(D)

    def fit(k, y):
        """PCR estimate using the leading k principal components.

        β̂_k = V_k diag(1/D_k) U_k^T y.

        Raises:
            ValueError: if k is outside [1, r].
        """
        # Explicit check rather than `assert`: under `python -O` an assert is
        # stripped and an out-of-range k would silently slice to a wrong fit.
        if not 1 <= k <= r:
            raise ValueError(f"k must be between 1 and r={r}, got {k}")
        U_k = U[:, :k]
        V_k = V[:, :k]
        return V_k @ ((U_k.T @ y) / D[:k])

    def fit_all(y, k_max=None):
        """Estimators for all k = 1, ..., k_max in one shot.

        Returns a p × k_max matrix whose column j (0-based) is β̂_{j+1}.
        k_max defaults to r and is capped at r; k_max = 0 yields a p×0 matrix.

        Raises:
            ValueError: if k_max is negative.
        """
        k_max = r if k_max is None else min(k_max, r)
        # A negative k_max would turn the slices below into "all but the last
        # |k_max| components", which is silently wrong.
        if k_max < 0:
            raise ValueError(f"k_max must be non-negative, got {k_max}")

        proj = (U[:, :k_max].T @ y) / D[:k_max]
        contrib = V[:, :k_max] * proj[np.newaxis, :]

        # Running sum over components: column j accumulates components 1..j+1.
        return np.cumsum(contrib, axis=1)

    return {"fit": fit, "fit_all": fit_all, "r": r}


def mse_pred(X, beta_hat, beta_true):
    """Return the in-sample prediction MSE, mean_i ((X(β̂ − β*))_i)²."""
    fitted_gap = X @ (beta_hat - beta_true)
    return np.mean(np.square(fitted_gap))


# ════════════════════════════════════════════════════════════════════════════
# 4  Alignment‑matrix builders  (non‑orthogonal → can change ESD)
# ════════════════════════════════════════════════════════════════════════════

def build_A_diag_scale(p, min=1, max=2):
    """Return a p×p random diagonal scaling matrix diag(u_1, ..., u_p).

    Each u_i is drawn from Uniform(min, max) using NumPy's global RNG,
    so results are reproducible through np.random.seed.
    """
    draws = np.random.uniform(low=min, high=max, size=p)
    return np.diag(draws)


def build_A_diag_exp(p, alpha):
    """Return the smooth exponential diagonal stretch diag(exp(α t_i)).

    The t_i form an evenly spaced grid from -½ to ½ (np.linspace), so α
    controls severity and α = 0 gives the identity. For p == 1 the single
    grid point is -½.
    """
    grid = np.linspace(-0.5, 0.5, p)
    stretch = np.exp(alpha * grid)
    return np.diag(stretch)


def build_A_diag_perm(beta, s_hi=2, s_lo=0.5):
    """Graded diagonal scaling, permuted so the largest |β| components get the smallest scales.

    p evenly spaced scales from s_lo to s_hi are assigned in order of
    decreasing |β_j|. The coefficient with the largest magnitude receives
    s_lo, the smallest receives s_hi, and ties are broken by index order.

    Args:
        beta: Coefficient vector (length p).
        s_hi: Scale given to the smallest-|β| component.
        s_lo: Scale given to the largest-|β| component.

    Returns:
        p×p diagonal float matrix diag(scales).
    """
    beta = np.asarray(beta)
    p = len(beta)
    # Indices by decreasing |beta|; a stable sort makes tie handling deterministic.
    ord_idx = np.argsort(-np.abs(beta), kind="stable")
    # Float buffer: an int buffer (from an int s_hi) would truncate the graded
    # scales to 0/1/2 and could make A singular.
    scales = np.empty(p, dtype=float)
    # The i-th largest |beta| component gets the i-th smallest scale.
    scales[ord_idx] = np.linspace(s_lo, s_hi, p)
    return np.diag(scales)


# ════════════════════════════════════════════════════════════════════════════
# 5  Alignment application & ESD tracing
# ════════════════════════════════════════════════════════════════════════════

def apply_alignment(X, beta, A):
    """Reparametrize the regression through the alignment matrix A.

    Returns X_A = X·A and beta_A = A⁻¹·β (obtained by solving A·b = β),
    so X_A·beta_A = X·β and the mean response is unchanged.
    """
    transformed_design = X @ A
    transformed_coef = linalg.solve(A, beta)
    return dict(X_A=transformed_design, beta_A=transformed_coef)


def trace_esd_path(X, beta, sigma2, alpha_seq, build_A_fn):
    """Compute the ESD along the alignment path α ↦ A(α).

    For each α in alpha_seq, this:
    1. builds A = build_A_fn(p, α) and aligns (X, β) with apply_alignment;
    2. re-expresses the aligned problem in θ/λ coordinates via cache_svd
       and translate_theta_lambda;
    3. evaluates compute_esd(θ, λ, sigma2 / n).

    Args:
        X: Design matrix (n×p).
        beta: Coefficient vector of length p.
        sigma2: Presumably the noise variance (verify); compute_esd receives sigma2 / n.
        alpha_seq: Sequence of path parameters (must support len()).
        build_A_fn: Callable (p, alpha) -> p×p alignment matrix.

    Returns:
        Dictionary with "alpha" (alpha_seq, returned as given) and "esd"
        (float array with one ESD value per α).
    """
    # compute_esd lives in a separate module and is imported at call time.
    from esd_modular_functions import compute_esd

    esd_vals = np.zeros(len(alpha_seq))
    n_obs = X.shape[0]

    for pos, alpha in enumerate(alpha_seq):
        aligned = apply_alignment(X, beta, build_A_fn(X.shape[1], alpha))
        coords = translate_theta_lambda(cache_svd(aligned["X_A"]), aligned["beta_A"])
        esd_vals[pos] = compute_esd(coords["theta"], coords["lambda"], sigma2 / n_obs)

    return {"alpha": alpha_seq, "esd": esd_vals}