"""
Base class for CoBET methods.
"""
from abc import ABC, abstractmethod
import numpy as np
from scipy.stats import norm

from ..core import (
    clayton_copula_sample_nd,
    apply_transform,
    build_AB_features,
    all_nonempty_subsets_indices,
    compute_full_T,
    plugin_var_tildeT1,
    compute_Z_statistic,
)


class BaseCoBET(ABC):
    """
    Abstract base class for CoBET family of tests.

    Subclasses provide the weight matrices through :meth:`get_weights`. This
    base class builds the features, computes the test statistic and its
    plug-in variance, standardizes the statistic, and reports a one-sided
    (upper-tail) standard-normal p-value.

    Parameters
    ----------
    K : int, default=4
        Dyadic depth (number of bits in binary expansion)
    d : int, default=2
        Dimension of data
    theta : float, default=2
        Clayton copula parameter for data generation
    alpha : float, default=0.05
        Significance level
    seed : int, default=123
        Random seed for reproducibility. A single ``np.random.RandomState``
        is created from it at construction time and shared by every call to
        :meth:`generate_data`, so successive calls return different samples.
    unbiased : bool, default=True
        Use unbiased variance estimator

    Attributes
    ----------
    method_name : str
        Label reported in test results and in ``repr``; subclasses are
        expected to overwrite it.
    """

    # Lower bound applied to the variance estimate before taking the square
    # root, so a zero or negative estimate cannot produce a division by zero
    # or a NaN from sqrt.
    _VAR_FLOOR = 1e-16

    def __init__(self, K=4, d=2, theta=2, alpha=0.05, seed=123, unbiased=True):
        self.K = K
        self.d = d
        self.theta = theta
        self.alpha = alpha
        self.seed = seed
        self.unbiased = unbiased

        # Subsets are computed once from K and reused by build_features.
        self._subsets = all_nonempty_subsets_indices(K)
        # Cache slot, not used by this base class; presumably filled by
        # subclasses — confirm against subclass implementations.
        self._J_cached = None
        # One generator for the object's lifetime: generate_data advances it.
        self._rng = np.random.RandomState(seed)

        # Method name (to be set by subclasses)
        self.method_name = "BaseCoBET"

    @abstractmethod
    def get_weights(self, A, B):
        """
        Get weight matrices for the test.

        Must be implemented by subclasses.

        Parameters
        ----------
        A, B : np.ndarray, shape (p, n)
            Feature matrices

        Returns
        -------
        W_A, W_B, W_C : np.ndarray
            Weight matrices
        """
        pass

    def build_features(self, X, Y):
        """
        Build feature matrices from data.

        Parameters
        ----------
        X, Y : np.ndarray, shape (n, d)
            Data matrices

        Returns
        -------
        A, B : np.ndarray, shape (d * (2^K - 1), n)
            Feature matrices
        """
        return build_AB_features(X, Y, self.K, self._subsets)

    def compute_test_statistic(self, A, B, W_A, W_B, W_C):
        """
        Compute test statistic T.

        Parameters
        ----------
        A, B : np.ndarray
            Feature matrices
        W_A, W_B, W_C : np.ndarray
            Weight matrices

        Returns
        -------
        T : float
            Test statistic value
        """
        return compute_full_T(A, B, W_A, W_B, W_C)

    def compute_variance(self, A, B, W_A, W_B):
        """
        Compute plug-in variance estimate.

        Only ``W_A`` and ``W_B`` are passed to the estimator; ``W_C`` does not
        enter the variance used for standardization.

        Parameters
        ----------
        A, B : np.ndarray
            Feature matrices
        W_A, W_B : np.ndarray
            Weight matrices

        Returns
        -------
        var_T : float
            Variance estimate (uses ``self.unbiased`` to select the estimator)
        """
        return plugin_var_tildeT1(A, B, W_A, W_B, unbiased=self.unbiased)

    def test(self, X, Y):
        """
        Perform independence test on data X and Y.

        Parameters
        ----------
        X, Y : np.ndarray, shape (n, d)
            Data matrices

        Returns
        -------
        result : dict
            Dictionary containing:
            - 'statistic': Test statistic T
            - 'Z': Standardized statistic, T / sqrt(max(variance, floor))
            - 'variance': Plug-in variance estimate (before flooring)
            - 'p_value': One-sided upper-tail p-value, P(N(0, 1) >= Z)
            - 'reject': bool, whether p_value < alpha
            - 'method': The object's ``method_name``
        """
        A, B = self.build_features(X, Y)
        W_A, W_B, W_C = self.get_weights(A, B)

        T = self.compute_test_statistic(A, B, W_A, W_B, W_C)
        var_T = self.compute_variance(A, B, W_A, W_B)

        # Standardize; the floor guards against a zero/negative estimate.
        Z = T / np.sqrt(max(var_T, self._VAR_FLOOR))

        # Upper-tail p-value (reject for large Z). The survival function is
        # used instead of 1 - cdf: the subtraction cancels to exactly 0.0 for
        # Z above roughly 8.3 and loses relative precision before that.
        p_value = norm.sf(Z)

        return {
            'statistic': T,
            'Z': Z,
            'variance': var_T,
            'p_value': p_value,
            # Plain Python bool rather than numpy.bool_ for callers that
            # test identity or serialize the result.
            'reject': bool(p_value < self.alpha),
            'method': self.method_name
        }

    def generate_data(self, n, transform_key, b):
        """
        Generate synthetic data for simulations.

        ``u`` and ``v`` are drawn by two separate calls on the shared
        generator ``self._rng``, so repeated calls yield fresh samples. Any
        dependence between X and Y is presumably introduced by
        ``apply_transform`` through ``b`` — confirm against that helper.

        Parameters
        ----------
        n : int
            Sample size
        transform_key : str
            Transform type ('trigU', 'expquad', 'linear', 'logquad')
        b : float or array-like
            Dependence parameter

        Returns
        -------
        X, Y : np.ndarray, shape (n, d)
            Generated data
        """
        # Sample from Clayton copula
        u = clayton_copula_sample_nd(n, self.theta, self.d, rng=self._rng)
        v = clayton_copula_sample_nd(n, self.theta, self.d, rng=self._rng)

        # Apply transform
        X, Y = apply_transform(u, v, b, transform_key)

        return X, Y

    def __repr__(self):
        return (f"{self.method_name}(K={self.K}, d={self.d}, "
                f"theta={self.theta}, alpha={self.alpha})")
