"""
Fairness Metrics for Clinical Models

This module implements fairness evaluation metrics including:
- Group fairness metrics (PRAUC gap, worst-case PRAUC)
- Max-min fairness analysis
- Demographic parity metrics

The metrics are computed across demographic groups (race, gender, age groups, etc.)
"""

import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Any
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score, f1_score
import warnings


class FairnessEvaluator:
    """
    Evaluator for fairness metrics across demographic groups
    """
    
    def __init__(self, 
                 sensitive_attributes: List[str] = None,
                 age_bins: List[float] = None,
                 compute_intersectional: bool = True,
                 include_cxr_availability: bool = True):  # enabled by default
        """
        Set up the evaluator configuration.

        Args:
            sensitive_attributes: Column names to evaluate fairness over.
                Defaults to ['race', 'gender', 'age'] when omitted.
            age_bins: Bin edges used to discretize age. Defaults to
                [0, 20, 40, 60, 80, inf] when omitted.
            compute_intersectional: Whether pairwise (intersectional)
                fairness metrics are computed as well.
            include_cxr_availability: Whether 'has_cxr' is appended to the
                sensitive attributes (if it is not already listed).
        """
        attrs = ['race', 'gender', 'age'] if sensitive_attributes is None else sensitive_attributes

        # When CXR availability analysis is enabled, treat it as one more
        # sensitive attribute. A new list is built, so the caller's list is
        # not modified.
        if include_cxr_availability and 'has_cxr' not in attrs:
            attrs = attrs + ['has_cxr']

        self.sensitive_attributes = attrs
        self.compute_intersectional = compute_intersectional
        self.include_cxr_availability = include_cxr_availability

        # Default edges give the groups 0-20, 20-40, 40-60, 60-80 and 80+.
        self.age_bins = [0, 20, 40, 60, 80, float('inf')] if age_bins is None else age_bins

        # One label per pair of consecutive edges; an open upper edge is
        # rendered as "<lower>+".
        self.age_labels = [
            f"{int(lower)}+" if upper == float('inf') else f"{int(lower)}-{int(upper)}"
            for lower, upper in zip(self.age_bins[:-1], self.age_bins[1:])
        ]

        # Detailed ethnicity strings grouped under the 5 consolidated
        # categories (the original comment notes this mirrors dataset.py).
        # UNKNOWN / UNABLE TO OBTAIN / PATIENT DECLINED TO ANSWER are merged
        # into OTHER.
        consolidated = (
            ('WHITE', (
                'WHITE',
                'WHITE - OTHER EUROPEAN',
                'WHITE - RUSSIAN',
                'WHITE - EASTERN EUROPEAN',
                'WHITE - BRAZILIAN',
            )),
            ('BLACK/AFRICAN AMERICAN', (
                'BLACK/AFRICAN AMERICAN',
                'BLACK/CAPE VERDEAN',
                'BLACK/CARIBBEAN ISLAND',
                'BLACK/AFRICAN',
            )),
            ('HISPANIC/LATINO', (
                'HISPANIC/LATINO - PUERTO RICAN',
                'HISPANIC OR LATINO',
                'HISPANIC/LATINO - DOMINICAN',
                'HISPANIC/LATINO - GUATEMALAN',
                'HISPANIC/LATINO - SALVADORAN',
                'HISPANIC/LATINO - MEXICAN',
                'HISPANIC/LATINO - CUBAN',
                'HISPANIC/LATINO - COLUMBIAN',
                'HISPANIC/LATINO - HONDURAN',
                'HISPANIC/LATINO - CENTRAL AMERICAN',
            )),
            ('ASIAN', (
                'ASIAN',
                'ASIAN - CHINESE',
                'ASIAN - SOUTH EAST ASIAN',
                'ASIAN - ASIAN INDIAN',
                'ASIAN - KOREAN',
            )),
            ('OTHER', (
                'OTHER',
                'AMERICAN INDIAN/ALASKA NATIVE',
                'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER',
                'SOUTH AMERICAN',
                'MULTIPLE RACE/ETHNICITY',
                'UNKNOWN',
                'UNABLE TO OBTAIN',
                'PATIENT DECLINED TO ANSWER',
            )),
        )
        self.ethnicity_mapping = {
            detailed: category
            for category, detailed_values in consolidated
            for detailed in detailed_values
        }

    def _discretize_age(self, ages: np.ndarray) -> np.ndarray:
        """
        Discretize continuous age values into bins
        
        Args:
            ages: Array of age values
            
        Returns:
            Array of age bin labels
        """
        age_discrete = np.digitize(ages, self.age_bins[1:-1])
        age_labels = [self.age_labels[i] for i in age_discrete]
        return np.array(age_labels)

    def _map_ethnicity(self, ethnicity_values: np.ndarray) -> np.ndarray:
        """
        Map detailed ethnicity values to 5 main categories
        
        Args:
            ethnicity_values: Array of detailed ethnicity values
            
        Returns:
            Array of mapped ethnicity categories
        """
        mapped_values = []
        for value in ethnicity_values:
            if pd.isna(value) or value == '' or value is None:
                mapped_values.append('OTHER')
            else:
                mapped_values.append(self.ethnicity_mapping.get(str(value).upper(), 'OTHER'))
        return np.array(mapped_values)

    def _prepare_demographic_data(self, meta_attrs: pd.DataFrame) -> pd.DataFrame:
        """
        Prepare demographic data for fairness evaluation
        """
        demo_df = meta_attrs.copy()
        
        # Process age if it's in sensitive attributes
        if 'age' in self.sensitive_attributes and 'age' in demo_df.columns:
            demo_df['age_group'] = self._discretize_age(demo_df['age'].values)
            # Replace 'age' with 'age_group' in sensitive attributes
            sensitive_attrs = [attr if attr != 'age' else 'age_group' 
                             for attr in self.sensitive_attributes]
            self.sensitive_attributes = sensitive_attrs
        
        # Process CXR availability if it's in sensitive attributes
        if 'has_cxr' in self.sensitive_attributes and 'has_cxr' in demo_df.columns:
            print("Processing CXR availability for fairness analysis...")
            
            # Convert boolean to string for consistency
            demo_df['has_cxr'] = demo_df['has_cxr'].astype(str)
            
            # Show distribution
            cxr_counts = demo_df['has_cxr'].value_counts()
            print(f"CXR availability distribution:")
            for availability, count in cxr_counts.items():
                print(f"  {availability}: {count} ({count/len(demo_df)*100:.1f}%)")
        
        # Process race/ethnicity mapping if it's in sensitive attributes
        race_columns = ['race', 'ethnicity', 'race_ethnicity']
        for race_col in race_columns:
            if race_col in self.sensitive_attributes and race_col in demo_df.columns:
                print(f"Mapping {race_col} values to 5 main categories...")
                original_values = demo_df[race_col].unique()
                print(f"Original {race_col} values: {sorted([str(v) for v in original_values if pd.notna(v)])}")
                
                demo_df[race_col] = self._map_ethnicity(demo_df[race_col].values)
                
                mapped_values = demo_df[race_col].unique()
                print(f"Mapped {race_col} values: {sorted(mapped_values)}")
                
                # Show mapping counts
                value_counts = demo_df[race_col].value_counts()
                print(f"{race_col} distribution after mapping:")
                for category, count in value_counts.items():
                    print(f"  {category}: {count} ({count/len(demo_df)*100:.1f}%)")
        
        # Handle missing values by creating an "Unknown" category
        for attr in self.sensitive_attributes:
            if attr in demo_df.columns:
                demo_df[attr] = demo_df[attr].fillna('Unknown')
        
        return demo_df

    def compute_group_metrics(self, 
                            y_true: np.ndarray,
                            y_score: np.ndarray, 
                            meta_attrs: pd.DataFrame,
                            task_type: str = 'binary') -> Dict[str, Any]:
        """
        Compute fairness metrics across demographic groups
        
        Args:
            y_true: True labels
            y_score: Predicted scores/probabilities
            meta_attrs: DataFrame with demographic attributes
            task_type: Type of task ('binary', 'multiclass', 'multilabel')
            
        Returns:
            Dictionary containing fairness metrics
        """
        if len(y_true) != len(meta_attrs):
            raise ValueError(f"Length mismatch: y_true ({len(y_true)}) vs meta_attrs ({len(meta_attrs)})")
        
        # Prepare demographic data
        demo_df = self._prepare_demographic_data(meta_attrs)
        
        fairness_results = {}
        
        # Compute metrics for each sensitive attribute
        for attr in self.sensitive_attributes:
            if attr not in demo_df.columns:
                print(f"Warning: Sensitive attribute '{attr}' not found in demographic data")
                continue
                
            attr_results = self._compute_attribute_fairness(
                y_true, y_score, demo_df[attr], attr, task_type
            )
            fairness_results.update(attr_results)
        
        # Compute intersectional fairness if requested
        if self.compute_intersectional and len(self.sensitive_attributes) >= 2:
            intersectional_results = self._compute_intersectional_fairness(
                y_true, y_score, demo_df, task_type
            )
            fairness_results.update(intersectional_results)
        
        return fairness_results

    def _compute_attribute_fairness(self, 
                                  y_true: np.ndarray,
                                  y_score: np.ndarray,
                                  attribute_values: np.ndarray,
                                  attribute_name: str,
                                  task_type: str) -> Dict[str, float]:
        """
        Compute fairness metrics for a single sensitive attribute
        
        Args:
            y_true: True labels
            y_score: Predicted scores
            attribute_values: Values of the sensitive attribute
            attribute_name: Name of the sensitive attribute
            task_type: Type of task
            
        Returns:
            Dictionary with fairness metrics for this attribute
        """
        results = {}
        
        # Get unique groups
        unique_groups = np.unique(attribute_values)
        
        if len(unique_groups) < 2:
            print(f"Warning: Only {len(unique_groups)} group(s) found for attribute '{attribute_name}'")
            return results
        
        group_metrics = {}
        group_sizes = {}
        
        # Compute metrics for each group
        for group in unique_groups:
            mask = attribute_values == group
            if np.sum(mask) < 10:  # Skip groups with too few samples
                print(f"Warning: Group '{group}' in '{attribute_name}' has only {np.sum(mask)} samples, skipping")
                continue
                
            y_true_group = y_true[mask]
            y_score_group = y_score[mask]
            
            try:
                if task_type == 'binary':
                    metrics = self._compute_binary_metrics(y_true_group, y_score_group)
                elif task_type == 'multiclass':
                    metrics = self._compute_multiclass_metrics(y_true_group, y_score_group)
                elif task_type == 'multilabel':
                    metrics = self._compute_multilabel_metrics(y_true_group, y_score_group)
                else:
                    raise ValueError(f"Unsupported task type: {task_type}")
                
                group_metrics[group] = metrics
                group_sizes[group] = np.sum(mask)
                
                # Store individual group metrics
                for metric_name, value in metrics.items():
                    results[f'fairness/{attribute_name}/{group}/{metric_name}'] = float(value)
                    
            except Exception as e:
                print(f"Warning: Error computing metrics for group '{group}' in '{attribute_name}': {e}")
                continue
        
        if len(group_metrics) < 2:
            print(f"Warning: Not enough valid groups for fairness analysis of '{attribute_name}'")
            return results
        
        # Compute fairness metrics
        fairness_metrics = self._compute_fairness_statistics(group_metrics, attribute_name)
        results.update(fairness_metrics)
        
        # Add group size information
        total_samples = sum(group_sizes.values())
        for group, size in group_sizes.items():
            results[f'fairness/{attribute_name}/{group}/sample_size'] = int(size)
            results[f'fairness/{attribute_name}/{group}/sample_proportion'] = float(size / total_samples)
        
        return results

    def _compute_binary_metrics(self, y_true: np.ndarray, y_score: np.ndarray) -> Dict[str, float]:
        """Compute metrics for binary classification - PRAUC only"""
        y_true = y_true.flatten()
        y_score = y_score.flatten()
        
        # Check if we have both classes
        if len(np.unique(y_true)) < 2:
            # Only one class present, return limited metrics
            return {
                'sample_positive_rate': float(np.mean(y_true)),
            }
        
        try:
            prauc = average_precision_score(y_true, y_score)
        except ValueError:
            # Handle edge cases
            prauc = 0.0
        
        return {
            'PRAUC': float(prauc),
            'sample_positive_rate': float(np.mean(y_true)),
        }

    def _compute_multiclass_metrics(self, y_true: np.ndarray, y_score: np.ndarray) -> Dict[str, float]:
        """
        Multiclass metrics for one group: accuracy only (no PRAUC).

        A 2-D ``y_score`` with more than one column is reduced to predicted
        classes by argmax over columns; otherwise ``y_score`` is flattened
        and cast to int, i.e. taken to already hold class indices.
        """
        labels = y_true.flatten().astype(int)

        has_class_scores = y_score.ndim > 1 and y_score.shape[1] > 1
        if has_class_scores:
            predictions = np.argmax(y_score, axis=1)
        else:
            predictions = y_score.flatten().astype(int)

        return {'accuracy': float(accuracy_score(labels, predictions))}

    def _compute_multilabel_metrics(self, y_true: np.ndarray, y_score: np.ndarray) -> Dict[str, float]:
        """Compute metrics for multilabel classification - PRAUC only"""
        
        # Compute PRAUC for each label and average
        prauc_scores = []
        for i in range(y_true.shape[1]):
            if len(np.unique(y_true[:, i])) > 1:
                try:
                    prauc = average_precision_score(y_true[:, i], y_score[:, i])
                    prauc_scores.append(prauc)
                except ValueError:
                    continue
        
        avg_prauc = np.mean(prauc_scores) if prauc_scores else 0.0
        
        return {
            'PRAUC': float(avg_prauc),
        }

    def _compute_fairness_statistics(self, 
                                   group_metrics: Dict[str, Dict[str, float]], 
                                   attribute_name: str) -> Dict[str, float]:
        """
        Summarise each metric across the groups of one attribute.

        A metric is summarised only if at least two groups report it.

        Args:
            group_metrics: Mapping from group name to that group's metrics
            attribute_name: Name of the sensitive attribute (used in keys)

        Returns:
            Dictionary keyed by 'fairness/<attribute>/<metric>/<statistic>'
            holding mean, std, min, max, gap (max - min), ratio (min / max,
            or 0.0 when max is not positive), the worst/best group names and
            values, and max_min_fairness (the worst-case value).
        """
        statistics = {}

        # Union of the metric names reported by any group
        metric_names = set()
        for per_group in group_metrics.values():
            metric_names.update(per_group.keys())

        for metric_name in metric_names:
            # Keep only the groups that actually report this metric
            reporting = [
                (group, per_group[metric_name])
                for group, per_group in group_metrics.items()
                if metric_name in per_group
            ]
            if len(reporting) < 2:
                continue

            groups = [group for group, _ in reporting]
            values = np.array([value for _, value in reporting])
            lowest = np.min(values)
            highest = np.max(values)
            key_base = f'fairness/{attribute_name}/{metric_name}'

            statistics.update({
                # Overall statistics
                f'{key_base}/mean': float(np.mean(values)),
                f'{key_base}/std': float(np.std(values)),
                f'{key_base}/min': float(lowest),
                f'{key_base}/max': float(highest),
                # Gap metrics
                f'{key_base}/gap': float(highest - lowest),
                f'{key_base}/ratio': float(lowest / highest) if highest > 0 else 0.0,
                # Worst-case and best-case groups
                f'{key_base}/worst_group': groups[np.argmin(values)],
                f'{key_base}/best_group': groups[np.argmax(values)],
                f'{key_base}/worst_case': float(lowest),
                f'{key_base}/best_case': float(highest),
                # Max-min fairness (worst-case performance)
                f'{key_base}/max_min_fairness': float(lowest),
            })

        return statistics

    def _compute_intersectional_fairness(self, 
                                       y_true: np.ndarray,
                                       y_score: np.ndarray,
                                       demo_df: pd.DataFrame,
                                       task_type: str) -> Dict[str, float]:
        """
        Compute intersectional fairness metrics
        
        Args:
            y_true: True labels
            y_score: Predicted scores
            demo_df: DataFrame with demographic attributes
            task_type: Type of task
            
        Returns:
            Dictionary with intersectional fairness metrics
        """
        results = {}
        
        # Create intersectional groups (combinations of attributes)
        available_attrs = [attr for attr in self.sensitive_attributes if attr in demo_df.columns]
        
        if len(available_attrs) < 2:
            return results
        
        # Specifically compute age x race intersectional fairness if both are available
        age_attr = None
        race_attr = None
        
        # Find age and race attributes
        for attr in available_attrs:
            if 'age' in attr.lower():
                age_attr = attr
            elif 'race' in attr.lower() or 'ethnicity' in attr.lower():
                race_attr = attr
        
        if age_attr and race_attr:
            print(f"Computing intersectional fairness for {age_attr} x {race_attr}")
            
            # Create combined attribute for age x race
            combined_attr = demo_df[age_attr].astype(str) + "_x_" + demo_df[race_attr].astype(str)
            
            # Count combinations to see which ones have enough samples
            combination_counts = combined_attr.value_counts()
            print(f"Age x Race intersectional groups:")
            for combo, count in combination_counts.items():
                print(f"  {combo}: {count} samples")
            
            intersectional_results = self._compute_attribute_fairness(
                y_true, y_score, combined_attr.values, f"{age_attr}_x_{race_attr}", task_type
            )
            
            results.update(intersectional_results)
        
        # Also compute all other pairs of attributes as before
        for i in range(len(available_attrs)):
            for j in range(i + 1, len(available_attrs)):
                attr1, attr2 = available_attrs[i], available_attrs[j]
                
                # Skip if we already computed age x race above
                if (age_attr and race_attr and 
                    ((attr1 == age_attr and attr2 == race_attr) or 
                     (attr1 == race_attr and attr2 == age_attr))):
                    continue
                
                # Create combined attribute
                combined_attr = demo_df[attr1].astype(str) + "_x_" + demo_df[attr2].astype(str)
                
                intersectional_results = self._compute_attribute_fairness(
                    y_true, y_score, combined_attr.values, f"{attr1}_x_{attr2}", task_type
                )
                
                results.update(intersectional_results)
        
        return results

    def generate_fairness_report(self, fairness_results: Dict[str, Any]) -> str:
        """
        Generate a human-readable fairness report

        For every attribute found in ``fairness_results`` the report shows
        the gap, worst/best case and per-group values of one primary metric:
        accuracy (labelled "ACC") when accuracy keys are present, otherwise
        PRAUC. Intersectional attributes (names containing '_x_') are
        reported in a separate section with looser assessment thresholds.
        The 'has_cxr' attribute is reported as sample counts per group only.

        Args:
            fairness_results: Dictionary keyed by
                'fairness/<attribute>/<group>/<field>' and
                'fairness/<attribute>/<metric>/<statistic>'; group names may
                themselves contain '/'. Other keys are ignored.

        Returns:
            The report as a single newline-joined string
        """
        report_lines = ["=" * 80, "FAIRNESS EVALUATION REPORT", "=" * 80]

        # Group results by attribute (second path component)
        attr_results: Dict[str, Dict[str, Any]] = {}
        for key, value in fairness_results.items():
            if not key.startswith('fairness/'):
                continue
            parts = key.split('/')
            if len(parts) >= 3:
                attr_results.setdefault(parts[1], {})[key] = value

        # Separate individual attributes from intersectional ones
        individual_attrs = {
            name: res for name, res in attr_results.items() if '_x_' not in name
        }
        intersectional_attrs = {
            name: res for name, res in attr_results.items() if '_x_' in name
        }

        # Report individual attributes first
        report_lines.append("\n🔍 INDIVIDUAL ATTRIBUTE ANALYSIS")
        report_lines.append("=" * 50)

        for attr_name, results in individual_attrs.items():
            if attr_name == 'has_cxr':
                report_lines.append("\n📊 CXR Availability Fairness Analysis")
                report_lines.append("    (Comparing performance between samples with and without CXR)")
                report_lines.append("--------------------------------------------------")
                # One line per availability group (the group name, not the
                # trailing 'sample_size' field, identifies the row)
                _, sizes = self._report_collect_groups(attr_name, results, None)
                for availability, count in sizes.items():
                    report_lines.append(f"  {availability}: {count} samples")
                report_lines.append("-" * 50)
                continue  # Skip to next attribute

            metric, label = self._report_primary_metric(results)

            report_lines.append(f"\n📊 Fairness Analysis for: {attr_name.upper()}")
            if attr_name == 'race':
                report_lines.append("    (Using 5 consolidated categories: WHITE, BLACK/AFRICAN AMERICAN, HISPANIC/LATINO, ASIAN, OTHER)")
            elif 'age' in attr_name.lower():
                # Reflect the configured bins rather than the default ones
                report_lines.append(f"    (Age groups: {', '.join(self.age_labels)})")
            report_lines.append("-" * 50)

            report_lines.extend(
                self._report_summary_lines(attr_name, results, metric, label, False)
            )

            # Group-wise performance, best first
            rows = self._report_ranked_rows(attr_name, results, metric)
            if rows:
                report_lines.append(f"\n  Group-wise {label} Performance:")
                for group, value, size in rows:
                    report_lines.append(f"    {group}: {label}={value:.4f}, Samples={size}")

        # Report intersectional attributes
        if intersectional_attrs:
            report_lines.append("\n\n🔄 INTERSECTIONAL ANALYSIS")
            report_lines.append("=" * 50)

            for attr_name, results in intersectional_attrs.items():
                metric, label = self._report_primary_metric(results)

                report_lines.append(f"\n Intersectional {label} Analysis: {attr_name.upper().replace('_X_', ' × ')}")
                report_lines.append("-" * 50)

                report_lines.extend(
                    self._report_summary_lines(attr_name, results, metric, label, True)
                )

                rows = self._report_ranked_rows(attr_name, results, metric)
                if not rows:
                    continue

                # With many groups only the top and bottom 5 are listed
                if len(rows) > 10:
                    sections = [
                        ("\n  Top 5 Performing Intersectional Groups:", rows[:5]),
                        ("\n  Bottom 5 Performing Intersectional Groups:", rows[-5:]),
                    ]
                else:
                    sections = [("\n  All Intersectional Groups Performance:", rows)]

                for title, section_rows in sections:
                    report_lines.append(title)
                    for group, value, size in section_rows:
                        formatted_group = str(group).replace('_x_', ' × ')
                        report_lines.append(f"    {formatted_group}: {label}={value:.4f}, Samples={size}")

        report_lines.append("\n" + "=" * 80)

        return "\n".join(report_lines)

    def _report_primary_metric(self, results: Dict[str, Any]) -> Tuple[str, str]:
        """Return (metric key, display label): accuracy/'ACC' if present, else PRAUC."""
        for key in results:
            if key.endswith('/accuracy') or '/accuracy/' in key:
                return 'accuracy', 'ACC'
        return 'PRAUC', 'PRAUC'

    def _report_collect_groups(self,
                               attr_name: str,
                               results: Dict[str, Any],
                               metric: Optional[str]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Return per-group ({group: metric value}, {group: sample size}) for one attribute."""
        prefix = f'fairness/{attr_name}/'
        values: Dict[str, Any] = {}
        sizes: Dict[str, Any] = {}
        for key, value in results.items():
            # Split on the LAST '/' so group names containing '/' (e.g.
            # 'BLACK/AFRICAN AMERICAN') stay intact. Aggregate keys such as
            # '<metric>/gap' end in a statistic name and match neither field.
            group, separator, field = key[len(prefix):].rpartition('/')
            if not separator or not group:
                continue
            if metric is not None and field == metric:
                values[group] = value
            elif field == 'sample_size':
                sizes[group] = value
        return values, sizes

    def _report_ranked_rows(self,
                            attr_name: str,
                            results: Dict[str, Any],
                            metric: str) -> List[Tuple[str, Any, Any]]:
        """Return (group, metric value, sample size or 'N/A') rows, best value first."""
        values, sizes = self._report_collect_groups(attr_name, results, metric)
        rows = [(group, value, sizes.get(group, 'N/A')) for group, value in values.items()]
        rows.sort(key=lambda row: row[1], reverse=True)
        return rows

    def _report_summary_lines(self,
                              attr_name: str,
                              results: Dict[str, Any],
                              metric: str,
                              label: str,
                              intersectional: bool) -> List[str]:
        """Return the gap / worst-case / best-case / assessment lines for one attribute."""
        prefix = f'fairness/{attr_name}/{metric}/'
        gap = results.get(prefix + 'gap')
        if gap is None:
            return []

        worst_case = results.get(prefix + 'worst_case')
        best_case = results.get(prefix + 'best_case')
        worst_group = results.get(prefix + 'worst_group')
        best_group = results.get(prefix + 'best_group')

        if intersectional:
            # Larger gaps are expected between intersectional groups, so the
            # (heuristic) thresholds are looser than for single attributes.
            gap_title = f"Intersectional {label} Gap"
            good_limit, moderate_limit = 0.08, 0.15
            qualifier = "intersectional fairness"
            worst_group = str(worst_group).replace('_x_', ' × ') if worst_group else 'N/A'
            best_group = str(best_group).replace('_x_', ' × ') if best_group else 'N/A'
        else:
            gap_title = f"{label} Gap"
            good_limit, moderate_limit = 0.05, 0.10
            qualifier = "fairness"

        lines = [f"  {gap_title}: {gap:.4f}"]
        if worst_case is not None:
            lines.append(f"  Worst-case {label}: {worst_case:.4f} (Group: {worst_group})")
        if best_case is not None:
            lines.append(f"  Best-case {label}: {best_case:.4f} (Group: {best_group})")

        # Fairness assessment
        if gap < good_limit:
            assessment = f"✅ Good {qualifier}"
        elif gap < moderate_limit:
            assessment = f"⚠️  Moderate {qualifier} concern"
        else:
            assessment = f"❌ Significant {qualifier} concern"
        lines.append(f"  Assessment: {assessment}")

        return lines


def compute_fairness_metrics(y_true: np.ndarray,
                           y_score: np.ndarray,
                           meta_attrs: pd.DataFrame,
                           task_type: str = 'binary',
                           sensitive_attributes: List[str] = None,
                           age_bins: List[float] = None,
                           compute_intersectional: bool = False,
                           include_cxr_availability: bool = True) -> Tuple[Dict[str, Any], str]:  # enabled by default
    """
    One-call wrapper: build a FairnessEvaluator, compute the group fairness
    metrics and render the text report.

    A short debug summary of the inputs is printed first.

    Args:
        y_true: True labels
        y_score: Predicted scores/probabilities
        meta_attrs: DataFrame with demographic attributes
        task_type: Task type, forwarded to the evaluator
        sensitive_attributes: Sensitive attribute columns, forwarded to the
            evaluator
        age_bins: Age bin edges, forwarded to the evaluator
        compute_intersectional: Whether intersectional metrics are computed
        include_cxr_availability: Whether CXR availability is evaluated as a
            sensitive attribute

    Returns:
        Tuple of (fairness metrics dictionary, report string)
    """
    print("=== FAIRNESS METRICS DEBUG ===")
    print(f"y_true shape: {y_true.shape}")
    print(f"y_score shape: {y_score.shape}")
    print(f"meta_attrs shape: {meta_attrs.shape}")
    print(f"meta_attrs columns: {meta_attrs.columns.tolist()}")
    print(f"sensitive_attributes: {sensitive_attributes}")
    print(f"task_type: {task_type}")
    print(f"include_cxr_availability: {include_cxr_availability}")
    print("=== END DEBUG ===")

    evaluator_options = {
        'sensitive_attributes': sensitive_attributes,
        'age_bins': age_bins,
        'compute_intersectional': compute_intersectional,
        'include_cxr_availability': include_cxr_availability,
    }
    evaluator = FairnessEvaluator(**evaluator_options)

    metrics = evaluator.compute_group_metrics(y_true, y_score, meta_attrs, task_type)
    report = evaluator.generate_fairness_report(metrics)

    return metrics, report
