import numpy as np
import pandas as pd
from scipy import stats
from statsmodels.stats.anova import AnovaRM
from statsmodels.stats.multitest import fdrcorrection
import warnings


def compute_scale_comparison_stats(*data_arrays, rich_scale_idx, lazy_scale_idx, 
                                  alpha=0.05, correction_method='fdr'):
    """
    Compute statistical comparisons between rich and lazy scales for multiple dimensionality measures.

    For every data array and every timepoint, the rich and lazy scales are
    compared with a paired t-test across seeds (seeds are the pairing unit),
    and a paired Cohen's d (mean difference / SD of the differences) is
    reported. p-values are corrected across the timepoints of each array.

    Args:
        *data_arrays: Variable number of data arrays, each with shape (n_scales, n_seeds, n_timepoints)
        rich_scale_idx: Index of rich scale in the data arrays
        lazy_scale_idx: Index of lazy scale in the data arrays
        alpha: Significance level for statistical tests
        correction_method: 'fdr' for FDR correction (statsmodels ``fdrcorrection``),
            'none' (or None) for no correction. Any other value emits a warning
            and applies no correction.

    Returns:
        Dictionary with one entry ``'data_{i}'`` per input array, each holding
        'p_values', 'corrected_p', 'rejected', 'effect_sizes', 'test_statistics'
        and 'n_significant', plus a 'metadata' entry describing the run.
        Timepoints whose p-value is NaN (e.g. rich and lazy identical for every
        seed) keep a NaN corrected p-value, are never rejected, and are left out
        of the FDR family.

    Raises:
        ValueError: If no arrays are given or their shapes differ.
    """
    if len(data_arrays) == 0:
        raise ValueError("At least one data array must be provided")

    # Get dimensions from first array and require every array to match it.
    n_scales, n_seeds, n_timepoints = data_arrays[0].shape
    expected_shape = (n_scales, n_seeds, n_timepoints)
    for i, data in enumerate(data_arrays):
        if data.shape != expected_shape:
            raise ValueError(f"Data array {i} has shape {data.shape}, expected {expected_shape}")

    if correction_method not in ('fdr', 'none', None):
        # Previously any unrecognised value silently meant "no correction".
        warnings.warn(f"Unknown correction_method {correction_method!r}; "
                      "no multiple-comparison correction will be applied")

    print(f"Running statistical tests: Rich (idx={rich_scale_idx}) vs Lazy (idx={lazy_scale_idx})")
    print(f"Using paired t-tests with {correction_method} correction (α={alpha})")
    print(f"Testing {len(data_arrays)} dimensionality measure(s) across {n_timepoints} timepoints")

    all_results = {}

    for data_idx, data in enumerate(data_arrays):
        data_rich = data[rich_scale_idx, :, :]  # shape: (n_seeds, n_timepoints)
        data_lazy = data[lazy_scale_idx, :, :]  # shape: (n_seeds, n_timepoints)

        p_values = []
        effect_sizes = []
        test_statistics = []

        for t in range(n_timepoints):
            rich_vals = data_rich[:, t]  # one value per seed
            lazy_vals = data_lazy[:, t]  # one value per seed

            # Paired t-test: seeds pair the rich and lazy runs.
            stat, p_val = stats.ttest_rel(rich_vals, lazy_vals)

            # Cohen's d for paired samples; undefined when the differences have
            # zero spread, in which case 0 is reported.
            differences = rich_vals - lazy_vals
            sd_diff = np.std(differences, ddof=1)
            cohens_d = np.mean(differences) / sd_diff if sd_diff != 0 else 0

            p_values.append(p_val)
            effect_sizes.append(cohens_d)
            test_statistics.append(stat)

        p_array = np.asarray(p_values, dtype=float)
        if correction_method == 'fdr':
            # fdrcorrection does not skip NaN: a single NaN p-value can turn every
            # corrected p-value into NaN and also inflates the family size, so
            # only finite p-values are corrected.
            corrected_p = np.full(p_array.shape, np.nan)
            rejected = np.zeros(p_array.shape, dtype=bool)
            finite = np.isfinite(p_array)
            if finite.any():
                rejected[finite], corrected_p[finite] = fdrcorrection(p_array[finite], alpha=alpha)
        else:  # no correction
            rejected = p_array < alpha
            corrected_p = p_array

        all_results[f'data_{data_idx}'] = {
            'p_values': p_values,
            'corrected_p': corrected_p,
            'rejected': rejected,
            'effect_sizes': effect_sizes,
            'test_statistics': test_statistics,
            'n_significant': np.sum(rejected)
        }

        print(f"  Data {data_idx}: {np.sum(rejected)}/{n_timepoints} significant timepoints")

    all_results['metadata'] = {
        'n_data_arrays': len(data_arrays),
        'n_timepoints': n_timepoints,
        'n_seeds': n_seeds,
        'rich_scale_idx': rich_scale_idx,
        'lazy_scale_idx': lazy_scale_idx,
        'correction_method': correction_method,
        'alpha': alpha
    }

    return all_results

