"""
Enhanced comprehensive probe analysis script for APO.

IMPROVEMENTS IN THIS VERSION:
- Separate training and evaluation samples:
  * Training samples (for computing directions) can be any size
  * Evaluation samples (for measuring accuracy) always >= 500
  * Use --max-samples for training, --min-eval-samples for minimum evaluation size
- Significantly improved plot aesthetics:
  * Modern color schemes with custom palette
  * Better typography and spacing
  * Enhanced visual hierarchy
  * Professional styling with shadows and borders
  * Cleaner grid lines and backgrounds
  * Improved legends and annotations
  * Higher quality output (150 DPI, see PLOT_CONFIG)
  * Better use of whitespace

NEW FEATURES:
- Cross-stage comparison plots (layer-wise accuracy evolution)
- Direction transfer analysis (test pretrained directions on SFT/PO)
- Feature importance evolution tracking
- Statistical significance testing
- Summary dashboard visualizations
- Probe transferability metrics
"""

import argparse
import json
import os
from copy import deepcopy
import random
from typing import List, Dict, Tuple, Optional
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score, silhouette_score
from sklearn.model_selection import train_test_split, cross_val_score
from scipy.stats import ttest_ind, spearmanr
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import warnings
warnings.filterwarnings('ignore')

from dataset_utils import load_preference_dataset

# Global plot theme, applied once at import time: matplotlib's
# 'seaborn-v0_8-whitegrid' style sheet, seaborn's "notebook" context with
# fonts scaled by 1.1, and seaborn's "Set2" palette as the default colors.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context(context="notebook", font_scale=1.1)
sns.set_palette(palette="Set2")

# Custom color palette used by the plotting helpers. Note that 'pretrained'
# and 'chosen' share the same blue, and 'po' and 'rejected' the same red.
COLORS = dict(
    pretrained='#3498db',  # Blue
    sft='#2ecc71',         # Green
    po='#e74c3c',          # Red
    chosen='#3498db',      # Blue
    rejected='#e74c3c',    # Red
    direction='#f39c12',   # Orange
    neutral='#95a5a6',     # Gray
)

# Keyword arguments unpacked into the plt.savefig(...) calls below:
# 150 DPI, tight bounding box, white face color and no edge color.
PLOT_CONFIG = dict(
    dpi=150,
    bbox_inches='tight',
    facecolor='white',
    edgecolor='none',
)


def get_model_layers(model):
    """Return the container of transformer blocks held by ``model``.

    The following attribute paths are tried in order and the first one that
    fully resolves is returned:

    1. ``model.model.layers``
    2. ``model.transformer.h``
    3. ``model.model.decoder.layers``
    4. ``model.language_model.layers``

    Raises:
        ValueError: if none of the attribute paths exist on ``model``.
    """
    candidate_paths = (
        ('model', 'layers'),
        ('transformer', 'h'),
        ('model', 'decoder', 'layers'),
        ('language_model', 'layers'),
    )
    missing = object()  # sentinel: distinguishes "absent" from a falsy attribute

    for path in candidate_paths:
        node = model
        for attr_name in path:
            node = getattr(node, attr_name, missing)
            if node is missing:
                break
        else:
            # Every hop on this path resolved.
            return node

    raise ValueError("Could not find transformer layers in model architecture")


def get_completion_token_indices(tokenizer, full_messages: List[Dict], prompt_messages: List[Dict]) -> Tuple[int, int]:
    """Return the token span of the assistant completion in a chat sequence.

    The full conversation is rendered with ``add_generation_prompt=False`` and
    the prompt alone with ``add_generation_prompt=True``; both are tokenized
    with ``return_tensors="pt"`` and their lengths along dimension 1 are used.

    Returns:
        ``(start, end)``: ``start`` is the prompt length and ``end`` the full
        conversation length, meant for the half-open slice ``[start, end)``.
        NOTE(review): ``end`` may include any end-of-turn tokens the chat
        template appends after the assistant message — confirm per template.
    """
    def _sequence_length(messages, generation_prompt):
        encoded = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=generation_prompt,
        )
        return encoded.shape[1]

    end = _sequence_length(full_messages, False)
    start = _sequence_length(prompt_messages, True)
    return start, end


def _pool_hidden_states(
    hidden_states: torch.Tensor,
    completion_start: int,
    completion_end: int,
    position: str,
) -> torch.Tensor:
    """Reduce one sequence's hidden states to a single vector.

    ``hidden_states`` is presumably ``(seq_len, hidden)`` after the batch
    dimension was squeezed — TODO confirm for every supported architecture.

    - ``"completion_mean"``: mean over rows ``[completion_start, completion_end)``
    - ``"last_token"``: the final row
    - ``"max"``: element-wise max over ``[completion_start, completion_end)``
    - ``"last_prompt"``: the row just before ``completion_start``
    - anything else: mean over all rows
    """
    if position == "completion_mean":
        return hidden_states[completion_start:completion_end].mean(dim=0)
    if position == "last_token":
        return hidden_states[-1]
    if position == "max":
        return hidden_states[completion_start:completion_end].max(dim=0).values
    if position == "last_prompt":
        return hidden_states[completion_start - 1]
    return hidden_states.mean(dim=0)


def extract_activations_single_layer(
    model,
    tokenizer,
    dataset: List[Dict],
    layer_idx: int,
    max_samples: int = 500,
    position: str = "completion_mean"
) -> Tuple[np.ndarray, np.ndarray]:
    """Extract one pooled activation vector per example from a single layer.

    A forward hook on ``get_model_layers(model)[layer_idx]`` captures that
    layer's output (the first element when the layer returns a tuple). For
    each of the first ``max_samples`` items, prompt+chosen and prompt+rejected
    are rendered with the tokenizer's chat template, run through ``model``
    under ``torch.no_grad()``, and pooled according to ``position`` (see
    ``_pool_hidden_states``). Only this one layer is hooked, so only its
    output is kept per forward pass.

    Args:
        model: causal LM whose layers are located via ``get_model_layers``.
        tokenizer: tokenizer providing ``apply_chat_template``.
        dataset: items with ``"prompt"``, ``"chosen"`` and ``"rejected"``
            entries; presumably lists of chat-message dicts (the prompt is
            extended with the completion) — confirm against the loader.
        layer_idx: index into the model's layer list.
        max_samples: number of leading dataset items to process.
        position: pooling mode — ``"completion_mean"``, ``"last_token"``,
            ``"max"``, ``"last_prompt"``; any other value means a whole-sequence mean.

    Returns:
        ``(chosen, rejected)`` float32 arrays built with ``np.vstack``, one row
        per processed item.

    Raises:
        ValueError: if ``dataset[:max_samples]`` is empty.
    """
    layers = get_model_layers(model)
    samples = dataset[:max_samples]
    if not samples:
        raise ValueError("extract_activations_single_layer: no samples to extract (dataset is empty)")

    activations_cache = {}

    def hook_fn(_module, _inputs, output):
        # Some layers return a tuple whose first element is the hidden states.
        hidden = output[0] if isinstance(output, tuple) else output
        activations_cache['output'] = hidden.detach()

    def encode_and_pool(prompt, prompt_only, completion) -> np.ndarray:
        """Run prompt+completion through the model and pool the hooked layer output."""
        full_messages = deepcopy(prompt)
        full_messages.extend(completion)
        completion_start, completion_end = get_completion_token_indices(
            tokenizer, full_messages, prompt_only
        )
        input_ids = tokenizer.apply_chat_template(full_messages, return_tensors="pt").to(model.device)

        activations_cache.clear()
        with torch.no_grad():
            model(input_ids=input_ids)

        hidden_states = activations_cache['output'].squeeze(0)
        pooled = _pool_hidden_states(hidden_states, completion_start, completion_end, position)
        return pooled.float().cpu().numpy()

    chosen_activations = []
    rejected_activations = []

    model.eval()
    hook = layers[layer_idx].register_forward_hook(hook_fn)
    # Remove the hook even if a forward pass fails; otherwise it would stay
    # registered on the layer and fire during every later forward pass.
    try:
        for item in samples:
            prompt = item["prompt"]
            prompt_only = deepcopy(prompt)
            chosen_activations.append(encode_and_pool(prompt, prompt_only, item["chosen"]))
            rejected_activations.append(encode_and_pool(prompt, prompt_only, item["rejected"]))
    finally:
        hook.remove()

    return np.vstack(chosen_activations), np.vstack(rejected_activations)


