"""
Weight matrix construction for CoBET methods.
"""
import numpy as np
from .features import all_nonempty_subsets_indices


def trapezoid_weights(x):
    """
    Compute trapezoidal integration weights.

    For grid points ``x_0 < ... < x_{n-1}`` the weights satisfy
    ``sum(w * f(x)) == trapezoid rule of f over [x_0, x_{n-1}]``, i.e.
    ``w_0 = dx_0 / 2``, ``w_{n-1} = dx_{n-2} / 2`` and
    ``w_i = (dx_{i-1} + dx_i) / 2`` for interior points.

    Parameters
    ----------
    x : array_like, shape (n,)
        Grid points. Integer grids are promoted to float so that the
        half-interval end weights are not truncated.

    Returns
    -------
    w : np.ndarray, shape (n,)
        Integration weights (floating dtype). A grid with fewer than two
        points spans no interval, so all its weights are zero.

    Raises
    ------
    ValueError
        If ``x`` is not one-dimensional.
    """
    x = np.asarray(x)
    if x.ndim != 1:
        raise ValueError(f"trapezoid_weights expects a 1-D grid, got shape {x.shape}")
    # zeros_like would inherit an integer dtype and truncate 0.5*dx to 0.
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(float)

    w = np.zeros_like(x)
    if x.shape[0] < 2:
        # No interval to integrate over.
        return w

    dx = np.diff(x)
    w[1:-1] = 0.5 * (dx[:-1] + dx[1:])
    w[0] = 0.5 * dx[0]
    w[-1] = 0.5 * dx[-1]
    return w


def J_numeric_K(K, t_min=1e-4, t_max=100.0, T=2001):
    """
    Compute the J(K) matrix by numerical quadrature.

    For each subset S of levels {1, ..., K} a profile

        P_S(t) = prod_{r in S} sin(t / 2^r) * prod_{r not in S} cos(t / 2^r)

    is evaluated on a log-spaced grid. The matrix entries are

        J[i, j] = integral_{t_min}^{t_max} P_{S_i}(t) P_{S_j}(t) / t^2 dt,

    approximated with trapezoidal weights.

    Parameters
    ----------
    K : int
        Dyadic depth (number of levels).
    t_min, t_max : float
        Integration domain [t_min, t_max]; both must be positive because
        the grid is built in log10 space.
    T : int
        Number of quadrature nodes.

    Returns
    -------
    J : np.ndarray, shape (len(subsets), len(subsets))
        Symmetric J matrix (2^K - 1 square when ``subsets`` holds every
        non-empty subset).
    subsets : list of tuples
        Subsets indexing the rows/columns of J, as returned by
        ``all_nonempty_subsets_indices(K)``. Entries are used as 1-based
        level numbers (shifted by -1 before indexing).
    """
    subsets = all_nonempty_subsets_indices(K)

    # Log-spaced quadrature nodes; the 1/t^2 factor of the integrand is
    # folded directly into the trapezoid weights.
    nodes = np.logspace(np.log10(t_min), np.log10(t_max), T)
    quad = trapezoid_weights(nodes) / (nodes ** 2)

    # Level r (1-based) oscillates with angular frequency 2^{-r}.
    scales = np.array([1.0 / (2 ** level) for level in range(1, K + 1)])

    profiles = np.empty((len(subsets), T))
    for row, subset in enumerate(subsets):
        # Mask over 0-based levels: True -> sine factor, False -> cosine factor.
        use_sin = np.zeros(K, dtype=bool)
        use_sin[[level - 1 for level in subset]] = True

        acc = np.ones_like(nodes)
        for level in range(K):
            phase = nodes * scales[level]
            if use_sin[level]:
                acc *= np.sin(phase)
            else:
                acc *= np.cos(phase)
        profiles[row, :] = acc

    # Weighted Gram matrix: gram[i, j] = sum_k quad_k * P_i(t_k) * P_j(t_k).
    gram = (profiles * quad) @ profiles.T
    # The two products in each mirrored pair are rounded differently, so
    # average with the transpose to make J exactly symmetric.
    gram = 0.5 * (gram + gram.T)

    return gram, subsets


def block_diag(*mats):
    """
    Create a block-diagonal matrix from a sequence of matrices.

    Parameters
    ----------
    *mats : array_like
        Blocks to place on the diagonal, in order. 2-D inputs are used
        as-is. 0-D and 1-D inputs are promoted with ``np.atleast_2d``, so a
        1-D block of length n becomes a 1 x n row (matching
        ``scipy.linalg.block_diag``).

    Returns
    -------
    out : np.ndarray, shape (sum of block rows, sum of block columns)
        Block-diagonal matrix with zeros off the blocks. The dtype is at
        least float64 and is promoted further as the blocks require (e.g.
        to complex). Calling with no blocks returns a (0, 0) array.

    Raises
    ------
    ValueError
        If any block has more than two dimensions.
    """
    blocks = [np.atleast_2d(np.asarray(m)) for m in mats]
    for b in blocks:
        if b.ndim != 2:
            raise ValueError(f"block_diag expects 2-D blocks, got shape {b.shape}")

    # Never narrower than float64 (the historical output dtype), but wide
    # enough to hold complex blocks without discarding the imaginary part.
    dtype = np.result_type(float, *(b.dtype for b in blocks))
    n_rows = sum(b.shape[0] for b in blocks)
    n_cols = sum(b.shape[1] for b in blocks)
    out = np.zeros((n_rows, n_cols), dtype=dtype)

    r0 = c0 = 0
    for b in blocks:
        h, w = b.shape
        out[r0:r0 + h, c0:c0 + w] = b
        r0 += h
        c0 += w

    return out