def compute_within_scale_comparison_stats(*data_arrays, rich_scale_idx, lazy_scale_idx, 
                                         alpha=0.05, correction_method='fdr'):
    """
    Compute statistical comparisons between different measures within each scale.

    For the rich and the lazy scale separately, every pair of measures is
    compared at every timepoint with a paired t-test across seeds. All p-values
    of one scale (all pairs x all timepoints) form a single correction family.

    Args:
        *data_arrays: Variable number of data arrays, each with shape (n_scales, n_seeds, n_timepoints)
        rich_scale_idx: Index of rich scale in the data arrays
        lazy_scale_idx: Index of lazy scale in the data arrays
        alpha: Significance level for statistical tests
        correction_method: 'fdr' for FDR correction, 'none' (or None) for no
            correction. Any other value emits a warning and applies no correction.

    Returns:
        ``{'rich': {...}, 'lazy': {...}}``; each holds 'pairwise_comparisons',
        mapping ``'measure_{i}_vs_measure_{j}'`` to a dict with
        'timepoint_results' (t_stat/p_value per timepoint), 'uncorrected_p_values',
        'corrected_p_values' and 'significant'. Without correction,
        'corrected_p_values' equals the uncorrected values and 'significant' is
        ``p < alpha``. NaN p-values are excluded from the FDR family and are
        never significant.

    Raises:
        ValueError: If fewer than two arrays are given or their shapes differ.
    """
    if len(data_arrays) < 2:
        raise ValueError("At least two data arrays must be provided for within-scale comparisons")

    # All measures must share one (n_scales, n_seeds, n_timepoints) layout,
    # otherwise seeds/timepoints would be paired incorrectly.
    n_scales, n_seeds, n_timepoints = data_arrays[0].shape
    expected_shape = (n_scales, n_seeds, n_timepoints)
    for i, data in enumerate(data_arrays):
        if data.shape != expected_shape:
            raise ValueError(f"Data array {i} has shape {data.shape}, expected {expected_shape}")

    if correction_method not in ('fdr', 'none', None):
        warnings.warn(f"Unknown correction_method {correction_method!r}; "
                      "no multiple-comparison correction will be applied")

    print(f"Running within-scale statistical tests between {len(data_arrays)} measures")
    print(f"Rich scale (idx={rich_scale_idx}) and Lazy scale (idx={lazy_scale_idx})")
    print(f"Using paired t-tests with {correction_method} correction (α={alpha})")

    n_measures = len(data_arrays)
    results = {'rich': {}, 'lazy': {}}

    for scale_name, scale_idx in [('rich', rich_scale_idx), ('lazy', lazy_scale_idx)]:
        comparisons = {}
        results[scale_name]['pairwise_comparisons'] = comparisons
        all_p_values = []
        spans = []  # (comparison_name, start, stop) slices into all_p_values

        for i in range(n_measures):
            data_i = data_arrays[i][scale_idx, :, :]  # (n_seeds, n_timepoints)
            for j in range(i + 1, n_measures):
                data_j = data_arrays[j][scale_idx, :, :]  # (n_seeds, n_timepoints)
                comparison_name = f'measure_{i}_vs_measure_{j}'

                timepoint_results = []
                timepoint_p_values = []
                for t in range(n_timepoints):
                    t_stat, p_val = stats.ttest_rel(data_i[:, t], data_j[:, t])
                    timepoint_results.append({'t_stat': t_stat, 'p_value': p_val})
                    timepoint_p_values.append(p_val)

                comparisons[comparison_name] = {
                    'timepoint_results': timepoint_results,
                    'uncorrected_p_values': timepoint_p_values
                }
                start = len(all_p_values)
                all_p_values.extend(timepoint_p_values)
                spans.append((comparison_name, start, len(all_p_values)))

        if not all_p_values:
            continue

        p_array = np.asarray(all_p_values, dtype=float)
        if correction_method == 'fdr':
            # fdrcorrection does not skip NaN: a single NaN can turn every
            # corrected p-value into NaN, so only finite p-values are corrected.
            corrected = np.full(p_array.shape, np.nan)
            rejected = np.zeros(p_array.shape, dtype=bool)
            finite = np.isfinite(p_array)
            if finite.any():
                rejected[finite], corrected[finite] = fdrcorrection(p_array[finite], alpha=alpha)
        else:
            corrected = p_array
            rejected = p_array < alpha

        # Redistribute the pooled results back to their comparisons.
        for comparison_name, start, stop in spans:
            comparisons[comparison_name]['corrected_p_values'] = corrected[start:stop].tolist()
            comparisons[comparison_name]['significant'] = rejected[start:stop].tolist()

    return results



