#!/usr/bin/env python3
"""
PCMCI Multiple Testing Correction
Applies Bonferroni and FDR corrections to causal discovery p-values
"""

import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional
import logging
from pathlib import Path
import json

# Configure logging BEFORE the optional imports below. The module-level
# ``logging.warning(...)`` helper implicitly calls ``basicConfig()`` with
# default settings when the root logger has no handler; a later explicit
# ``basicConfig(level=..., format=...)`` would then be a no-op, and every
# INFO message would be dropped whenever an optional dependency is missing.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# statsmodels is optional: it is only used for the FDR procedures, which have
# a manual fallback implementation in this module.
try:
    from statsmodels.stats.multitest import multipletests
    STATSMODELS_AVAILABLE = True
except ImportError:
    STATSMODELS_AVAILABLE = False
    logger.warning("statsmodels not available - using manual FDR correction")

# Tigramite is optional at import time, but PCMCI cannot be run without it.
try:
    from tigramite import data_processing as pp
    from tigramite.pcmci import PCMCI
    from tigramite.independence_tests import ParCorr, GPDC, CMIknn
    TIGRAMITE_AVAILABLE = True
except ImportError:
    TIGRAMITE_AVAILABLE = False
    logger.warning("Tigramite not available - PCMCI features disabled")