def extract_activations_all_layers(
    model,
    tokenizer,
    dataset: List[Dict],
    max_samples: int = 500,
    position: str = "completion_mean"
) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
    """Collect pooled chosen/rejected activations for every layer of ``model``.

    Calls ``extract_activations_single_layer`` once per layer, so the first
    ``max_samples`` items are forwarded through the whole model once per
    layer (two forward passes per item per layer: chosen and rejected).

    Args:
        model: The model to extract activations from.
        tokenizer: Tokenizer for the model.
        dataset: Dataset to extract activations from.
        max_samples: Number of leading samples to use (may be fewer than 500,
            e.g. for training).
        position: Pooling position forwarded to the per-layer extractor.

    Returns:
        ``(chosen_by_layer, rejected_by_layer)``: dicts keyed by layer index,
        each value the 2-D array returned for that layer.
    """
    print(f"\nExtracting activations from all layers (samples: {max_samples}, position: {position})...")

    num_layers = len(get_model_layers(model))
    chosen_by_layer: Dict[int, np.ndarray] = {}
    rejected_by_layer: Dict[int, np.ndarray] = {}

    for idx in tqdm(range(num_layers), desc="Extracting layers"):
        chosen_by_layer[idx], rejected_by_layer[idx] = extract_activations_single_layer(
            model, tokenizer, dataset, idx, max_samples, position
        )

    print(f"Extracted activations shape: {chosen_by_layer[0].shape}")
    return chosen_by_layer, rejected_by_layer


def compute_mean_difference_direction(pos_acts: np.ndarray, neg_acts: np.ndarray) -> np.ndarray:
    """Return the unit-length difference between the class means.

    The raw vector is ``mean(pos_acts, axis=0) - mean(neg_acts, axis=0)``; it
    is divided by its L2 norm plus ``1e-8`` so a zero difference yields zeros
    rather than NaNs.
    """
    pos_centroid = np.mean(pos_acts, axis=0)
    neg_centroid = np.mean(neg_acts, axis=0)
    raw = pos_centroid - neg_centroid
    scale = np.linalg.norm(raw) + 1e-8
    return raw / scale


def compute_logistic_regression_direction(pos_acts: np.ndarray, neg_acts: np.ndarray) -> np.ndarray:
    """Return the unit-normalized weight vector of a logistic-regression probe.

    Positives are labelled 1 and negatives 0; the probe is fitted with
    ``max_iter=1000`` and ``random_state=42``. Its coefficient vector is
    divided by its L2 norm plus ``1e-8``.
    """
    labels = np.concatenate([np.ones(len(pos_acts)), np.zeros(len(neg_acts))])
    features = np.vstack([pos_acts, neg_acts])

    probe = LogisticRegression(max_iter=1000, random_state=42).fit(features, labels)

    weights = probe.coef_[0]
    return weights / (np.linalg.norm(weights) + 1e-8)


def compute_all_directions(pos_acts: np.ndarray, neg_acts: np.ndarray) -> Dict[str, np.ndarray]:
    """Run every direction estimator on the same positive/negative data.

    Returns a dict with keys ``'mean_diff'`` and ``'logistic_regression'``
    (computed in that order), each mapping to a unit-normalized direction.
    """
    estimators = (
        ('mean_diff', compute_mean_difference_direction),
        ('logistic_regression', compute_logistic_regression_direction),
    )
    return {name: estimate(pos_acts, neg_acts) for name, estimate in estimators}


def compute_separation_metrics(
    pos_acts: np.ndarray,
    neg_acts: np.ndarray,
    direction: np.ndarray,
    use_cross_val: bool = True
) -> Tuple[Dict[str, float], np.ndarray, np.ndarray]:
    """Quantify how well ``direction`` separates positive from negative activations.

    The direction-dependent metrics are computed on the 1-D projections
    ``acts @ direction``. Different directions therefore get different scores,
    e.g. mean-diff vs. logistic regression, or a direction learned on another
    training stage.

    Metrics:
        cohens_d: difference of projection means over the pooled (population)
            std. It is positive when positives project higher.
        classification_accuracy: accuracy of a 1-D logistic-regression
            threshold fitted on the projections. It is 5-fold cross-validated
            when ``use_cross_val`` is set and there are at least 10 samples;
            otherwise a single 80/20 split is used. The fitted weight can flip
            sign, so this measures separability along the axis, not its
            orientation (``cohens_d`` carries the sign).
        classification_accuracy_std: std of the CV fold accuracies (0.0 on
            the single-split path).
        euclidean_distance: L2 distance between the class means in the full
            activation space. It does not depend on ``direction``.

    Returns:
        ``(metrics, pos_proj, neg_proj)``.
    """
    # Project onto direction
    pos_proj = pos_acts @ direction
    neg_proj = neg_acts @ direction

    metrics = {}

    # 1. Cohen's d along the direction
    mean_diff = pos_proj.mean() - neg_proj.mean()
    pooled_std = np.sqrt((pos_proj.std()**2 + neg_proj.std()**2) / 2)
    metrics['cohens_d'] = mean_diff / (pooled_std + 1e-8)

    # 2. Classification accuracy from the projection alone. Fitting the probe
    # on the full activations would ignore `direction` entirely and give every
    # direction (and every source stage in a transfer analysis) the same score.
    X = np.concatenate([pos_proj, neg_proj]).reshape(-1, 1)
    y = np.concatenate([np.ones(len(pos_proj)), np.zeros(len(neg_proj))])

    # One global rescale keeps the L2-regularised 1-D weight well-conditioned
    # whatever the activation magnitude. It does not change the set of
    # achievable thresholds.
    scale = X.std()
    if scale > 0:
        X = X / scale

    if use_cross_val and len(X) >= 10:
        lr = LogisticRegression(max_iter=1000, random_state=42)
        cv_scores = cross_val_score(lr, X, y, cv=5, scoring='accuracy')
        metrics['classification_accuracy'] = cv_scores.mean()
        metrics['classification_accuracy_std'] = cv_scores.std()
    else:
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        lr = LogisticRegression(max_iter=1000, random_state=42)
        lr.fit(X_train, y_train)
        metrics['classification_accuracy'] = accuracy_score(y_test, lr.predict(X_test))
        metrics['classification_accuracy_std'] = 0.0

    # 3. Euclidean distance between means (full space)
    metrics['euclidean_distance'] = np.linalg.norm(pos_acts.mean(axis=0) - neg_acts.mean(axis=0))

    return metrics, pos_proj, neg_proj