def run_multi_dimensionality_anova_analysis(*data_arrays,
                                           rich_scale_idx, lazy_scale_idx,
                                           measure_names=None,
                                           analysis_type='proximity',
                                           include_posthocs=False,
                                           alpha=0.05):
    """
    Run ANOVA analysis comparing rich vs lazy initialization across multiple dimensionality measures.

    Seeds are treated as the repeated-measures subjects. Timepoints are averaged
    into task periods assuming the 10-timepoint layout
    Rule, Stim1_e, Stim1_l, Dly1_e, Dly1_l, Stim2_e, Stim2_l, Dly2_e, Dly2_l, Resp.

    Args:
        *data_arrays: 1-3 data arrays, each with shape (n_scales, n_seeds, n_timepoints)
        rich_scale_idx: Index of rich scale in the data arrays
        lazy_scale_idx: Index of lazy scale in the data arrays
        measure_names: List of names for each measure (e.g., ['Task', 'Stimulus', 'Motor'])
        analysis_type: Type of analysis to run:
            - 'proximity': Two ANOVAs comparing Stim1-Dly1 and Stim2-Dly2 periods (2-3 measures;
              automatically switched to 'proximity_single' for one measure)
            - 'proximity_single': Two ANOVAs for single measure (2 initialization × 2 periods)
            - 'rich_epochs': Rich-only ANOVA (2 epochs × 2 periods), with a Measure
              factor added when more than one measure is given
            - 'phases': Single ANOVA across 6 task phases
            - 'full_temporal': Single ANOVA across all timepoints
            For 'phases' and 'full_temporal' the Measure factor is omitted when only
            one measure is given (a one-level factor makes the RM-ANOVA degenerate).
        include_posthocs: Whether to include post-hoc tests. Only implemented for
            'proximity_single'; for other types a warning is emitted and, for
            'proximity', an empty 'posthoc_tests' dict is returned.
        alpha: Threshold used to flag ANOVA effects and post-hoc tests as significant.

    Returns:
        Dictionary containing ANOVA results and summary statistics, or None if the
        analysis itself fails (the error is printed). A failure while printing the
        summary only emits a warning; the computed results are still returned.

    Raises:
        ValueError: For an invalid number of measures, mismatched measure_names,
            or data arrays of differing shapes (checked before the analysis runs).
    """
    # Validate inputs
    n_measures = len(data_arrays)
    if n_measures < 1 or n_measures > 3:
        raise ValueError("Function supports 1 to 3 dimensionality measures")

    if measure_names is None:
        measure_names = [f'Measure_{i+1}' for i in range(n_measures)]
    elif len(measure_names) != n_measures:
        raise ValueError("measure_names must have same length as number of data arrays")

    reference_shape = data_arrays[0].shape
    n_scales, n_seeds, n_timepoints = reference_shape
    for i, data in enumerate(data_arrays):
        if data.shape != reference_shape:
            raise ValueError(f"Data array {i} has shape {data.shape}, expected {reference_shape}")

    if n_timepoints != 10:
        warnings.warn(f"Expected 10 timepoints, got {n_timepoints}. Timepoint mapping may be incorrect.")

    if analysis_type == 'proximity' and n_measures == 1:
        analysis_type = 'proximity_single'
        print("Auto-switching to 'proximity_single' for single measure analysis")

    if include_posthocs and analysis_type != 'proximity_single':
        warnings.warn("Post-hoc tests are only implemented for 'proximity_single'; "
                      f"none will be run for '{analysis_type}'.")

    print(f"Running {analysis_type} ANOVA analysis")
    print(f"Measures: {', '.join(measure_names)}")
    print(f"Rich scale index: {rich_scale_idx}, Lazy scale index: {lazy_scale_idx}")
    print(f"Data shape per measure: {reference_shape}")

    # Labels for the expected 10-timepoint layout; any extra timepoints get
    # generic labels so 'full_temporal' does not index past this list.
    base_labels = ['Rule', 'Stim1_e', 'Stim1_l', 'Dly1_e', 'Dly1_l',
                   'Stim2_e', 'Stim2_l', 'Dly2_e', 'Dly2_l', 'Resp']
    timepoint_labels = [base_labels[t] if t < len(base_labels) else f'T{t}'
                        for t in range(n_timepoints)]

    # Timepoint indices averaged into each task period/phase.
    periods = {
        'Rule': [0],
        'Stim1': [1, 2],
        'Dly1': [3, 4],
        'Stim2': [5, 6],
        'Dly2': [7, 8],
        'Resp': [9],
    }
    scales = (('Rich', rich_scale_idx), ('Lazy', lazy_scale_idx))
    multi_measure = n_measures > 1

    def build_long_df(conditions, include_measure):
        """Build a long-format DataFrame: one row per measure × condition × seed.

        Each condition is ``(factor_levels, scale_idx, timepoint_indices)``; the
        row value is that seed's mean over ``timepoint_indices`` at ``scale_idx``.
        Within every condition, seeds appear in ascending order, which the
        paired post-hoc tests rely on.
        """
        rows = []
        for measure_name, data_array in zip(measure_names, data_arrays):
            for factor_levels, scale_idx, indices in conditions:
                values = np.mean(data_array[scale_idx, :, :][:, indices], axis=1)  # (n_seeds,)
                for seed_idx, value in enumerate(values):
                    row = {'subject': f'seed_{seed_idx}'}
                    if include_measure:
                        row['measure'] = measure_name
                    row.update(factor_levels)
                    row['value'] = value
                    rows.append(row)
        return pd.DataFrame(rows)

    def prepare_proximity_data(single_measure):
        """Initialization × Period datasets for Stim1 vs Dly1 and Stim2 vs Dly2."""
        if single_measure and n_measures != 1:
            raise ValueError("proximity_single analysis requires exactly 1 measure")
        prepared = {}
        for comparison, stim, dly in (('stim1_dly1', 'Stim1', 'Dly1'),
                                      ('stim2_dly2', 'Stim2', 'Dly2')):
            # The single-measure variant uses generic 'Stim'/'Dly' levels, which
            # the post-hoc tests look up for both comparisons.
            stim_label, dly_label = ('Stim', 'Dly') if single_measure else (stim, dly)
            conditions = [
                ({'initialization': init, 'period': label}, scale_idx, periods[period])
                for init, scale_idx in scales
                for label, period in ((stim_label, stim), (dly_label, dly))
            ]
            prepared[comparison] = build_long_df(conditions, include_measure=not single_measure)
        return prepared

    def prepare_rich_epochs_data():
        """Rich-only Epoch (Early = Stim1/Dly1, Late = Stim2/Dly2) × Period dataset."""
        conditions = [
            ({'epoch': epoch, 'period': label}, rich_scale_idx, periods[period])
            for epoch, stim, dly in (('Early', 'Stim1', 'Dly1'), ('Late', 'Stim2', 'Dly2'))
            for label, period in (('Stim', stim), ('Dly', dly))
        ]
        return build_long_df(conditions, include_measure=multi_measure)

    def prepare_phases_data():
        """Initialization × Phase dataset over the six task phases."""
        conditions = [
            ({'initialization': init, 'phase': phase}, scale_idx, indices)
            for phase, indices in periods.items()
            for init, scale_idx in scales
        ]
        return build_long_df(conditions, include_measure=True)

    def prepare_full_temporal_data():
        """Initialization × Timepoint dataset using every raw timepoint."""
        conditions = [
            ({'initialization': init, 'timepoint': label}, scale_idx, [t])
            for t, label in enumerate(timepoint_labels)
            for init, scale_idx in scales
        ]
        return build_long_df(conditions, include_measure=True)

    preparers = {
        'proximity': lambda: prepare_proximity_data(single_measure=False),
        'proximity_single': lambda: prepare_proximity_data(single_measure=True),
        'rich_epochs': prepare_rich_epochs_data,
        'phases': prepare_phases_data,
        'full_temporal': prepare_full_temporal_data,
    }

    # A factor with a single level is meaningless in an RM-ANOVA, so 'measure'
    # is only a within factor when several measures are analysed.
    measure_factor = ['measure'] if multi_measure else []
    within_factors_by_type = {
        'proximity': ['measure', 'initialization', 'period'],
        'proximity_single': ['initialization', 'period'],
        'rich_epochs': measure_factor + ['epoch', 'period'],
        'phases': measure_factor + ['initialization', 'phase'],
        'full_temporal': measure_factor + ['initialization', 'timepoint'],
    }

    # (banner, analysis label) for the analysis types that run a single ANOVA.
    single_anova_info = {
        'rich_epochs': ("\nRunning rich epochs ANOVA...", "Rich Epochs"),
        'phases': ("\nRunning 6-phase ANOVA...", "6-Phase"),
        'full_temporal': ("\nRunning full temporal ANOVA...", "Full Temporal"),
    }

    def run_anova(df, within_factors, dependent_var='value'):
        """Fit a repeated-measures ANOVA; print the error and return None on failure."""
        try:
            return AnovaRM(df, dependent_var, 'subject', within_factors).fit()
        except Exception as e:
            print(f"ANOVA failed: {e}")
            return None

    def extract_anova_results(anova_result, analysis_name=""):
        """Split the ANOVA table into main effects and interactions (':' in the name)."""
        if anova_result is None:
            return None
        table = anova_result.anova_table
        results = {
            'analysis_name': analysis_name,
            'anova_table': table,
            'main_effects': {},
            'interactions': {}
        }
        for effect_name in table.index:
            f_stat = table.loc[effect_name, 'F Value']
            p_value = table.loc[effect_name, 'Pr > F']
            bucket = 'interactions' if ':' in effect_name else 'main_effects'
            results[bucket][effect_name] = {
                'F': f_stat,
                'p': p_value,
                'significant': p_value < alpha
            }
        return results

    def run_posthoc_tests(prepared, analysis_type):
        """Paired Stim-vs-Dly t-tests per initialization, FDR-corrected across all tests."""
        if analysis_type != 'proximity_single' or not include_posthocs:
            return {}

        posthoc_results = {}
        raw_p_values = []
        test_info = []  # (comparison, test_name) in the order of raw_p_values

        for comparison in ('stim1_dly1', 'stim2_dly2'):
            if comparison not in prepared:
                continue
            comp_df = prepared[comparison]
            comparison_results = {}
            for init in ('Rich', 'Lazy'):
                in_init = comp_df['initialization'] == init
                stim_vals = comp_df.loc[in_init & (comp_df['period'] == 'Stim'), 'value'].values
                dly_vals = comp_df.loc[in_init & (comp_df['period'] == 'Dly'), 'value'].values
                t_stat, p_val = stats.ttest_rel(stim_vals, dly_vals)
                test_name = f'{init.lower()}_stim_vs_dly'
                comparison_results[test_name] = {
                    't_stat': t_stat,
                    'p_value': p_val,
                    'p_uncorrected': p_val
                }
                raw_p_values.append(p_val)
                test_info.append((comparison, test_name))
            posthoc_results[comparison] = comparison_results

        if raw_p_values:
            # fdrcorrection does not skip NaN, so only finite p-values are corrected.
            p_array = np.asarray(raw_p_values, dtype=float)
            corrected = np.full(p_array.shape, np.nan)
            rejected = np.zeros(p_array.shape, dtype=bool)
            finite = np.isfinite(p_array)
            if finite.any():
                rejected[finite], corrected[finite] = fdrcorrection(p_array[finite], alpha=alpha)
            for i, (comparison, test_name) in enumerate(test_info):
                entry = posthoc_results[comparison][test_name]
                entry['p_corrected'] = corrected[i]
                entry['significant'] = rejected[i]
                entry['significant_uncorrected'] = raw_p_values[i] < alpha

        return posthoc_results

    # Main analysis execution
    try:
        if analysis_type not in preparers:
            raise ValueError(f"Unknown analysis_type: {analysis_type}")
        prepared_data = preparers[analysis_type]()
        within_factors = within_factors_by_type[analysis_type]
        results = {'analysis_type': analysis_type, 'measure_names': measure_names}

        if analysis_type in ('proximity', 'proximity_single'):
            # Two separate ANOVAs, one per stimulus/delay pair.
            print("\nRunning Stim1-Dly1 proximity ANOVA...")
            results['stim1_dly1'] = extract_anova_results(
                run_anova(prepared_data['stim1_dly1'], within_factors), "Stim1-Dly1")

            print("Running Stim2-Dly2 proximity ANOVA...")
            results['stim2_dly2'] = extract_anova_results(
                run_anova(prepared_data['stim2_dly2'], within_factors), "Stim2-Dly2")

            if include_posthocs:
                results['posthoc_tests'] = run_posthoc_tests(prepared_data, analysis_type)
        else:
            banner, label = single_anova_info[analysis_type]
            print(banner)
            results[analysis_type] = extract_anova_results(
                run_anova(prepared_data, within_factors), label)
    except Exception as e:
        print(f"Analysis failed: {e}")
        return None

    # Printing is kept outside the analysis try-block so that a problem while
    # formatting the summary cannot discard results that were already computed.
    try:
        print_anova_summary(results)
    except Exception as e:
        warnings.warn(f"Could not print ANOVA summary: {e}")

    return results