class PCMCIWithMultipleTestingCorrection:
    """
    Wrapper for PCMCI with proper multiple testing correction
    Applies Bonferroni or FDR (Benjamini-Hochberg, Benjamini-Yekutieli)
    """
    
    def __init__(self,
                 dataframe: Optional[pd.DataFrame] = None,
                 var_names: Optional[List[str]] = None,
                 cond_ind_test: str = 'parcorr',
                 verbosity: int = 0):
        """
        Initialize PCMCI with multiple testing correction

        Args:
            dataframe: Time series data (n_timesteps × n_variables). When it is
                omitted, or Tigramite is not installed, no PCMCI object is built.
            var_names: Variable names. When omitted (or empty) and a dataframe is
                given, the dataframe's column labels are used, so reported link
                names match the names handed to Tigramite.
            cond_ind_test: Conditional independence test ('parcorr', 'gpdc', 'cmiknn')
            verbosity: Verbosity level
        """
        self.dataframe_obj = None
        self.pcmci = None
        self.verbosity = verbosity

        # Without this fallback the wrapper would label links "X0", "X1", ...
        # even though the Tigramite dataframe is built with the column labels.
        if not var_names and dataframe is not None:
            var_names = dataframe.columns.tolist()
        self.var_names = var_names

        # Results storage
        self.results = None
        self.corrected_results = None

        # Correction parameters
        self.correction_method = None
        self.alpha_original = 0.05
        self.alpha_corrected = None

        if dataframe is not None and TIGRAMITE_AVAILABLE:
            self._initialize_pcmci(dataframe, cond_ind_test)
    
    def _initialize_pcmci(self, dataframe: pd.DataFrame, cond_ind_test: str):
        """
        Build the Tigramite dataframe and the PCMCI object for it.

        Args:
            dataframe: Time series data; its values are passed to Tigramite.
            cond_ind_test: Key of the conditional independence test
                ('parcorr', 'gpdc', 'cmiknn'); any other value falls back to
                ParCorr with a warning.
        """
        # Explicit names take precedence; otherwise reuse the column labels
        names = self.var_names or dataframe.columns.tolist()
        self.dataframe_obj = pp.DataFrame(dataframe.values, var_names=names)

        # Test key -> (test class, significance mode)
        known_tests = {
            'parcorr': (ParCorr, 'analytic'),
            'gpdc': (GPDC, 'analytic'),
            'cmiknn': (CMIknn, 'shuffle_test'),
        }
        if cond_ind_test not in known_tests:
            logger.warning(f"Unknown test {cond_ind_test}, using ParCorr")
        test_cls, significance = known_tests.get(cond_ind_test, known_tests['parcorr'])
        cond_ind_test_obj = test_cls(significance=significance)

        self.pcmci = PCMCI(
            dataframe=self.dataframe_obj,
            cond_ind_test=cond_ind_test_obj,
            verbosity=self.verbosity,
        )

        logger.info(f"Initialized PCMCI with {cond_ind_test} test")
    
    def run_pcmci_with_correction(self,
                                  tau_max: int = 5,
                                  pc_alpha: float = 0.05,
                                  correction_method: str = 'fdr_bh',
                                  alpha_level: float = 0.05) -> Dict:
        """
        Run PCMCI with multiple testing correction

        The raw PCMCI output is stored in ``self.results`` and the corrected
        output in ``self.corrected_results``; ``self.correction_method``,
        ``self.alpha_original`` and ``self.alpha_corrected`` record the
        settings and the resulting threshold of the run.

        Args:
            tau_max: Maximum time lag
            pc_alpha: Significance level for PC algorithm
            correction_method: 'bonferroni', 'fdr_bh', 'fdr_by', 'none'
            alpha_level: Family-wise or FDR control level

        Returns:
            Dictionary with corrected results, or an empty dictionary when
            Tigramite is unavailable or PCMCI was not initialized.
        """
        if not TIGRAMITE_AVAILABLE or self.pcmci is None:
            logger.error("PCMCI not available or not initialized")
            return {}

        logger.info(f"Running PCMCI with tau_max={tau_max}, alpha={pc_alpha}")
        logger.info(f"Correction method: {correction_method}, level={alpha_level}")

        # Run standard PCMCI
        results = self.pcmci.run_pcmci(
            tau_max=tau_max,
            pc_alpha=pc_alpha
        )

        self.results = results
        self.correction_method = correction_method
        self.alpha_original = alpha_level

        # p_matrix has shape (n_vars, n_vars, tau_max+1); val_matrix holds the
        # matching test statistics
        corrected_results = self._apply_multiple_testing_correction(
            results['p_matrix'],
            results['val_matrix'],
            correction_method,
            alpha_level
        )

        self.corrected_results = corrected_results
        # Keep the attribute declared in __init__ in sync with the run
        self.alpha_corrected = corrected_results['alpha_corrected']

        return corrected_results
    
    def _apply_multiple_testing_correction(self,
                                          p_matrix: np.ndarray,
                                          val_matrix: np.ndarray,
                                          correction_method: str,
                                          alpha: float) -> Dict:
        """
        Apply multiple testing correction to PCMCI p-values

        Every non-NaN entry of ``p_matrix`` except the lag-0 diagonal
        (a variable cannot cause itself at lag 0) counts as one test.

        NOTE(review): entries (i, j, 0) and (j, i, 0) are both counted. If the
        p-matrix is symmetric at lag 0, each contemporaneous test is counted
        twice, which makes the correction more conservative - confirm whether
        that is intended.

        Args:
            p_matrix: P-values (n_vars × n_vars × tau_max+1)
            val_matrix: Test statistics
            correction_method: 'bonferroni', 'none', or any name starting with
                'fdr' (Benjamini-Hochberg when it contains 'bh', otherwise
                Benjamini-Yekutieli). Unknown names fall back to Bonferroni.
            alpha: Significance level

        Returns:
            Dictionary with the original and corrected p-matrices, the boolean
            rejection matrix, the test statistics, the method and levels used,
            and summary counts. ``alpha_corrected`` is the per-test threshold
            on the ORIGINAL p-values: ``alpha / n_tests`` for Bonferroni, the
            step-up threshold ``k / n_tests * alpha`` (divided by the harmonic
            number for BY, ``k`` = number of rejections) for FDR, and ``alpha``
            when no correction is applied or there is nothing to test.
        """
        n_vars, _, n_lags = p_matrix.shape

        # Boolean mask of the entries that take part in the correction.
        # Boolean indexing walks the array in (i, j, tau) order.
        tested = ~np.isnan(p_matrix)
        if n_lags > 0:
            diagonal = np.arange(n_vars)
            tested[diagonal, diagonal, 0] = False

        p_values_flat = p_matrix[tested]
        n_tests = int(p_values_flat.size)

        logger.info(f"Total number of tests: {n_tests}")

        p_matrix_corrected = np.full(p_matrix.shape, np.nan)
        reject_matrix = np.zeros(p_matrix.shape, dtype=bool)

        if n_tests == 0:
            # Nothing to correct; returning here avoids dividing by zero below
            logger.warning("No valid p-values found - nothing to correct")
            return {
                'p_matrix_original': p_matrix,
                'p_matrix_corrected': p_matrix_corrected,
                'reject_matrix': reject_matrix,
                'val_matrix': val_matrix,
                'correction_method': correction_method,
                'alpha_original': alpha,
                'alpha_corrected': alpha,
                'n_tests': 0,
                'n_significant_original': 0,
                'n_significant_corrected': 0,
                'false_positive_rate_reduction': 0.0
            }

        # Resolve the method that is actually applied
        is_fdr = correction_method != 'bonferroni' and correction_method.startswith('fdr')
        if correction_method not in ('bonferroni', 'none') and not is_fdr:
            logger.warning(f"Unknown correction method: {correction_method}, using Bonferroni")

        if is_fdr:
            variant = 'bh' if 'bh' in correction_method else 'by'
            if STATSMODELS_AVAILABLE:
                # The 3rd/4th return values are the Sidak and Bonferroni
                # alphas, not an FDR threshold, so they are discarded.
                reject, p_values_corrected, _, _ = multipletests(
                    p_values_flat,
                    alpha=alpha,
                    method=f'fdr_{variant}'
                )
                # Step-up threshold, matching what the manual fallback reports
                n_rejected = int(np.count_nonzero(reject))
                scale = np.sum(1.0 / np.arange(1, n_tests + 1)) if variant == 'by' else 1.0
                alpha_corrected = float(n_rejected / n_tests * alpha / scale)
            else:
                # Manual FDR correction
                p_values_corrected, reject, alpha_corrected = self._manual_fdr_correction(
                    p_values_flat,
                    alpha,
                    method=variant
                )
        elif correction_method == 'none':
            p_values_corrected = p_values_flat
            reject = p_values_flat < alpha
            alpha_corrected = alpha
        else:
            # 'bonferroni' and the fallback for unknown method names
            p_values_corrected = self._bonferroni_correction(p_values_flat, alpha)
            reject = p_values_corrected < alpha
            alpha_corrected = alpha / n_tests

        # Scatter the flat results back into matrix form
        p_matrix_corrected[tested] = p_values_corrected
        reject_matrix[tested] = reject

        # Plain ints keep the result dictionary JSON-serializable
        n_significant_original = int(np.count_nonzero(p_values_flat < alpha))
        n_significant_corrected = int(np.count_nonzero(reject))

        logger.info(f"Significant links (original): {n_significant_original}/{n_tests} "
                   f"({n_significant_original/n_tests*100:.1f}%)")
        logger.info(f"Significant links (corrected): {n_significant_corrected}/{n_tests} "
                   f"({n_significant_corrected/n_tests*100:.1f}%)")
        logger.info(f"Corrected significance threshold: {alpha_corrected:.6f}")

        # With no uncorrected discoveries there is nothing to reduce, so the
        # metric is 0 rather than a misleading 100%.
        if n_significant_original:
            reduction = 1 - n_significant_corrected / n_significant_original
        else:
            reduction = 0.0

        return {
            'p_matrix_original': p_matrix,
            'p_matrix_corrected': p_matrix_corrected,
            'reject_matrix': reject_matrix,
            'val_matrix': val_matrix,
            'correction_method': correction_method,
            'alpha_original': alpha,
            'alpha_corrected': alpha_corrected,
            'n_tests': n_tests,
            'n_significant_original': n_significant_original,
            'n_significant_corrected': n_significant_corrected,
            'false_positive_rate_reduction': reduction
        }
    
    def _bonferroni_correction(self, p_values: np.ndarray, alpha: float) -> np.ndarray:
        """
        Apply Bonferroni correction: p_corrected = min(p * n_tests, 1.0)

        Args:
            p_values: Original p-values
            alpha: Significance level (only used for the log message)

        Returns:
            Corrected p-values; an empty array when no p-values are supplied
        """
        p_values = np.asarray(p_values, dtype=float)
        n_tests = p_values.size

        if n_tests == 0:
            # alpha / n_tests is undefined here, so skip the usual log line
            logger.info("Bonferroni correction skipped: no p-values supplied")
            return p_values.copy()

        p_corrected = np.minimum(p_values * n_tests, 1.0)

        logger.info(f"Bonferroni correction applied: n_tests={n_tests}, "
                   f"corrected alpha={alpha/n_tests:.6f}")

        return p_corrected
    
    def _manual_fdr_correction(self,
                               p_values: np.ndarray,
                               alpha: float,
                               method: str = 'bh') -> Tuple[np.ndarray, np.ndarray, float]:
        """
        Manual FDR correction (Benjamini-Hochberg or Benjamini-Yekutieli)

        Step-up rule: with sorted p-values p_(1) <= ... <= p_(m), find the
        largest k with p_(k) <= k/m * alpha / c and reject the k smallest.
        c is 1 for BH and the harmonic number sum(1/i, i=1..m) for BY.

        Adjusted p-values are p_(i) * m * c / i, made monotone with a
        cumulative minimum taken from the largest p-value downwards and capped
        at 1, so that ``p_corrected <= alpha`` agrees with ``reject``.

        Args:
            p_values: Original p-values (expected to contain no NaN)
            alpha: Desired FDR level
            method: 'bh' for Benjamini-Hochberg; any other value selects
                Benjamini-Yekutieli

        Returns:
            Tuple of (corrected_p_values, reject, threshold), where threshold
            is the step-up cut-off on the original p-values (0.0 when nothing
            is rejected)
        """
        p_values = np.asarray(p_values, dtype=float)
        n_tests = p_values.size

        if n_tests == 0:
            logger.info(f"FDR-{method.upper()} correction: threshold={0.0:.6f}")
            return p_values.copy(), np.zeros(0, dtype=bool), 0.0

        order = np.argsort(p_values)
        sorted_p = p_values[order]
        ranks = np.arange(1, n_tests + 1)

        # BY scales every threshold down by the harmonic number c(m)
        c_m = 1.0 if method == 'bh' else float(np.sum(1.0 / ranks))
        thresholds = ranks / n_tests * alpha / c_m

        # Largest rank whose p-value is under its threshold; all smaller ranks
        # are rejected too, even if they individually miss their threshold
        passing = np.nonzero(sorted_p <= thresholds)[0]
        n_rejected = int(passing[-1]) + 1 if passing.size else 0
        threshold = float(thresholds[n_rejected - 1]) if n_rejected else 0.0

        reject = np.zeros(n_tests, dtype=bool)
        reject[order[:n_rejected]] = True

        # Adjusted p-values: scale, enforce monotonicity from the top, cap at 1
        adjusted_sorted = sorted_p * n_tests * c_m / ranks
        adjusted_sorted = np.minimum.accumulate(adjusted_sorted[::-1])[::-1]
        adjusted_sorted = np.minimum(adjusted_sorted, 1.0)

        # Undo the sort so the output lines up with the input order
        p_corrected = np.empty(n_tests)
        p_corrected[order] = adjusted_sorted

        logger.info(f"FDR-{method.upper()} correction: threshold={threshold:.6f}")

        return p_corrected, reject, threshold
    
    def get_significant_links(self, use_corrected: bool = True) -> List[Dict]:
        """
        Extract significant causal links

        Tigramite indexes its result matrices as [i, j, tau] for the link
        X_i(t - tau) -> X_j(t), so the first index is the source (cause) and
        the second index is the target (effect).

        Args:
            use_corrected: Use corrected p-values and rejection decisions if
                True; otherwise compare the original p-values against the
                uncorrected significance level.

        Returns:
            List of significant links (dicts with 'source', 'target', 'lag',
            'test_statistic', 'p_value', 'significant'), sorted by p-value.
            Empty when no corrected results are available.
        """
        if self.corrected_results is None:
            logger.warning("No corrected results available")
            return []

        results = self.corrected_results
        if use_corrected:
            p_matrix = results['p_matrix_corrected']
            reject_matrix = results['reject_matrix']
        else:
            p_matrix = results['p_matrix_original']
            # NaN entries (untested links) compare as not significant
            with np.errstate(invalid='ignore'):
                reject_matrix = p_matrix < results['alpha_original']

        val_matrix = results['val_matrix']
        var_names = self.var_names if self.var_names else [f"X{i}" for i in range(p_matrix.shape[0])]

        links = []
        # argwhere yields (i, j, tau) triples in row-major order
        for i, j, tau in np.argwhere(reject_matrix):
            # A variable cannot cause itself at lag 0
            if tau == 0 and i == j:
                continue

            links.append({
                'source': var_names[i],
                'target': var_names[j],
                'lag': int(tau),
                # Plain Python floats keep the records JSON-serializable
                'test_statistic': float(val_matrix[i, j, tau]),
                'p_value': float(p_matrix[i, j, tau]),
                'significant': True
            })

        # Sort by p-value
        links.sort(key=lambda link: link['p_value'])

        logger.info(f"Found {len(links)} significant links (corrected={use_corrected})")

        return links
    
    def generate_report(self, output_path: Optional[str] = None) -> str:
        """
        Generate summary report of PCMCI results with correction

        The report lists the correction settings, the counts of significant
        links before and after correction, and up to 20 corrected links with
        the smallest p-values.

        Args:
            output_path: Optional file path to save report. The file is
                written as UTF-8 because the report contains non-ASCII text.

        Returns:
            Report string ("No results available.\\n" when nothing has been
            computed yet)
        """
        if self.corrected_results is None:
            return "No results available.\n"

        results = self.corrected_results
        rule = "="*80

        report_lines = [
            rule,
            "PCMCI CAUSAL DISCOVERY WITH MULTIPLE TESTING CORRECTION",
            rule,
            "",
            f"Correction method: {results['correction_method'].upper()}",
            f"Original significance level (α): {results['alpha_original']:.4f}",
            f"Corrected significance level: {results['alpha_corrected']:.6f}",
            f"Total number of tests: {results['n_tests']}",
            "",
            "RESULTS:",
            f"  Significant links (original): {results['n_significant_original']}",
            f"  Significant links (corrected): {results['n_significant_corrected']}",
            f"  False positive reduction: {results['false_positive_rate_reduction']*100:.1f}%",
            "",
            rule,
            "TOP SIGNIFICANT CAUSAL LINKS (CORRECTED)",
            rule,
            ""
        ]

        # Links arrive sorted by p-value, so the first 20 are the strongest
        links = self.get_significant_links(use_corrected=True)

        if not links:
            report_lines.append("No significant causal links found after correction.")
        else:
            report_lines.append(f"{'Source':<15} {'Target':<15} {'Lag':<5} {'p-value':<12} {'Test Stat':<12}")
            report_lines.append("-"*80)

            for link in links[:20]:  # Top 20
                report_lines.append(
                    f"{link['source']:<15} {link['target']:<15} {link['lag']:<5} "
                    f"{link['p_value']:<12.6f} {link['test_statistic']:<12.4f}"
                )

        report_lines.extend([
            "",
            rule,
            "RECOMMENDATION",
            rule,
            "",
            "Always use corrected p-values for causal inference to control",
            f"false discovery rate at {results['alpha_original']*100:.0f}% level.",
            "Uncorrected p-values inflate type I error rate with multiple testing.",
            ""
        ])

        report_str = "\n".join(report_lines)

        if output_path:
            # Explicit encoding: the platform default may not be able to
            # encode the "α" in the report
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(report_str)
            logger.info(f"Report saved to {output_path}")

        return report_str


