"""
Attention Analysis for TAN and Baselines
Analyzes attention patterns by loading pre-trained models and using hooks for extraction.
"""
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple
from dataclasses import dataclass
import json
import logging
from tqdm import tqdm
from scipy.stats import entropy
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

# Use Hugging Face datasets library for robust data loading
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, AutoConfig

# Import your TAN architecture from its original file
from tan_architecture import TANForMultiLabelClassification, TANConfig

# Silence every Python warning process-wide. This also hides any warnings
# emitted by torch / transformers / datasets during loading.
# NOTE(review): consider narrowing this to specific categories/modules.
warnings.filterwarnings('ignore')
# Root logger: INFO level with timestamped, level-tagged records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class Config:
    """Empty stand-in so older checkpoints whose pickles reference a ``Config`` class can be unpickled."""

# --- JSON SERIALIZATION HELPERS ---
def convert_numpy_types(obj):
    """``default=`` hook for ``json.dump`` that converts NumPy/torch values to builtins.

    Handles NumPy booleans, floats, integers and arrays, plus torch tensors
    (via ``tolist``).

    Raises:
        TypeError: for any other type, which is the contract ``json`` expects
            from a ``default`` hook.
    """
    # np.bool_ is neither np.integer nor np.floating, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")


# --- DATA HANDLING ---
class GoEmotionsHFDataset(Dataset):
    """Wraps a Hugging Face GoEmotions split and tokenizes each example on access.

    Each item holds only ``input_ids`` and ``attention_mask`` (both 1-D, padded or
    truncated to ``max_len``) so a collated batch can be passed directly as
    ``model(**batch)``. Labels are deliberately not returned: this dataset only
    feeds attention extraction, and an extra ``labels`` kwarg would be forwarded
    to the model.
    """
    def __init__(self, hf_dataset, tokenizer, max_len: int = 128):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.num_labels = 28 # GoEmotions has 27 emotions + 'neutral'

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        text = self.hf_dataset[idx]['text']
        encoding = self.tokenizer.encode_plus(
            text, add_special_tokens=True, max_length=self.max_len,
            padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt'
        )
        # return_tensors='pt' yields shape (1, max_len); flatten to (max_len,) so the
        # DataLoader adds the batch dimension.
        return {'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten()}

# --- CONFIGURATION ---
@dataclass
class AttentionConfig:
    """Configuration for attention analysis.

    Path-like fields may be given as ``str`` or ``Path``; they are normalized to
    ``Path`` on construction. ``results_dir`` (including missing parents) is
    created on construction.
    """
    # Fine-tuned TAN checkpoint.
    tan_checkpoint: Path = Path('goemotion_best_model.pt')
    # Display name -> (checkpoint path, Hugging Face model id).
    # None selects the BERT/RoBERTa defaults set in __post_init__.
    baseline_checkpoints: Dict[str, Tuple[Path, str]] = None
    # Number of validation examples passed through each model.
    num_samples: int = 50
    # Output directory for the JSON results and plots.
    results_dir: Path = Path('attention_analysis_results')
    # Evaluated once, when the class is defined.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __post_init__(self):
        self.tan_checkpoint = Path(self.tan_checkpoint)
        self.results_dir = Path(self.results_dir)
        # parents=True so nested output locations work on a fresh checkout.
        self.results_dir.mkdir(parents=True, exist_ok=True)
        if self.baseline_checkpoints is None:
            self.baseline_checkpoints = {
                'BERT': (Path('checkpoints_goemotions/bert-base-uncased_best.pt'), 'bert-base-uncased'),
                'RoBERTa': (Path('checkpoints_goemotions/roberta-base_best.pt'), 'roberta-base'),
            }
        else:
            # Callers later use `.exists()` on these, so normalize str paths.
            self.baseline_checkpoints = {
                name: (Path(path), hf_name)
                for name, (path, hf_name) in self.baseline_checkpoints.items()
            }

# --- ATTENTION EXTRACTION (WITH HOOKS FOR TAN) ---
class AttentionExtractor:
    """Extracts per-layer attention weights from a model without modifying its source code.

    - ``model_name == 'TAN'``: temporary forward hooks on every
      ``layer.attention.dropout`` under ``model.tan.layers`` record the 4-D
      tensor entering that dropout.
    - Any other name: the model is called with ``output_attentions=True`` and
      ``outputs.attentions`` is returned.

    ``extract`` returns a list with one CPU tensor per layer, or ``[]`` when the
    model produced no attention maps.
    """
    def __init__(self, model_name: str, model: nn.Module):
        self.model_name = model_name
        self.model = model
        self.attentions = []

    def _tan_hook(self, module, input_tuple, output):
        """Forward hook recording the attention map that enters TAN's dropout.

        The same dropout module also processes a 3-D tensor (the output
        projection, per the original design), so only 4-D inputs are kept;
        presumably (batch, heads, seq, seq) — confirm against tan_architecture.
        Forward hooks receive only positional inputs, so an empty tuple (dropout
        called with keyword args) is ignored instead of raising IndexError.
        """
        if not input_tuple:
            return
        tensor = input_tuple[0]
        if isinstance(tensor, torch.Tensor) and tensor.ndim == 4:
            self.attentions.append(tensor.detach().cpu())

    def extract(self, batch: Dict) -> List[torch.Tensor]:
        """Run one forward pass on ``batch`` (passed as ``**batch``) and return its attention maps."""
        self.attentions = []

        if self.model_name == 'TAN':
            hooks = []
            try:
                # Register inside the try so a failure part-way still removes
                # the hooks that were already attached.
                for layer in self.model.tan.layers:
                    hooks.append(layer.attention.dropout.register_forward_hook(self._tan_hook))
                with torch.no_grad():
                    self.model(**batch)
            finally:
                for hook in hooks:
                    hook.remove()
            return self.attentions

        # For BERT, RoBERTa, etc.
        with torch.no_grad():
            outputs = self.model(**batch, output_attentions=True)
        attentions = getattr(outputs, 'attentions', None)
        if attentions is None:
            # No attention maps returned; callers treat [] as "nothing extracted".
            return []
        return [attn.cpu() for attn in attentions]

# --- ATTENTION METRICS ---
class AttentionMetrics:
    """Calculates summary metrics (entropy, token distance) from attention weights.

    Each layer tensor is reduced to one query x key matrix by averaging over every
    leading dimension (e.g. batch and heads for a 4-D tensor). Both metrics are
    then averaged over query rows.
    """

    @staticmethod
    def _mean_attention_matrix(attention_layer: torch.Tensor):
        """Average all leading dims into a 2-D float numpy matrix; None if input is not a non-empty >=2-D tensor."""
        if attention_layer.ndim < 2 or attention_layer.numel() == 0:
            return None
        seq_q, seq_k = attention_layer.shape[-2:]
        # .float() so half/bfloat16 tensors can be converted to numpy.
        flat = attention_layer.detach().float().reshape(-1, seq_q, seq_k)
        return flat.mean(dim=0).cpu().numpy()

    def calculate_entropy(self, attention_layer: torch.Tensor) -> float:
        """Mean Shannon entropy (scipy default base) of the averaged attention rows.

        Rows summing to ~0 are skipped. Returns 0.0 when nothing can be measured.
        """
        matrix = self._mean_attention_matrix(attention_layer)
        if matrix is None:
            return 0.0
        rows = matrix[matrix.sum(axis=1) > 1e-9]
        if rows.shape[0] == 0:
            return 0.0
        # entropy() normalizes each row along axis=1 before computing.
        return float(np.mean(entropy(rows, axis=1)))

    def calculate_attention_distance(self, attention_layer: torch.Tensor) -> float:
        """Mean over query positions of sum_k attn[q, k] * |q - k|, in tokens.

        Rows are not renormalized. Returns 0.0 when nothing can be measured.
        """
        matrix = self._mean_attention_matrix(attention_layer)
        if matrix is None:
            return 0.0
        seq_q, seq_k = matrix.shape
        offsets = np.abs(np.arange(seq_q)[:, None] - np.arange(seq_k)[None, :])
        return float(np.mean(np.sum(matrix * offsets, axis=1)))

    def analyze_patterns(self, attention_weights: List[torch.Tensor]) -> Dict:
        """Compute per-layer metrics plus their averages.

        ``None`` or empty layers are skipped; the remaining layers keep their
        original index in ``'layer'``.
        """
        layer_metrics = []
        for i, layer_attn in enumerate(attention_weights):
            if layer_attn is None or layer_attn.numel() == 0:
                continue
            layer_metrics.append({
                'layer': i,
                'entropy': self.calculate_entropy(layer_attn),
                'mean_distance': self.calculate_attention_distance(layer_attn),
            })

        return {
            'layer_metrics': layer_metrics,
            'avg_entropy': float(np.mean([m['entropy'] for m in layer_metrics])) if layer_metrics else 0.0,
            'avg_distance': float(np.mean([m['mean_distance'] for m in layer_metrics])) if layer_metrics else 0.0,
        }

# --- VISUALIZATION ---
def create_visualizations(results: Dict, save_dir: Path):
    """Plot per-model average attention entropy and distance side by side.

    Args:
        results: model name -> metrics dict; reads ``avg_entropy`` and
            ``avg_distance`` (missing keys plot as 0).
        save_dir: directory that receives ``attention_comparison.png``.
    """
    # The 'seaborn-v0_8-*' style names are not available in every matplotlib version.
    style = 'seaborn-v0_8-whitegrid'
    if style not in plt.style.available:
        style = 'default'

    models = list(results.keys())
    colors = ['#FFC300' if m == 'TAN' else ('#581845' if m == 'BERT' else '#DAF7A6') for m in models]
    avg_entropies = [results[m].get('avg_entropy', 0) for m in models]
    avg_distances = [results[m].get('avg_distance', 0) for m in models]
    save_path = Path(save_dir) / 'attention_comparison.png'

    # Scope the style to this figure rather than changing global rcParams.
    with plt.style.context(style):
        fig, axes = plt.subplots(1, 2, figsize=(16, 7), sharey=False)
        try:
            axes[0].bar(models, avg_entropies, color=colors, edgecolor='black')
            axes[0].set_title('Average Attention Entropy', fontsize=14, fontweight='bold')
            axes[0].set_ylabel('Entropy (Higher is more distributed)')

            axes[1].bar(models, avg_distances, color=colors, edgecolor='black')
            axes[1].set_title('Average Attention Distance', fontsize=14, fontweight='bold')
            axes[1].set_ylabel('Mean Distance (Tokens)')
            # Distances cannot be negative, so anchor the y-axis at 0.
            if all(d >= 0 for d in avg_distances):
                axes[1].set_ylim(bottom=0)

            fig.suptitle('Comparative Attention Analysis', fontsize=16, fontweight='bold')
            fig.tight_layout(rect=[0, 0, 1, 0.96])
            fig.savefig(save_path, dpi=300)
        finally:
            # Release the figure even if plotting or saving raised.
            plt.close(fig)
    logger.info(f"Visualizations saved to {save_path}")

# --- MAIN EXECUTION ---
def _build_loader(hf_split, tokenizer, num_samples: int) -> DataLoader:
    """Return a batch-size-1 DataLoader over the first ``num_samples`` rows of ``hf_split``.

    Iteration order is fixed, so runs are reproducible. ``num_samples`` is clamped
    to the split length.
    """
    dataset = GoEmotionsHFDataset(hf_split, tokenizer)
    subset = torch.utils.data.Subset(dataset, range(min(num_samples, len(dataset))))
    return DataLoader(subset, batch_size=1, shuffle=False)


def _load_model(name: str, checkpoint_path: Path, hf_name, vocab_size: int, device: str) -> nn.Module:
    """Build the architecture for ``name`` and load its fine-tuned checkpoint.

    ``hf_name`` is the Hugging Face model id for baselines and None for TAN.
    Returns the model moved to ``device`` and switched to eval mode.
    """
    if name == 'TAN':
        tan_config = TANConfig(vocab_size=vocab_size)
        model = TANForMultiLabelClassification(tan_config, num_labels=28)
    else:
        model_config = AutoConfig.from_pretrained(hf_name, output_attentions=True)
        model = AutoModel.from_pretrained(hf_name, config=model_config)

    # SECURITY: weights_only=False performs full pickle deserialization. It is
    # needed for checkpoints that pickled a `Config` object, so only load
    # trusted checkpoint files.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False tolerates head/prefix differences, but it would also silently
    # keep the initial weights if no key matched. Surface that case.
    matched = len(state_dict) - len(load_result.unexpected_keys)
    if matched == 0:
        logger.warning(f"No checkpoint keys matched the {name} architecture; "
                       f"analysis will use its initial (not fine-tuned) weights.")
    elif load_result.missing_keys or load_result.unexpected_keys:
        logger.info(f"{name}: {len(load_result.missing_keys)} missing / "
                    f"{len(load_result.unexpected_keys)} unexpected keys when loading {checkpoint_path}")
    return model.to(device).eval()


def _average_attention(extractor, data_loader, device: str, name: str) -> List[torch.Tensor]:
    """Extract attention for every batch and average each layer over samples.

    Returns one tensor per layer (mean over dim 0, keepdim=True), or [] if no
    sample produced attention maps.
    """
    sample_attentions = []
    for batch in tqdm(data_loader, desc=f"Extracting from {name}"):
        batch_on_device = {k: v.to(device) for k, v in batch.items()}
        attentions = extractor.extract(batch_on_device)
        if attentions:
            sample_attentions.append(attentions)

    if not sample_attentions:
        return []
    num_layers = len(sample_attentions[0])
    return [torch.cat([sample[i] for sample in sample_attentions]).mean(dim=0, keepdim=True)
            for i in range(num_layers)]


def main():
    """Run attention extraction and metrics for TAN and every baseline, then save results and plots."""
    config = AttentionConfig()
    logger.info("=" * 60)
    logger.info("Starting Model Attention Analysis")
    logger.info(f"Using device: {config.device}")
    logger.info("=" * 60)

    # 1. Load Data
    logger.info("Loading GoEmotions dataset...")
    dataset_dict = load_dataset("go_emotions", "simplified")
    eval_split = dataset_dict['validation']
    # Each model must receive ids from its own tokenizer's vocabulary. Models
    # sharing a tokenizer share the (tokenizer, loader) pair.
    loaders = {}

    # 2. Define Models and Loaders
    models_to_analyze = {'TAN': (config.tan_checkpoint, None)}
    models_to_analyze.update(config.baseline_checkpoints)

    all_results = {}
    metrics_analyzer = AttentionMetrics()

    # 3. Run Analysis Loop
    for name, (checkpoint_path, hf_name) in models_to_analyze.items():
        logger.info(f"\n--- Analyzing {name} ---")
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            logger.warning(f"Checkpoint not found for {name} at {checkpoint_path}. Skipping.")
            continue

        # TAN has no hf_name; it is tokenized with bert-base-uncased.
        tokenizer_name = hf_name or 'bert-base-uncased'
        if tokenizer_name not in loaders:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            loaders[tokenizer_name] = (tokenizer, _build_loader(eval_split, tokenizer, config.num_samples))
        tokenizer, data_loader = loaders[tokenizer_name]

        model = _load_model(name, checkpoint_path, hf_name, tokenizer.vocab_size, config.device)
        avg_attention_per_layer = _average_attention(
            AttentionExtractor(name, model), data_loader, config.device, name)

        # Free the model before loading the next one.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if not avg_attention_per_layer:
            logger.warning(f"Could not extract attention for {name}.")
            continue

        # NOTE(review): inputs are padded to max_len, so padded positions are
        # included in the averaged maps and therefore in the metrics.
        all_results[name] = metrics_analyzer.analyze_patterns(avg_attention_per_layer)
        logger.info(f"Analysis for {name} complete. Avg Entropy: {all_results[name]['avg_entropy']:.4f}, Avg Distance: {all_results[name]['avg_distance']:.4f}")

    # 4. Save and Visualize
    if all_results:
        results_path = config.results_dir / 'attention_analysis_results.json'
        with open(results_path, 'w') as f:
            json.dump(all_results, f, indent=2, default=convert_numpy_types)
        logger.info(f"\nFull analysis results saved to {results_path}")

        create_visualizations(all_results, config.results_dir)

    logger.info("=" * 60)
    logger.info("Attention Analysis Complete.")
    logger.info("=" * 60)

if __name__ == "__main__":
    main()