def print_anova_summary(results):
    """Print a human-readable summary of an ANOVA results dictionary.

    Reads ``results['analysis_type']`` and ``results['measure_names']`` to
    describe the factorial design. It then prints the main effects and
    interactions stored under the analysis-specific key(s):

    * ``'stim1_dly1'`` / ``'stim2_dly2'`` for 'proximity' and 'proximity_single';
    * ``'rich_epochs'``, ``'phases'`` or ``'full_temporal'`` otherwise (any
      other analysis type falls back to ``'full_temporal'``).

    Each of those entries is expected to hold ``'main_effects'`` and
    ``'interactions'`` dicts that map effect names to ``{'F': ..., 'p': ...}``.
    For the proximity analyses, ``results['posthoc_tests']`` (if present) maps
    comparison names to dicts of tests. Each test has ``'t_stat'`` and either
    ``'p_value'`` or ``'p_corrected'`` (with ``'p_uncorrected'`` or, failing
    that, ``'p_value'``).
    """
    analysis_type = results['analysis_type']

    print(f"\n{'='*60}")
    print(f"ANOVA SUMMARY - {analysis_type.upper()} ANALYSIS")
    print(f"{'='*60}")
    print(f"Measures: {', '.join(results['measure_names'])}")

    # Number of measures drives the size of the 'Measure' factor in the design.
    n_measures = len(results['measure_names'])

    def significance_marker(p):
        """Return '***', '**', '*' or '' for p below 0.001, 0.01, 0.05."""
        if p < 0.001:
            return "***"
        if p < 0.01:
            return "**"
        if p < 0.05:
            return "*"
        return ""

    def get_anova_structure(kind):
        """Return the factorial design structure as a string."""
        if kind == 'proximity':
            return f"{n_measures} (Measure) × 2 (Initialization) × 2 (Period)"
        if kind == 'proximity_single':
            return "2 (Initialization) × 2 (Period)"
        if kind == 'rich_epochs':
            if n_measures == 1:
                return "2 (Epoch) × 2 (Period)"
            return f"{n_measures} (Measure) × 2 (Epoch) × 2 (Period)"
        if kind == 'phases':
            return f"{n_measures} (Measure) × 2 (Initialization) × 6 (Phase)"
        if kind == 'full_temporal':
            return f"{n_measures} (Measure) × 2 (Initialization) × 10 (Timepoint)"
        return "Unknown structure"

    def print_effects(effects_dict, effect_type):
        """Print one line per effect with F, p and significance stars."""
        if not effects_dict:
            print(f"  No {effect_type} found")
            return
        print(f"  {effect_type.upper()}:")
        for effect_name, effect_stats in effects_dict.items():
            marker = significance_marker(effect_stats['p'])
            print(f"    {effect_name}: F={effect_stats['F']:.3f}, p={effect_stats['p']:.4f} {marker}")

    def print_posthoc_results(posthoc_dict):
        """Print post-hoc t-tests, judging significance on the corrected p when present."""
        if not posthoc_dict:
            return
        print("\n  POST-HOC PAIRED T-TESTS (FDR corrected):")
        for comparison, tests in posthoc_dict.items():
            print(f"    {comparison.upper().replace('_', '-')}:")
            for test_name, test_stats in tests.items():
                test_label = test_name.replace('_', ' ').title()
                t_stat = test_stats['t_stat']
                if 'p_corrected' in test_stats:
                    p_corrected = test_stats['p_corrected']
                    # Look the fallback up only when it is needed: an eager
                    # dict.get default would demand 'p_value' even for tests
                    # that only carry corrected/uncorrected p-values.
                    if 'p_uncorrected' in test_stats:
                        p_uncorrected = test_stats['p_uncorrected']
                    else:
                        p_uncorrected = test_stats['p_value']
                    marker = significance_marker(p_corrected)
                    print(f"      {test_label}: t={t_stat:.3f}, p_uncorr={p_uncorrected:.4f}, p_corr={p_corrected:.4f} {marker}")
                else:
                    p_value = test_stats['p_value']
                    marker = significance_marker(p_value)
                    print(f"      {test_label}: t={t_stat:.3f}, p={p_value:.4f} {marker}")

    def print_section(section, anova_structure):
        """Print the structure line plus main effects and interactions of one ANOVA."""
        print(f"  Structure: {anova_structure}")
        print_effects(section['main_effects'], 'Main Effects')
        print_effects(section['interactions'], 'Interactions')

    anova_structure = get_anova_structure(analysis_type)
    print(f"Design: {anova_structure} (repeated measures)")

    if analysis_type in ['proximity', 'proximity_single']:
        # Two separate ANOVAs, one per stimulus/delay comparison.
        for comparison in ['stim1_dly1', 'stim2_dly2']:
            section = results.get(comparison)
            if section:
                print(f"\n{comparison.upper().replace('_', '-')} COMPARISON:")
                print_section(section, anova_structure)

        if 'posthoc_tests' in results:
            print_posthoc_results(results['posthoc_tests'])
    else:
        # Single-ANOVA analyses; unknown types fall back to 'full_temporal'.
        if analysis_type in ('rich_epochs', 'phases'):
            analysis_key = analysis_type
        else:
            analysis_key = 'full_temporal'
        section = results.get(analysis_key)
        if section:
            print(f"\n{analysis_key.replace('_', ' ').upper()} ANALYSIS:")
            print_section(section, anova_structure)

    print(f"\n{'*' * 60}")
    print("Significance levels: * p<0.05, ** p<0.01, *** p<0.001")
    print(f"{'*' * 60}")