def get_identity_weights(d, K, subsets=None):
    """
    Construct identity weight matrices.

    Each side has one feature per (coordinate, subset) pair, so both side
    matrices are the identity of size ``d * len(subsets)``. The combined
    matrix stacks them block-diagonally.

    Parameters
    ----------
    d : int
        Dimension (number of coordinates).
    K : int
        Dyadic depth; only used to generate ``subsets`` when none are given.
    subsets : list of tuples, optional
        Subsets to use. Only their count matters here; no check is
        performed against ``K``. Defaults to
        ``all_nonempty_subsets_indices(K)``.

    Returns
    -------
    W_A, W_B : np.ndarray, shape (d * len(subsets), d * len(subsets))
        Identity matrices.
    W_C : np.ndarray, shape (2 * d * len(subsets), 2 * d * len(subsets))
        ``block_diag(W_A, W_B)``.
    subsets : list of tuples
        The subsets actually used (the argument itself when supplied).
    """
    chosen = all_nonempty_subsets_indices(K) if subsets is None else subsets

    side = d * len(chosen)
    W_A, W_B = np.eye(side), np.eye(side)

    return W_A, W_B, block_diag(W_A, W_B), chosen


def _align_J_to_subsets(J, subsets_J, subsets):
    """Select/permute the rows and columns of ``J`` (indexed by ``subsets_J``) to follow ``subsets``."""
    if list(subsets_J) == list(subsets):
        return J
    position = {tuple(S): i for i, S in enumerate(subsets_J)}
    # A KeyError here means ``subsets`` contains a subset that
    # J_numeric_K(K) did not produce.
    perm = [position[tuple(S)] for S in subsets]
    return J[np.ix_(perm, perm)]


def get_J_weights(d, K, J_cached=None, subsets=None, reuse_J=True):
    """
    Construct J-based weight matrices.

    Parameters
    ----------
    d : int
        Dimension (number of coordinates); J is repeated once per coordinate
        along the diagonal.
    K : int
        Dyadic depth.
    J_cached : array_like, optional
        Pre-computed J matrix. It is assumed to already be indexed in the
        same order as ``subsets`` (e.g. the ``J`` returned by an earlier
        call with the same subsets).
    subsets : list of tuples, optional
        Subsets to use. Defaults to ``all_nonempty_subsets_indices(K)``.
    reuse_J : bool, default=True
        If True and ``J_cached`` is given, use it. Otherwise J is recomputed
        with ``J_numeric_K(K)`` and reordered to match ``subsets``.

    Returns
    -------
    W_A, W_B : np.ndarray, shape (d * m, d * m)
        Block-diagonal matrices with ``d`` copies of J on the diagonal,
        where ``m = len(subsets)``.
    W_C : np.ndarray, shape (2 * d * m, 2 * d * m)
        ``block_diag(W_A, W_B)``.
    subsets : list of tuples
        The subsets actually used.
    J : np.ndarray, shape (m, m)
        The J matrix aligned with ``subsets`` (suitable for caching).

    Raises
    ------
    ValueError
        If J is not square with side ``len(subsets)``.
    KeyError
        If ``subsets`` contains a subset not produced by
        ``J_numeric_K(K)``.
    """
    if subsets is None:
        subsets = all_nonempty_subsets_indices(K)

    if reuse_J and J_cached is not None:
        J_base = np.asarray(J_cached)
    else:
        # Every recompute path aligns J with ``subsets`` so that rows/cols
        # always correspond to the caller's subset order.
        J_full, subsets_J = J_numeric_K(K)
        J_base = _align_J_to_subsets(J_full, subsets_J, subsets)

    m = len(subsets)
    if J_base.shape != (m, m):
        raise ValueError(
            f"J matrix has shape {J_base.shape}, expected {(m, m)} "
            f"to match {m} subsets"
        )

    # Build block-diagonal structure for d coordinates; W_B equals W_A.
    W_A = block_diag(*([J_base] * d))
    W_B = W_A.copy()
    W_C = block_diag(W_A, W_B)

    return W_A, W_B, W_C, subsets, J_base


def blend_weights(W_A_id, W_B_id, W_A_J, W_B_J, w_identity, w_J):
    """
    Create blended weight matrices for wa_dCoBET.

    Each side is the linear combination
    ``w_identity * W_identity + w_J * W_J``. The combined matrix places the
    two blended sides block-diagonally.

    Parameters
    ----------
    W_A_id, W_B_id : np.ndarray
        Identity weight matrices for sides A and B.
    W_A_J, W_B_J : np.ndarray
        J-based weight matrices for sides A and B (same shapes as the
        identity ones).
    w_identity : float
        Coefficient on the identity matrices (typically in [0, 1]).
    w_J : float
        Coefficient on the J matrices (typically 1 - w_identity). The two
        coefficients are not normalised here.

    Returns
    -------
    W_A, W_B : np.ndarray
        Blended side matrices.
    W_C : np.ndarray
        ``block_diag(W_A, W_B)``.
    """
    def mix(identity_part, j_part):
        return w_identity * identity_part + w_J * j_part

    W_A = mix(W_A_id, W_A_J)
    W_B = mix(W_B_id, W_B_J)

    return W_A, W_B, block_diag(W_A, W_B)