def analyze_layer_wise_accuracy(
    chosen_by_layer: Dict[int, np.ndarray],
    rejected_by_layer: Dict[int, np.ndarray],
    chosen_eval_by_layer: Optional[Dict[int, np.ndarray]] = None,
    rejected_eval_by_layer: Optional[Dict[int, np.ndarray]] = None,
    methods: Optional[List[str]] = None
) -> Dict[int, Dict[str, Dict]]:
    """Fit directions per layer on training data and score them on evaluation data.

    Args:
        chosen_by_layer: Training activations for chosen responses, keyed by layer.
        rejected_by_layer: Training activations for rejected responses, keyed by layer.
        chosen_eval_by_layer: Evaluation activations for chosen responses. If
            None, the training data is used, so scores are in-sample.
        rejected_eval_by_layer: Evaluation activations for rejected responses.
            If None, the training data is used.
        methods: Direction methods to report (names produced by
            ``compute_all_directions``); unknown names are skipped. Defaults
            to ``['mean_diff', 'logistic_regression']``.

    Returns:
        layer -> method -> ``{'metrics': ..., 'direction': ...}``. The metrics
        come from ``compute_separation_metrics`` on the evaluation activations.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if methods is None:
        methods = ['mean_diff', 'logistic_regression']

    print("\nAnalyzing layer-wise probe accuracy...")

    # Use training data for eval if not provided
    if chosen_eval_by_layer is None:
        chosen_eval_by_layer = chosen_by_layer
    if rejected_eval_by_layer is None:
        rejected_eval_by_layer = rejected_by_layer

    results_by_layer = {}

    for layer_idx in tqdm(sorted(chosen_by_layer.keys()), desc="Analyzing layers"):
        # Training data (for computing directions)
        pos_acts_train = chosen_by_layer[layer_idx]
        neg_acts_train = rejected_by_layer[layer_idx]

        # Evaluation data (for measuring metrics)
        pos_acts_eval = chosen_eval_by_layer[layer_idx]
        neg_acts_eval = rejected_eval_by_layer[layer_idx]

        directions = compute_all_directions(pos_acts_train, neg_acts_train)

        layer_results = {}
        for method in methods:
            if method not in directions:
                continue
            metrics, _, _ = compute_separation_metrics(
                pos_acts_eval, neg_acts_eval, directions[method], use_cross_val=True
            )
            layer_results[method] = {
                'metrics': metrics,
                'direction': directions[method]
            }

        results_by_layer[layer_idx] = layer_results

    return results_by_layer


# =============================================================================
# CROSS-STAGE ANALYSIS FUNCTIONS WITH IMPROVED PLOTS
# =============================================================================

def compare_layer_wise_across_stages(
    stage_results: Dict[str, Dict[int, Dict]],
    output_dir: str,
    metric_name: str = 'classification_accuracy',
    methods: Optional[List[str]] = None
):
    """Plot ``metric_name`` against layer index for every training stage.

    Draws one row per method and one line per stage, colored via ``COLORS``
    with gray as the fallback. For accuracy metrics (names ending in
    ``'accuracy'``) a dashed 0.5 "Random" line is added and the y-axis floor
    is set to 0.45. The figure is saved as
    ``cross_stage_layerwise_{metric_name}.pdf`` in ``output_dir``.

    Args:
        stage_results: stage -> layer -> method -> ``{'metrics': {...}, ...}``
            (the structure returned by ``analyze_layer_wise_accuracy``).
        output_dir: directory the PDF is written to.
        metric_name: key inside each ``'metrics'`` dict to plot.
        methods: methods to plot, one row each. Defaults to
            ``['mean_diff', 'logistic_regression']``.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if methods is None:
        methods = ['mean_diff', 'logistic_regression']

    print(f"\nCreating cross-stage layer-wise comparison for {metric_name}...")

    stages = list(stage_results.keys())
    metric_label = metric_name.replace('_', ' ').title()
    # The chance line and the 0.45 floor only make sense for accuracies. For
    # metrics such as cohens_d or euclidean_distance they would mislabel or
    # clip the data.
    is_accuracy = metric_name.endswith('accuracy')

    plt.figure(figsize=(18, 6 * len(methods)), facecolor='white')

    for method_idx, method in enumerate(methods):
        # NOTE: each row keeps a 3-column slot layout, but only column 1 is
        # drawn; the heatmap and delta-from-pretrained panels are disabled.
        ax = plt.subplot(len(methods), 3, method_idx * 3 + 1)

        for stage in stages:
            layers = sorted(stage_results[stage].keys())
            values = [stage_results[stage][layer][method]['metrics'][metric_name] for layer in layers]
            ax.plot(layers, values, marker='o', linewidth=3, markersize=8,
                    label=stage.upper(), color=COLORS.get(stage, 'gray'),
                    alpha=0.85, markeredgewidth=1.5, markeredgecolor='white')

        if is_accuracy:
            ax.axhline(y=0.5, color='gray', linestyle='--', linewidth=2, alpha=0.4, label='Random', zorder=0)
        ax.set_xlabel('Layer Index', fontsize=13, fontweight='bold')
        ax.set_ylabel(metric_label, fontsize=13, fontweight='bold')
        ax.set_title(f'{method.replace("_", " ").title()}\nEvolution Across Layers',
                     fontsize=14, fontweight='bold', pad=15)
        ax.legend(fontsize=11, frameon=True, shadow=True, fancybox=True)
        ax.grid(True, alpha=0.2, linestyle='--')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        if is_accuracy:
            ax.set_ylim(bottom=0.45)

    plt.tight_layout(pad=2.0)
    plt.savefig(os.path.join(output_dir, f'cross_stage_layerwise_{metric_name}.pdf'),
                **PLOT_CONFIG)
    plt.close()