# Example usage functions
def example_usage_proximity_single():
    """Demonstrate the 'proximity_single' analysis on one synthetic measure.

    Seeds NumPy's global random generator with 42, draws a single
    standard-normal array shaped (n_scales=5, n_seeds=10, n_timepoints=10),
    and returns the result of the analysis run with post-hoc tests enabled.
    """
    np.random.seed(42)
    data_shape = (5, 10, 10)  # (n_scales, n_seeds, n_timepoints)
    task_data = np.random.randn(*data_shape)

    return run_multi_dimensionality_anova_analysis(
        task_data,
        rich_scale_idx=0,
        lazy_scale_idx=4,
        measure_names=['Task'],
        analysis_type='proximity_single',
        include_posthocs=True,
    )


def example_usage_rich_epochs():
    """Demonstrate the 'rich_epochs' analysis with one and with three measures.

    Seeds NumPy's global random generator with 42. It draws a 'Task' array
    and runs the single-measure analysis on it. It then draws 'Stimulus' and
    'Motor' arrays (after the first run, preserving the original draw order)
    and runs the three-measure analysis.

    Returns:
        Tuple ``(results_single, results_multi)``.
    """
    np.random.seed(42)
    data_shape = (5, 10, 10)  # (n_scales, n_seeds, n_timepoints)
    shared_kwargs = {
        'rich_scale_idx': 0,
        'lazy_scale_idx': 4,
        'analysis_type': 'rich_epochs',
    }

    task_data = np.random.randn(*data_shape)
    results_single = run_multi_dimensionality_anova_analysis(
        task_data, measure_names=['Task'], **shared_kwargs
    )

    stimulus_data = np.random.randn(*data_shape)
    motor_data = np.random.randn(*data_shape)
    results_multi = run_multi_dimensionality_anova_analysis(
        task_data, stimulus_data, motor_data,
        measure_names=['Task', 'Stimulus', 'Motor'],
        **shared_kwargs,
    )

    return results_single, results_multi


