"""
wa_dCoBET: Weighted Adaptive dCoBET that blends identity and J weightings using n-fold (default 10) SNR-based selection.
"""
import numpy as np
from ..core import (
    get_identity_weights,
    get_J_weights,
    blend_weights,
    compute_Z_statistic,
    block_view
)
from .base import BaseCoBET


class wa_dCoBET(BaseCoBET):
    """
    Weighted Adaptive dCoBET method.

    The sample is split into ``n_folds`` random folds. On each fold the Z
    statistic is computed under both the identity weighting and the J
    weighting. The weighting with the larger Z (used as an SNR proxy) wins
    that fold, and ties go to identity. The final weight matrices are a
    blend of the two weightings. Each blend weight is the fraction of
    evaluated folds that the weighting won.

    Parameters
    ----------
    K : int, default=4
        Dyadic depth
    d : int, default=2
        Dimension
    theta : float, default=2
        Clayton copula parameter
    alpha : float, default=0.05
        Significance level
    seed : int, default=123
        Random seed
    unbiased : bool, default=True
        Use unbiased variance estimator
    n_folds : int, default=10
        Number of folds used for the adaptive selection. Must be >= 1.
    reuse_J : bool, default=True
        Cache and reuse J matrix computation
    fold_seed : int, optional
        Separate seed for fold generation (uses seed if not provided)

    Raises
    ------
    ValueError
        If ``n_folds < 1``.

    Notes
    -----
    The fold generator is created once from ``fold_seed`` and advanced on
    every call to ``get_weights``. Repeated calls on the same data therefore
    use different fold assignments.

    Examples
    --------
    >>> from cobet.methods import wa_dCoBET
    >>> import numpy as np
    >>> 
    >>> # Create test instance
    >>> test = wa_dCoBET(K=4, d=2, alpha=0.05, n_folds=10)
    >>> 
    >>> # Generate some data
    >>> n = 500
    >>> X = np.random.randn(n, 2)
    >>> Y = 0.5 * X + np.random.randn(n, 2)
    >>> 
    >>> # Run test
    >>> result = test.test(X, Y)
    >>> print(f"P-value: {result['p_value']:.4f}")
    >>> print(f"Weight on identity: {result['w_identity']:.3f}")
    """

    def __init__(self, K=4, d=2, theta=2, alpha=0.05, seed=123,
                 unbiased=True, n_folds=10, reuse_J=True, fold_seed=None):
        # Fail fast with a clear message instead of an opaque np.array_split
        # error on the first call to get_weights().
        if n_folds < 1:
            raise ValueError(f"n_folds must be >= 1, got {n_folds!r}")

        super().__init__(K=K, d=d, theta=theta, alpha=alpha, seed=seed, unbiased=unbiased)
        self.method_name = 'wa_dCoBET'
        self.n_folds = n_folds
        self.reuse_J = reuse_J
        self.fold_seed = fold_seed if fold_seed is not None else seed
        # Dedicated generator so fold assignment does not perturb (or depend
        # on) any other randomness used by the base class.
        self._fold_rng = np.random.RandomState(self.fold_seed)

        # Lazily-filled caches of the (W_A, W_B) pairs for each weighting.
        self._W_id_cache = None
        self._W_J_cache = None

        # Always define these so attribute access never fails. _J_cached is
        # only populated when reuse_J is set; the _last_w_* values are filled
        # by get_weights() and reported by test().
        self._J_cached = None
        self._last_w_identity = None
        self._last_w_J = None

        if reuse_J:
            _, _, _, _, self._J_cached = get_J_weights(
                self.d, self.K, J_cached=None, subsets=self._subsets, reuse_J=True
            )

    def _get_base_weights(self):
        """
        Return the identity and J weight-matrix pairs, computing each once.

        Returns
        -------
        ((W_A_id, W_B_id), (W_A_J, W_B_J)) : tuple of tuples of np.ndarray
        """
        if self._W_id_cache is None:
            W_A_id, W_B_id, _, _ = get_identity_weights(self.d, self.K, self._subsets)
            self._W_id_cache = (W_A_id, W_B_id)

        if self._W_J_cache is None:
            W_A_J, W_B_J, _, _, _ = get_J_weights(
                self.d, self.K,
                J_cached=self._J_cached if self.reuse_J else None,
                subsets=self._subsets,
                reuse_J=self.reuse_J
            )
            self._W_J_cache = (W_A_J, W_B_J)

        return self._W_id_cache, self._W_J_cache

    def _create_folds(self, n):
        """
        Randomly split ``range(n)`` into ``n_folds`` near-equal index arrays.

        Folds are empty when ``n < n_folds``. Each call advances
        ``self._fold_rng``.
        """
        indices = self._fold_rng.permutation(n)
        return np.array_split(indices, self.n_folds)

    def _compute_fold_Z(self, A_fold, B_fold, W_A, W_B):
        """Compute the Z statistic on one fold under the given (W_A, W_B) weighting."""
        from ..core import block_diag
        W_C = block_diag(W_A, W_B)
        Z, _, _ = compute_Z_statistic(A_fold, B_fold, W_A, W_B, W_C, self.unbiased)
        return Z

    @staticmethod
    def _selection_score(Z):
        """
        Map an undefined (NaN) Z to -inf so that it can never beat a defined Z.

        Without this, ``nan >= x`` is always False. A NaN identity Z would
        then silently count as a J win, and a NaN J Z as an identity win.
        If both are NaN they tie at -inf, and the tie rule (identity) applies.
        """
        return -np.inf if np.isnan(Z) else Z

    def get_weights(self, A, B):
        """
        Get adaptively blended weight matrices using n-fold SNR selection.

        Parameters
        ----------
        A, B : np.ndarray, shape (p, n)
            Feature matrices. Samples are along axis 1.

        Returns
        -------
        W_A, W_B, W_C : np.ndarray
            Blended weight matrices

        Raises
        ------
        ValueError
            If there are no samples (every fold is empty).
        """
        n = A.shape[1]

        # Get base weight matrices
        (W_A_id, W_B_id), (W_A_J, W_B_J) = self._get_base_weights()

        # Count how many folds each weighting wins.
        count_identity = 0
        count_J = 0

        for fold_indices in self._create_folds(n):
            # np.array_split produces empty folds when n < n_folds. They carry
            # no information, so skip them rather than compute Z on zero samples.
            if fold_indices.size == 0:
                continue

            A_fold = A[:, fold_indices]
            B_fold = B[:, fold_indices]

            # Compute Z for both weightings
            Z_id = self._selection_score(
                self._compute_fold_Z(A_fold, B_fold, W_A_id, W_B_id))
            Z_J = self._selection_score(
                self._compute_fold_Z(A_fold, B_fold, W_A_J, W_B_J))

            # Larger SNR (Z value) wins; ties go to identity.
            if Z_id >= Z_J:
                count_identity += 1
            else:
                count_J += 1

        # Normalise by the folds actually evaluated. This equals n_folds
        # whenever n >= n_folds.
        n_evaluated = count_identity + count_J
        if n_evaluated == 0:
            raise ValueError(
                "wa_dCoBET needs at least one sample to select blend weights")

        w_identity = count_identity / n_evaluated
        w_J = count_J / n_evaluated

        # Store for reporting by test()
        self._last_w_identity = w_identity
        self._last_w_J = w_J

        # Blend weight matrices
        W_A, W_B, W_C = blend_weights(W_A_id, W_B_id, W_A_J, W_B_J, w_identity, w_J)

        return W_A, W_B, W_C

    def test(self, X, Y):
        """
        Perform adaptive independence test.

        Runs the parent test and adds the blend weights chosen by the most
        recent ``get_weights`` call.

        Parameters
        ----------
        X, Y : np.ndarray, shape (n, d)
            Data matrices

        Returns
        -------
        result : dict
            Test results including the 'w_identity' and 'w_J' blend weights.
            Both are None if no blend weights have been computed yet.
        """
        # Call parent test method
        result = super().test(X, Y)

        # Add weight information
        result['w_identity'] = self._last_w_identity
        result['w_J'] = self._last_w_J

        return result