def analyze_direction_transfer(
    stage_activations: Dict[str, Tuple[Dict, Dict]],
    stage_directions: Dict[str, Dict[str, np.ndarray]],
    output_dir: str,
    best_layer: int
):
    """Test whether directions learned on one stage separate another stage's data.

    For every method (the keys of the first stage's directions) and every
    (source, test) stage pair, the source stage's direction is scored with
    ``compute_separation_metrics`` on the test stage's chosen/rejected
    activations at ``best_layer``. The resulting classification accuracies
    form a source-by-test matrix.

    Side effects:
        Writes ``direction_transfer_analysis.pdf`` (one annotated heatmap per
        method) and ``direction_transfer_results.json`` into ``output_dir``.

    Args:
        stage_activations: stage -> ``(chosen_by_layer, rejected_by_layer)``.
        stage_directions: stage -> {method: direction vector}. Presumably the
            directions were computed at ``best_layer`` — confirm with the caller.
        output_dir: directory for the PDF and JSON outputs.
        best_layer: layer index whose activations are evaluated.

    Returns:
        ``{method: {'matrix', 'mean_transfer', 'mean_diagonal'}}``.
        ``mean_transfer`` averages the strict upper triangle (row < column):
        directions from earlier-listed stages evaluated on later-listed ones.
        ``mean_diagonal`` averages the same-stage entries.
    """
    print("\nAnalyzing direction transfer across stages...")

    stages = list(stage_activations.keys())
    methods = list(stage_directions[stages[0]].keys())
    n_stages = len(stages)

    # Compute each matrix once and reuse it for both the figure and the JSON
    # summary. Every cell runs a cross-validated probe, so this is the
    # expensive part.
    # Rows = direction source stage, columns = test stage.
    transfer_matrices = {}
    for method in methods:
        transfer_matrix = np.zeros((n_stages, n_stages))
        for i, source_stage in enumerate(stages):
            direction = stage_directions[source_stage][method]
            for j, test_stage in enumerate(stages):
                chosen_acts = stage_activations[test_stage][0][best_layer]
                rejected_acts = stage_activations[test_stage][1][best_layer]
                metrics, _, _ = compute_separation_metrics(
                    chosen_acts, rejected_acts, direction, use_cross_val=True
                )
                transfer_matrix[i, j] = metrics['classification_accuracy']
        transfer_matrices[method] = transfer_matrix

    _, axes = plt.subplots(1, len(methods), figsize=(8 * len(methods), 7), facecolor='white')
    if len(methods) == 1:
        axes = [axes]

    stage_labels = [s.upper() for s in stages]
    for method_idx, method in enumerate(methods):
        transfer_matrix = transfer_matrices[method]
        ax = axes[method_idx]
        im = ax.imshow(transfer_matrix, cmap='RdYlGn', vmin=0.5, vmax=1.0, interpolation='nearest')

        ax.set_xticks(range(n_stages))
        ax.set_yticks(range(n_stages))
        ax.set_xticklabels(stage_labels, fontsize=12, fontweight='bold')
        ax.set_yticklabels(stage_labels, fontsize=12, fontweight='bold')
        ax.set_xlabel('Test Stage', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_ylabel('Direction Source', fontsize=14, fontweight='bold', labelpad=10)
        ax.set_title(f'{method.replace("_", " ").title()}\nDirection Transfer Analysis',
                     fontsize=15, fontweight='bold', pad=15)

        for i in range(n_stages):
            for j in range(n_stages):
                value = transfer_matrix[i, j]
                text_color = 'white' if value < 0.7 else 'black'
                # Translucent white disc behind each value for readability.
                circle = plt.Circle((j, i), 0.35, color='white', alpha=0.3, zorder=1)
                ax.add_patch(circle)
                ax.text(j, i, f'{value:.3f}',
                        ha='center', va='center', color=text_color,
                        fontsize=13, fontweight='bold', zorder=2)

        cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Accuracy', fontsize=12, fontweight='bold', rotation=270, labelpad=20)
        cbar.ax.tick_params(labelsize=11)

        # Minor-tick grid lines between cells.
        ax.set_xticks([x - 0.5 for x in range(1, n_stages)], minor=True)
        ax.set_yticks([y - 0.5 for y in range(1, n_stages)], minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=1.5, alpha=0.3)

    plt.tight_layout(pad=2.0)
    plt.savefig(os.path.join(output_dir, 'direction_transfer_analysis.pdf'), **PLOT_CONFIG)
    plt.close()

    # Diagonal = same-stage performance; strict upper triangle = transfer from
    # an earlier-listed source stage to a later-listed test stage.
    upper = np.triu_indices(n_stages, k=1)
    transfer_results = {
        method: {
            'matrix': transfer_matrix.tolist(),
            'mean_transfer': float(np.mean(transfer_matrix[upper])),
            'mean_diagonal': float(np.mean(np.diag(transfer_matrix))),
        }
        for method, transfer_matrix in transfer_matrices.items()
    }

    with open(os.path.join(output_dir, 'direction_transfer_results.json'), 'w') as f:
        json.dump(transfer_results, f, indent=2)

    return transfer_results


def _per_feature_cohens_d(chosen_acts: np.ndarray, rejected_acts: np.ndarray) -> np.ndarray:
    """Per-column Cohen's d: (mean_chosen - mean_rejected) / (pooled population std + 1e-8)."""
    mean_diff = chosen_acts.mean(axis=0) - rejected_acts.mean(axis=0)
    pooled_std = np.sqrt((chosen_acts.std(axis=0) ** 2 + rejected_acts.std(axis=0) ** 2) / 2)
    return mean_diff / (pooled_std + 1e-8)


def _jaccard(a: set, b: set) -> float:
    """Jaccard similarity |a & b| / |a | b|, or 0 when both sets are empty."""
    union = len(a | b)
    return len(a & b) / union if union > 0 else 0


def _draw_stage_matrix(ax, matrix, stages, cmap, vmin, vmax, title, cbar_label):
    """Draw a stage-by-stage matrix on ``ax`` as an annotated heatmap with a colorbar."""
    n = len(stages)
    labels = [s.upper() for s in stages]

    im = ax.imshow(matrix, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='nearest')
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(labels, fontsize=13, fontweight='bold')
    ax.set_yticklabels(labels, fontsize=13, fontweight='bold')
    ax.set_title(title, fontsize=15, fontweight='bold', pad=15)

    for i in range(n):
        for j in range(n):
            value = matrix[i, j]
            # White text for |value| < 0.5, black otherwise.
            text_color = 'white' if abs(value) < 0.5 else 'black'
            # Subtle translucent disc behind the number.
            ax.add_patch(plt.Circle((j, i), 0.35, color='white', alpha=0.2, zorder=1))
            ax.text(j, i, f'{value:.2f}',
                    ha='center', va='center', color=text_color,
                    fontsize=14, fontweight='bold', zorder=2)

    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(cbar_label, fontsize=12, fontweight='bold', rotation=270, labelpad=20)
    cbar.ax.tick_params(labelsize=11)

    # Minor-tick grid lines between cells.
    ax.set_xticks([x - 0.5 for x in range(1, n)], minor=True)
    ax.set_yticks([y - 0.5 for y in range(1, n)], minor=True)
    ax.grid(which='minor', color='gray', linestyle='-', linewidth=1.5, alpha=0.2)


def compare_feature_importance_across_stages(
    stage_activations: Dict[str, Tuple[Dict, Dict]],
    output_dir: str,
    best_layer: int,
    top_k: int = 50
):
    """Compare which activation features separate chosen from rejected across stages.

    At ``best_layer``, every feature (activation column) gets a Cohen's d per
    stage. Each stage's top-``top_k`` features by ``|d|`` are then compared
    across stages.

    Writes to ``output_dir``:
        feature_importance_comparison.pdf: Jaccard overlap of the top-k sets
            and Spearman correlation of the full d vectors. Both panels are
            drawn only when there are at least two stages.
        feature_importance_heatmap.pdf: d per stage for up to ``top_k``
            features taken from the union of the per-stage top sets. Features
            are chosen by their largest ``|d|`` in any stage and displayed in
            feature-index order.
        feature_overlap_stats.json: pairwise intersection, union, Jaccard and
            overlap fraction of the top-k sets.
    """
    print("\nComparing feature importance across stages...")

    stages = list(stage_activations.keys())

    stage_cohens_d = {
        stage: _per_feature_cohens_d(
            stage_activations[stage][0][best_layer],
            stage_activations[stage][1][best_layer],
        )
        for stage in stages
    }
    stage_top_features = {
        stage: set(np.argsort(np.abs(d))[::-1][:top_k])
        for stage, d in stage_cohens_d.items()
    }

    # Figure 1: top-k overlap (Jaccard) and effect-size rank correlation.
    _, axes = plt.subplots(1, 2, figsize=(18, 8), facecolor='white')

    if len(stages) >= 2:
        overlap_matrix = np.array(
            [[_jaccard(stage_top_features[s1], stage_top_features[s2]) for s2 in stages]
             for s1 in stages],
            dtype=float,
        )
        _draw_stage_matrix(
            axes[0], overlap_matrix, stages, cmap='YlOrRd', vmin=0, vmax=1,
            title=f'Top-{top_k} Feature Overlap\n(Jaccard Similarity)',
            cbar_label='Jaccard Similarity',
        )

        corr_matrix = np.zeros((len(stages), len(stages)))
        for i, s1 in enumerate(stages):
            for j, s2 in enumerate(stages):
                corr, _ = spearmanr(stage_cohens_d[s1], stage_cohens_d[s2])
                corr_matrix[i, j] = corr
        _draw_stage_matrix(
            axes[1], corr_matrix, stages, cmap='coolwarm', vmin=-1, vmax=1,
            title='Effect Size Correlation\n(Spearman ρ)',
            cbar_label='Spearman ρ',
        )

    plt.tight_layout(pad=2.0)
    plt.savefig(os.path.join(output_dir, 'feature_importance_comparison.pdf'), **PLOT_CONFIG)
    plt.close()

    # Figure 2: effect sizes of the strongest features across stages.
    # Rank the union of per-stage top sets by their strongest |d| in any stage.
    # Truncating the index-sorted union instead would keep only the
    # lowest-numbered features, not the most important ones.
    candidates = np.array(sorted(set().union(*stage_top_features.values())), dtype=int)
    strongest = np.abs(np.vstack([stage_cohens_d[s][candidates] for s in stages])).max(axis=0)
    selected = candidates[np.argsort(-strongest, kind='stable')[:top_k]]
    all_top_features = sorted(int(f) for f in selected)

    matrix = np.vstack([stage_cohens_d[s][all_top_features] for s in stages])

    _, ax = plt.subplots(figsize=(16, 9), facecolor='white')
    # 'nearest': each cell is a distinct (stage, feature) pair, and
    # interpolation would blend unrelated neighbouring features.
    im = ax.imshow(matrix, cmap='RdBu_r', aspect='auto', vmin=-2, vmax=2, interpolation='nearest')
    ax.set_yticks(range(len(stages)))
    ax.set_yticklabels([s.upper() for s in stages], fontsize=12, fontweight='bold')
    # Label columns with the actual feature indices, not column positions.
    ax.set_xticks(range(len(all_top_features)))
    ax.set_xticklabels([str(f) for f in all_top_features], fontsize=7, rotation=90)
    ax.set_xlabel('Feature Index (top features across all stages)',
                  fontsize=13, fontweight='bold', labelpad=10)
    ax.set_title(f"Effect Size (Cohen's d) for Top-{top_k} Features",
                 fontsize=15, fontweight='bold', pad=15)

    cbar = plt.colorbar(im, ax=ax, fraction=0.03, pad=0.04)
    cbar.set_label("Cohen's d", fontsize=12, fontweight='bold', rotation=270, labelpad=20)
    cbar.ax.tick_params(labelsize=11)

    # Subtle minor-tick grid between cells.
    ax.set_xticks([x - 0.5 for x in range(1, len(all_top_features))], minor=True)
    ax.set_yticks([y - 0.5 for y in range(1, len(stages))], minor=True)
    ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.2)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'feature_importance_heatmap.pdf'), **PLOT_CONFIG)
    plt.close()

    # Pairwise overlap statistics of the top-k sets.
    overlap_stats = {}
    for i, s1 in enumerate(stages):
        for s2 in stages[i + 1:]:
            top1, top2 = stage_top_features[s1], stage_top_features[s2]
            intersection = len(top1 & top2)
            union = len(top1 | top2)
            overlap_stats[f'{s1}_vs_{s2}'] = {
                'intersection': intersection,
                'union': union,
                'jaccard': intersection / union if union > 0 else 0,
                # Each top set holds min(top_k, n_features) indices; dividing
                # by the real set size keeps this a true fraction.
                'overlap_pct': intersection / len(top1) if top1 else 0.0,
            }

    with open(os.path.join(output_dir, 'feature_overlap_stats.json'), 'w') as f:
        json.dump(overlap_stats, f, indent=2)