def example_usage_original():
    """Demonstrate the multi-measure 'proximity' analysis on synthetic data.

    Seeds NumPy's global random generator with 42, draws three standard-normal
    arrays ('Task', 'Stimulus', 'Motor') shaped (5, 10, 10), and returns the
    analysis result.
    """
    np.random.seed(42)
    data_shape = (5, 10, 10)  # (n_scales, n_seeds, n_timepoints)
    names = ['Task', 'Stimulus', 'Motor']
    # One draw per measure, in the same order as ``names``.
    arrays = [np.random.randn(*data_shape) for _ in names]

    return run_multi_dimensionality_anova_analysis(
        *arrays,
        rich_scale_idx=0,
        lazy_scale_idx=4,
        measure_names=names,
        analysis_type='proximity',
    )


# Uncomment to test different analysis types:
# results_single = example_usage_proximity_single()
# results_epochs_single, results_epochs_multi = example_usage_rich_epochs()
# results_original = example_usage_original()


import numpy as np
import pandas as pd
from statsmodels.stats.anova import AnovaRM
from scipy import stats
from itertools import combinations

def two_way_repeated_measures_anova(data, factor1_name='Factor1', factor2_name='Factor2', 
                                   factor1_levels=None, factor2_levels=None, 
                                   posthoc=False, correction='bonferroni'):
    """
    Perform a 2-way repeated measures ANOVA with statsmodels ``AnovaRM``.

    Both factors are treated as within-subject factors, and axis 1 of ``data``
    indexes the subjects.

    Parameters:
    -----------
    data : array-like
        3D array of shape (n_factor1, n_samples, n_factor2)
        - n_factor1: number of levels for factor 1
        - n_samples: number of subjects (repeated measures)
        - n_factor2: number of levels for factor 2
    factor1_name : str
        Name of the first factor (default: 'Factor1'). Must differ from
        ``factor2_name`` and from the reserved columns 'value' and 'subject'.
    factor2_name : str
        Name of the second factor (default: 'Factor2'). Same restrictions.
    factor1_levels : sequence or None
        One name per entry of axis 0. If None, uses ['L0', 'L1', ...].
    factor2_levels : sequence or None
        One name per entry of axis 2. If None, uses ['L0', 'L1', ...].
    posthoc : bool
        If True, also run pairwise paired t-tests via ``perform_posthoc_tests``
        and store them under the 'posthoc' key.
    correction : str
        Multiple-comparison correction forwarded to the post-hoc tests.

    Returns:
    --------
    dict with keys:
        'anova_results' : fitted AnovaRM result
        'anova_table'   : copy of its ANOVA table
        'summary'       : F, p and (Num DF, Den DF) per main effect / interaction
        'means'         : grand, per-factor-level and per-cell means
        'dataframe'     : the long-format DataFrame fed to AnovaRM
        'posthoc'       : post-hoc results (only when ``posthoc`` is True)

    Raises:
    -------
    ValueError
        If ``data`` is not 3-dimensional, a level list's length does not
        match its axis, or the factor names collide with each other or with
        the reserved 'value'/'subject' columns.
    """
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(
            f"data must be 3D (n_factor1, n_samples, n_factor2), got shape {data.shape}"
        )
    n_factor1, n_samples, n_factor2 = data.shape

    # Column names must be unique, or the DataFrame below silently loses a column.
    if factor1_name == factor2_name:
        raise ValueError("factor1_name and factor2_name must differ")
    reserved = {'value', 'subject'}
    if factor1_name in reserved or factor2_name in reserved:
        raise ValueError(f"factor names must not be one of {sorted(reserved)}")

    # Create default level names if not provided
    if factor1_levels is None:
        factor1_levels = [f'L{i}' for i in range(n_factor1)]
    if factor2_levels is None:
        factor2_levels = [f'L{i}' for i in range(n_factor2)]
    factor1_levels = list(factor1_levels)
    factor2_levels = list(factor2_levels)
    if len(factor1_levels) != n_factor1:
        raise ValueError(
            f"factor1_levels has {len(factor1_levels)} entries, expected {n_factor1}"
        )
    if len(factor2_levels) != n_factor2:
        raise ValueError(
            f"factor2_levels has {len(factor2_levels)} entries, expected {n_factor2}"
        )

    # Long format, ordered factor1 (slowest) -> factor2 -> subject (fastest).
    # Transposing to (n_factor1, n_factor2, n_samples) and flattening in C
    # order yields data[i, k, j] in exactly that order.
    observations = data.transpose(0, 2, 1).reshape(-1)
    cell_block = n_factor2 * n_samples
    df = pd.DataFrame({
        'value': observations,
        'subject': list(range(n_samples)) * (n_factor1 * n_factor2),
        factor1_name: [lvl for lvl in factor1_levels for _ in range(cell_block)],
        factor2_name: [lvl for lvl in factor2_levels for _ in range(n_samples)] * n_factor1,
    })

    # Ensure factors are properly typed
    df['subject'] = df['subject'].astype('category')
    df[factor1_name] = df[factor1_name].astype('category')
    df[factor2_name] = df[factor2_name].astype('category')

    # Both factors vary within each subject.
    aovrm = AnovaRM(df, 'value', 'subject', within=[factor1_name, factor2_name])
    anova_results = aovrm.fit()

    # Calculate descriptive statistics
    grand_mean = df['value'].mean()
    factor1_means = df.groupby(factor1_name, observed=True)['value'].mean()
    factor2_means = df.groupby(factor2_name, observed=True)['value'].mean()
    interaction_means = df.groupby([factor1_name, factor2_name], observed=True)['value'].mean()

    anova_table = anova_results.anova_table.copy()

    # (row label in AnovaRM's table, key used in the summary dict)
    effect_rows = (
        (factor1_name, f'{factor1_name} main effect'),
        (factor2_name, f'{factor2_name} main effect'),
        (f'{factor1_name}:{factor2_name}', f'{factor1_name} × {factor2_name} interaction'),
    )
    summary_results = {}
    for row, label in effect_rows:
        if row in anova_table.index:
            summary_results[label] = {
                'F': anova_table.loc[row, 'F Value'],
                'p': anova_table.loc[row, 'Pr > F'],
                'df': (anova_table.loc[row, 'Num DF'], anova_table.loc[row, 'Den DF']),
            }

    results = {
        'anova_results': anova_results,
        'anova_table': anova_table,
        'summary': summary_results,
        'means': {
            'grand_mean': grand_mean,
            f'{factor1_name}_means': factor1_means.to_dict(),
            f'{factor2_name}_means': factor2_means.to_dict(),
            'interaction_means': interaction_means.to_dict()
        },
        'dataframe': df  # Include the long-format dataframe
    }

    if posthoc:
        results['posthoc'] = perform_posthoc_tests(
            df, factor1_name, factor2_name, factor1_levels, factor2_levels, correction
        )

    return results

