"""
Simulation runner for power analysis and Type I error estimation.
"""
import numpy as np
import pandas as pd
from scipy.stats import norm


class SimulationRunner:
    """
    Orchestrate Monte Carlo simulations for CoBET methods.

    For each sample size in ``n_list`` and each transform in ``transforms``
    (that has an entry in ``b_config_by_n[n]``), the runner repeatedly draws
    data with ``method.generate_data`` and applies ``method.test``. It records
    the empirical rejection rate (labelled Type I error when ``b == 0``, power
    otherwise) plus the mean and sample standard deviation of the statistic
    ``Z``.

    Parameters
    ----------
    method : BaseCoBET
        Test method instance (CoBET, dCoBET, or wa_dCoBET). The runner calls
        ``method.generate_data(n, transform_key, b)`` expecting ``(X, Y)``,
        ``method.test(X, Y)`` expecting a mapping with ``'reject'`` and
        ``'Z'`` entries, and reads ``method.method_name``.
    n_list : list of int
        Sample sizes to test
    theta : float
        Clayton copula parameter (recorded in the results)
    K : int
        Dyadic depth (recorded in the results)
    d : int
        Dimension (recorded in the results)
    alpha : float
        Significance level; must lie strictly in (0, 1)
    R_eval : int
        Number of Monte Carlo replications per setting; must be >= 1
    seed : int
        Random seed. Only recorded in the results; the runner itself does not
        seed any random number generator.
    transforms : tuple of str
        Transform types to test
    b_config_by_n : dict
        Mapping from n to a dict of transform -> b value (or list of b values)
    unbiased : bool
        Use unbiased variance estimator. Stored on the runner; the runner
        itself does not read it.

    Raises
    ------
    ValueError
        If ``R_eval < 1`` or ``alpha`` is not in (0, 1).
    """

    def __init__(self, method, n_list, theta, K, d, alpha, R_eval, seed,
                 transforms, b_config_by_n, unbiased=True):
        # Fail fast. Otherwise R_eval == 0 only surfaces later as a
        # ZeroDivisionError in run_setting, and alpha outside (0, 1) makes
        # the critical value nan/inf.
        if R_eval < 1:
            raise ValueError(f"R_eval must be >= 1, got {R_eval}")
        if not 0 < alpha < 1:
            raise ValueError(f"alpha must be in (0, 1), got {alpha}")

        self.method = method
        self.n_list = n_list
        self.theta = theta
        self.K = K
        self.d = d
        self.alpha = alpha
        self.R_eval = R_eval
        self.seed = seed
        self.transforms = transforms
        self.b_config_by_n = b_config_by_n
        self.unbiased = unbiased

        # Critical value for a one-sided test at level alpha.
        self.z_critical = norm.ppf(1 - alpha)

    def run_single_replication(self, n, transform_key, b):
        """
        Run one Monte Carlo replication.

        Parameters
        ----------
        n : int
            Sample size
        transform_key : str
            Transform type
        b : float
            Dependence parameter

        Returns
        -------
        result : dict
            Whatever ``method.test`` returns. ``run_setting`` reads its
            ``'reject'`` and ``'Z'`` entries.
        """
        X, Y = self.method.generate_data(n, transform_key, b)
        return self.method.test(X, Y)

    def run_setting(self, n, transform_key, b):
        """
        Run R_eval replications for one (n, transform, b) setting.

        Parameters
        ----------
        n : int
            Sample size
        transform_key : str
            Transform type
        b : float
            Dependence parameter; ``b == 0`` is labelled ``'typeI'``,
            anything else ``'power'``.

        Returns
        -------
        result : dict
            Summary statistics for this setting. ``'value'`` is the fraction
            of replications that rejected. ``'Z_std'`` is nan when
            ``R_eval == 1``.
        """
        rejections = 0
        Z_values = []

        for _ in range(self.R_eval):
            result = self.run_single_replication(n, transform_key, b)
            if result['reject']:
                rejections += 1
            Z_values.append(result['Z'])

        power_or_type1 = rejections / self.R_eval
        metric = 'typeI' if b == 0 else 'power'

        Z_arr = np.asarray(Z_values, dtype=float)
        # ddof=1 needs at least two samples. Report nan explicitly instead of
        # letting numpy emit a "Degrees of freedom <= 0" RuntimeWarning.
        Z_std = float(np.std(Z_arr, ddof=1)) if Z_arr.size > 1 else float('nan')

        return {
            'method': self.method.method_name,
            'n': n,
            'transform': transform_key,
            'b': float(b),
            'metric': metric,
            'value': power_or_type1,
            'd': self.d,
            'K': self.K,
            'theta': self.theta,
            'alpha': self.alpha,
            'R_eval': self.R_eval,
            'seed': self.seed,
            'Z_mean': float(np.mean(Z_arr)),
            'Z_std': Z_std,
        }

    def run_all(self, report_typeI=True):
        """
        Run simulations for all n, transforms, and b values.

        Transforms missing from ``b_config_by_n[n]`` are skipped. Within a
        transform, any b value equal to 0 in the configuration is skipped.
        The Type I setting (b = 0) is run once, and only when
        ``report_typeI`` is true.

        Parameters
        ----------
        report_typeI : bool, default=True
            Whether to include Type I error (b=0) in results

        Returns
        -------
        results : pd.DataFrame
            All simulation results, one row per setting.

        Raises
        ------
        ValueError
            If some n in ``n_list`` has no entry in ``b_config_by_n``.
        """
        all_results = []

        for n in self.n_list:
            if n not in self.b_config_by_n:
                raise ValueError(f"Missing b_config for n={n}")

            b_config = self.b_config_by_n[n]

            for transform_key in self.transforms:
                if transform_key not in b_config:
                    continue

                b_list = b_config[transform_key]
                # Allow a single scalar b as shorthand for a one-element list.
                if not isinstance(b_list, (list, tuple, np.ndarray)):
                    b_list = [b_list]

                # Type I error (b=0)
                if report_typeI:
                    print(f"Running n={n}, transform={transform_key}, b=0 (Type I)...")
                    all_results.append(self.run_setting(n, transform_key, 0.0))

                # Power (b != 0)
                for b in b_list:
                    if b == 0:
                        continue
                    print(f"Running n={n}, transform={transform_key}, b={b} (Power)...")
                    all_results.append(self.run_setting(n, transform_key, b))

        return pd.DataFrame(all_results)