def create_summary_dashboard(
    stage_results: Dict[str, Dict],
    stage_activations: Dict[str, Tuple[Dict, Dict]],
    output_dir: str,
    best_layer: int,
    methods: Optional[List[str]] = None,
    table_method: str = 'logistic_regression',
):
    """Create a comprehensive summary dashboard comparing all stages.

    Layout:
      * Row 1: accuracy bars, Cohen's d and Euclidean distance at
        ``best_layer`` for every stage and method.
      * Row 2: layer-wise accuracy curve for each stage, with ``best_layer``
        marked.
      * Row 3: summary table (best / mean accuracy, max |Cohen's d|, mean
        distance over layers) built from ``table_method``.

    Args:
        stage_results: Nested mapping indexed as
            ``stage_results[stage][layer][method]['metrics'][metric_name]``.
        stage_activations: Not used by this function; kept so existing
            callers keep working.
        output_dir: Directory that receives ``summary_dashboard.pdf``.
        best_layer: Layer shown in row 1 and marked in row 2.
        methods: Probe methods to plot. Defaults to
            ``['mean_diff', 'logistic_regression']``.
        table_method: Method whose metrics fill the summary table. If the
            results do not contain it, the first entry of ``methods`` is used.
    """
    if methods is None:
        methods = ['mean_diff', 'logistic_regression']

    print("\nCreating summary dashboard...")

    stages = list(stage_results.keys())
    if not stages:
        print("No stage results to summarize; skipping dashboard.")
        return

    x_pos = np.arange(len(stages))
    stage_labels = [s.upper() for s in stages]

    # Only switch away from the requested table method when it is absent,
    # so runs that used logistic regression render exactly as before.
    sample_layer = next(iter(stage_results[stages[0]].values()), {})
    if table_method not in sample_layer and methods:
        table_method = methods[0]

    # Row 2 needs one column per stage; keep at least 3 columns for row 1.
    n_cols = max(3, len(stages))
    fig = plt.figure(figsize=(22, 14), facecolor='white')
    gs = GridSpec(3, n_cols, figure=fig, hspace=0.35, wspace=0.35)

    # 1. Best layer accuracy comparison with improved styling
    ax1 = fig.add_subplot(gs[0, 0])
    n_methods = max(len(methods), 1)
    # Keep the 0.35 bar width for up to two methods; shrink it for more so a
    # group of bars still fits inside one stage slot.
    width = min(0.35, 0.8 / n_methods)
    # Offsets centre each group of bars on its stage tick (works for any count).
    offsets = (np.arange(n_methods) - (n_methods - 1) / 2) * width

    for offset, method in zip(offsets, methods):
        accuracies = [stage_results[s][best_layer][method]['metrics']['classification_accuracy']
                      for s in stages]
        bars = ax1.bar(x_pos + offset, accuracies, width,
                       label=method.replace('_', ' ').title(),
                       alpha=0.85, edgecolor='white', linewidth=1.5)

        # Add value labels on bars
        for bar in bars:
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width() / 2., height,
                     f'{height:.3f}',
                     ha='center', va='bottom', fontsize=10, fontweight='bold')

    ax1.set_ylabel('Accuracy', fontweight='bold', fontsize=12)
    ax1.set_title(f'Best Layer ({best_layer}) Accuracy Comparison',
                  fontweight='bold', fontsize=13, pad=10)
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(stage_labels, fontweight='bold')
    ax1.legend(fontsize=10, frameon=True, shadow=True)
    ax1.grid(True, alpha=0.2, axis='y', linestyle='--')
    # Chance level for a balanced binary probe.
    ax1.axhline(y=0.5, color='red', linestyle='--', alpha=0.4, linewidth=2)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.set_ylim(bottom=0.45)

    # 2-3. Effect size and mean separation distance across stages. Plotting
    # against numeric positions with explicit ticks avoids relabelling a
    # categorical axis through set_xticklabels alone.
    line_panels = [
        (gs[0, 1], 'cohens_d', 'o', "Cohen's d", 'Effect Size Evolution Across Stages'),
        (gs[0, 2], 'euclidean_distance', 's', 'Euclidean Distance', 'Mean Separation Distance'),
    ]
    for spec, metric, marker, ylabel, title in line_panels:
        ax = fig.add_subplot(spec)
        for method in methods:
            values = [stage_results[s][best_layer][method]['metrics'][metric]
                      for s in stages]
            ax.plot(x_pos, values, marker=marker, linewidth=3.5, markersize=12,
                    label=method.replace('_', ' ').title(), markeredgewidth=2,
                    markeredgecolor='white', alpha=0.85)

        ax.set_ylabel(ylabel, fontweight='bold', fontsize=12)
        ax.set_title(title, fontweight='bold', fontsize=13, pad=10)
        ax.set_xticks(x_pos)
        ax.set_xticklabels(stage_labels, fontweight='bold')
        ax.legend(fontsize=10, frameon=True, shadow=True)
        ax.grid(True, alpha=0.2, linestyle='--')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    # 4-6. Layer-wise progression for each stage with improved styling
    for idx, stage in enumerate(stages):
        ax = fig.add_subplot(gs[1, idx])

        layers = sorted(stage_results[stage].keys())
        for method in methods:
            accuracies = [stage_results[stage][layer][method]['metrics']['classification_accuracy']
                          for layer in layers]
            ax.plot(layers, accuracies, marker='o', linewidth=2.5, markersize=7,
                    label=method.replace('_', ' ').title(), markeredgewidth=1.5,
                    markeredgecolor='white', alpha=0.85)

        ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.4, linewidth=2)
        ax.axvline(x=best_layer, color='red', linestyle=':', alpha=0.6,
                   linewidth=2.5, label='Best Layer')
        ax.set_xlabel('Layer', fontweight='bold', fontsize=11)
        ax.set_ylabel('Accuracy', fontweight='bold', fontsize=11)
        ax.set_title(f'{stage.upper()}\nLayer Progression',
                     fontweight='bold', fontsize=12, pad=10)
        ax.legend(fontsize=9, frameon=True)
        ax.grid(True, alpha=0.2, linestyle='--')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_ylim(bottom=0.45)

    # 7. Summary statistics table with improved styling
    ax7 = fig.add_subplot(gs[2, :])
    ax7.axis('tight')
    ax7.axis('off')

    table_data = [['Stage', 'Best Accuracy', 'Mean Accuracy', 'Max Cohen\'s d', 'Avg Distance']]

    for stage in stages:
        layers = sorted(stage_results[stage].keys())
        metrics = [stage_results[stage][layer][table_method]['metrics'] for layer in layers]
        accuracies = [m['classification_accuracy'] for m in metrics]
        # Magnitude only: the sign of d depends on the direction's orientation.
        cohens_d = [abs(m['cohens_d']) for m in metrics]
        distances = [m['euclidean_distance'] for m in metrics]

        table_data.append([
            stage.upper(),
            f"{max(accuracies):.4f}",
            f"{np.mean(accuracies):.4f}",
            f"{max(cohens_d):.4f}",
            f"{np.mean(distances):.2f}"
        ])

    table = ax7.table(cellText=table_data, cellLoc='center', loc='center',
                      colWidths=[0.15, 0.2, 0.2, 0.2, 0.2])
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1, 3)

    # Style header row
    for col in range(len(table_data[0])):
        cell = table[(0, col)]
        cell.set_facecolor('#2c3e50')
        cell.set_text_props(weight='bold', color='white', fontsize=13)
        cell.set_edgecolor('white')
        cell.set_linewidth(2)

    # Style data rows; the stage-name column uses the stage colour.
    for row, stage in enumerate(stages, 1):
        for col in range(len(table_data[0])):
            cell = table[(row, col)]
            if col == 0:
                cell.set_facecolor(COLORS.get(stage, 'lightgray'))
                cell.set_text_props(weight='bold', fontsize=12)
            else:
                cell.set_facecolor('#ecf0f1')
                cell.set_text_props(fontsize=11)
            cell.set_edgecolor('white')
            cell.set_linewidth(1.5)

    fig.suptitle('Cross-Stage Probe Analysis Summary Dashboard',
                 fontsize=18, fontweight='bold', y=0.98)

    plt.savefig(os.path.join(output_dir, 'summary_dashboard.pdf'), **PLOT_CONFIG)
    plt.close(fig)