def _paired_posthoc_comparison(data1, data2, label):
    """Run a paired t-test plus effect sizes on two per-subject Series.

    The Series are matched on their index (subject id) before testing. The
    pairing therefore never depends on row order, and subjects missing from
    either condition are dropped.
    """
    data1, data2 = data1.align(data2, join='inner')
    t_stat, p_val = stats.ttest_rel(data1, data2)
    mean_diff = data1.mean() - data2.mean()
    # Cohen's d standardised by the root-mean of the two sample variances
    # (pandas .std() uses ddof=1).
    average_sd = np.sqrt((data1.std() ** 2 + data2.std() ** 2) / 2)
    return {
        'comparison': label,
        't_stat': t_stat,
        'p_uncorrected': p_val,
        'mean_diff': mean_diff,
        'effect_size_d': mean_diff / average_sd,
    }


def _attach_corrected_p(comparisons, correction):
    """Add 'p_corrected' to every dict of one comparison family (corrected within the family)."""
    if not comparisons:
        return
    corrected = apply_correction([comp['p_uncorrected'] for comp in comparisons], correction)
    for comp, p_corr in zip(comparisons, corrected):
        comp['p_corrected'] = p_corr


def perform_posthoc_tests(df, factor1_name, factor2_name, factor1_levels, factor2_levels, correction='bonferroni'):
    """
    Run pairwise paired t-tests after a two-way repeated measures ANOVA.

    The tests run unconditionally; they are not gated on ANOVA significance.
    Three families are computed, each corrected separately with
    ``apply_correction(..., correction)``:

    * ``'<factor1_name>_pairwise'``: every pair of factor1 levels, using each
      subject's mean across factor2;
    * ``'<factor2_name>_pairwise'``: every pair of factor2 levels, using each
      subject's mean across factor1;
    * ``'<factor2_name>_at_<factor1_name>_levels'``: every pair of factor2
      levels within each factor1 level (simple effects).

    Parameters:
    -----------
    df : pandas.DataFrame
        Long-format data with columns 'subject', 'value', ``factor1_name``
        and ``factor2_name``.
    factor1_levels, factor2_levels : sequence
        Levels to compare, in the order pairs should be formed.
    correction : str
        Method name forwarded to ``apply_correction``.

    Returns:
    --------
    dict mapping each family name to a list of dicts with keys 'comparison',
    't_stat', 'p_uncorrected', 'mean_diff', 'effect_size_d' and 'p_corrected'.
    """
    def subject_means(mask):
        # One value per subject: the mean of 'value' over the selected rows.
        # observed=True avoids the pandas FutureWarning for categorical keys.
        return df[mask].groupby('subject', observed=True)['value'].mean()

    posthoc_results = {}

    # Factor1 main effect (averaging across factor2 within each subject).
    f1_means = {level: subject_means(df[factor1_name] == level) for level in factor1_levels}
    factor1_comparisons = [
        _paired_posthoc_comparison(f1_means[a], f1_means[b], f'{a} vs {b}')
        for a, b in combinations(factor1_levels, 2)
    ]
    _attach_corrected_p(factor1_comparisons, correction)
    posthoc_results[f'{factor1_name}_pairwise'] = factor1_comparisons

    # Factor2 main effect (averaging across factor1 within each subject).
    f2_means = {level: subject_means(df[factor2_name] == level) for level in factor2_levels}
    factor2_comparisons = [
        _paired_posthoc_comparison(f2_means[a], f2_means[b], f'{a} vs {b}')
        for a, b in combinations(factor2_levels, 2)
    ]
    _attach_corrected_p(factor2_comparisons, correction)
    posthoc_results[f'{factor2_name}_pairwise'] = factor2_comparisons

    # Simple effects: factor2 pairs within each level of factor1. Per-subject
    # grouping keeps the samples matched by subject, not by row position.
    simple_effects_f2_at_f1 = []
    for f1_level in factor1_levels:
        in_f1_level = df[factor1_name] == f1_level
        cell_means = {
            f2_level: subject_means(in_f1_level & (df[factor2_name] == f2_level))
            for f2_level in factor2_levels
        }
        simple_effects_f2_at_f1.extend(
            _paired_posthoc_comparison(cell_means[a], cell_means[b], f'{a} vs {b} at {f1_level}')
            for a, b in combinations(factor2_levels, 2)
        )
    _attach_corrected_p(simple_effects_f2_at_f1, correction)
    posthoc_results[f'{factor2_name}_at_{factor1_name}_levels'] = simple_effects_f2_at_f1

    return posthoc_results