def main():
    """Run the wrapper on synthetic data once per correction method and print each report."""
    banner = "=" * 60
    logger.info(banner)
    logger.info("PCMCI Multiple Testing Correction Demonstration")
    logger.info(banner)

    if not TIGRAMITE_AVAILABLE:
        logger.error("Tigramite not available - cannot run demonstration")
        return

    # Synthetic series with a known structure:
    #   X0 -> X1 (lag 1), X1 -> X2 (lag 2), X0 -> X3 (lag 1), X4 independent
    np.random.seed(42)
    n_timesteps, n_vars = 500, 5

    data = np.zeros((n_timesteps, n_vars))
    data[:, 0] = np.random.randn(n_timesteps)

    # Start at 2 so the lag-2 dependency always has history to read from
    for step in range(2, n_timesteps):
        data[step, 1] = 0.7 * data[step - 1, 0] + 0.3 * np.random.randn()
        data[step, 2] = 0.6 * data[step - 2, 1] + 0.4 * np.random.randn()
        data[step, 3] = 0.5 * data[step - 1, 0] + 0.5 * np.random.randn()
        data[step, 4] = np.random.randn()

    column_names = [f'X{i}' for i in range(n_vars)]
    df = pd.DataFrame(data, columns=column_names)

    # A fresh wrapper (and a fresh PCMCI run) is used for every method
    for correction_method in ('none', 'bonferroni', 'fdr_bh'):
        logger.info(f"\n{banner}")
        logger.info(f"Testing with correction method: {correction_method.upper()}")
        logger.info(banner)

        wrapper = PCMCIWithMultipleTestingCorrection(
            dataframe=df,
            var_names=df.columns.tolist(),
            cond_ind_test='parcorr',
        )
        wrapper.run_pcmci_with_correction(
            tau_max=5,
            pc_alpha=0.05,
            correction_method=correction_method,
            alpha_level=0.05,
        )

        print(wrapper.generate_report())

    logger.info("\n" + banner)
    logger.info("PCMCI Correction Demonstration Complete")
    logger.info(banner)


# Run the demonstration only when executed as a script
if __name__ == "__main__":
    main()