def visualize_2d_separation_multi_stage(
    stage_activations: Dict[str, Tuple[Dict, Dict]],
    stage_directions: Dict[str, Dict[str, np.ndarray]],
    output_dir: str,
    best_layer: int,
    method: str = 'mean_diff'
):
    """Visualize 2D separation for all stages in one plot.

    For each stage, the chosen and rejected activations at ``best_layer`` are
    stacked and projected onto their first two principal components. The
    probe direction for ``method`` is drawn as an arrow from the origin of the
    PCA space, which is the mean of the stacked activations. The chosen and
    rejected centroids are marked with stars.

    Args:
        stage_activations: ``{stage: (chosen_by_layer, rejected_by_layer)}``;
            each inner mapping is indexed by layer.
        stage_directions: ``{stage: {method: direction}}`` for the directions
            computed at ``best_layer``.
        output_dir: Directory that receives
            ``2d_separation_all_stages_{method}.pdf``.
        best_layer: Layer whose activations are visualized.
        method: Key of the direction to draw.
    """
    print("\nCreating multi-stage 2D separation visualization...")

    stages = list(stage_activations.keys())
    n_stages = len(stages)
    if n_stages == 0:
        print("No stage activations to visualize; skipping.")
        return

    # squeeze=False always gives a 2D array of axes, even for a single stage.
    fig, axes = plt.subplots(1, n_stages, figsize=(8*n_stages, 7), facecolor='white',
                             squeeze=False)

    for ax, stage in zip(axes[0], stages):
        chosen_acts = stage_activations[stage][0][best_layer]
        rejected_acts = stage_activations[stage][1][best_layer]
        direction = np.asarray(stage_directions[stage][method], dtype=np.float64).ravel()

        all_acts = np.vstack([chosen_acts, rejected_acts])
        pca = PCA(n_components=2)
        acts_2d = pca.fit_transform(all_acts)

        n_chosen = len(chosen_acts)
        chosen_2d = acts_2d[:n_chosen]
        rejected_2d = acts_2d[n_chosen:]

        # A direction is a vector, not a point, so project it with the
        # principal axes only. pca.transform() would first subtract the data
        # mean and draw the direction of (direction - mean) instead.
        direction_2d = pca.components_ @ direction
        direction_2d = direction_2d / (np.linalg.norm(direction_2d) + 1e-8)

        # Plot rejected with subtle border
        ax.scatter(rejected_2d[:, 0], rejected_2d[:, 1],
                   c=COLORS['rejected'], alpha=0.45, s=60,
                   label='Rejected', edgecolors='white', linewidths=0.5)

        # Plot chosen with subtle border
        ax.scatter(chosen_2d[:, 0], chosen_2d[:, 1],
                   c=COLORS['chosen'], alpha=0.65, s=60,
                   label='Chosen', edgecolors='white', linewidths=0.5)

        # Arrow length is 45% of the largest absolute projected coordinate.
        scale = np.abs(acts_2d).max() * 0.45
        ax.arrow(0, 0, direction_2d[0]*scale, direction_2d[1]*scale,
                 head_width=scale*0.1, head_length=scale*0.06,
                 fc=COLORS['direction'], ec=COLORS['direction'],
                 linewidth=4, label='Direction', alpha=0.9, zorder=5)

        # Class centroids
        ax.scatter(*chosen_2d.mean(axis=0), c='darkblue',
                   marker='*', s=300,
                   edgecolors='black', linewidths=2, label='Chosen Center', zorder=10)
        ax.scatter(*rejected_2d.mean(axis=0), c='darkred',
                   marker='*', s=300,
                   edgecolors='black', linewidths=2, label='Rejected Center', zorder=10)

        ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%})',
                      fontweight='bold', fontsize=12)
        ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%})',
                      fontweight='bold', fontsize=12)
        ax.set_title(f'{stage.upper()}\nLayer {best_layer}',
                     fontsize=14, fontweight='bold', pad=12)
        ax.legend(loc='best', fontsize=10, frameon=True, shadow=True,
                  fancybox=True, framealpha=0.95)
        ax.grid(True, alpha=0.2, linestyle='--')
        ax.set_aspect('equal', adjustable='box')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_facecolor('#fafafa')

    plt.tight_layout(pad=2.0)
    plt.savefig(os.path.join(output_dir, f'2d_separation_all_stages_{method}.pdf'),
                **PLOT_CONFIG)
    plt.close(fig)


