import argparse
import datetime
import json
import logging
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.decomposition import PCA

from recognizers.neural_networks.data import load_prepared_data_from_directory
from recognizers.neural_networks.model_interface import RecognitionModelInterface


class TransformerAnalyzer:
    def __init__(self, model_path, data_dir, device=None):
        """
        Load a trained model for inspection and prepare the output directory.

        Args:
            model_path: Directory of the trained model. It is handed to the
                model interface as the model to load, and plots are written to
                its ``analysis`` subdirectory (created if missing).
            data_dir: Directory that holds the prepared test datasets.
            device: Device to run the model on. When None, CUDA is chosen if
                it is available, otherwise the CPU.
        """
        # Resolve the device before loading, because the loader receives it.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.model_path = Path(model_path)
        self.data_dir = Path(data_dir)

        # The interface is only used to load an existing model.
        self.model_interface = RecognitionModelInterface(
            use_load=True,
            use_init=False,
            use_output=False,
            require_output=False,
        )

        # construct_saver reads its options as attributes of an args object;
        # supply every loading option it may look for, with neutral values.
        saver_args = argparse.Namespace(
            load_model=model_path,
            device=self.device,
            load_parameters=None,
            load_optimizer=False,
            load_best=False,
            load_specific_epoch=None,
        )
        self.saver = self.model_interface.construct_saver(saver_args)
        self.model = self.saver.model.to(self.device)
        self.model.eval()

        # Named logger writing to stdout. The handler check keeps repeated
        # instantiation from attaching duplicate handlers.
        self.logger = logging.getLogger('transformer_analyzer')
        self.logger.setLevel(logging.INFO)
        if not self.logger.handlers:
            self.logger.addHandler(logging.StreamHandler(sys.stdout))

        # Scratch buffers that receive attention weights during extraction.
        self._temp_attention_weights = []
        self._temp_attention_sources = []

        # Every generated plot is saved here.
        self.results_dir = self.model_path / 'analysis'
        self.results_dir.mkdir(parents=True, exist_ok=True)
    
    def load_test_data(self, datasets, num_examples, min_length=10, max_length=50):
        """
        Collect up to ``num_examples`` examples from the prepared test datasets.

        Datasets are read in the order given, from ``<data_dir>/datasets/<name>``.
        Within a dataset, examples whose label equals True are preferred; if
        there are none, every example of that dataset is considered instead.
        Only examples whose string length lies in ``[min_length, max_length]``
        are kept. Missing or unreadable datasets are logged and skipped.

        Args:
            datasets: Names of the datasets to read. ``'training'`` is skipped.
            num_examples: Maximum number of examples to return. A value <= 0
                returns an empty list without reading anything.
            min_length: Shortest accepted string length (inclusive).
            max_length: Longest accepted string length (inclusive).

        Returns:
            List of at most ``num_examples`` examples, each as returned by
            ``load_prepared_data_from_directory``.
        """
        if num_examples <= 0:
            return []

        examples = []
        for dataset in datasets:
            if dataset == 'training':
                self.logger.info("Skipping the 'training' dataset")
                continue

            dataset_dir = self.data_dir / 'datasets' / dataset
            if not dataset_dir.exists():
                self.logger.warning(f"Dataset directory {dataset_dir} does not exist, skipping")
                continue

            try:
                dataset_examples = load_prepared_data_from_directory(dataset_dir, self.model_interface)
                self.logger.info(f"Loaded {len(dataset_examples)} examples from {dataset}")
            except Exception as e:
                # Best effort: one unreadable dataset should not abort the others.
                self.logger.error(f"Error loading data from {dataset_dir}: {e}")
                continue

            if not dataset_examples:
                continue

            # Examples are (string, (label, next_symbols)). `==` rather than `is`
            # so labels that merely compare equal to True (e.g. 1) also count.
            positive_examples = [ex for ex in dataset_examples if ex[1][0] == True]  # noqa: E712
            self.logger.info(f"Found {len(positive_examples)} positive examples in {dataset}")
            if not positive_examples:
                self.logger.info(f"No positive examples found, using all examples from {dataset}")
                positive_examples = dataset_examples

            for ex in positive_examples:
                length = self._sequence_length(ex[0])
                if min_length <= length <= max_length:
                    examples.append(ex)
                    if len(examples) >= num_examples:
                        return examples
                else:
                    self.logger.debug(
                        f"Skipping example of length {length} "
                        f"(outside [{min_length}, {max_length}])"
                    )

        return examples

    @staticmethod
    def _sequence_length(string):
        """Return the number of symbols in ``string``, or 0 if it has no length."""
        if isinstance(string, torch.Tensor):
            # A 0-d tensor has no first dimension to measure.
            return string.size(0) if string.dim() > 0 else 0
        try:
            return len(string)
        except TypeError:
            return 0
    
    def _preprocess_input(self, example):
        """
        Turn a data-loader example into a batched input tensor for the model.

        Args:
            example: Tuple ``(string, (label, next_symbols))``; only the
                string is used.

        Returns:
            The string as a tensor on ``self.device``. Non-tensor strings are
            converted with dtype ``torch.long``; a one-dimensional result gets
            a leading batch dimension of size 1.
        """
        symbols = example[0]
        tensor = (
            symbols
            if isinstance(symbols, torch.Tensor)
            else torch.tensor(symbols, dtype=torch.long, device=self.device)
        )
        batched = tensor.unsqueeze(0) if tensor.dim() == 1 else tensor
        return batched.to(self.device)
    
    def extract_residual_stream(self, example):
        """
        Capture per-layer activations for one example using forward hooks.

        A forward hook is attached to every module whose type string contains
        ``'TransformerEncoderLayer'`` or ``'TransformerLayer'``. Each hook
        records the module's output (the first element when the output is a
        tuple), detached from the autograd graph. The hooks are always removed
        before returning, even if the forward pass raises.

        If no hook fires (no matching module, or the forward pass fails first),
        placeholder activations from ``_make_synthetic_residual_activations``
        are returned instead so that plotting can still proceed.

        Args:
            example: Example from the data loader, ``(string, (label, next_symbols))``.

        Returns:
            List of activation tensors in the order the hooks fired. Captured
            tensors keep the module's output shape; synthetic ones are
            ``(seq_len, d_model)``.
        """
        model_input = self._preprocess_input(example)

        self.logger.info(f"Model structure: {type(self.model)}")
        self.logger.info(f"Input shape: {model_input.shape}")

        # Diagnostic logging only: report where the transformer stack lives.
        self.logger.info("Exploring model component structure...")
        core_found = False
        for name, module in self.model.named_modules():
            if 'UnidirectionalTransformerEncoderLayers' in str(type(module)):
                self.logger.info(f"Found transformer encoder layers at: {name}")
                core_found = True
            elif 'core' in name.lower() and not core_found:
                self.logger.info(f"Potential core component: {name} of type {type(module)}")

        activations = []
        activation_sources = []

        def make_hook(name):
            def hook(module, inputs, output):
                act_tensor = output[0] if isinstance(output, tuple) else output
                act_tensor = act_tensor.detach()
                source_info = f"{name} shape:{act_tensor.shape}"
                activation_sources.append(source_info)
                self.logger.info(f"Captured activation from {source_info}")
                activations.append(act_tensor)
            return hook

        hooks = []
        try:
            for name, module in self.model.named_modules():
                # NOTE(review): this is a substring test on the full type string,
                # so a container class such as 'UnidirectionalTransformerEncoderLayers'
                # also matches; its hook fires after those of any matched children
                # and adds one more entry. Confirm whether that entry is wanted.
                module_type = str(type(module))
                if 'TransformerEncoderLayer' in module_type or 'TransformerLayer' in module_type:
                    self.logger.info(f"Registering hook on module: {name} of type {type(module)}")
                    hooks.append(module.register_forward_hook(make_hook(name)))

            with torch.no_grad():
                try:
                    batch_size = model_input.shape[0]
                    seq_length = model_input.shape[1]
                    last_index = torch.tensor([seq_length - 1] * batch_size, device=self.device)
                    positive_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)

                    self.logger.info("Attempting model forward pass with tag_kwargs...")
                    # tag_kwargs passes parameters to specific parts of the model.
                    output = self.model(
                        model_input,
                        tag_kwargs={
                            'core': {'include_first': False},
                            'output_heads': {
                                'last_index': last_index,
                                'positive_mask': positive_mask,
                            },
                        },
                    )
                    self.logger.info(f"Model forward pass successful: {type(output)}")
                except Exception as e:
                    # Best effort: keep whatever the hooks captured, or fall
                    # back to synthetic activations below if they captured nothing.
                    self.logger.warning(f"Forward pass failed: {e}")
        finally:
            for handle in hooks:
                handle.remove()

        if not activations:
            self.logger.warning("No activations captured, creating synthetic activations for visualization")
            num_layers, d_model = self._load_residual_metadata()
            synthetic = self._make_synthetic_residual_activations(
                model_input.shape[1], num_layers, d_model, device=self.device
            )
            activations.extend(synthetic)
            activation_sources.extend(f"synthetic_layer_{i}" for i in range(len(synthetic)))

        self.logger.info(f"Extracted {len(activations)} activation tensors")
        for i, act in enumerate(activations):
            source = activation_sources[i] if i < len(activation_sources) else 'synthetic'
            self.logger.info(f"  Activation {i}: shape {act.shape} source: {source}")

        return activations

    def _load_residual_metadata(self):
        """
        Read ``num_layers`` and ``d_model`` from ``<model_path>/model_metadata.json``.

        Returns:
            ``(num_layers, d_model)``. Missing keys default to 1 and 28. A
            missing file gives ``(1, 28)``, and so does an unreadable one,
            which is also logged.
        """
        num_layers, d_model = 1, 28
        metadata_path = Path(self.model_path) / "model_metadata.json"
        try:
            if metadata_path.exists():
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                num_layers = metadata.get('num_layers', 1)
                d_model = metadata.get('d_model', 28)
                self.logger.info(f"Using metadata: layers={num_layers}, d_model={d_model}")
        except Exception as e:
            self.logger.warning(f"Could not read {metadata_path}: {e}")
            num_layers, d_model = 1, 28
        return num_layers, d_model

    @staticmethod
    def _make_synthetic_residual_activations(seq_length, num_layers, d_model, device=None):
        """
        Build ``num_layers`` placeholder tensors of shape ``(seq_length, d_model)``.

        Dimensions 0-4 hold sine waves of increasing frequency over the
        positions, dimensions 5-9 a linear ramp from 0 to 1, and the rest are
        zero. Every layer gets an identical, independent copy.
        """
        positions = torch.arange(seq_length, device=device).float()
        pattern = torch.zeros((seq_length, d_model), device=device)
        for j in range(min(5, d_model)):
            freq = 0.5 * (j + 1) / d_model
            # 3.14159 approximates pi.
            pattern[:, j] = torch.sin(freq * positions * 3.14159)
        if d_model > 5:
            ramp = torch.linspace(0, 1, seq_length, device=device)
            pattern[:, 5:10] = ramp.unsqueeze(1)
        return [pattern.clone() for _ in range(num_layers)]
    
    def extract_attention_weights(self, example):
        """
        Capture per-head attention weights for one example.

        For the duration of one forward pass only, every
        ``torch.nn.MultiheadAttention`` module has its ``forward`` wrapped. The
        wrapper calls it with ``need_weights=True`` and
        ``average_attn_weights=False`` and records the returned weights.

        Other modules whose name or type mentions ``attention``/``attn`` get a
        forward hook instead. The hook records either a tensor of rank >= 3
        found as the second element of a tuple output, or a tensor stored in
        one of the attributes ``attn_weights``, ``attention_weights``,
        ``weights`` or ``attn_output_weights``.

        Wrappers and hooks are removed before returning, even on error, so the
        model behaves normally outside this call.

        If nothing is captured, synthetic causal patterns from
        ``_make_synthetic_attention`` are returned instead.

        Args:
            example: Example from the data loader, ``(string, (label, next_symbols))``.

        Returns:
            List of attention weight tensors in capture order. Weights from
            ``MultiheadAttention`` are presumably ``(batch, heads, seq, seq)``
            (verify for this model); synthetic ones are ``(heads, seq, seq)``.
        """
        model_input = self._preprocess_input(example)
        self._temp_attention_weights.clear()
        self._temp_attention_sources.clear()

        attention_weights = []
        attention_sources = []

        def make_hook(name):
            def hook(module, inputs, output):
                self.logger.info(f"Attention hook fired for module: {name}")
                if isinstance(output, tuple) and len(output) > 1:
                    item = output[1]
                    if isinstance(item, torch.Tensor) and item.dim() >= 3:
                        self.logger.info(f"Hook found potential attention weights in output: shape {item.shape}")
                        self._temp_attention_weights.append(item.detach())
                        self._temp_attention_sources.append(name)
                        return
                for attr_name in ('attn_weights', 'attention_weights', 'weights', 'attn_output_weights'):
                    attr_value = getattr(module, attr_name, None)
                    if isinstance(attr_value, torch.Tensor):
                        self.logger.info(f"Hook found attention weights via {attr_name}: {attr_value.shape}")
                        self._temp_attention_weights.append(attr_value.detach())
                        self._temp_attention_sources.append(f"{name}_{attr_name}")
                        return
            return hook

        patched = []
        hooks = []
        try:
            self._patch_attention_modules(patched)

            for name, module in self.model.named_modules():
                lowered_name = name.lower()
                lowered_type = str(type(module)).lower()
                if not any(key in lowered_name or key in lowered_type for key in ('attention', 'attn')):
                    continue
                if isinstance(module, torch.nn.MultiheadAttention):
                    # Its weights are captured by the forward wrapper instead.
                    self.logger.info(f"Skipping hook registration for pre-patched MHA module: {name}")
                    continue
                self.logger.info(f"Registering hook for potential attention module: {name}, type: {type(module)}")
                hooks.append(module.register_forward_hook(make_hook(name)))

            with torch.no_grad():
                try:
                    batch_size = model_input.shape[0]
                    seq_length = model_input.shape[1]
                    last_index = torch.tensor([seq_length - 1] * batch_size, device=self.device)
                    positive_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)
                    forward_kwargs = {
                        'tag_kwargs': {
                            'core': {'include_first': False},
                            'output_heads': {
                                'last_index': last_index,
                                'positive_mask': positive_mask,
                            },
                        },
                    }
                    self.logger.info("Attempting forward pass with pre-patched attention modules...")
                    self.model(model_input, **forward_kwargs)

                    attention_weights = list(self._temp_attention_weights)
                    attention_sources = list(self._temp_attention_sources)
                    if not attention_weights:
                        self.logger.info("Forward pass completed, but no attention weights captured by patch or hooks.")
                except Exception as e:
                    self.logger.warning(f"Forward pass failed during attention extraction: {str(e)}.")
        finally:
            for handle in hooks:
                handle.remove()
            self._unpatch_attention_modules(patched)
            self._temp_attention_weights.clear()
            self._temp_attention_sources.clear()

        if not attention_weights:
            self.logger.warning("No attention weights captured, creating synthetic attention patterns")
            num_layers, num_heads = self._load_attention_metadata()
            attention_weights = self._make_synthetic_attention(
                model_input.shape[1], num_layers, num_heads, device=self.device
            )
            attention_sources = [f"synthetic_layer_{i}" for i in range(len(attention_weights))]

        self.logger.info(f"Extracted {len(attention_weights)} attention weight tensors")
        for i, attn in enumerate(attention_weights):
            source = attention_sources[i] if i < len(attention_sources) else 'unknown'
            self.logger.info(f"  Attention {i}: shape {attn.shape} source: {source}")

        return attention_weights

    def _patch_attention_modules(self, patched):
        """
        Wrap the forward of each MultiheadAttention in the model that is not already wrapped.

        Args:
            patched: List that receives one ``(module, had_instance_forward,
                original_forward)`` entry per wrapped module. Each entry is
                appended as soon as its wrapper is installed, so the caller can
                undo partial work if this raises.
        """
        for name, module in self.model.named_modules():
            if not isinstance(module, torch.nn.MultiheadAttention) or hasattr(module, '_patched_for_weights'):
                continue
            # Remember whether forward was already overridden on the instance,
            # so unpatching restores exactly the previous state.
            had_instance_forward = 'forward' in vars(module)
            original_forward = module.forward
            module.forward = self._make_weight_capturing_forward(original_forward, name)
            module._patched_for_weights = True
            patched.append((module, had_instance_forward, original_forward))
            self.logger.info(f"Pre-patched {name} for attention weight extraction")

    @staticmethod
    def _unpatch_attention_modules(patched):
        """Undo ``_patch_attention_modules`` for every entry in ``patched``, then empty it."""
        for module, had_instance_forward, original_forward in reversed(patched):
            if had_instance_forward:
                module.forward = original_forward
            else:
                # Drop the instance attribute so the class's forward is used again.
                del module.forward
            del module._patched_for_weights
        patched.clear()

    def _make_weight_capturing_forward(self, original_forward, name):
        """
        Return a forward that requests per-head weights and records them in the temp buffers.

        NOTE(review): if a caller passes ``need_weights`` or
        ``average_attn_weights`` positionally, the call raises TypeError
        (the argument is given twice). PyTorch's ``TransformerEncoderLayer``
        may also take a fused fast path in eval mode that never calls
        ``self_attn``. In that case nothing is captured here; confirm for
        this model.
        """
        def patched_forward(*args, **kwargs):
            kwargs['need_weights'] = True
            kwargs['average_attn_weights'] = False  # keep one map per head
            result = original_forward(*args, **kwargs)
            if isinstance(result, tuple) and len(result) > 1:
                weights = result[1]
                if isinstance(weights, torch.Tensor) and weights.dim() >= 3:
                    self.logger.info(f"Captured attention weights via patched forward ({name}): shape {weights.shape}")
                    if weights.dim() == 4:
                        self.logger.info(f"  -> Batch: {weights.shape[0]}, Heads: {weights.shape[1]}, SeqLen: {weights.shape[2]}")
                    elif weights.dim() == 3:
                        self.logger.warning(f"  -> Captured 3D tensor, potential shape ambiguity: {weights.shape}")
                    self._temp_attention_weights.append(weights.detach())
                    self._temp_attention_sources.append(f"{name}_patched")
            return result
        return patched_forward

    def _load_attention_metadata(self):
        """
        Read ``num_layers`` and ``num_heads`` from ``<model_path>/model_metadata.json``.

        Returns:
            ``(num_layers, num_heads)``. Missing keys default to 1. A missing
            file gives ``(1, 1)``, and so does an unreadable one, which is
            also logged.
        """
        num_layers, num_heads = 1, 1
        metadata_path = Path(self.model_path) / "model_metadata.json"
        try:
            if metadata_path.exists():
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                num_layers = metadata.get('num_layers', 1)
                num_heads = metadata.get('num_heads', 1)
                self.logger.info(f"Using metadata: layers={num_layers}, heads={num_heads}")
        except Exception as e:
            self.logger.warning(f"Could not read {metadata_path}: {e}")
            num_layers, num_heads = 1, 1
        return num_layers, num_heads

    @staticmethod
    def _make_synthetic_attention(seq_length, num_layers, num_heads, device=None, decay=0.7):
        """
        Build ``num_layers`` placeholder tensors of shape ``(num_heads, seq_length, seq_length)``.

        Query ``j`` gives key ``k`` the weight ``decay ** (j - k)`` when
        ``k <= j`` (so 1 on the diagonal) and 0 when ``k > j``. Each row is
        then divided by its sum (plus 1e-8). Every head and every layer gets
        the same pattern.
        """
        positions = torch.arange(seq_length, device=device)
        offset = (positions.unsqueeze(1) - positions.unsqueeze(0)).float()
        # clamp keeps negative exponents from blowing up before tril masks them.
        pattern = (decay ** offset.clamp(min=0)).tril()
        pattern = pattern / (pattern.sum(dim=-1, keepdim=True) + 1e-8)
        return [pattern.unsqueeze(0).repeat(num_heads, 1, 1) for _ in range(num_layers)]

    def find_most_variable_dimensions(self, activations_list, percentage=10):
        """
        Rank hidden dimensions by variance and return the top ``percentage`` percent.

        Each tensor is reshaped to ``(-1, hidden_dim)``, treating its last axis
        as the hidden dimension. The results are concatenated, so the variance
        is taken over every position of every tensor in the list.

        Args:
            activations_list: Non-empty list of tensors that share the same
                last dimension. ``torch.cat`` raises on an empty list.
            percentage: Share of dimensions to return. The count is floored,
                then clamped so that at least one dimension (when any exist)
                and at most all of them are returned.

        Returns:
            NumPy array of dimension indices, highest variance first.
        """
        all_activations = torch.cat(
            [act.reshape(-1, act.shape[-1]) for act in activations_list], dim=0
        )

        # Population variance ranks dimensions exactly like the unbiased
        # estimate (every dimension shares the same correction factor), but it
        # stays finite when only a single position is available.
        variances = torch.var(all_activations, dim=0, unbiased=False)

        num_dims = variances.shape[0]
        # Multiply before dividing so integer percentages are exact:
        # 29 / 100 * 100 evaluates to 28.999..., which int() truncates to 28.
        num_top_dims = min(num_dims, max(1, int(percentage * num_dims / 100)))
        _, top_indices = torch.topk(variances, num_top_dims)

        return top_indices.cpu().numpy()
    
    def plot_residual_stream_by_layer(self, examples, plot_attention=False, num_dims=5):
        """
        Plot the values of the most variable dimensions and a heatmap of all dimensions
        in the transformer's residual stream for each layer. Also handles attention plotting.
        
        Args:
            examples: List of examples to analyze (tuples from the data loader)
            plot_attention: Whether to also plot attention patterns
            num_dims: Number of dimensions to plot per layer for the line plot
        """
        # Track generated files
        generated_files = []
        
        # Process each example
        for example_idx, example in enumerate(examples):
            self.logger.info(f"Processing example {example_idx+1}/{len(examples)}")
            
            # Extract string from example - format is (string, (label, next_symbols))
            string = example[0]
            
            # Check if string is empty - handle tensor case
            if isinstance(string, torch.Tensor):
                string_length = string.size(0)  # Get tensor length
                if string_length == 0:
                    self.logger.warning(f"Empty string in example {example_idx}, skipping")
                    continue
            elif not string:  # For non-tensor types
                self.logger.warning(f"Empty string in example {example_idx}, skipping")
                continue
            
            # Extract residual stream activations
            layer_activations = self.extract_residual_stream(example)
            if not layer_activations:
                self.logger.warning(f"No activations extracted for example {example_idx}, skipping")
                continue

            # --- Get Vocabulary/Labels (moved up for reuse) ---
            x_labels = []
            string_length = 0
            if hasattr(example, 'vocabulary') and example.vocabulary:
                if isinstance(string, torch.Tensor):
                    string_list = string.cpu().tolist()
                    string_length = len(string_list)
                    x_labels = [example.vocabulary.get_string(s) for s in string_list]
                else:
                    string_length = len(string)
                    x_labels = [example.vocabulary.get_string(s) for s in string]
            else:
                # Otherwise just use the raw symbols
                if isinstance(string, torch.Tensor):
                    string_list = string.cpu().tolist()
                    string_length = len(string_list)
                    x_labels = [str(s) for s in string_list]
                else:
                    string_length = len(string)
                    x_labels = [str(s) for s in string]
            # --- End Vocabulary/Labels ---

            # For each layer, find the most variable dimensions and create plots
            for layer_idx, layer_act in enumerate(layer_activations):
                # Get activations for this example at this layer
                activations = layer_act
                
                # Debug activation shape before any processing
                self.logger.info(f"Original activation shape for layer {layer_idx}: {activations.shape}")
                
                # Remove batch dimension if present
                if activations.dim() == 3:  # [batch_size, seq_len, hidden_dim]
                    self.logger.info(f"Removing batch dimension from activations")
                    activations = activations.squeeze(0)  # Remove batch dimension to get [seq_len, hidden_dim]
                
                # Double-check the shape is now correct
                self.logger.info(f"Processed activation shape: {activations.shape}")
                    
                # Find the most variable dimensions for this layer
                top_dims = self.find_most_variable_dimensions([activations], percentage=10)
                
                # Convert to numpy for plotting
                activations_np = activations.cpu().numpy()

                # --- Plot 1: Most Variable Dimensions (Line Plot) ---
                plt.figure(figsize=(12, 8))
                x = np.arange(string_length)
                plot_length = min(string_length, activations_np.shape[0])

                for dim_idx in top_dims[:num_dims]:
                    plt.plot(x[:plot_length], activations_np[:plot_length, dim_idx], 
                             marker='o', label=f'Dim {dim_idx}')
                
                plt.title(f'Layer {layer_idx} - Top Variable Dimensions - Example {example_idx+1}')
                plt.xlabel('Position in sequence')
                plt.ylabel('Activation value')
                plt.xticks(x[:plot_length], x_labels[:plot_length], rotation=45)
                plt.legend()
                plt.grid(True)
                plt.tight_layout()
                
                filename_line = self.results_dir / f"layer{layer_idx}_example{example_idx+1}_top_dims.png"
                plt.savefig(filename_line)
                plt.close()
                generated_files.append(filename_line)

                # --- Plot 2: Full Activation Heatmap ---
                plt.figure(figsize=(15, 10))
                # Ensure plot_length matches sequence dimension
                plot_length_heatmap = min(string_length, activations_np.shape[0])
                # Ensure labels match the plotted length
                heatmap_labels = x_labels[:plot_length_heatmap]

                sns.heatmap(activations_np[:plot_length_heatmap, :].T, cmap="viridis", 
                            xticklabels=heatmap_labels, yticklabels=False, cbar_kws={'label': 'Activation Value'})
                plt.title(f'Layer {layer_idx} - Full Activation Heatmap - Example {example_idx+1}')
                plt.xlabel('Position in sequence')
                plt.ylabel('Hidden Dimension')
                plt.xticks(rotation=45)
                plt.tight_layout()

                filename_heatmap = self.results_dir / f"layer{layer_idx}_example{example_idx+1}_full_heatmap.png"
                plt.savefig(filename_heatmap)
                plt.close()
                generated_files.append(filename_heatmap)

            # Plot attention patterns if requested
            if plot_attention:
                attention_weights = self.extract_attention_weights(example)
                if attention_weights:
                    num_layers = len(attention_weights)
                    num_heads = 0 # Determine later
                    
                    # --- Plot 3: Individual Attention Heads ---
                    for layer_idx, attn in enumerate(attention_weights):
                        # Remove batch dimension if present and determine num_heads
                        if attn.dim() == 4:  # [batch_size, num_heads, seq_len, seq_len]
                            attn = attn.squeeze(0)
                        num_heads = attn.shape[0] # Now shape is [num_heads, seq_len, seq_len]

                        fig, axes = plt.subplots(1, num_heads, figsize=(5*num_heads, 5))
                        if num_heads == 1:
                            axes = [axes] # Make iterable

                        plot_length_attn = min(string_length, attn.shape[1])
                        x_labels_plot = x_labels[:plot_length_attn]

                        for head_idx in range(num_heads):
                            head_attn = attn[head_idx, :plot_length_attn, :plot_length_attn].cpu().numpy()
                            
                            sns.heatmap(head_attn, ax=axes[head_idx], cmap="viridis", 
                                      xticklabels=x_labels_plot, yticklabels=x_labels_plot)
                            axes[head_idx].set_title(f"Head {head_idx}")
                            # ... (rest of individual head plotting) ...
                            if head_idx == 0:
                                axes[head_idx].set_ylabel("Query position")
                            # Removed xlabel setting here, will add to combined plot

                        fig.suptitle(f"Layer {layer_idx} Attention Patterns - Example {example_idx+1}")
                        plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout for suptitle

                        filename_attn_individual = self.results_dir / f"attention_layer{layer_idx}_example{example_idx+1}.png"
                        plt.savefig(filename_attn_individual)
                        plt.close(fig) # Close the figure for individual heads
                        generated_files.append(filename_attn_individual)

                    # --- Plot 4: Combined L x H Attention Plot ---
                    if num_layers > 0 and num_heads > 0:
                        fig_combined, axes_combined = plt.subplots(num_layers, num_heads, 
                                                                   figsize=(4*num_heads, 4*num_layers), 
                                                                   squeeze=False) # Ensure axes is always 2D

                        self.logger.info(f"Creating combined {num_layers}x{num_heads} attention plot for example {example_idx+1}")

                        for layer_idx in range(num_layers):
                            attn_layer = attention_weights[layer_idx]
                            if attn_layer.dim() == 4: # Squeeze batch dim if needed
                                attn_layer = attn_layer.squeeze(0)
                            
                            plot_length_attn = min(string_length, attn_layer.shape[1])
                            x_labels_plot = x_labels[:plot_length_attn]

                            for head_idx in range(num_heads):
                                ax = axes_combined[layer_idx, head_idx]
                                head_attn = attn_layer[head_idx, :plot_length_attn, :plot_length_attn].cpu().numpy()
                                
                                sns.heatmap(head_attn, ax=ax, cmap="viridis", 
                                            xticklabels=False, yticklabels=False, # No labels on combined plot cells
                                            cbar=False) # No individual colorbars
                                
                                ax.set_title(f"L{layer_idx} H{head_idx}", fontsize=10)
                                ax.set_aspect('equal') # Make cells square

                                # Add labels only to edge plots
                                if layer_idx == num_layers - 1:
                                    ax.set_xlabel("Key", fontsize=8)
                                    # Optionally add tick labels here if desired, but might be crowded
                                    # ax.set_xticks(np.arange(plot_length_attn) + 0.5)
                                    # ax.set_xticklabels(x_labels_plot, rotation=90, fontsize=6)
                                if head_idx == 0:
                                    ax.set_ylabel("Query", fontsize=8)
                                    # Optionally add tick labels here
                                    # ax.set_yticks(np.arange(plot_length_attn) + 0.5)
                                    # ax.set_yticklabels(x_labels_plot, rotation=0, fontsize=6)


                        fig_combined.suptitle(f"Combined Attention Patterns - Example {example_idx+1}", fontsize=16)
                        fig_combined.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout

                        filename_attn_combined = self.results_dir / f"attention_combined_example{example_idx+1}.png"
                        plt.savefig(filename_attn_combined)
                        plt.close(fig_combined)
                        generated_files.append(filename_attn_combined)
        
        return generated_files

    def analyze(self, datasets, num_examples=5, plot_attention=False):
        """
        Analyze the trained model on the provided data.
        
        Args:
            datasets: List of dataset names to analyze
            num_examples: Number of examples to analyze
            plot_attention: Whether to also plot attention patterns
        """
        # Load test data
        examples = self.load_test_data(datasets, num_examples)
        
        # Ensure we have some examples
        if not examples:
            self.logger.error("No examples loaded. Check your data directory and dataset names.")
            return
        
        self.logger.info(f"Analyzing {len(examples)} examples")
        
        # --- Generate Plots ---
        all_generated_files = []

        # Plot residual stream activations by layer (and individual/combined attention if requested)
        residual_files = self.plot_residual_stream_by_layer(examples, plot_attention)
        all_generated_files.extend(residual_files)
        
        # Plot dimension heatmaps across examples
        heatmap_files = self.plot_dimension_heatmap(examples) # Add layer_idx if needed later
        all_generated_files.extend(heatmap_files)

        # Plot dimensionality reduction (PCA/t-SNE) - potentially for the last layer
        # Determine last layer index dynamically if possible
        last_layer_idx = None
        if hasattr(self.model, 'core'):
             if hasattr(self.model.core, 'layers'):
                 last_layer_idx = len(self.model.core.layers) - 1
             elif hasattr(self.model.core, 'encoder') and hasattr(self.model.core.encoder, 'layers'):
                 last_layer_idx = len(self.model.core.encoder.layers) - 1
        
        if last_layer_idx is not None:
            dim_reduction_files = self.create_dimensionality_reduction_plot(examples, layer_idx=last_layer_idx)
            all_generated_files.extend(dim_reduction_files)
        else:
             self.logger.warning("Could not determine last layer index for dimensionality reduction plots.")


        # Plot average attention by position type if requested
        if plot_attention:
            avg_attn_files = self.plot_average_attention_by_position_type(examples) # Add layer/head idx if needed
            all_generated_files.extend(avg_attn_files)

        # --- Create Summary and Index ---
        # Create a summary file with metadata
        self._create_summary_file(examples, all_generated_files)
        
        # Create an HTML index for easy browsing
        self._create_html_index(examples, all_generated_files)
        
        self.logger.info(f"Analysis completed. Check {self.results_dir} for visualizations.")
    
    def _create_summary_file(self, examples, generated_files):
        """Create a summary file with metadata about the analysis."""
        summary = {
            "model_path": str(self.model_path),
            "data_dir": str(self.data_dir),
            "num_examples_analyzed": len(examples),
            "examples": [
                {
                    "length": len(example[0]) if isinstance(example[0], list) else 
                              example[0].size(0) if isinstance(example[0], torch.Tensor) else
                              len(example[0]),
                    "is_positive": example[1][0] if len(example) > 1 and len(example[1]) > 0 else None,
                }
                for example in examples
            ],
            "files_generated": [str(f.relative_to(self.results_dir)) for f in generated_files],
            "analysis_date": str(datetime.datetime.now())
        }
        
        # Try to extract model architecture info
        if hasattr(self.model, 'core'):
            if hasattr(self.model.core, 'layers'):
                summary["num_layers"] = len(self.model.core.layers)
            elif hasattr(self.model.core, 'encoder') and hasattr(self.model.core.encoder, 'layers'):
                summary["num_layers"] = len(self.model.core.encoder.layers)
                
            # Try to extract hidden size
            summary["hidden_size"] = None
            for param_name, param in self.model.named_parameters():
                if 'weight' in param_name and len(param.shape) == 2:
                    summary["hidden_size"] = param.shape[1]
                    break
        
        # Write summary to file
        with open(self.results_dir / 'analysis_summary.json', 'w') as f:
            json.dump(summary, f, indent=2)
            
    def _create_html_index(self, examples, generated_files):
        """Create an HTML index for easy browsing of results."""
        html_path = self.results_dir / 'index.html'
        
        # --- Group Files ---
        residual_top_dims_files = sorted([f for f in generated_files if 'top_dims' in f.name])
        residual_heatmap_files = sorted([f for f in generated_files if 'full_heatmap' in f.name])
        attention_individual_files = sorted([f for f in generated_files if 'attention_layer' in f.name])
        attention_combined_files = sorted([f for f in generated_files if 'attention_combined' in f.name])
        dimension_heatmap_files = sorted([f for f in generated_files if f.name.startswith('heatmap_layer')]) # Heatmaps across examples
        pca_tsne_files = sorted([f for f in generated_files if f.name.startswith('pca_') or f.name.startswith('tsne_')])
        avg_attention_files = sorted([f for f in generated_files if f.name.startswith('avg_attention_')])


        # Group by example index for example-specific plots
        def get_example_idx(filename):
            try:
                # Handle different naming conventions
                if 'example' in filename.stem:
                    return int(filename.stem.split('example')[1].split('_')[0])
                else:
                    return -1 # Not an example-specific plot
            except:
                return -1 # Should not happen with expected naming

        examples_data = {}
        for idx, example in enumerate(examples):
            example_key = idx + 1
            examples_data[example_key] = {
                'residual_top': [],
                'residual_heatmap': [],
                'attention_individual': [],
                'attention_combined': None # Only one combined plot per example
            }

        for f in residual_top_dims_files:
            ex_idx = get_example_idx(f)
            if ex_idx in examples_data: examples_data[ex_idx]['residual_top'].append(f)
        for f in residual_heatmap_files:
            ex_idx = get_example_idx(f)
            if ex_idx in examples_data: examples_data[ex_idx]['residual_heatmap'].append(f)
        for f in attention_individual_files:
            ex_idx = get_example_idx(f)
            if ex_idx in examples_data: examples_data[ex_idx]['attention_individual'].append(f)
        for f in attention_combined_files:
            ex_idx = get_example_idx(f)
            if ex_idx in examples_data: examples_data[ex_idx]['attention_combined'] = f
        
        # Group other plots (not strictly per-example)
        # Dimension heatmaps (grouped by length, layer, dim)
        dim_heatmaps_grouped = {}
        for f in dimension_heatmap_files:
            try:
                parts = f.stem.split('_')
                layer = parts[1] # e.g., layer0
                dim = parts[2]   # e.g., dim12
                length = parts[3] # e.g., len25
                key = f"{length}-{layer}"
                if key not in dim_heatmaps_grouped:
                    dim_heatmaps_grouped[key] = []
                dim_heatmaps_grouped[key].append(f)
            except:
                self.logger.warning(f"Could not parse dimension heatmap filename: {f.name}")

        # Average attention (grouped by layer, head)
        avg_attn_grouped = {}
        for f in avg_attention_files:
            try:
                parts = f.stem.split('_')
                layer = parts[2] # e.g., layer0
                head = parts[3] # e.g., head1
                key = f"{layer}-{head}"
                if key not in avg_attn_grouped:
                    avg_attn_grouped[key] = []
                avg_attn_grouped[key].append(f) # Should only be one per key
            except:
                 self.logger.warning(f"Could not parse average attention filename: {f.name}")


        # --- Create HTML Content ---
        model_path_str = str(self.model_path)
        date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        num_examples_str = str(len(examples))
        
        html_content = """
        <!DOCTYPE html>
        <html>
        <head>
            <title>Transformer Analysis Results</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; }}
                h1, h2, h3, h4 {{ color: #333; }}
                .section {{ margin-bottom: 40px; border-top: 2px solid #666; padding-top: 20px; }}
                .example-section {{ margin-bottom: 40px; border: 1px solid #ccc; padding: 15px; border-radius: 8px; }}
                .plot-grid {{ display: grid; grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); gap: 20px; margin-top: 15px; }}
                .plot-item {{ border: 1px solid #ddd; padding: 10px; border-radius: 5px; text-align: center; }}
                .plot-item img {{ max-width: 100%; height: auto; }}
                .plot-item-full {{ grid-column: 1 / -1; }} /* Make combined plot span full width */
                .plot-item-full img {{ max-width: 80%; margin: auto; display: block; }}
            </style>
        </head>
        <body>
            <h1>Transformer Analysis Results</h1>
            <p>Model: {model_path}</p>
            <p>Date: {date}</p>
            <p>Total Examples Analyzed: {num_examples}</p>
        """.format(
            model_path=model_path_str,
            date=date_str,
            num_examples=num_examples_str
        )
        
        # --- Section: Per-Example Analysis ---
        html_content += """
            <div class="section">
                <h2>Per-Example Analysis</h2>
        """
        for example_idx in sorted(examples_data.keys()):
            data = examples_data[example_idx]
            html_content += """
            <div class="example-section">
                <h3>Example {example_idx}</h3>
                
                <h4>Residual Stream Activations</h4>
                <div class="plot-grid">
            """.format(example_idx=example_idx)
            # Add residual top-dims plots (grouped by layer)
            for f in sorted(data['residual_top'], key=lambda x: int(x.stem.split('_')[0].replace('layer',''))):
                rel_path = f.relative_to(self.results_dir)
                html_content += f'<div class="plot-item"><img src="{rel_path}" alt="{f.name}"><p>{f.name}</p></div>'
            
            # Add residual heatmap plots (grouped by layer)
            for f in sorted(data['residual_heatmap'], key=lambda x: int(x.stem.split('_')[0].replace('layer',''))):
                rel_path = f.relative_to(self.results_dir)
                html_content += f'<div class="plot-item"><img src="{rel_path}" alt="{f.name}"><p>{f.name}</p></div>'

            html_content += """
                </div> 

                <h4>Attention Patterns</h4>
                <div class="plot-grid">
            """
            # Add combined attention plot first if it exists
            if data['attention_combined']:
                f = data['attention_combined']
                rel_path = f.relative_to(self.results_dir)
                html_content += f'<div class="plot-item plot-item-full"><img src="{rel_path}" alt="{f.name}"><p>{f.name}</p></div>'

            # Add individual attention plots (grouped by layer)
            for f in sorted(data['attention_individual'], key=lambda x: int(x.stem.split('_')[1].replace('layer',''))):
                rel_path = f.relative_to(self.results_dir)
                html_content += f'<div class="plot-item"><img src="{rel_path}" alt="{f.name}"><p>{f.name}</p></div>'

            html_content += """
                </div>
            </div> 
            """
        html_content += "</div>" # End Per-Example Section

        # --- Section: Aggregate Analysis ---
        html_content += """
            <div class="section">
                <h2>Aggregate Analysis Across Examples</h2>
        """

        # Dimensionality Reduction Plots
        if pca_tsne_files:
            html_content += """
                <h3>Dimensionality Reduction (PCA/t-SNE)</h3>
                <div class="plot-grid">
            """
            for f in pca_tsne_files:
                rel_path = f.relative_to(self.results_dir)
                html_content += f'<div class="plot-item"><img src="{rel_path}" alt="{f.name}"><p>{f.name}</p></div>'
            html_content += "</div>"

        # Dimension Heatmaps Across Examples
        if dim_heatmaps_grouped:
            html_content += """
                <h3>Dimension Activation Heatmaps (Across Examples)</h3>
            """
            for group_key in sorted(dim_heatmaps_grouped.keys()):
                 html_content += f"<h4>{group_key.replace('-', ' - ')}</h4><div class='plot-grid'>"
                 for f in sorted(dim_heatmaps_grouped[group_key], key=lambda x: int(x.stem.split('_')[2].replace('dim',''))):
                     rel_path = f.relative_to(self.results_dir)
                     html_content += f'<div class="plot-item"><img src="{rel_path}" alt="{f.name}"><p>{f.name}</p></div>'
                 html_content += "</div>"

        # Average Attention by Position Type
        if avg_attn_grouped:
            html_content += """
                <h3>Average Attention by Position Type</h3>
            """
            for group_key in sorted(avg_attn_grouped.keys()):
                 html_content += f"<h4>{group_key.replace('-', ' - ')}</h4><div class='plot-grid'>"
                 # Should only be one file per group key
                 for f in avg_attn_grouped[group_key]:
                     rel_path = f.relative_to(self.results_dir)
                     html_content += f'<div class="plot-item"><img src="{rel_path}" alt="{f.name}"><p>{f.name}</p></div>'
                 html_content += "</div>"


        html_content += "</div>" # End Aggregate Analysis Section

        # --- End HTML ---
        html_content += """
        </body>
        </html>
        """
        
        with open(html_path, 'w') as f:
            f.write(html_content)

    def plot_dimension_heatmap(self, examples, layer_idx=None, output_path=None):
        """
        Create a heatmap showing how selected dimensions evolve across different examples.
        
        Args:
            examples: List of examples to analyze
            layer_idx: Optional specific layer to analyze (None for all layers)
            output_path: Optional path to save the plots
        """
        self.logger.info("Creating dimension heatmaps across examples")
        
        # Track files generated
        generated_files = []
        
        # Group examples by length to create more meaningful heatmaps
        examples_by_length = {}
        for example in examples:
            # Extract string and handle tensor case
            if isinstance(example, tuple):
                string = example[0]
            else:
                string = example.string
            
            length = string.size(0) if isinstance(string, torch.Tensor) else len(string)
            
            if length not in examples_by_length:
                examples_by_length[length] = []
            examples_by_length[length].append(example)
        
        # For each length group with multiple examples
        for length, length_examples in examples_by_length.items():
            if len(length_examples) < 2:
                continue  # Skip if only one example of this length
                
            self.logger.info(f"Processing {len(length_examples)} examples of length {length}")
            
            # Get activations for each example
            all_layer_activations = {}
            
            for ex in length_examples:
                layer_acts = self.extract_residual_stream(ex)
                
                for i, act in enumerate(layer_acts):
                    if layer_idx is not None and i != layer_idx:
                        continue  # Skip layers we don't want
                        
                    if i not in all_layer_activations:
                        all_layer_activations[i] = []
                    
                    # --- Start Change ---
                    # Remove batch dimension if present and ensure tensor is 2D
                    processed_act = act
                    if processed_act.dim() == 3:
                        # Assuming shape [batch, seq_len, hidden_dim] where batch is 1
                        if processed_act.shape[0] == 1:
                            processed_act = processed_act.squeeze(0) # Shape becomes [seq_len, hidden_dim]
                        else:
                            self.logger.warning(f"Activation tensor has batch size > 1 ({processed_act.shape[0]}). Taking first element. Shape: {processed_act.shape}")
                            processed_act = processed_act[0]
                    elif processed_act.dim() > 3: # Handle unexpected higher dims
                         self.logger.warning(f"Activation tensor has unexpected dimension {processed_act.dim()}, attempting to squeeze.")
                         processed_act = processed_act.squeeze() # General squeeze
                         # Re-check dimension after squeeze
                         if processed_act.dim() != 2:
                             self.logger.error(f"Could not reduce activation tensor to 2D after squeeze. Shape: {processed_act.shape}. Skipping layer {i} for example.")
                             continue # Skip this activation if shape is wrong
                    elif processed_act.dim() != 2: # Must be 2D at this point
                        self.logger.error(f"Activation tensor is not 2D or 3D. Shape: {processed_act.shape}. Skipping layer {i} for example.")
                        continue

                    # Append the processed 2D tensor, sliced to the correct length
                    all_layer_activations[i].append(processed_act[:length])
                    # --- End Change ---

            # For each layer, create a heatmap
            for layer_i, activations in all_layer_activations.items():
                if layer_idx is not None and layer_i != layer_idx:
                    continue
                    
                # Find most variable dimensions for this layer across examples
                if not activations: # Skip if no activations were collected for this layer/length
                    continue
                top_dims = self.find_most_variable_dimensions(activations, percentage=5)  # Top 5%
                
                # Create a heatmap for each dimension
                for dim_idx in top_dims[:3]:  # Limit to 3 dimensions for clarity
                    plt.figure(figsize=(14, 8))
                    
                    # Extract values for this dimension across all examples
                    try:
                        # 'act' in 'activations' should now be 2D: [seq_len, hidden_dim]
                        dim_data_list = [act[:, dim_idx].cpu().numpy() for act in activations]

                        # Basic validation before creating numpy array
                        if not dim_data_list:
                            self.logger.warning(f"No data points for heatmap: layer {layer_i}, dim {dim_idx}, len {length}")
                            plt.close()
                            continue
                        first_shape = dim_data_list[0].shape
                        if not all(d.shape == first_shape for d in dim_data_list):
                             self.logger.warning(f"Inconsistent sequence lengths for heatmap data: layer {layer_i}, dim {dim_idx}, len {length}. Shapes: {[d.shape for d in dim_data_list]}. Skipping.")
                             plt.close()
                             continue

                        dim_data = np.array(dim_data_list) # Should now be [num_examples, seq_len]

                    except IndexError:
                         self.logger.error(f"IndexError accessing dim {dim_idx} in layer {layer_i} for length {length}. Activation shapes: {[a.shape for a in activations]}. Skipping heatmap.")
                         plt.close()
                         continue
                    except Exception as e:
                         self.logger.error(f"Error preparing data for heatmap: {e}. Skipping heatmap.")
                         plt.close()
                         continue

                    # --- Start Change ---
                    # Check dim_data shape before plotting - must be 2D
                    if dim_data.ndim != 2:
                        self.logger.error(f"Resulting dim_data is not 2D! Shape: {dim_data.shape}. Skipping heatmap for layer {layer_i} dim {dim_idx} len {length}")
                        plt.close()
                        continue
                    # --- End Change ---

                    # Create heatmap
                    sns.heatmap(dim_data, cmap="viridis",
                               xticklabels=[f"Pos {i+1}" for i in range(dim_data.shape[1])],
                               yticklabels=[f"Ex {i+1}" for i in range(len(length_examples))],
                               cbar_kws={'label': 'Activation Value'})

                    # Format plot
                    plt.title(f'Layer {layer_i} - Dimension {dim_idx} Activations (Length {length})')
                    plt.xlabel('Position in sequence')
                    plt.ylabel('Example')
                    plt.tight_layout()
                    
                    # Save the plot
                    filename = self.results_dir / f"heatmap_layer{layer_i}_dim{dim_idx}_len{length}.png"
                    plt.savefig(filename)
                    plt.close()
                    
                    generated_files.append(filename)
        
        return generated_files

    def create_dimensionality_reduction_plot(self, examples, layer_idx=None):
        """
        Create PCA plots to visualize the model's internal representations.
        
        Args:
            examples: List of examples to analyze (tuples from data loader)
            layer_idx: Optional specific layer to analyze (None for last layer)
            
        Returns:
            List of generated filenames
        """
        self.logger.info("Creating dimensionality reduction plots")
        generated_files = []
        
        # If layer_idx is None, use the last layer
        if layer_idx is None and hasattr(self.model, 'core'):
            if hasattr(self.model.core, 'layers'):
                layer_idx = len(self.model.core.layers) - 1
            elif hasattr(self.model.core, 'encoder') and hasattr(self.model.core.encoder, 'layers'):
                layer_idx = len(self.model.core.encoder.layers) - 1
        
        # Collect activations for all examples
        all_activations = []
        all_positions = []
        all_labels = []
        
        for example_idx, example in enumerate(examples):
            # Skip empty examples
            string = example[0]
            if isinstance(string, torch.Tensor) and string.size(0) == 0:
                continue
            elif not string:
                continue
                
            # Get label from the example
            is_positive = example[1][0] if len(example) > 1 and len(example[1]) > 0 else True
            
            # Extract activations
            layer_activations = self.extract_residual_stream(example)
            
            if not layer_activations or layer_idx >= len(layer_activations):
                self.logger.warning(f"No activations for example {example_idx} at layer {layer_idx}")
                continue
            
            # Get activations for the specified layer
            activations = layer_activations[layer_idx]
            if activations.dim() > 3:
                activations = activations[0]  # Remove batch dimension
            
            # Add each position's activation
            for pos, act in enumerate(activations):
                all_activations.append(act.cpu().numpy())
                all_positions.append(pos)
                all_labels.append(is_positive)
        
        if not all_activations:
            self.logger.warning("No activations collected for dimensionality reduction")
            return generated_files
        
        # Convert to numpy array
        activation_vectors = np.array(all_activations)
        
        # Apply PCA
        pca = PCA(n_components=2)
        reduced_data = pca.fit_transform(activation_vectors)
        
        # Create PCA plot
        plt.figure(figsize=(12, 10))
        
        # Color by position
        position_scatter = plt.scatter(
            reduced_data[:, 0], 
            reduced_data[:, 1],
            c=all_positions,
            cmap='viridis',
            alpha=0.7,
            label="Position"
        )
        plt.colorbar(position_scatter, label="Position in sequence")
        
        plt.title(f"PCA of Layer {layer_idx} Activations by Position", fontsize=16)
        plt.xlabel(f"PC1 (Explained Variance: {pca.explained_variance_ratio_[0]:.2f})", fontsize=14)
        plt.ylabel(f"PC2 (Explained Variance: {pca.explained_variance_ratio_[1]:.2f})", fontsize=14)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        
        # Save PCA plot
        pca_filename = self.results_dir / f"pca_layer{layer_idx}_by_position.png"
        plt.savefig(pca_filename)
        plt.close()
        generated_files.append(pca_filename)
        
        # Create another PCA plot colored by positive/negative
        if any(not is_pos for is_pos in all_labels):  # Only if we have negative examples
            plt.figure(figsize=(12, 10))
            
            # Plot positive and negative examples with different colors
            for is_positive, label, color in [(True, "Positive", "blue"), (False, "Negative", "red")]:
                mask = np.array(all_labels) == is_positive
                if np.any(mask):  # Only plot if we have examples of this type
                    plt.scatter(
                        reduced_data[mask, 0], 
                        reduced_data[mask, 1],
                        c=color,
                        label=label,
                        alpha=0.7
                    )
            
            plt.title(f"PCA of Layer {layer_idx} Activations by Label", fontsize=16)
            plt.xlabel(f"PC1 (Explained Variance: {pca.explained_variance_ratio_[0]:.2f})", fontsize=14)
            plt.ylabel(f"PC2 (Explained Variance: {pca.explained_variance_ratio_[1]:.2f})", fontsize=14)
            plt.legend(fontsize=12)
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            
            # Save PCA label plot
            pca_label_filename = self.results_dir / f"pca_layer{layer_idx}_by_label.png"
            plt.savefig(pca_label_filename)
            plt.close()
            generated_files.append(pca_label_filename)
        
        return generated_files

    def plot_average_attention_by_position_type(self, examples, layer_idx=None, head_idx=None):
        """
        Plot average attention weights by position type (comparing patterns across examples).
        
        Args:
            examples: List of examples to analyze
            layer_idx: Optional specific layer to analyze (None for all layers)
            head_idx: Optional specific head to analyze (None for all heads)
            
        Returns:
            List of generated filenames
        """
        self.logger.info("Creating average attention by position type plots")
        generated_files = []
        
        # Get all attention weights
        example_attentions = []
        for example in examples:
            attention = self.extract_attention_weights(example)
            if attention:
                example_attentions.append((example, attention))
        
        if not example_attentions:
            self.logger.warning("No attention weights extracted")
            return generated_files
        
        # First example's attention to determine structure
        _, first_attn = example_attentions[0]
        
        # Determine which layers and heads to analyze
        layers_to_analyze = []
        if layer_idx is not None:
            if layer_idx < len(first_attn):
                layers_to_analyze = [layer_idx]
        else:
            layers_to_analyze = range(len(first_attn))
        
        # For each layer
        for layer_i in layers_to_analyze:
            layer_attn = [attn[layer_i] for _, attn in example_attentions]
            
            # Get the first attention tensor to determine shapes
            if layer_attn[0].dim() > 3:  # Has batch dimension
                num_heads = layer_attn[0].shape[1]
            else:
                num_heads = layer_attn[0].shape[0]
            
            # Determine which heads to analyze
            heads_to_analyze = []
            if head_idx is not None:
                if head_idx < num_heads:
                    heads_to_analyze = [head_idx]
            else:
                heads_to_analyze = range(num_heads)
            
            # For each head
            for head_i in heads_to_analyze:
                # Collect attention patterns by position type
                position_type_weights = {
                    'start→start': [],
                    'start→middle': [],
                    'start→end': [],
                    'middle→start': [],
                    'middle→middle': [],
                    'middle→end': [],
                    'end→start': [],
                    'end→middle': [],
                    'end→end': []
                }
                
                for (example, _), attn in zip(example_attentions, layer_attn):
                    string_len = len(example.string)
                    if string_len < 3:
                        continue  # Need at least 3 tokens to define start/middle/end
                    
                    # Define regions: start (first 1/3), middle (middle 1/3), end (last 1/3)
                    start_idx = max(1, string_len // 3)
                    end_idx = max(2, 2 * string_len // 3)
                    
                    # Get attention weights for this head
                    if attn.dim() > 3:  # Has batch dimension
                        head_attn = attn[0, head_i, :string_len, :string_len].cpu().numpy()
                    else:
                        head_attn = attn[head_i, :string_len, :string_len].cpu().numpy()
                    
                    # Calculate average attention by region
                    regions = [
                        ('start', 0, start_idx),
                        ('middle', start_idx, end_idx),
                        ('end', end_idx, string_len)
                    ]
                    
                    for q_name, q_start, q_end in regions:
                        for k_name, k_start, k_end in regions:
                            region_attn = head_attn[q_start:q_end, k_start:k_end]
                            if region_attn.size > 0:  # Ensure the region has attention values
                                key = f'{q_name}→{k_name}'
                                position_type_weights[key].append(np.mean(region_attn))
                
                # Create bar plot for this head
                nonempty_types = {k: v for k, v in position_type_weights.items() if v}
                if not nonempty_types:
                    self.logger.warning(f"No attention data for layer {layer_i}, head {head_i}")
                    continue
                    
                # Calculate means and standard errors
                means = {k: np.mean(v) for k, v in nonempty_types.items() if v}
                stderrs = {k: np.std(v) / np.sqrt(len(v)) for k, v in nonempty_types.items() if len(v) > 1}
                
                plt.figure(figsize=(12, 6))
                positions = range(len(means))
                plt.bar(
                    positions, 
                    list(means.values()), 
                    yerr=[stderrs.get(k, 0) for k in means.keys()], 
                    capsize=10
                )
                plt.xticks(positions, list(means.keys()), rotation=45)
                
                plt.title(f'Average Attention Weight by Position Type - Layer {layer_i}, Head {head_i}', fontsize=14)
                plt.ylabel('Average Attention Weight', fontsize=12)
                plt.grid(axis='y', alpha=0.3)
                plt.tight_layout()
                
                # Save the plot
                filename = self.results_dir / f"avg_attention_layer{layer_i}_head{head_i}.png"
                plt.savefig(filename)
                plt.close()
                
                generated_files.append(filename)
        
        return generated_files

def main():
    """
    Command-line entry point: parse arguments, load the model, and run the analysis.

    Reads arguments from ``sys.argv`` and builds a ``TransformerAnalyzer`` from
    ``--load-model`` and ``--data-dir``. If ``--output`` is given, it points the
    analyzer's results directory there and creates it. Finally, it calls
    ``analyze`` with the requested datasets, example count and attention-plot flag.
    """
    # Log to stdout so progress is visible when run as a script.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[logging.StreamHandler(sys.stdout)]
    )
    logger = logging.getLogger('main')

    # This interface is only used to contribute its forward-pass CLI options
    # (if it has any). TransformerAnalyzer builds its own interface to load the model.
    model_interface = RecognitionModelInterface(
        use_load=True,
        use_init=False,
        use_output=False,
        require_output=False
    )

    parser = argparse.ArgumentParser(
        description='Analyze transformer model residual stream and attention patterns'
    )

    parser.add_argument('--load-model', type=str, required=True,
                      help='Path to the trained model directory')
    parser.add_argument('--data-dir', type=str, required=True,
                      help='Path to directory containing test data')
    parser.add_argument('--datasets', nargs='+', default=['test'],
                      help='Names of datasets to analyze (default: test)')
    parser.add_argument('--output', type=str, default=None,
                      help='Directory to save analysis results (default: model_dir/analysis)')
    parser.add_argument('--num-examples', type=int, default=5,
                      help='Maximum number of examples to analyze')
    parser.add_argument('--plot-attention', action='store_true',
                      help='Also plot attention patterns')

    # NOTE(review): if add_forward_arguments registers an option that clashes
    # with one defined above, argparse raises an ArgumentError here. The options
    # it adds are also never forwarded to TransformerAnalyzer below.
    if hasattr(model_interface, 'add_forward_arguments'):
        model_interface.add_forward_arguments(parser)

    args = parser.parse_args()
    logger.info(f'Arguments: {args}')

    analyzer = TransformerAnalyzer(args.load_model, args.data_dir)

    # Override the analyzer's results directory when --output is given.
    if args.output:
        analyzer.results_dir = Path(args.output)
        analyzer.results_dir.mkdir(parents=True, exist_ok=True)

    analyzer.analyze(args.datasets, args.num_examples, args.plot_attention)


if __name__ == '__main__':
    # Run the CLI only when executed directly, never on import.
    main()
