"""
Ablation Study for TAN (Topological Attention Network) - GoEmotions
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, field
import json
import logging
from tqdm import tqdm
from scipy import stats
from sklearn.metrics import (
    f1_score, accuracy_score, precision_score, recall_score,
    hamming_loss, jaccard_score, roc_auc_score
)
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
import warnings
warnings.filterwarnings('ignore')

# Import TAN architecture
from tan_architecture import (
    TANConfig, TANForMultiLabelClassification,
    create_tan_model, TopologicalFeatureExtractor,
    LocalitySensitiveHashing, TopologicalAttention
)

# Configure the root logger once at import time (INFO level, timestamped
# "time - LEVEL - message" lines) and create the module-level logger that
# every class and function below writes to.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def _default_ablation_variants() -> Dict[str, Dict]:
    """Build the default table of ablation variants.

    Each key is a variant name and each value is a dict of configuration
    overrides for that variant; an empty dict means "no overrides".  A new
    dict is created on every call, so config instances never share one.
    """
    return {
        # Reference model with every component enabled
        'Full': {},
        # Component removal
        'NoTopology': {'use_topology': False},
        'NoLSH': {'use_lsh': False},
        'NoTopologyNoLSH': {'use_topology': False, 'use_lsh': False},
        'SingleHead': {'num_heads': 1},
        # Neighbourhood size sweep
        'K8': {'k_neighbors': 8},
        'K16': {'k_neighbors': 16},
        'K64': {'k_neighbors': 64},
        'K128': {'k_neighbors': 128},
        # Topology feature width sweep
        'TopologyDim64': {'topology_dim': 64},
        'TopologyDim256': {'topology_dim': 256},
        # Explicit None override disables multi-scale k
        'NoMultiScale': {'multi_scale_k': None},
        # Hashing sweeps
        'HashBits128': {'hash_bits': 128},
        'HashBits512': {'hash_bits': 512},
        'LessHashes': {'num_hashes': 4},
        'MoreHashes': {'num_hashes': 16},
    }


@dataclass
class AblationConfig:
    """Settings for one ablation study run.

    Creating an instance also creates ``results_dir`` (including missing
    parents) on disk.
    """
    # Where the trained base model lives and where outputs are written
    model_checkpoint: Path = Path('goemotion_best_model.pt')
    results_dir: Path = Path('ablation_results')

    # Evaluation parameters
    num_labels: int = 27
    max_seq_length: int = 128
    batch_size: int = 32
    # Each variant is evaluated this many times so scores can be compared statistically
    num_evaluation_runs: int = 5
    # Chosen once, when the class body is executed
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Variant name -> configuration overrides
    ablation_variants: Dict[str, Dict] = field(default_factory=_default_ablation_variants)

    def __post_init__(self):
        """Ensure the output directory exists."""
        self.results_dir.mkdir(parents=True, exist_ok=True)

class GoEmotionsDataset(Dataset):
    """Map-style dataset yielding tokenized GoEmotions text with multi-hot labels."""

    def __init__(self, split: str, tokenizer, max_length: int = 128, cache_dir: Path = Path('./data')):
        """Load one split of ``go_emotions`` through ``load_dataset``.

        Args:
            split: Split name forwarded to ``load_dataset``.
            tokenizer: Callable used in ``__getitem__`` to encode each text.
            max_length: Length every example is padded/truncated to.
            cache_dir: Directory handed to ``load_dataset`` as its cache.

        Raises:
            Exception: Whatever ``load_dataset`` raises is logged and re-raised.
        """
        self.split = split
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.cache_dir = cache_dir

        logger.info(f"Loading GoEmotions {split} split...")
        try:
            self.data = load_dataset('go_emotions', split=split, cache_dir=str(cache_dir))
            # Width of the multi-hot target; label ids >= this value are ignored
            self.num_labels = 27
            logger.info(f"Loaded {len(self.data)} samples for {split}")
        except Exception as e:
            logger.error(f"Failed to load GoEmotions dataset: {e}")
            logger.info("Please ensure you have internet connection or cached data")
            raise

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and multi-hot ``labels`` for one example."""
        record = self.data[idx]

        tokens = self.tokenizer(
            record['text'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        # Start from all zeros and switch on every label id that fits in the vector
        target = torch.zeros(self.num_labels)
        for class_id in record['labels']:
            if class_id >= self.num_labels:
                continue
            target[class_id] = 1

        # The tokenizer returns a leading batch axis of size 1; drop it
        input_ids = tokens['input_ids'].squeeze(0)
        attention_mask = tokens['attention_mask'].squeeze(0)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': target,
        }

class ModelLoader:
    """Helpers for reading a training checkpoint and recovering its TAN configuration."""

    @staticmethod
    def load_checkpoint(checkpoint_path: Path) -> Dict:
        """Load a checkpoint onto the CPU and log what it contains.

        The stored macro-F1 (``checkpoint['metrics']['f1_macro']``), when
        present, is logged exactly as stored; it is never rescaled.

        Args:
            checkpoint_path: File to read with ``torch.load``.

        Returns:
            The loaded checkpoint object, unmodified.

        Raises:
            Exception: Any failure while loading or inspecting the checkpoint
                is logged and re-raised.
        """
        try:
            # SECURITY: weights_only=False lets torch.load unpickle arbitrary
            # Python objects, which can execute code.  Only load checkpoints
            # from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
            logger.info(f"Loaded checkpoint from {checkpoint_path}")

            # Log checkpoint contents
            logger.info(f"Checkpoint keys: {checkpoint.keys()}")

            # Report the metrics saved alongside the weights, if any
            if 'metrics' in checkpoint:
                stored_metrics = checkpoint['metrics']
                logger.info(f"Found metrics: {stored_metrics.keys()}")
                if 'f1_macro' in stored_metrics:
                    # A low score is reported as-is: multiplying it by 10 (as
                    # an earlier revision did) would misreport a genuinely
                    # weak model as a good one.
                    logger.info(f"F1-Macro: {stored_metrics['f1_macro']:.4f}")

            return checkpoint
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise

    @staticmethod
    def extract_config(checkpoint: Dict) -> TANConfig:
        """Return the ``TANConfig`` described by ``checkpoint['config']``.

        * If the checkpoint has no ``'config'`` entry, a hard-coded default
          configuration is returned and a warning is logged.
        * If the entry is already a ``TANConfig`` instance, a shallow copy of
          it is returned.
        * Otherwise the entry is treated as a mapping: every key that names an
          existing attribute of a default ``TANConfig`` is copied onto it;
          other keys are ignored.
        """
        if 'config' not in checkpoint:
            # Default configuration for GoEmotions
            logger.warning("No config found in checkpoint, using defaults")
            return TANConfig(
                vocab_size=30522,  # BERT tokenizer vocab size
                embed_dim=768,
                num_heads=12,
                num_layers=12,
                max_seq_length=128,
                k_neighbors=32,
                use_topology=True,
                use_lsh=True,
                topology_dim=128,
                num_hashes=8,
                hash_bits=256
            )

        stored = checkpoint['config']
        if isinstance(stored, TANConfig):
            # The config object itself was pickled; calling .items() on it
            # would fail.  Copy so callers cannot mutate the checkpoint.
            import copy
            return copy.copy(stored)

        config = TANConfig()
        ignored = []
        for key, value in stored.items():
            if hasattr(config, key):
                setattr(config, key, value)
            else:
                ignored.append(key)
        if ignored:
            logger.debug(f"Ignored checkpoint config keys unknown to TANConfig: {ignored}")
        return config

class AblationEvaluator:
    """Build ablated TAN models from a base checkpoint and score them.

    The instance keeps the study configuration in ``self.config`` and a
    ``bert-base-uncased`` tokenizer in ``self.tokenizer``.
    """

    def __init__(self, config: AblationConfig):
        """Store ``config`` and load the tokenizer with ``AutoTokenizer.from_pretrained``."""
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    def create_ablation_model(self, variant_name: str, base_checkpoint: Dict) -> nn.Module:
        """Instantiate the model for one ablation variant.

        The base configuration is taken from the checkpoint, the overrides
        registered for ``variant_name`` are applied on top, and every
        checkpoint tensor whose name and shape still match is copied in.

        Args:
            variant_name: Key into ``self.config.ablation_variants``.
            base_checkpoint: Checkpoint dict as returned by
                ``ModelLoader.load_checkpoint``.

        Returns:
            A ``TANForMultiLabelClassification`` with ``self.config.num_labels``
            outputs.

        Raises:
            KeyError: If ``variant_name`` is not a registered variant.
        """
        base_config = ModelLoader.extract_config(base_checkpoint)
        variant_settings = self.config.ablation_variants[variant_name]

        def resolve(field_name: str):
            # An override wins whenever the key is present, even if its value
            # is None (this is how 'NoMultiScale' disables multi_scale_k).
            return variant_settings.get(field_name, getattr(base_config, field_name))

        modified_config = TANConfig(
            vocab_size=base_config.vocab_size,
            embed_dim=base_config.embed_dim,
            num_heads=resolve('num_heads'),
            num_layers=base_config.num_layers,
            max_seq_length=base_config.max_seq_length,
            dropout=base_config.dropout,
            k_neighbors=resolve('k_neighbors'),
            use_topology=resolve('use_topology'),
            topology_dim=resolve('topology_dim'),
            use_lsh=resolve('use_lsh'),
            num_hashes=resolve('num_hashes'),
            hash_bits=resolve('hash_bits'),
            lsh_temperature=resolve('lsh_temperature'),
            multi_scale_k=resolve('multi_scale_k'),
        )

        model = TANForMultiLabelClassification(modified_config, self.config.num_labels)

        # Reuse trained weights wherever the variant's architecture still matches
        for state_key in ('model_state_dict', 'state_dict'):
            if state_key in base_checkpoint:
                self.load_compatible_weights(base_checkpoint[state_key], model)
                break
        else:
            logger.warning(
                f"Checkpoint has no 'model_state_dict' or 'state_dict'; "
                f"variant {variant_name} keeps its initial weights"
            )

        return model

    def load_compatible_weights(self, state_dict: Dict, model: nn.Module):
        """Copy every tensor from ``state_dict`` that fits ``model``.

        A tensor fits when the model has an entry of the same name and the
        same shape.  Everything else is skipped.  Model tensors that receive
        no value keep whatever they were initialised with; their count is
        logged as a warning because it affects how the variant's scores
        should be interpreted.

        Args:
            state_dict: Mapping of parameter name to tensor.
            model: Module to load into (modified in place).
        """
        model_dict = model.state_dict()

        compatible_dict = {}
        incompatible_keys = []
        for key, value in state_dict.items():
            if key not in model_dict:
                incompatible_keys.append(f"{key} (not in model)")
            elif model_dict[key].shape != value.shape:
                incompatible_keys.append(f"{key} (shape mismatch)")
            else:
                compatible_dict[key] = value

        # strict=False: entries missing from compatible_dict are simply left untouched
        model.load_state_dict(compatible_dict, strict=False)

        logger.info(f"Loaded {len(compatible_dict)}/{len(state_dict)} weights")
        if incompatible_keys:
            logger.debug(f"Incompatible keys: {incompatible_keys[:5]}...")

        not_restored = [key for key in model_dict if key not in compatible_dict]
        if not_restored:
            logger.warning(
                f"{len(not_restored)} model tensors were not restored from the checkpoint "
                f"and keep their initial values (first few: {not_restored[:5]})"
            )

    def evaluate_model(self, model: nn.Module, data_loader: DataLoader) -> Dict[str, float]:
        """Run ``model`` over ``data_loader`` and return its metrics.

        The model is switched to eval mode and moved to ``self.config.device``.
        Predictions are ``sigmoid(logits) > 0.5``.  The returned dict is the
        output of ``calculate_metrics`` plus ``'loss'``, the mean of the
        per-batch losses the model reported.

        Raises:
            ValueError: If ``data_loader`` yields no batches.
        """
        model.eval()
        model.to(self.config.device)

        all_preds = []
        all_labels = []
        all_scores = []
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(data_loader, desc="Evaluating", leave=False):
                input_ids = batch['input_ids'].to(self.config.device)
                attention_mask = batch['attention_mask'].to(self.config.device)
                labels = batch['labels'].to(self.config.device)

                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

                # Only batches that actually report a loss contribute to the mean
                if outputs['loss'] is not None:
                    total_loss += outputs['loss'].item()
                    num_batches += 1

                # Independent per-label probabilities, thresholded at 0.5
                scores = torch.sigmoid(outputs['logits'])
                preds = (scores > 0.5).float()

                all_preds.append(preds.cpu().numpy())
                all_labels.append(labels.cpu().numpy())
                all_scores.append(scores.cpu().numpy())

        if not all_preds:
            # np.vstack([]) would raise a far less helpful ValueError
            raise ValueError("data_loader produced no batches; nothing to evaluate")

        all_preds = np.vstack(all_preds)
        all_labels = np.vstack(all_labels)
        all_scores = np.vstack(all_scores)

        metrics = self.calculate_metrics(all_labels, all_preds, all_scores)
        metrics['loss'] = total_loss / max(num_batches, 1)

        # Sanity check only; values are never altered
        self.validate_metrics(metrics)

        return metrics

    def calculate_metrics(self, y_true: np.ndarray, y_pred: np.ndarray,
                         y_scores: np.ndarray) -> Dict[str, float]:
        """Compute multi-label classification metrics.

        Args:
            y_true: Ground-truth label indicators.
            y_pred: Thresholded predictions.
            y_scores: Scores used for AUC-ROC.

        Returns:
            Dict with F1 (macro/micro/weighted/samples), accuracy, macro
            precision/recall, Hamming loss, macro Jaccard, macro AUC-ROC
            (``0.0`` when it cannot be computed) and the mean/std of the
            per-label F1 scores.
        """
        metrics = {}

        metrics['f1_macro'] = f1_score(y_true, y_pred, average='macro', zero_division=0)
        metrics['f1_micro'] = f1_score(y_true, y_pred, average='micro', zero_division=0)
        metrics['f1_weighted'] = f1_score(y_true, y_pred, average='weighted', zero_division=0)
        metrics['f1_samples'] = f1_score(y_true, y_pred, average='samples', zero_division=0)

        metrics['accuracy'] = accuracy_score(y_true, y_pred)
        metrics['precision_macro'] = precision_score(y_true, y_pred, average='macro', zero_division=0)
        metrics['recall_macro'] = recall_score(y_true, y_pred, average='macro', zero_division=0)
        metrics['hamming_loss'] = hamming_loss(y_true, y_pred)
        metrics['jaccard_score'] = jaccard_score(y_true, y_pred, average='macro', zero_division=0)

        # AUC-ROC is best-effort: keep the 0.0 fallback but say why it was used
        # instead of silently swallowing every exception.
        try:
            metrics['auc_macro'] = roc_auc_score(y_true, y_scores, average='macro')
        except Exception as exc:
            logger.warning(f"AUC-ROC could not be computed ({exc}); reporting 0.0")
            metrics['auc_macro'] = 0.0

        per_label_f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
        metrics['per_label_f1_mean'] = float(np.mean(per_label_f1))
        metrics['per_label_f1_std'] = float(np.std(per_label_f1))

        return metrics

    def validate_metrics(self, metrics: Dict[str, float]):
        """Warn about implausible metric values without changing them.

        A warning is logged for any non-finite value, and for any value
        outside ``[0, 1]`` whose key contains neither ``'loss'`` nor
        ``'hamming'``.  ``metrics`` is never modified: a low score is a
        result, not a formatting error, so nothing is rescaled.
        """
        for key, value in metrics.items():
            if not np.isfinite(value):
                logger.warning(f"Metric {key} is not finite: {value}")
                continue
            if 'loss' in key or 'hamming' in key:
                continue
            if not 0.0 <= value <= 1.0:
                logger.warning(f"Metric {key}={value:.4f} is outside the expected [0, 1] range")

class AblationStudy:
    """End-to-end ablation study: evaluate variants, compare, summarise, save and plot.

    Result keys for a variant named ``X`` are always ``'TAN-X'`` in both
    ``raw_results`` and ``comparisons``.
    """

    def __init__(self, config: AblationConfig):
        """Store ``config`` and create the ``AblationEvaluator`` used for scoring."""
        self.config = config
        self.evaluator = AblationEvaluator(config)

    def run_ablation_study(self) -> Dict:
        """Evaluate every configured variant and produce the full result bundle.

        Each variant is built and evaluated ``config.num_evaluation_runs``
        times on the GoEmotions test split.  The results are then compared
        against ``TAN-Full``, summarised, written to disk and plotted.

        Returns:
            Dict with ``'raw_results'`` (per-variant score lists),
            ``'comparisons'`` and ``'summary'``.
        """
        logger.info("="*80)
        logger.info("Starting TAN Ablation Study for GoEmotions")
        logger.info("="*80)

        base_checkpoint = ModelLoader.load_checkpoint(self.config.model_checkpoint)

        logger.info("\nLoading GoEmotions test dataset...")
        test_dataset = GoEmotionsDataset('test', self.evaluator.tokenizer, self.config.max_seq_length)
        test_loader = DataLoader(
            test_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=2,
            # Pinned host memory only helps (and only avoids a warning) on CUDA
            pin_memory=str(self.config.device).startswith('cuda')
        )
        logger.info(f"Test dataset: {len(test_dataset)} samples")

        ablation_results = {
            'raw_results': {},
            'comparisons': {},
            'summary': {}
        }

        tracked_metrics = (
            'f1_macro', 'f1_micro', 'accuracy',
            'hamming_loss', 'precision_macro', 'recall_macro',
        )

        for variant_name in self.config.ablation_variants:
            logger.info(f"\n{'='*60}")
            logger.info(f"Evaluating variant: TAN-{variant_name}")
            logger.info(f"{'='*60}")

            variant_scores = {metric: [] for metric in tracked_metrics}

            for run in range(self.config.num_evaluation_runs):
                logger.info(f"Evaluation run {run+1}/{self.config.num_evaluation_runs}")

                if variant_name == 'Full':
                    model = self._build_full_model(base_checkpoint)
                else:
                    model = self.evaluator.create_ablation_model(variant_name, base_checkpoint)

                metrics = self.evaluator.evaluate_model(model, test_loader)

                for key, scores in variant_scores.items():
                    if key in metrics:
                        scores.append(metrics[key])

                logger.info(f"  F1-Macro: {metrics['f1_macro']:.4f}")
                logger.info(f"  F1-Micro: {metrics['f1_micro']:.4f}")
                logger.info(f"  Accuracy: {metrics['accuracy']:.4f}")

                # Free the model before the next one is built
                del model
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            ablation_results['raw_results'][f'TAN-{variant_name}'] = variant_scores

            avg_f1_macro = np.mean(variant_scores['f1_macro'])
            std_f1_macro = np.std(variant_scores['f1_macro'])
            logger.info(f"\nAverage F1-Macro: {avg_f1_macro:.4f} ± {std_f1_macro:.4f}")

        logger.info("\n" + "="*80)
        logger.info("Performing Statistical Comparisons")
        logger.info("="*80)
        ablation_results['comparisons'] = self.compare_variants(ablation_results['raw_results'])

        ablation_results['summary'] = self.summarize_ablation(ablation_results)

        self.save_results(ablation_results)

        return ablation_results

    def _build_full_model(self, base_checkpoint: Dict) -> nn.Module:
        """Build the unmodified model and load the checkpoint weights non-strictly."""
        model = TANForMultiLabelClassification(
            ModelLoader.extract_config(base_checkpoint),
            self.config.num_labels
        )
        for state_key in ('model_state_dict', 'state_dict'):
            if state_key in base_checkpoint:
                model.load_state_dict(base_checkpoint[state_key], strict=False)
                break
        else:
            logger.warning(
                "Checkpoint has no 'model_state_dict' or 'state_dict'; "
                "the full model keeps its initial weights"
            )
        return model

    @staticmethod
    def _paired_t_test(f1_full, f1_variant) -> Optional[Dict]:
        """Paired t-test of two equally long score lists, or ``None`` if undefined.

        ``None`` is returned when there are fewer than two runs, when the two
        lists differ in length, or when all paired differences are identical.
        In the last case the test statistic is 0/0 or x/0, and reporting the
        resulting p-value (NaN or 0) as evidence would be misleading.

        The effect size is the mean of the paired differences divided by
        their sample standard deviation (``ddof=1``).
        """
        if len(f1_full) < 2 or len(f1_full) != len(f1_variant):
            return None

        diff = np.asarray(f1_full, dtype=float) - np.asarray(f1_variant, dtype=float)
        # Compare max and min rather than std: the std of identical floats can
        # come out as a tiny non-zero number because of rounding in the mean.
        if diff.max() - diff.min() < 1e-12:
            return None

        t_stat, p_value = stats.ttest_rel(f1_full, f1_variant)
        cohens_d = np.mean(diff) / np.std(diff, ddof=1)

        return {
            't_statistic': float(t_stat),
            'p_value': float(p_value),
            'cohens_d': float(cohens_d),
            # Plain bool so json.dump writes true/false instead of a string
            'significant': bool(p_value < 0.05)
        }

    def compare_variants(self, results: Dict) -> Dict:
        """Compare every variant's macro-F1 against ``TAN-Full``.

        Args:
            results: ``raw_results`` mapping (``'TAN-<name>'`` -> metric lists).

        Returns:
            Mapping ``'TAN-<name>'`` -> comparison dict with
            ``performance_drop`` (full mean minus variant mean),
            ``relative_drop_%``, the means and standard deviations, and,
            when a paired t-test is defined, a ``'t_test'`` sub-dict.

        Raises:
            KeyError: If ``results`` has no ``'TAN-Full'`` entry.
        """
        comparisons = {}
        f1_full = results['TAN-Full']['f1_macro']
        mean_full = float(np.mean(f1_full))
        std_full = float(np.std(f1_full))

        for variant_name, variant_scores in results.items():
            if variant_name == 'TAN-Full':
                continue

            f1_variant = variant_scores['f1_macro']
            comparison = {}

            t_test = self._paired_t_test(f1_full, f1_variant)
            if t_test is not None:
                comparison['t_test'] = t_test

            mean_variant = float(np.mean(f1_variant))
            drop = mean_full - mean_variant
            comparison['performance_drop'] = drop
            # Guard the division: a full-model mean of exactly 0 has no relative drop
            comparison['relative_drop_%'] = (drop / mean_full) * 100 if mean_full else 0.0
            comparison['mean_full'] = mean_full
            comparison['mean_variant'] = mean_variant
            comparison['std_full'] = std_full
            comparison['std_variant'] = float(np.std(f1_variant))

            comparisons[variant_name] = comparison

            logger.info(f"\n{variant_name}:")
            logger.info(f"  Performance drop: {comparison['performance_drop']:.4f} ({comparison['relative_drop_%']:.1f}%)")
            if 't_test' in comparison:
                logger.info(f"  Statistical significance: p={comparison['t_test']['p_value']:.4f}")
                logger.info(f"  Effect size (Cohen's d): {comparison['t_test']['cohens_d']:.3f}")
            else:
                logger.info("  Paired t-test skipped (too few paired runs or no run-to-run variation)")

        return comparisons

    @staticmethod
    def _find_comparison(comparisons: Dict, component: str) -> Optional[Dict]:
        """Return the comparison for ``component``, accepting ``'TAN-<component>'`` or the bare name."""
        for key in (f'TAN-{component}', component):
            if key in comparisons:
                return comparisons[key]
        return None

    def summarize_ablation(self, results: Dict) -> Dict:
        """Condense raw results and comparisons into a ranked summary.

        Args:
            results: Dict holding ``'raw_results'`` and ``'comparisons'``.

        Returns:
            Dict with ``component_importance`` (sorted by absolute drop,
            largest first), ``critical_components`` (absolute drop > 0.02 and
            significant), ``optimal_configurations``, ``key_findings`` and,
            when available, ``k_analysis``.
        """
        raw_results = results['raw_results']
        comparisons = results['comparisons']

        summary = {
            'component_importance': [],
            'critical_components': [],
            'optimal_configurations': {},
            'key_findings': []
        }

        for variant_name, comparison in comparisons.items():
            t_test = comparison.get('t_test', {})
            summary['component_importance'].append({
                'component': variant_name.replace('TAN-', ''),
                'performance_drop': comparison['performance_drop'],
                'relative_drop_%': comparison['relative_drop_%'],
                'mean_performance': comparison['mean_variant'],
                'significant': bool(t_test.get('significant', False)),
                'p_value': t_test.get('p_value', 1.0),
                'effect_size': abs(t_test.get('cohens_d', 0))
            })

        summary['component_importance'].sort(
            key=lambda entry: abs(entry['performance_drop']),
            reverse=True
        )

        summary['critical_components'] = [
            comp['component'] for comp in summary['component_importance']
            if abs(comp['performance_drop']) > 0.02 and comp['significant']
        ]

        full_mean = None
        if 'TAN-Full' in raw_results:
            full_mean = float(np.mean(raw_results['TAN-Full']['f1_macro']))

        # K-neighbours sweep
        k_analysis = {}
        for variant in ['K8', 'K16', 'K64', 'K128']:
            key = f'TAN-{variant}'
            if key in raw_results:
                k_analysis[int(variant[1:])] = float(np.mean(raw_results[key]['f1_macro']))
        # NOTE(review): the full model is recorded as k=32, which assumes the
        # checkpoint was trained with k_neighbors=32 -- confirm against its config.
        if full_mean is not None:
            k_analysis[32] = full_mean

        if k_analysis:
            summary['optimal_configurations']['k_neighbors'] = max(k_analysis, key=k_analysis.get)
            summary['k_analysis'] = k_analysis

        # Topology dimension sweep
        topology_dim_analysis = {}
        for variant, dim in [('TopologyDim64', 64), ('TopologyDim256', 256)]:
            key = f'TAN-{variant}'
            if key in raw_results:
                topology_dim_analysis[dim] = float(np.mean(raw_results[key]['f1_macro']))
        # NOTE(review): likewise assumes the full model uses topology_dim=128.
        if full_mean is not None:
            topology_dim_analysis[128] = full_mean

        if topology_dim_analysis:
            summary['optimal_configurations']['topology_dim'] = max(
                topology_dim_analysis,
                key=topology_dim_analysis.get
            )

        # Headline findings.  Comparisons are keyed 'TAN-<variant>', so the
        # lookup must use the prefixed name or nothing is ever reported.
        headline_templates = [
            ('NoTopology', "Topological features contribute {drop:.4f} ({pct:.1f}%) to F1-Macro"),
            ('NoLSH', "LSH attention efficiency contributes {drop:.4f} ({pct:.1f}%) to F1-Macro"),
            ('NoTopologyNoLSH', "Combined topology+LSH contributes {drop:.4f} ({pct:.1f}%) to F1-Macro"),
        ]
        for component, template in headline_templates:
            comparison = self._find_comparison(comparisons, component)
            if comparison is not None:
                summary['key_findings'].append(template.format(
                    drop=comparison['performance_drop'],
                    pct=comparison['relative_drop_%'],
                ))

        if summary['critical_components']:
            summary['key_findings'].append(
                f"Critical components (>2% significant drop): {', '.join(summary['critical_components'])}"
            )

        if summary['optimal_configurations']:
            configs = [f"{k}={v}" for k, v in summary['optimal_configurations'].items()]
            summary['key_findings'].append(
                f"Optimal configurations: {', '.join(configs)}"
            )

        return summary

    @staticmethod
    def _json_default(value):
        """Convert values ``json`` cannot encode: NumPy scalars/arrays natively, anything else via ``str``."""
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        return str(value)

    def save_results(self, results: Dict):
        """Write ``results`` to ``<results_dir>/ablation_results.json`` and render the figure."""
        results_file = self.config.results_dir / 'ablation_results.json'
        with open(results_file, 'w', encoding='utf-8') as f:
            # NumPy values are converted to real numbers/booleans rather than
            # being stringified, so the file can be read back with its types intact.
            json.dump(results, f, indent=2, default=self._json_default)
        logger.info(f"\nResults saved to {results_file}")

        self.create_visualizations(results)

    def create_visualizations(self, results: Dict):
        """Render the six-panel summary figure and save it as a PNG.

        Panels: component importance, k-neighbours sensitivity, all-variant
        comparison, effect sizes, component synergy and a text summary.  The
        file is written to
        ``<results_dir>/ablation_study_visualization.png``.

        Args:
            results: Full result bundle with ``'raw_results'``,
                ``'comparisons'`` and ``'summary'``.
        """
        summary = results['summary']
        raw_results = results['raw_results']
        importance = summary['component_importance']

        fig = plt.figure(figsize=(22, 14))
        try:
            gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

            self._plot_component_importance(fig.add_subplot(gs[0, :2]), importance[:12])
            self._plot_k_sensitivity(fig.add_subplot(gs[0, 2]), summary)
            self._plot_variant_performance(fig.add_subplot(gs[1, :]), raw_results)
            self._plot_effect_sizes(fig.add_subplot(gs[2, 0]), importance[:10])
            self._plot_synergy(fig.add_subplot(gs[2, 1]), raw_results)
            self._plot_summary_box(fig.add_subplot(gs[2, 2]), self._build_summary_text(results))

            fig.suptitle('TAN Ablation Study - GoEmotions Multi-Label Classification',
                         fontsize=16, fontweight='bold', y=0.98)
            fig.tight_layout()

            save_path = self.config.results_dir / 'ablation_study_visualization.png'
            fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        finally:
            # Always release the figure, even if a panel or the save fails
            plt.close(fig)

        logger.info(f"Visualizations saved to {save_path}")

    @staticmethod
    def _plot_component_importance(ax, importance_data: List[Dict]):
        """Panel 1: horizontal bars of F1 drop per variant, red when significant."""
        ax.set_xlabel('Performance Drop (% F1-Macro)', fontsize=12)
        ax.set_title('Component Importance in TAN Architecture', fontsize=14, fontweight='bold')
        if not importance_data:
            return

        components = [d['component'] for d in importance_data]
        drops = [d['performance_drop'] * 100 for d in importance_data]
        colors = ['#D32F2F' if d['significant'] else '#9E9E9E' for d in importance_data]

        bars = ax.barh(components, drops, color=colors, alpha=0.7, edgecolor='black', linewidth=1)
        ax.axvline(x=0, color='black', linestyle='-', linewidth=0.8)
        ax.grid(True, alpha=0.3, axis='x')

        # Star markers next to significant bars
        for bar, data in zip(bars, importance_data):
            if data['significant']:
                p_value = data['p_value']
                significance = '***' if p_value < 0.001 else '**' if p_value < 0.01 else '*'
                ax.text(bar.get_width() + 0.2, bar.get_y() + bar.get_height()/2,
                        significance, va='center', fontsize=14, color='red', fontweight='bold')

    @staticmethod
    def _plot_k_sensitivity(ax, summary: Dict):
        """Panel 2: macro-F1 against k on a log2 axis, best k highlighted."""
        if 'k_analysis' not in summary:
            return
        k_analysis = summary['k_analysis']

        k_values = sorted(k_analysis.keys())
        f1_scores = [k_analysis[k] for k in k_values]
        ax.plot(k_values, f1_scores, 'o-', color='#1976D2', linewidth=2.5, markersize=10)

        optimal_k = summary['optimal_configurations'].get('k_neighbors', 32)
        if optimal_k in k_analysis:
            ax.plot(optimal_k, k_analysis[optimal_k], 'o', color='#D32F2F', markersize=14,
                    label=f'Optimal k={optimal_k}', zorder=5)
            ax.legend(loc='best')

        ax.set_xlabel('K (Number of Neighbors)', fontsize=12)
        ax.set_ylabel('F1-Macro Score', fontsize=12)
        ax.set_title('K-Neighbors Sensitivity', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.set_xscale('log', base=2)

    @staticmethod
    def _plot_variant_performance(ax, raw_results: Dict):
        """Panel 3: mean macro-F1 (+/- std) of every variant, best first."""
        ranked = sorted(
            (
                (name, np.mean(scores['f1_macro']), np.std(scores['f1_macro']))
                for name, scores in raw_results.items()
            ),
            key=lambda row: row[1],
            reverse=True,
        )
        clean_names = [name.replace('TAN-', '') for name, _, _ in ranked]
        variant_means = [mean for _, mean, _ in ranked]
        variant_stds = [std for _, _, std in ranked]

        x_pos = np.arange(len(clean_names))
        colors = ['#2E7D32' if name == 'Full' else '#1976D2' for name in clean_names]

        bars = ax.bar(x_pos, variant_means, yerr=variant_stds, capsize=5,
                      color=colors, alpha=0.7, edgecolor='black', linewidth=1)

        ax.set_xlabel('Model Variant', fontsize=12)
        ax.set_ylabel('F1-Macro Score', fontsize=12)
        ax.set_title('Ablation Variants Performance Comparison', fontsize=14, fontweight='bold')
        ax.set_xticks(x_pos)
        ax.set_xticklabels(clean_names, rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')

        # Fixed reference line at the hard-coded reported score
        ax.axhline(y=0.6479, color='r', linestyle='--', alpha=0.5,
                   label='Reported TAN (0.6479)', linewidth=2)
        ax.legend(loc='upper right')

        for bar, mean, std in zip(bars, variant_means, variant_stds):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + std + 0.005,
                    f'{mean:.3f}', ha='center', va='bottom', fontsize=9)

    @staticmethod
    def _plot_effect_sizes(ax, importance_data: List[Dict]):
        """Panel 4: one-column heatmap of absolute effect sizes (colour scale 0..2)."""
        ax.set_title("Effect Sizes (Cohen's d)", fontsize=14, fontweight='bold')
        ax.set_xticks([])
        if not importance_data:
            # imshow cannot draw an empty matrix
            ax.set_yticks([])
            return

        labels = [comp['component'] for comp in importance_data]
        effect_sizes = [comp['effect_size'] for comp in importance_data]

        effect_matrix = np.array(effect_sizes).reshape(-1, 1)
        im = ax.imshow(effect_matrix, cmap='RdYlGn_r', aspect='auto', vmin=0, vmax=2)

        ax.set_yticks(range(len(labels)))
        ax.set_yticklabels(labels)

        for row, effect in enumerate(effect_sizes):
            ax.text(0, row, f'{effect:.2f}', ha='center', va='center',
                    color='white' if effect > 1 else 'black', fontweight='bold')

        cbar = ax.figure.colorbar(im, ax=ax, orientation='horizontal', pad=0.1)
        cbar.set_label('Effect Size', fontsize=10)

    @staticmethod
    def _plot_synergy(ax, raw_results: Dict):
        """Panel 5: full model vs. no-topology, no-LSH and neither.

        Drawn only when all four variants were evaluated.  Results are keyed
        ``'TAN-<variant>'``, so those are the keys checked here.
        """
        panel_variants = [
            ('Full\nModel', 'TAN-Full'),
            ('No\nTopology', 'TAN-NoTopology'),
            ('No\nLSH', 'TAN-NoLSH'),
            ('No\nBoth', 'TAN-NoTopologyNoLSH'),
        ]
        if not all(key in raw_results for _, key in panel_variants):
            return

        categories = [label for label, _ in panel_variants]
        scores = [np.mean(raw_results[key]['f1_macro']) for _, key in panel_variants]

        colors = ['#2E7D32', '#FF9800', '#03A9F4', '#F44336']
        bars = ax.bar(categories, scores, color=colors, alpha=0.7,
                      edgecolor='black', linewidth=2)

        ax.set_ylabel('F1-Macro Score', fontsize=12)
        ax.set_title('Component Synergy Analysis', fontsize=14, fontweight='bold')
        ax.set_ylim([min(scores) * 0.95, max(scores) * 1.05])
        ax.grid(True, alpha=0.3, axis='y')

        for bar, score in zip(bars, scores):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.002,
                    f'{score:.4f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

    @staticmethod
    def _build_summary_text(results: Dict) -> str:
        """Compose the monospace text shown in the summary panel."""
        # Standard library; imported here because the module header does not import it
        import textwrap

        summary = results['summary']
        full_scores = results['raw_results']['TAN-Full']['f1_macro']

        lines = [
            "ABLATION STUDY SUMMARY",
            "=" * 35,
            "",
            f"Full Model: {np.mean(full_scores):.4f} ± {np.std(full_scores):.4f}",
            "",
        ]

        if summary['critical_components']:
            lines.append("Critical Components:")
            lines.extend(f"  • {comp}" for comp in summary['critical_components'][:5])

        lines += ["", "Key Findings:"]
        for finding in summary['key_findings'][:4]:
            # Every finding gets a bullet; continuation lines are indented under it
            lines.extend(textwrap.wrap(
                finding,
                width=42,
                initial_indent="  • ",
                subsequent_indent="    ",
                break_long_words=False,
            ))

        if summary['optimal_configurations']:
            lines += ["", "Optimal Config:"]
            lines.extend(f"  {k}: {v}" for k, v in summary['optimal_configurations'].items())

        return "\n".join(lines) + "\n"

    @staticmethod
    def _plot_summary_box(ax, summary_text: str):
        """Panel 6: framed text box on an axis with ticks and spines hidden."""
        ax.axis('off')
        ax.text(0.05, 0.95, summary_text, transform=ax.transAxes,
                fontsize=10, verticalalignment='top', fontfamily='monospace',
                bbox=dict(boxstyle="round,pad=0.7", facecolor="#E3F2FD",
                          edgecolor="#1976D2", linewidth=2))

def main():
    """Run the ablation study with the default configuration and log a report.

    Returns early (after logging two error lines) when the configured model
    checkpoint file does not exist.
    """
    rule = "=" * 80

    logger.info(rule)
    logger.info("TAN ABLATION STUDY FOR GOEMOTIONS")
    logger.info(rule)

    config = AblationConfig()

    # Nothing can be evaluated without the trained base model
    checkpoint_path = config.model_checkpoint
    if not checkpoint_path.exists():
        logger.error(f"Model checkpoint not found at {checkpoint_path}")
        logger.error("Please ensure goemotion_best_model.pt exists")
        return

    results = AblationStudy(config).run_ablation_study()
    summary = results['summary']

    logger.info("\n" + rule)
    logger.info("ABLATION STUDY COMPLETE")
    logger.info(rule)

    # Ten highest-ranked components, one line each
    logger.info("\nComponent Importance Ranking:")
    rank = 0
    for comp in summary['component_importance'][:10]:
        rank += 1
        if comp['significant']:
            sig = "***"
        else:
            sig = "   "
        logger.info(
            f"  {rank:2d}. {comp['component']:20s}: "
            f"{comp['performance_drop']:+.4f} ({comp['relative_drop_%']:+5.1f}%) "
            f"p={comp['p_value']:.4f} {sig}"
        )

    critical = summary['critical_components']
    if critical:
        logger.info(f"\nCritical Components: {', '.join(critical)}")

    logger.info("\nKey Findings:")
    for finding in summary['key_findings']:
        logger.info(f"  • {finding}")

    optimal = summary['optimal_configurations']
    if optimal:
        logger.info("\nOptimal Configuration:")
        for param in optimal:
            logger.info(f"  {param}: {optimal[param]}")

    logger.info("\n" + rule)
    logger.info("Results saved to: " + str(config.results_dir))
    logger.info(rule)


# Run the study only when executed as a script, not when imported
if __name__ == "__main__":
    main()