def _print_banner(title):
    """Print ``title`` framed by 70-character '=' rules, preceded by a blank line."""
    print("\n" + "="*70)
    print(title)
    print("="*70)


def _ensure_pad_token(tokenizer):
    """Use the EOS token for padding when the tokenizer defines no pad token."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _select_best_layer(layer_results):
    """Return the layer with the highest probe classification accuracy.

    Ranks by 'logistic_regression' when the results contain it, otherwise by
    the first method present in the results.

    Raises:
        ValueError: if ``layer_results`` is empty.
    """
    if not layer_results:
        raise ValueError("No layer results available to select a best layer.")
    first = next(iter(layer_results.values()))
    method = 'logistic_regression' if 'logistic_regression' in first else next(iter(first))
    return max(layer_results,
               key=lambda layer: layer_results[layer][method]['metrics']['classification_accuracy'])


def _probe_stage(model, tokenizer, dataset, args, eval_samples):
    """Extract activations for one model and run the layer-wise probe analysis.

    Training activations come from the first ``args.max_samples`` examples.
    A separate evaluation extraction with ``eval_samples`` examples is done
    only when ``eval_samples`` is larger; otherwise the training activations
    are reused for evaluation.

    Returns:
        ``(layer_results, (chosen_train, rejected_train), (chosen_eval, rejected_eval))``.
    """
    print(f"\nExtracting training activations ({args.max_samples} samples)...")
    chosen_train, rejected_train = extract_activations_all_layers(
        model, tokenizer, dataset, args.max_samples, args.position
    )

    if eval_samples > args.max_samples:
        print(f"\nExtracting evaluation activations ({eval_samples} samples)...")
        chosen_eval, rejected_eval = extract_activations_all_layers(
            model, tokenizer, dataset, eval_samples, args.position
        )
    else:
        chosen_eval, rejected_eval = chosen_train, rejected_train

    layer_results = analyze_layer_wise_accuracy(
        chosen_train, rejected_train,
        chosen_eval, rejected_eval,
        args.methods
    )
    return layer_results, (chosen_train, rejected_train), (chosen_eval, rejected_eval)


def _analyze_finetuned_stage(model_path, dataset, args, eval_samples, best_layer):
    """Load a fine-tuned checkpoint, probe it, and compute its directions.

    Directions are computed at ``best_layer``, the pretrained model's best
    layer, so that directions from different stages come from the same layer.
    The model is released even if extraction fails.

    Returns:
        ``(layer_results, (chosen_eval, rejected_eval), directions)``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", torch_dtype="auto"
    )
    tokenizer = _ensure_pad_token(AutoTokenizer.from_pretrained(model_path))
    try:
        layer_results, (chosen_train, rejected_train), eval_acts = _probe_stage(
            model, tokenizer, dataset, args, eval_samples
        )
        directions = compute_all_directions(chosen_train[best_layer], rejected_train[best_layer])
    finally:
        del model
        torch.cuda.empty_cache()
    return layer_results, eval_acts, directions