def run_power_analysis(method, n_list, theta, K, d, alpha, R_eval, seed,
                       transforms, b_config_by_n, unbiased=True):
    """
    Convenience function to run power analysis.

    Resolves ``method`` to an instance if a name is given, builds a
    ``SimulationRunner`` and runs every setting with Type I error included.

    Parameters
    ----------
    method : BaseCoBET or str
        Test method instance or name ('cobet', 'dcobet', 'wa_dcobet',
        case-insensitive). A name is instantiated with
        ``K, d, theta, alpha, seed, unbiased``.
    n_list : list of int
        Sample sizes
    theta : float
        Clayton copula parameter
    K : int
        Dyadic depth
    d : int
        Dimension
    alpha : float
        Significance level
    R_eval : int
        Number of replications
    seed : int
        Random seed
    transforms : tuple of str
        Transform types
    b_config_by_n : dict
        b configurations
    unbiased : bool
        Use unbiased variance

    Returns
    -------
    results : pd.DataFrame
        Simulation results

    Raises
    ------
    ValueError
        If ``method`` is a string that does not name a known method.
    """
    if isinstance(method, str):
        # Import lazily, and only when a name has to be resolved. This avoids
        # circular imports and keeps calls that pass an instance independent
        # of the methods package.
        from ..methods import CoBET, dCoBET, wa_dCoBET

        method_map = {
            'cobet': CoBET,
            'dcobet': dCoBET,
            'wa_dcobet': wa_dCoBET,
        }
        method_class = method_map.get(method.lower())
        if method_class is None:
            raise ValueError(
                f"Unknown method: {method}. "
                f"Expected one of: {', '.join(sorted(method_map))}"
            )
        method = method_class(K=K, d=d, theta=theta, alpha=alpha, seed=seed, unbiased=unbiased)

    runner = SimulationRunner(
        method=method,
        n_list=n_list,
        theta=theta,
        K=K,
        d=d,
        alpha=alpha,
        R_eval=R_eval,
        seed=seed,
        transforms=transforms,
        b_config_by_n=b_config_by_n,
        unbiased=unbiased,
    )

    return runner.run_all(report_typeI=True)