def apply_correction(p_values, method='bonferroni'):
    """Adjust one family of p-values for multiple comparisons.

    Parameters:
    -----------
    p_values : sequence of float
        Unadjusted p-values of a single family of tests.
    method : {'bonferroni', 'holm', 'none'}
        'bonferroni' multiplies each p by the family size. 'holm' applies the
        Holm–Bonferroni step-down procedure. 'none' leaves the values unchanged.

    Returns:
    --------
    list of float
        Adjusted p-values (capped at 1.0), in the same order as the input.

    Raises:
    -------
    ValueError
        If ``method`` is not one of the supported names.
    """
    if method == 'bonferroni':
        return [min(p * len(p_values), 1.0) for p in p_values]
    elif method == 'holm':
        p_arr = np.asarray(p_values, dtype=float)
        m = p_arr.size
        order = np.argsort(p_arr, kind='stable')
        # Step-down: the k-th smallest p (0-based) is multiplied by (m - k) ...
        stepped = p_arr[order] * (m - np.arange(m))
        # ... and adjusted p-values must be non-decreasing in p, so take the
        # running maximum before capping at 1.
        stepped = np.minimum(np.maximum.accumulate(stepped), 1.0)
        corrected = np.empty(m)
        corrected[order] = stepped
        return corrected.tolist()
    elif method == 'none':
        return list(p_values)
    else:
        raise ValueError("Correction method must be 'bonferroni', 'holm', or 'none'")

def print_rm_anova_results(results):
    """Pretty-print the dictionary returned by ``two_way_repeated_measures_anova``.

    Prints each summary effect (F, p and significance stars), the ANOVA table
    rounded to 4 decimals, the grand mean, and the per-factor level means.
    If the dictionary has a 'posthoc' entry, every post-hoc comparison is also
    printed using its corrected p-value.
    """
    def stars(p):
        # Thresholds: *** p<0.001, ** p<0.01, * p<0.05.
        if p < 0.001:
            return "***"
        if p < 0.01:
            return "**"
        if p < 0.05:
            return "*"
        return ""

    double_rule = "=" * 60
    print("2-Way Repeated Measures ANOVA Results")
    print(double_rule)

    for effect_label, effect in results['summary'].items():
        marker = stars(effect['p'])
        print(f"{effect_label}:")
        print(f"  F({effect['df'][0]}, {effect['df'][1]}) = {effect['F']:.3f}, "
              f"p = {effect['p']:.3f}{marker}")
        print()

    print("Full ANOVA Table:")
    print(results['anova_table'].round(4))

    means = results['means']
    print(f"\nGrand Mean: {means['grand_mean']:.3f}")

    # Per-factor means; the grand mean and cell means are not listed here.
    skipped_keys = ('grand_mean', 'interaction_means')
    for means_key, level_means in means.items():
        if means_key in skipped_keys:
            continue
        print(f"\n{means_key.replace('_means', '')} Means:")
        for level, level_mean in level_means.items():
            print(f"  {level}: {level_mean:.3f}")

    if 'posthoc' not in results:
        return

    print("\n" + double_rule)
    print("POSTHOC TESTS (Paired t-tests)")
    print(double_rule)

    for family, comparisons in results['posthoc'].items():
        print(f"\n{family.replace('_', ' ').title()}:")
        print("-" * 50)
        for comp in comparisons:
            marker = stars(comp['p_corrected'])
            print(f"{comp['comparison']}:")
            print(f"  t = {comp['t_stat']:.3f}, p = {comp['p_corrected']:.3f}{marker}")
            print(f"  Mean diff = {comp['mean_diff']:.3f}, Cohen's d = {comp['effect_size_d']:.3f}")
            print()