def main():
    """Probe the pretrained model and the optional SFT/PO checkpoints.

    Runs the cross-stage comparisons and plots when at least two stages were
    analyzed.
    """
    parser = argparse.ArgumentParser(description="Enhanced Probe Analysis for APO")

    parser.add_argument("--model-name", type=str, default="meta-llama/Llama-3.2-1B")
    parser.add_argument("--dataset", type=str, default="Anthropic/hh-rlhf")
    parser.add_argument("--dataset-language", type=str, default="amh")
    parser.add_argument("--max-samples", type=int, default=500,
                        help="Number of samples for training (can be less than 500)")
    parser.add_argument("--min-eval-samples", type=int, default=500,
                        help="Minimum number of samples for evaluation (default: 500)")
    parser.add_argument("--output-dir", type=str, default="./probe_analysis_enhanced")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--sft-model-path", type=str, default=None)
    parser.add_argument("--po-model-path", type=str, default=None)
    parser.add_argument("--position", type=str, choices=["completion_mean", "last_token", "mean", "max", "last_prompt"],
                        default="completion_mean")
    parser.add_argument("--methods", type=str, nargs='+',
                        default=['logistic_regression'])

    args = parser.parse_args()

    # Ensure eval_samples is at least min_eval_samples
    eval_samples = max(args.max_samples, args.min_eval_samples)

    if args.max_samples < args.min_eval_samples:
        print(f"ℹ Info: Training on {args.max_samples} samples, but evaluating on {eval_samples} samples (minimum {args.min_eval_samples})")

    # Seed every RNG the script imports (including `random`) for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    print("="*70)
    print("ENHANCED PROBE ANALYSIS FOR APO")
    print("="*70)
    print(f"Model: {args.model_name}")
    print(f"Dataset: {args.dataset}")
    print(f"Training samples: {args.max_samples}")
    print(f"Evaluation samples: {eval_samples}")
    print(f"Position: {args.position}")
    print(f"Output: {args.output_dir}")
    print("="*70)

    dataset = load_preference_dataset(args.dataset, max_samples=eval_samples, language=args.dataset_language)
    # Sanity-check idea: randomly swapping chosen/rejected (p=0.5) in `dataset`
    # should push probe accuracy towards chance (~0.5).

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.chat_template = "{{- bos_token }}\n{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    _ensure_pad_token(tokenizer)

    # Storage for cross-stage analysis
    stage_results = {}
    stage_activations = {}
    stage_directions = {}

    # Analyze pretrained model
    _print_banner("ANALYZING PRETRAINED MODEL")

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        device_map="auto",
        torch_dtype="auto"
    )
    try:
        layer_results, (chosen_train, rejected_train), eval_acts = _probe_stage(
            model, tokenizer, dataset, args, eval_samples
        )
    finally:
        # Release GPU memory before any fine-tuned checkpoint is loaded.
        del model
        torch.cuda.empty_cache()

    best_layer = _select_best_layer(layer_results)
    print(f"\n✓ Best layer for pretrained: {best_layer}")

    # Evaluation activations feed the cross-stage analysis; directions are
    # computed from the training activations at the best layer.
    stage_results['pretrained'] = layer_results
    stage_activations['pretrained'] = eval_acts
    stage_directions['pretrained'] = compute_all_directions(
        chosen_train[best_layer], rejected_train[best_layer]
    )

    # Analyze fine-tuned checkpoints if provided (insertion order is preserved
    # so plots list stages as pretrained, sft, po).
    finetuned_stages = (
        ('sft', 'SFT', args.sft_model_path),
        ('po', 'PO', args.po_model_path),
    )
    for stage, label, model_path in finetuned_stages:
        if not model_path:
            continue
        _print_banner(f"ANALYZING {label} MODEL")
        results, stage_eval_acts, directions = _analyze_finetuned_stage(
            model_path, dataset, args, eval_samples, best_layer
        )
        stage_results[stage] = results
        stage_activations[stage] = stage_eval_acts
        stage_directions[stage] = directions

    # ==================================================================
    # CROSS-STAGE ANALYSIS WITH IMPROVED VISUALIZATIONS
    # ==================================================================

    _print_banner("CROSS-STAGE ANALYSIS")

    ran_cross_stage = len(stage_results) > 1
    if ran_cross_stage:
        # Layer-wise comparison
        compare_layer_wise_across_stages(
            stage_results, args.output_dir,
            'classification_accuracy', args.methods
        )

        compare_layer_wise_across_stages(
            stage_results, args.output_dir,
            'cohens_d', args.methods
        )

        # Direction transfer analysis
        analyze_direction_transfer(
            stage_activations, stage_directions, args.output_dir, best_layer
        )

        # Feature importance comparison
        compare_feature_importance_across_stages(
            stage_activations, args.output_dir, best_layer, top_k=50
        )

        # Summary dashboard
        create_summary_dashboard(
            stage_results, stage_activations, args.output_dir, best_layer, args.methods
        )

        # Multi-stage 2D visualization
        for method in args.methods:
            visualize_2d_separation_multi_stage(
                stage_activations, stage_directions, args.output_dir, best_layer, method
            )
    else:
        print("Only one stage analyzed; skipping cross-stage comparisons "
              "(pass --sft-model-path and/or --po-model-path to enable them).")

    print("\n" + "="*70)
    print("✓ ANALYSIS COMPLETE!")
    print(f"Results saved to: {args.output_dir}")
    print("="*70)
    print(f"\nTraining samples: {args.max_samples}")
    print(f"Evaluation samples: {eval_samples}")
    # Only advertise plot files when the cross-stage section actually ran.
    if ran_cross_stage:
        print("\nGenerated visualizations (high-quality PDFs at 300 DPI):")
        print("  • cross_stage_layerwise_*.pdf - Layer progression across stages")
        print("  • direction_transfer_analysis.pdf - Direction transferability")
        print("  • feature_importance_comparison.pdf - Feature overlap analysis")
        print("  • summary_dashboard.pdf - Comprehensive overview")
        print("  • 2d_separation_all_stages_*.pdf - Visual separation comparison")
    print("="*70)


if __name__ == "__main__":
    main()
